瀏覽代碼

Fix restoration bugs

Bobholamovic 2 年之前
父節點
當前提交
d3c1499a87

+ 1 - 0
docs/apis/train.md

@@ -30,6 +30,7 @@
 
 - 一般支持设置`sr_factor`参数,表示超分辨率倍数;对于不支持超分辨率重建任务的模型,`sr_factor`设置为`None`。
 - 可通过`losses`参数指定模型训练时使用的损失函数,传入实参需为可调用对象或字典。手动指定的`losses`与子类的`default_loss()`方法返回值必须具有相同的格式。
+- 可通过`min_max`参数指定模型输入、输出的数值范围;若为`None`,则使用类默认的数值范围。
 - 不同的子类支持与模型相关的输入参数,详情请参考[模型定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/res)和[训练器定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py)。
 
 ### 初始化`BaseSegmenter`子类对象

+ 1 - 1
docs/dev/dev_guide.md

@@ -64,7 +64,7 @@ Args:
 2. 在`paddlers/tasks`目录中找到任务对应的训练器定义文件(例如变化检测任务对应`paddlers/tasks/change_detector.py`)。
 
 3. 在文件尾部追加新的训练器定义。训练器需要继承自相关的基类(例如`BaseChangeDetector`),重写`__init__()`方法,并根据需要重写其他方法。对训练器`__init__()`方法编写的要求如下:
-    - 对于变化检测、场景分类、目标检测、图像分割任务,`__init__()`方法的第1个输入参数是`num_classes`,表示模型输出类别数。对于变化检测、场景分类、图像分割任务,第2个输入参数是`use_mixed_loss`,表示用户是否使用默认定义的混合损失;第3个输入参数是`losses`,表示训练时使用的损失函数。对于图像复原任务,第1个参数是`losses`,含义同上;第2个参数是`rs_factor`,表示超分辨率缩放倍数。
+    - 对于变化检测、场景分类、目标检测、图像分割任务,`__init__()`方法的第1个输入参数是`num_classes`,表示模型输出类别数。对于变化检测、场景分类、图像分割任务,第2个输入参数是`use_mixed_loss`,表示用户是否使用默认定义的混合损失;第3个输入参数是`losses`,表示训练时使用的损失函数。对于图像复原任务,第1个参数是`losses`,含义同上;第2个参数是`rs_factor`,表示超分辨率缩放倍数;第3个参数是`min_max`,表示输入、输出影像的数值范围
     - `__init__()`的所有输入参数都必须有默认值,且在**取默认值的情况下,模型接收3通道RGB输入**。
     - 在`__init__()`中需要更新`params`字典,该字典中的键值对将被用作模型构造时的输入参数。
 

+ 4 - 2
paddlers/tasks/base.py

@@ -90,7 +90,8 @@ class BaseModel(metaclass=ModelMeta):
                        pretrain_weights=None,
                        save_dir='.',
                        resume_checkpoint=None,
-                       is_backbone_weights=False):
+                       is_backbone_weights=False,
+                       load_optim_state=True):
         if pretrain_weights is not None and \
                 not osp.exists(pretrain_weights):
             if not osp.isdir(save_dir):
@@ -148,7 +149,8 @@ class BaseModel(metaclass=ModelMeta):
                 self.net,
                 self.optimizer,
                 model_name=self.model_name,
-                checkpoint=resume_checkpoint)
+                checkpoint=resume_checkpoint,
+                load_optim_state=load_optim_state)
 
     def get_model_info(self, get_raw_params=False, inplace=True):
         if inplace:

+ 1 - 0
paddlers/tasks/change_detector.py

@@ -528,6 +528,7 @@ class BaseChangeDetector(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

+ 1 - 0
paddlers/tasks/classifier.py

@@ -432,6 +432,7 @@ class BaseClassifier(BaseModel):
 
             return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

+ 1 - 0
paddlers/tasks/object_detector.py

@@ -567,6 +567,7 @@ class BaseDetector(BaseModel):
                 return scores, self.eval_details
             return scores
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

+ 39 - 9
paddlers/tasks/restorer.py

@@ -43,7 +43,12 @@ class BaseRestorer(BaseModel):
     MIN_MAX = (0., 1.)
     TEST_OUT_KEY = None
 
-    def __init__(self, model_name, losses=None, sr_factor=None, **params):
+    def __init__(self,
+                 model_name,
+                 losses=None,
+                 sr_factor=None,
+                 min_max=None,
+                 **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
             del self.init_params['with_net']
@@ -55,6 +60,8 @@ class BaseRestorer(BaseModel):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
         self.find_unused_parameters = True
+        if min_max is None:
+            self.min_max = self.MIN_MAX
 
     def build_net(self, **params):
         # Currently, only use models from cmres.
@@ -283,11 +290,13 @@ class BaseRestorer(BaseModel):
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
+        # XXX: Currently, do not load optimizer state dict.
         self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
-            is_backbone_weights=is_backbone_weights)
+            is_backbone_weights=is_backbone_weights,
+            load_optim_state=False)
 
         self.train_loop(
             num_epochs=num_epochs,
@@ -434,6 +443,7 @@ class BaseRestorer(BaseModel):
 
             return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
@@ -653,9 +663,9 @@ class BaseRestorer(BaseModel):
         if copy:
             im = im.copy()
         if clip:
-            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
-        im -= im.min()
-        im /= im.max() + 1e-32
+            im = np.clip(im, self.min_max[0], self.min_max[1])
+        im -= self.min_max[0]
+        im /= self.min_max[1] - self.min_max[0]
         if quantize:
             im *= 255
             im = im.astype('uint8')
@@ -668,6 +678,7 @@ class DRN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  scales=(2, 4),
                  n_blocks=30,
                  n_feats=16,
@@ -691,7 +702,11 @@ class DRN(BaseRestorer):
         self.dual_loss_weight = dual_loss_weight
         self.scales = scales
         super(DRN, self).__init__(
-            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='DRN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -769,6 +784,7 @@ class LESRCNN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  multi_scale=False,
                  group=1,
                  **params):
@@ -778,7 +794,11 @@ class LESRCNN(BaseRestorer):
             'group': group
         })
         super(LESRCNN, self).__init__(
-            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='LESRCNN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         net = ppgan.models.generators.LESRCNNGenerator(**params)
@@ -789,6 +809,7 @@ class ESRGAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  use_gan=True,
                  in_channels=3,
                  out_channels=3,
@@ -805,7 +826,11 @@ class ESRGAN(BaseRestorer):
         })
         self.use_gan = use_gan
         super(ESRGAN, self).__init__(
-            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='ESRGAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -932,6 +957,7 @@ class RCAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  n_resgroups=10,
                  n_resblocks=20,
                  n_feats=64,
@@ -950,4 +976,8 @@ class RCAN(BaseRestorer):
             'reduction': reduction
         })
         super(RCAN, self).__init__(
-            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='RCAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)

+ 1 - 0
paddlers/tasks/segmenter.py

@@ -497,6 +497,7 @@ class BaseSegmenter(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

+ 8 - 3
paddlers/utils/checkpoint.py

@@ -527,11 +527,16 @@ def load_optimizer(optimizer, state_dict_path):
     optimizer.set_state_dict(optim_state_dict)
 
 
-def load_checkpoint(model, optimizer, model_name, checkpoint):
+def load_checkpoint(model,
+                    optimizer,
+                    model_name,
+                    checkpoint,
+                    load_optim_state=True):
     logging.info("Loading checkpoint from {}".format(checkpoint))
     load_pretrain_weights(
         model,
         pretrain_weights=osp.join(checkpoint, 'model.pdparams'),
         model_name=model_name)
-    load_optimizer(
-        optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
+    if load_optim_state:
+        load_optimizer(
+            optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))