View source

Merge pull request #24 from Bobholamovic/refactor_res

[Refactor] Refactor Models for Restoration Tasks
cc 2 years ago
parent
commit
6b02cf875f
48 files changed with 2038 additions and 1460 deletions
  1. +4 -0 docs/apis/data.md
  2. +7 -3 docs/apis/infer.md
  3. +11 -1 docs/apis/train.md
  4. +3 -3 docs/dev/dev_guide.md
  5. +1 -1 paddlers/datasets/__init__.py
  6. +5 -5 paddlers/datasets/cd_dataset.py
  7. +83 -0 paddlers/datasets/res_dataset.py
  8. +1 -1 paddlers/datasets/seg_dataset.py
  9. +0 -99 paddlers/datasets/sr_dataset.py
  10. +18 -1 paddlers/deploy/predictor.py
  11. +1 -0 paddlers/models/__init__.py
  12. +1 -1 paddlers/models/ppdet/metrics/json_results.py
  13. +11 -10 paddlers/rs_models/res/generators/param_init.py
  14. +34 -17 paddlers/rs_models/res/generators/rcan.py
  15. +0 -106 paddlers/rs_models/res/rcan_model.py
  16. +1 -1 paddlers/tasks/__init__.py
  17. +24 -32 paddlers/tasks/base.py
  18. +18 -11 paddlers/tasks/change_detector.py
  19. +64 -41 paddlers/tasks/classifier.py
  20. +0 -786 paddlers/tasks/image_restorer.py
  21. +36 -15 paddlers/tasks/object_detector.py
  22. +934 -0 paddlers/tasks/restorer.py
  23. +17 -12 paddlers/tasks/segmenter.py
  24. +28 -11 paddlers/tasks/utils/infer_nets.py
  25. +132 -0 paddlers/tasks/utils/res_adapters.py
  26. +4 -0 paddlers/transforms/functions.py
  27. +102 -15 paddlers/transforms/operators.py
  28. +1 -1 paddlers/utils/__init__.py
  29. +32 -1 paddlers/utils/utils.py
  30. +28 -26 tests/data/data_utils.py
  31. +60 -7 tests/deploy/test_predictor.py
  32. +1 -3 tests/rs_models/test_cd_models.py
  33. +3 -0 tests/rs_models/test_det_models.py
  34. +46 -0 tests/rs_models/test_res_models.py
  35. +5 -3 tests/rs_models/test_seg_models.py
  36. +43 -0 tests/transforms/test_operators.py
  37. +5 -5 tutorials/train/README.md
  38. +1 -1 tutorials/train/change_detection/changeformer.py
  39. +1 -1 tutorials/train/classification/hrnet.py
  40. +1 -1 tutorials/train/classification/mobilenetv3.py
  41. +1 -1 tutorials/train/classification/resnet50_vd.py
  42. +3 -0 tutorials/train/image_restoration/data/.gitignore
  43. +89 -0 tutorials/train/image_restoration/drn.py
  44. +0 -80 tutorials/train/image_restoration/drn_train.py
  45. +89 -0 tutorials/train/image_restoration/esrgan.py
  46. +0 -80 tutorials/train/image_restoration/esrgan_train.py
  47. +89 -0 tutorials/train/image_restoration/lesrcnn.py
  48. +0 -78 tutorials/train/image_restoration/lesrcnn_train.py

+ 4 - 0
docs/apis/data.md

@@ -84,6 +84,9 @@
 
 - Each line of the file list should contain 2 space-separated items, giving, in order, the path of the input image relative to `data_dir` and the path of the [Pascal VOC format](http://host.robots.ox.ac.uk/pascal/VOC/) annotation file relative to `data_dir`.
 
+### Image restoration dataset `ResDataset`
+
+
 ### Image segmentation dataset `SegDataset`
 
 `SegDataset` is defined at: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/seg_dataset.py
@@ -143,6 +146,7 @@
 |`'aux_masks'`|Paths or data of auxiliary labels in image segmentation/change detection tasks.|
 |`'gt_bbox'`|Bounding box annotations in object detection tasks.|
 |`'gt_poly'`|Polygon annotations in object detection tasks.|
+|`'target'`|Paths or data of target images in image restoration tasks.|
 
 ### Composed data transformation operators
 

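The two-column file list convention documented in this hunk (and consumed by the new `ResDataset`) can be sketched as a tiny parser; the helper name `parse_file_list_line` is illustrative, not a PaddleRS API:

```python
def parse_file_list_line(line):
    """Split one file-list line into a (source, target) path pair.

    A single space is the delimiter, so the paths themselves must not
    contain spaces; both paths are interpreted relative to `data_dir`.
    """
    items = line.strip().split()
    if len(items) != 2:
        raise ValueError(
            "expected 2 space-separated items, got {}: {!r}".format(
                len(items), line))
    return items[0], items[1]

src, tgt = parse_file_list_line("lr/0001.png hr/0001.png\n")
```

A path containing a space would split into three items and be rejected, which is why the docs insist on space-free paths.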
+ 7 - 3
docs/apis/infer.md

@@ -26,7 +26,7 @@ def predict(self, img_file, transforms=None):
 If `img_file` is a tuple, the returned object is a dict containing the following key-value pairs:
 
 ```
-{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
+{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
 ```
 
 If `img_file` is a list, the returned object is a list of the same length as `img_file`, where each item is a dict (with the key-value pairs shown above), in an order corresponding to the elements of `img_file`.
@@ -51,7 +51,7 @@ def predict(self, img_file, transforms=None):
 If `img_file` is a string or a NumPy array, the returned object is a dict containing the following key-value pairs:
 
 ```
-{"label map": 输出类别标签,
+{"label_map": 输出类别标签,
  "scores_map": 输出类别概率,
  "label_names_map": 输出类别名称}
 ```
@@ -87,6 +87,10 @@ def predict(self, img_file, transforms=None):
 
 If `img_file` is a list, the returned object is a list of the same length as `img_file`, where each item is a list of dicts (with the key-value pairs shown above), in an order corresponding to the elements of `img_file`.
 
+#### `BaseRestorer.predict()`
+
+
+
 #### `BaseSegmenter.predict()`
 
 Interface:
@@ -107,7 +111,7 @@ def predict(self, img_file, transforms=None):
 If `img_file` is a string or a NumPy array, the returned object is a dict containing the following key-value pairs:
 
 ```
-{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
+{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
 ```
 
 If `img_file` is a list, the returned object is a list of the same length as `img_file`, where each item is a dict (with the key-value pairs shown above), in an order corresponding to the elements of `img_file`.

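The corrected `label_map` key is, conceptually, the argmax of `score_map` over its class axis; a toy NumPy sketch of the documented layouts (no PaddleRS call involved):

```python
import numpy as np

# Toy score_map in the documented [h, w, c] layout (h=2, w=2, c=2).
score_map = np.array([[[0.1, 0.9], [0.8, 0.2]],
                      [[0.3, 0.7], [0.6, 0.4]]], dtype=np.float32)
# label_map has layout [h, w]: per pixel, the argmax over the class axis.
label_map = score_map.argmax(axis=-1)

pred = {"label_map": label_map, "score_map": score_map}
assert pred["label_map"].shape == pred["score_map"].shape[:2]
```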
+ 11 - 1
docs/apis/train.md

@@ -18,11 +18,15 @@
 - The `use_mixed_loss` parameter will be deprecated in the future, so its use is not recommended.
 - Different subclasses support model-specific input parameters. For details, please refer to the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/clas) and [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py).
 
-### Initializing `Baseetector` subclass objects
+### Initializing `BaseDetector` subclass objects
 
 - Generally, the `num_classes` and `backbone` parameters can be set, specifying the number of output classes and the backbone network type, respectively. Compared with other tasks, trainers for object detection support more initialization parameters, covering network structure, loss functions, post-processing strategies, and more.
 - Different subclasses support model-specific input parameters. For details, please refer to the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/det) and [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/object_detector.py).
 
+### Initializing `BaseRestorer` subclass objects
+
+
+
 ### Initializing `BaseSegmenter` subclass objects
 
 - Generally, the `in_channels`, `num_classes`, and `use_mixed_loss` parameters can be set, specifying the number of input channels, the number of output classes, and whether to use the preset mixed loss, respectively. Some models, such as `FarSeg`, do not yet support setting `in_channels`.
@@ -170,6 +174,9 @@ def train(self,
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Path of the checkpoint. PaddleRS supports resuming training from a checkpoint (which contains the model weights and optimizer weights saved during a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
 
+### `BaseRestorer.train()`
+
+
 ### `BaseSegmenter.train()`
 
 Interface:
@@ -311,6 +318,9 @@ def evaluate(self,
  "mask": 预测得到的掩模图信息}
 ```
 
+### `BaseRestorer.evaluate()`
+
+
 ### `BaseSegmenter.evaluate()`
 
 Interface:

+ 3 - 3
docs/dev/dev_guide.md

@@ -22,7 +22,7 @@
 
 Create a new file in the subdirectory, named `{lowercase model name}.py`. Write the complete model definition in this file.
 
-The new model must be a subclass of `paddle.nn.Layer`. For image segmentation, object detection, and scene classification tasks, the relevant specifications established in the [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg), [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection), and [PaddleClas](https://github.com/PaddlePaddle/PaddleClas) suites must be followed, respectively. **For change detection, scene classification, and image segmentation tasks, the `num_classes` parameter must be passed when constructing the model to specify the number of output classes.** For change detection tasks, the model definition follows specifications similar to those for segmentation models, with the following differences:
+The new model must be a subclass of `paddle.nn.Layer`. For image segmentation, object detection, scene classification, and image restoration tasks, the relevant specifications established in the [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg), [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection), [PaddleClas](https://github.com/PaddlePaddle/PaddleClas), and [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN) suites must be followed, respectively. **For change detection, scene classification, and image segmentation tasks, the `num_classes` parameter must be passed when constructing the model to specify the number of output classes. For image restoration tasks, the `sr_factor` parameter must be passed when constructing the model to specify the super-resolution scaling factor (for non-super-resolution models, set this parameter to `None`).** For change detection tasks, the model definition follows specifications similar to those for segmentation models, with the following differences:
 
 - The `forward()` method takes 3 input parameters, `self`, `t1`, and `t2`, where `t1` and `t2` are the input images of the earlier and later temporal phases, respectively.
 - For multi-task change detection models (e.g., a model that outputs both change detection results and building extraction results for the two temporal phases), the class attribute `USE_MULTITASK_DECODER` must be set to `True`, and the `OUT_TYPES` attribute must specify the label type of each element in the list returned by the model's forward pass. See the definition of the `ChangeStar` model for reference.
@@ -64,7 +64,7 @@ Args:
 2. Find the trainer definition file corresponding to the task in the `paddlers/tasks` directory (e.g., `paddlers/tasks/change_detector.py` for the change detection task).
 
 3. Append the new trainer definition at the end of the file. The trainer must inherit from the relevant base class (e.g., `BaseChangeDetector`), override the `__init__()` method, and override other methods as needed. The requirements for the trainer's `__init__()` method are as follows:
-    - For change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of output classes of the model. For change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, indicating whether the user uses the default mixed loss.
+    - For change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of output classes of the model. For change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, indicating whether the user uses the default mixed loss; the third input parameter is `losses`, the loss function(s) used during training. For image restoration tasks, the first parameter is `losses`, with the same meaning as above; the second parameter is `sr_factor`, the super-resolution scaling factor.
     - All input parameters of `__init__()` must have default values, and **with the default values, the model must accept 3-channel RGB input**.
     - The `params` dict must be updated in `__init__()`; its key-value pairs will be used as input parameters when constructing the model.
     - 在`__init__()`中需要更新`params`字典,该字典中的键值对将被用作模型构造时的输入参数。
 
@@ -78,7 +78,7 @@ Args:
 
 ### 2.2 Adding data preprocessing / data augmentation operators
 
-Define the new operator in `paddlers/transforms/operators.py`; all operators inherit from the `paddlers.transforms.Transform` class. An operator's `apply()` method takes a dict `sample` as input, retrieves the relevant objects stored in it, modifies the dict in place after processing, and finally returns the modified dict. When defining an operator, the `apply()` method rarely needs to be overridden. In most cases, it is enough to override the `apply_im()`, `apply_mask()`, `apply_bbox()`, and `apply_segm()` methods to process the input image, segmentation labels, bounding boxes, and polygons, respectively.
+Define the new operator in `paddlers/transforms/operators.py`; all operators inherit from the `paddlers.transforms.Transform` class. An operator's `apply()` method takes a dict `sample` as input, retrieves the relevant objects stored in it, modifies the dict in place after processing, and finally returns the modified dict. When defining an operator, the `apply()` method rarely needs to be overridden. In most cases, it is enough to override the `apply_im()`, `apply_mask()`, `apply_bbox()`, and `apply_segm()` methods to process the image, segmentation labels, bounding boxes, and polygons, respectively.
 
 If the processing logic is fairly complex, it is recommended to first wrap it in a function, add the function to `paddlers/transforms/functions.py`, and then call it in the operator's `apply*()` methods.
 

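The in-place `sample`-dict contract described for operators can be sketched with a minimal stand-in; `Transform` and `ScaleBy2` below are illustrative stubs, not the real `paddlers.transforms` classes:

```python
class Transform:
    """Minimal stand-in for an operator base class (illustrative only)."""

    def apply_im(self, image):
        # Default: leave the image untouched; subclasses override this.
        return image

    def apply(self, sample):
        # Take the stored object out of the dict, process it, modify the
        # dict in place, and return the same dict.
        if 'image' in sample:
            sample['image'] = self.apply_im(sample['image'])
        return sample


class ScaleBy2(Transform):
    """Toy operator: only apply_im() is overridden, not apply()."""

    def apply_im(self, image):
        return [v * 2 for v in image]


sample = {'image': [1, 2, 3]}
out = ScaleBy2().apply(sample)
assert out is sample  # the dict is modified in place
```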
+ 1 - 1
paddlers/datasets/__init__.py

@@ -17,4 +17,4 @@ from .coco import COCODetDataset
 from .seg_dataset import SegDataset
 from .cd_dataset import CDDataset
 from .clas_dataset import ClasDataset
-from .sr_dataset import SRdataset, ComposeTrans
+from .res_dataset import ResDataset

+ 5 - 5
paddlers/datasets/cd_dataset.py

@@ -95,23 +95,23 @@ class CDDataset(BaseDataset):
                                      full_path_label))):
                     continue
                 if not osp.exists(full_path_im_t1):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t1))
                 if not osp.exists(full_path_im_t2):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t2))
                 if not osp.exists(full_path_label):
-                    raise IOError('Label file {} does not exist!'.format(
+                    raise IOError("Label file {} does not exist!".format(
                         full_path_label))
 
                 if with_seg_labels:
                     full_path_seg_label_t1 = osp.join(data_dir, items[3])
                     full_path_seg_label_t2 = osp.join(data_dir, items[4])
                     if not osp.exists(full_path_seg_label_t1):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t1))
                     if not osp.exists(full_path_seg_label_t2):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t2))
 
                 item_dict = dict(

+ 83 - 0
paddlers/datasets/res_dataset.py

@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import copy
+
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
+
+
+class ResDataset(BaseDataset):
+    """
+    Dataset for image restoration tasks.
+
+    Args:
+        data_dir (str): Root directory of the dataset.
+        file_list (str): Path of the file that contains relative paths of source and target image files.
+        transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
+        num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
+            the number of workers will be automatically determined according to the number of CPU cores: if
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
+            the number of CPU cores. Defaults to 'auto'.
+        shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
+        sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image 
+            restoration tasks. Defaults to None.
+    """
+
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 transforms,
+                 num_workers='auto',
+                 shuffle=False,
+                 sr_factor=None):
+        super(ResDataset, self).__init__(data_dir, None, transforms,
+                                         num_workers, shuffle)
+        self.batch_transforms = None
+        self.file_list = list()
+
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                items = line.strip().split()
+                if len(items) > 2:
+                    raise ValueError(
+                        "A single space is used as the delimiter between the source and "
+                        "target image paths, so the paths themselves must not contain "
+                        "spaces; however, line {} of file_list {} contains extra spaces."
+                        .format(line, file_list))
+                items[0] = norm_path(items[0])
+                items[1] = norm_path(items[1])
+                full_path_im = osp.join(data_dir, items[0])
+                full_path_tar = osp.join(data_dir, items[1])
+                if not is_pic(full_path_im) or not is_pic(full_path_tar):
+                    continue
+                if not osp.exists(full_path_im):
+                    raise IOError("Source image file {} does not exist!".format(
+                        full_path_im))
+                if not osp.exists(full_path_tar):
+                    raise IOError("Target image file {} does not exist!".format(
+                        full_path_tar))
+                sample = {
+                    'image': full_path_im,
+                    'target': full_path_tar,
+                }
+                if sr_factor is not None:
+                    sample['sr_factor'] = sr_factor
+                self.file_list.append(sample)
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+
+    def __len__(self):
+        return len(self.file_list)

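Each valid line of the file list becomes a sample dict, with `'sr_factor'` attached only for super-resolution runs; a pure-dict sketch of that keying (the helper name is hypothetical):

```python
def make_sample(src_path, tar_path, sr_factor=None):
    # Mirror ResDataset's sample layout: 'image' and 'target' are always
    # present; 'sr_factor' rides along only for super-resolution tasks.
    sample = {'image': src_path, 'target': tar_path}
    if sr_factor is not None:
        sample['sr_factor'] = sr_factor
    return sample


# Other restoration tasks (e.g., denoising) simply omit the factor.
plain = make_sample('noisy/0001.png', 'clean/0001.png')
sr = make_sample('lr/0001.png', 'hr/0001.png', sr_factor=4)
```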
+ 1 - 1
paddlers/datasets/seg_dataset.py

@@ -44,7 +44,7 @@ class SegDataset(BaseDataset):
                  shuffle=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
-        # TODO batch padding
+        # TODO: batch padding
         self.batch_transforms = None
         self.file_list = list()
         self.labels = list()

+ 0 - 99
paddlers/datasets/sr_dataset.py

@@ -1,99 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Super-resolution dataset definition
-class SRdataset(object):
-    def __init__(self,
-                 mode,
-                 gt_floder,
-                 lq_floder,
-                 transforms,
-                 scale,
-                 num_workers=4,
-                 batch_size=8):
-        if mode == 'train':
-            preprocess = []
-            preprocess.append({
-                'name': 'LoadImageFromFile',
-                'key': 'lq'
-            })  # loading method
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)  # transforms to apply
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'num_workers': num_workers,
-                'batch_size': batch_size,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-        if mode == "test":
-            preprocess = []
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-    def __call__(self):
-        return self.dataset
-
-
-# Compose the defined transforms and return them as a dict
-class ComposeTrans(object):
-    def __init__(self, input_keys, output_keys, pipelines):
-        if not isinstance(pipelines, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be List, but received is {}'
-                .format(type(pipelines)))
-        if len(pipelines) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(pipelines)))
-        self.transforms = pipelines
-        self.output_length = len(output_keys)  # output_keys of length 3 means DRN training
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
-    def __call__(self):
-        pipeline = []
-        for op in self.transforms:
-            if op['name'] == 'SRPairedRandomCrop':
-                op['keys'] = ['image'] * 2
-            else:
-                op['keys'] = ['image'] * self.output_length
-            pipeline.append(op)
-        if self.output_length == 2:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'pipeline': pipeline
-            }
-        else:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'output_keys': self.output_keys,
-                'pipeline': pipeline
-            }
-
-        return transform_dict

+ 18 - 1
paddlers/deploy/predictor.py

@@ -163,13 +163,23 @@ class Predictor(object):
                 'image2': preprocessed_samples[1],
                 'ori_shape': preprocessed_samples[2]
             }
+        elif self._model.model_type == 'restorer':
+            preprocessed_samples = {
+                'image': preprocessed_samples[0],
+                'tar_shape': preprocessed_samples[1]
+            }
         else:
             logging.error(
                 "Invalid model type {}".format(self._model.model_type),
                 exit=True)
         return preprocessed_samples
 
-    def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
+    def postprocess(self,
+                    net_outputs,
+                    topk=1,
+                    ori_shape=None,
+                    tar_shape=None,
+                    transforms=None):
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
@@ -201,6 +211,12 @@ class Predictor(object):
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
             }
             preds = self._model.postprocess(net_outputs)
+        elif self._model.model_type == 'restorer':
+            res_maps = self._model.postprocess(
+                net_outputs[0],
+                batch_tar_shape=tar_shape,
+                transforms=transforms.transforms)
+            preds = [{'res_map': res_map} for res_map in res_maps]
         else:
             logging.error(
                 "Invalid model type {}.".format(self._model.model_type),
@@ -244,6 +260,7 @@ class Predictor(object):
             net_outputs,
             topk,
             ori_shape=preprocessed_input.get('ori_shape', None),
+            tar_shape=preprocessed_input.get('tar_shape', None),
             transforms=transforms)
         self.timer.postprocess_time_s.end(iter_num=len(images))
 

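The new `'restorer'` branch follows the established dispatch-on-`model_type` pattern, wrapping each restored map in a single-key dict. A minimal stand-in with the postprocessing itself stubbed out (no Paddle involved):

```python
def postprocess(model_type, net_outputs, tar_shape=None):
    """Dispatch on the trainer's model type, mirroring Predictor.postprocess().

    The real restorer branch crops/resizes each map using tar_shape and the
    inference transforms; here that step is a pass-through stub.
    """
    if model_type == 'restorer':
        res_maps = net_outputs[0]  # stub: real code postprocesses the maps
        return [{'res_map': res_map} for res_map in res_maps]
    raise ValueError("Invalid model type {}.".format(model_type))


preds = postprocess('restorer', [["map_a", "map_b"]])
```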
+ 1 - 0
paddlers/models/__init__.py

@@ -16,3 +16,4 @@ from . import ppcls, ppdet, ppseg, ppgan
 import paddlers.models.ppseg.models.losses as seg_losses
 import paddlers.models.ppdet.modeling.losses as det_losses
 import paddlers.models.ppcls.loss as clas_losses
+import paddlers.models.ppgan.models.criterions as res_losses

+ 1 - 1
paddlers/models/ppdet/metrics/json_results.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .rcan_model import RCANModel
+from .generators import *

+ 11 - 10
paddlers/rs_models/res/generators/builder.py → paddlers/rs_models/res/generators/param_init.py

@@ -1,10 +1,10 @@
-#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
+import paddle
+import paddle.nn as nn
 
-from ....models.ppgan.utils.registry import Registry
+from paddlers.models.ppgan.modules.init import reset_parameters
 
-GENERATORS = Registry("GENERATOR")
 
+def init_sr_weight(net):
+    def reset_func(m):
+        if hasattr(m, 'weight') and (
+                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+            reset_parameters(m)
 
-def build_generator(cfg):
-    cfg_copy = copy.deepcopy(cfg)
-    name = cfg_copy.pop('name')
-    generator = GENERATORS.get(name)(**cfg_copy)
-    return generator
+    net.apply(reset_func)

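`init_sr_weight()` relies on the `Layer.apply()` pattern, which visits every sublayer and runs the given function on each one while the function skips batch-norm layers. The classes below are plain-Python stubs standing in for Paddle layers, not Paddle APIs:

```python
class Layer:
    """Tiny stand-in for a layer with a recursive apply() method."""

    def __init__(self, *children):
        self.children = list(children)

    def apply(self, fn):
        # Visit every sublayer, then the layer itself.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


class Conv(Layer):
    def __init__(self):
        super().__init__()
        self.weight = 'untouched'


class BatchNorm(Layer):
    def __init__(self):
        super().__init__()
        self.weight = 'untouched'


def reset_func(m):
    # Re-initialize every layer that has a weight, except batch norm.
    if hasattr(m, 'weight') and not isinstance(m, BatchNorm):
        m.weight = 'reset'


net = Layer(Conv(), BatchNorm(), Conv())
net.apply(reset_func)
```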
+ 34 - 17
paddlers/rs_models/res/generators/rcan.py

@@ -1,10 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Based on https://github.com/kongdebug/RCAN-Paddle
+
 import math
 
 import paddle
 import paddle.nn as nn
 
-from .builder import GENERATORS
+from .param_init import init_sr_weight
 
 
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
@@ -63,8 +78,10 @@ class RCAB(nn.Layer):
                  bias=True,
                  bn=False,
                  act=nn.ReLU(),
-                 res_scale=1):
+                 res_scale=1,
+                 use_init_weight=False):
         super(RCAB, self).__init__()
+
         modules_body = []
         for i in range(2):
             modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
@@ -74,6 +91,9 @@ class RCAB(nn.Layer):
         self.body = nn.Sequential(*modules_body)
         self.res_scale = res_scale
 
+        if use_init_weight:
+            init_sr_weight(self)
+
     def forward(self, x):
         res = self.body(x)
         res += x
@@ -128,21 +148,19 @@ class Upsampler(nn.Sequential):
         super(Upsampler, self).__init__(*m)
 
 
-@GENERATORS.register()
 class RCAN(nn.Layer):
-    def __init__(
-            self,
-            scale,
-            n_resgroups,
-            n_resblocks,
-            n_feats=64,
-            n_colors=3,
-            rgb_range=255,
-            kernel_size=3,
-            reduction=16,
-            conv=default_conv, ):
+    def __init__(self,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=255,
+                 kernel_size=3,
+                 reduction=16,
+                 conv=default_conv):
         super(RCAN, self).__init__()
-        self.scale = scale
+        self.scale = sr_factor
         act = nn.ReLU()
 
         n_resgroups = n_resgroups
@@ -150,7 +168,6 @@ class RCAN(nn.Layer):
         n_feats = n_feats
         kernel_size = kernel_size
         reduction = reduction
-        scale = scale
         act = nn.ReLU()
 
         rgb_mean = (0.4488, 0.4371, 0.4040)
@@ -171,7 +188,7 @@ class RCAN(nn.Layer):
         # Define tail module
         modules_tail = [
             Upsampler(
-                conv, scale, n_feats, act=False),
+                conv, self.scale, n_feats, act=False),
             conv(n_feats, n_colors, kernel_size)
         ]
 

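RCAN's tail reaches `sr_factor` through the `Upsampler`, which stacks ×2 pixel-shuffle steps for power-of-two factors (or a single ×3 step). A pixel-shuffle step with ratio r maps (C·r², H, W) to (C, rH, rW), so the shape bookkeeping is plain arithmetic; this is a sketch of the shapes only, not the Paddle implementation:

```python
import math


def upsampler_shapes(sr_factor, n_feats, h, w):
    """Trace the (C, H, W) feature shape through the upsampling steps."""
    shapes = [(n_feats, h, w)]
    if sr_factor & (sr_factor - 1) == 0:  # power of two: repeated x2 steps
        for _ in range(int(math.log2(sr_factor))):
            c, h, w = shapes[-1]
            shapes.append((c, 2 * h, 2 * w))
    elif sr_factor == 3:  # a single x3 pixel-shuffle step
        c, h, w = shapes[-1]
        shapes.append((c, 3 * h, 3 * w))
    else:
        raise NotImplementedError(sr_factor)
    return shapes


# Default RCAN settings: sr_factor=4, n_feats=64.
trace = upsampler_shapes(4, 64, 48, 48)
```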
+ 0 - 106
paddlers/rs_models/res/rcan_model.py

@@ -1,106 +0,0 @@
-#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle
-import paddle.nn as nn
-
-from .generators.builder import build_generator
-from ...models.ppgan.models.criterions.builder import build_criterion
-from ...models.ppgan.models.base_model import BaseModel
-from ...models.ppgan.models.builder import MODELS
-from ...models.ppgan.utils.visual import tensor2img
-from ...models.ppgan.modules.init import reset_parameters
-
-
-@MODELS.register()
-class RCANModel(BaseModel):
-    """
-    Base SR model for single image super-resolution.
-    """
-
-    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
-        """
-        Args:
-            generator (dict): config of generator.
-            pixel_criterion (dict): config of pixel criterion.
-        """
-        super(RCANModel, self).__init__()
-
-        self.nets['generator'] = build_generator(generator)
-        self.error_last = 1e8
-        self.batch = 0
-        if pixel_criterion:
-            self.pixel_criterion = build_criterion(pixel_criterion)
-        if use_init_weight:
-            init_sr_weight(self.nets['generator'])
-
-    def setup_input(self, input):
-        self.lq = paddle.to_tensor(input['lq'])
-        self.visual_items['lq'] = self.lq
-        if 'gt' in input:
-            self.gt = paddle.to_tensor(input['gt'])
-            self.visual_items['gt'] = self.gt
-        self.image_paths = input['lq_path']
-
-    def forward(self):
-        pass
-
-    def train_iter(self, optims=None):
-        optims['optim'].clear_grad()
-
-        self.output = self.nets['generator'](self.lq)
-        self.visual_items['output'] = self.output
-        # pixel loss
-        loss_pixel = self.pixel_criterion(self.output, self.gt)
-        self.losses['loss_pixel'] = loss_pixel
-
-        skip_threshold = 1e6
-
-        if loss_pixel.item() < skip_threshold * self.error_last:
-            loss_pixel.backward()
-            optims['optim'].step()
-        else:
-            print('Skip this batch {}! (Loss: {})'.format(self.batch + 1,
-                                                          loss_pixel.item()))
-        self.batch += 1
-
-        if self.batch % 1000 == 0:
-            self.error_last = loss_pixel.item() / 1000
-            print("update error_last:{}".format(self.error_last))
-
-    def test_iter(self, metrics=None):
-        self.nets['generator'].eval()
-        with paddle.no_grad():
-            self.output = self.nets['generator'](self.lq)
-            self.visual_items['output'] = self.output
-        self.nets['generator'].train()
-
-        out_img = []
-        gt_img = []
-        for out_tensor, gt_tensor in zip(self.output, self.gt):
-            out_img.append(tensor2img(out_tensor, (0., 255.)))
-            gt_img.append(tensor2img(gt_tensor, (0., 255.)))
-
-        if metrics is not None:
-            for metric in metrics.values():
-                metric.update(out_img, gt_img)
-
-
-def init_sr_weight(net):
-    def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
-            reset_parameters(m)
-
-    net.apply(reset_func)

+ 1 - 1
paddlers/tasks/__init__.py

@@ -16,7 +16,7 @@ import paddlers.tasks.object_detector as detector
 import paddlers.tasks.segmenter as segmenter
 import paddlers.tasks.change_detector as change_detector
 import paddlers.tasks.classifier as classifier
-import paddlers.tasks.image_restorer as restorer
+import paddlers.tasks.restorer as restorer
 from .load_model import load_model
 
 # Shorter aliases

+ 24 - 32
paddlers/tasks/base.py

@@ -30,12 +30,11 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 
 import paddlers
 import paddlers.utils.logging as logging
-from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
-                            get_pretrain_weights, load_pretrain_weights,
-                            load_checkpoint, SmoothedValue, TrainingStats,
-                            _get_shared_memory_size_in_M, EarlyStop)
+from paddlers.utils import (
+    seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
+    load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats,
+    _get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step)
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
-from .utils.infer_nets import InferNet, InferCDNet
 
 
 class ModelMeta(type):
@@ -268,7 +267,7 @@ class BaseModel(metaclass=ModelMeta):
                 'The volume of dataset({}) must be larger than batch size({}).'
                 .format(dataset.num_samples, batch_size))
         batch_size_each_card = get_single_card_bs(batch_size=batch_size)
-        # TODO: a check is needed in the detection eval stage
+
         batch_sampler = DistributedBatchSampler(
             dataset,
             batch_size=batch_size_each_card,
@@ -308,7 +307,7 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms, 'train')
 
-        if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
+        if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
             nranks = 1
         else:
@@ -321,10 +320,10 @@ class BaseModel(metaclass=ModelMeta):
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
-                ddp_net = paddle.DataParallel(
+                ddp_net = to_data_parallel(
                     self.net, find_unused_parameters=find_unused_parameters)
             else:
-                ddp_net = paddle.DataParallel(
+                ddp_net = to_data_parallel(
                     self.net, find_unused_parameters=find_unused_parameters)
 
         if use_vdl:
@@ -365,24 +364,14 @@ class BaseModel(metaclass=ModelMeta):
 
             for step, data in enumerate(self.train_data_loader()):
                 if nranks > 1:
-                    outputs = self.run(ddp_net, data, mode='train')
+                    outputs = self.train_step(step, data, ddp_net)
                 else:
-                    outputs = self.run(self.net, data, mode='train')
-                loss = outputs['loss']
-                loss.backward()
-                self.optimizer.step()
-                self.optimizer.clear_grad()
-                lr = self.optimizer.get_lr()
-                if isinstance(self.optimizer._learning_rate,
-                              paddle.optimizer.lr.LRScheduler):
-                    # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
-                    if isinstance(self.optimizer._learning_rate,
-                                  paddle.optimizer.lr.ReduceOnPlateau):
-                        self.optimizer._learning_rate.step(loss.item())
-                    else:
-                        self.optimizer._learning_rate.step()
+                    outputs = self.train_step(step, data, self.net)
+
+                scheduler_step(self.optimizer)
 
                 train_avg_metrics.update(outputs)
+                lr = self.optimizer.get_lr()
                 outputs['lr'] = lr
                 if ema is not None:
                     ema.update(self.net)
@@ -622,14 +611,7 @@ class BaseModel(metaclass=ModelMeta):
         return pipeline_info
 
     def _build_inference_net(self):
-        if self.model_type in ('classifier', 'detector'):
-            infer_net = self.net
-        elif self.model_type == 'change_detector':
-            infer_net = InferCDNet(self.net)
-        else:
-            infer_net = InferNet(self.net, self.model_type)
-        infer_net.eval()
-        return infer_net
+        raise NotImplementedError
 
     def _export_inference_model(self, save_dir, image_shape=None):
         self.test_inputs = self._get_test_inputs(image_shape)
@@ -674,6 +656,16 @@ class BaseModel(metaclass=ModelMeta):
         logging.info("The inference model for deployment is saved in {}.".
                      format(save_dir))
 
+    def train_step(self, step, data, net):
+        outputs = self.run(net, data, mode='train')
+
+        loss = outputs['loss']
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+
+        return outputs
+
     def _check_transforms(self, transforms, mode):
         # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
         if not isinstance(transforms, paddlers.transforms.Compose):

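The hunk above moves the inline learning-rate logic into a `scheduler_step` helper and the forward/backward/optimize steps into `train_step`. The real helper in `paddlers.utils` takes the optimizer and inspects its `_learning_rate`; the following is only a framework-free sketch of the dispatch the removed lines performed, with `PlainScheduler` and `ReduceOnPlateauLike` as illustrative stand-ins for Paddle's `LRScheduler` and `ReduceOnPlateau`:

```python
# Sketch of the scheduler dispatch centralized by `scheduler_step`.
# Class names are stand-ins, not Paddle APIs.

class PlainScheduler:
    """Scheduler stepped unconditionally (like LRScheduler)."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


class ReduceOnPlateauLike(PlainScheduler):
    """Metric-driven scheduler (like ReduceOnPlateau)."""

    def __init__(self):
        super().__init__()
        self.last_metric = None

    def step(self, metric):
        self.steps += 1
        self.last_metric = metric


def scheduler_step(scheduler, loss_value=None):
    # ReduceOnPlateau-style schedulers need a metric; the training
    # loop feeds them the current loss value. All others just step.
    if isinstance(scheduler, ReduceOnPlateauLike):
        scheduler.step(loss_value)
    else:
        scheduler.step()
```

This mirrors the `isinstance(..., ReduceOnPlateau)` branch deleted from the training loop, now invoked once per step after `train_step` returns.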
+ 18 - 11
paddlers/tasks/change_detector.py

@@ -30,10 +30,11 @@ import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
-from paddlers.utils import get_single_card_bs, DisablePrint
+from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferCDNet
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -69,6 +70,11 @@ class BaseChangeDetector(BaseModel):
                                              **params)
         return net
 
+    def _build_inference_net(self):
+        infer_net = InferCDNet(self.net)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -399,7 +405,8 @@ class BaseChangeDetector(BaseModel):
                 Defaults to False.
 
         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 For binary change detection (number of classes == 2), the key-value 
                     pairs are like:
                     {"iou": `intersection over union for the change class`,
@@ -527,12 +534,12 @@ class BaseChangeDetector(BaseModel):
 
         Returns:
             If `img_file` is a tuple of string or np.array, the result is a dict with 
-                key-value pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+                the following key-value pairs:
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
+                above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -787,11 +794,11 @@ class BaseChangeDetector(BaseModel):
                 elif item[0] == 'padding':
                     x, y = item[2]
                     if isinstance(label_map, np.ndarray):
-                        label_map = label_map[..., y:y + h, x:x + w]
-                        score_map = score_map[..., y:y + h, x:x + w]
+                        label_map = label_map[y:y + h, x:x + w]
+                        score_map = score_map[y:y + h, x:x + w]
                     else:
-                        label_map = label_map[:, :, y:y + h, x:x + w]
-                        score_map = score_map[:, :, y:y + h, x:x + w]
+                        label_map = label_map[:, y:y + h, x:x + w, :]
+                        score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                     pass
             label_map = label_map.squeeze()

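The padding-restoration hunk above switches the crop slicing from channel-first (`[:, :, y:y + h, x:x + w]`) to channel-last (`[:, y:y + h, x:x + w, :]`) layout. A minimal plain-Python sketch of the batched HWC crop (nested lists standing in for `np.ndarray`s; the function name is illustrative):

```python
# Crop a padded batch of HWC maps back to the pre-padding size,
# matching the `score_map[:, y:y + h, x:x + w, :]` slicing above.

def crop_batch_hwc(score_map, x, y, w, h):
    """score_map has shape (N, H, W, C); returns shape (N, h, w, C)."""
    # Slice rows (H) first, then columns (W); channels stay intact.
    return [[row[x:x + w] for row in img[y:y + h]] for img in score_map]
```

With NumPy arrays the equivalent is a single fancy slice; the point is only the axis order after the refactor.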
+ 64 - 41
paddlers/tasks/classifier.py

@@ -83,6 +83,11 @@ class BaseClassifier(BaseModel):
                 self.in_channels = 3
         return net
 
+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -373,7 +378,8 @@ class BaseClassifier(BaseModel):
                 Defaults to False.
 
         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 {"top1": `acc of top1`,
                  "top5": `acc of top5`}.
         """
@@ -389,38 +395,37 @@ class BaseClassifier(BaseModel):
             ):
                 paddle.distributed.init_parallel_env()
 
-        batch_size_each_card = get_single_card_bs(batch_size)
-        if batch_size_each_card > 1:
-            batch_size_each_card = 1
-            batch_size = batch_size_each_card * paddlers.env_info['num']
+        if batch_size > 1:
             logging.warning(
-                "Classifier only supports batch_size=1 for each gpu/cpu card " \
-                "during evaluation, so batch_size " \
-                "is forcibly set to {}.".format(batch_size))
-        self.eval_data_loader = self.build_data_loader(
-            eval_dataset, batch_size=batch_size, mode='eval')
-
-        logging.info(
-            "Start to evaluate(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples,
-                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
-
-        top1s = []
-        top5s = []
-        with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
-                outputs = self.run(self.net, data, 'eval')
-                top1s.append(outputs["top1"])
-                top5s.append(outputs["top5"])
-
-        top1 = np.mean(top1s)
-        top5 = np.mean(top5s)
-        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
-        if return_details:
-            # TODO: add details
-            return eval_metrics, None
-        return eval_metrics
+                "Classifier only supports single-card evaluation with "
+                "batch_size=1, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset, batch_size=batch_size, mode='eval')
+            logging.info(
+                "Start to evaluate(total_samples={}, total_steps={})...".format(
+                    eval_dataset.num_samples, eval_dataset.num_samples))
+
+            top1s = []
+            top5s = []
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    data.append(eval_dataset.transforms.transforms)
+                    outputs = self.run(self.net, data, 'eval')
+                    top1s.append(outputs["top1"])
+                    top5s.append(outputs["top5"])
+
+            top1 = np.mean(top1s)
+            top5 = np.mean(top5s)
+            eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
+
+            if return_details:
+                # TODO: Add details
+                return eval_metrics, None
+
+            return eval_metrics
 
     def predict(self, img_file, transforms=None):
         """
@@ -435,16 +440,14 @@ class BaseClassifier(BaseModel):
                 Defaults to None.
 
         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
-                pairs:
-                {"label map": `class_ids_map`, 
-                 "scores_map": `scores_map`, 
-                 "label_names_map": `label_names_map`}.
+            If `img_file` is a string or np.array, the result is a dict with the 
+                following key-value pairs:
+                class_ids_map (np.ndarray): IDs of predicted classes.
+                scores_map (np.ndarray): Scores of predicted classes.
+                label_names_map (np.ndarray): Names of predicted classes.
+            
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                class_ids_map (np.ndarray): class_ids
-                scores_map (np.ndarray): scores
-                label_names_map (np.ndarray): label_names
+                above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -555,6 +558,26 @@ class BaseClassifier(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeClassifier object.")
 
+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The volume of dataset({}) must be larger than batch size({}).'
+                .format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseClassifier, self).build_data_loader(
+                dataset, batch_size, mode)
+
 
 class ResNet50_vd(BaseClassifier):
     def __init__(self,

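The `build_data_loader` override added to `BaseClassifier` above is a template-method pattern: the subclass builds its own eval/test loader (notably with `use_shared_memory=False`) and defers to the base class for training. A framework-free sketch of the pattern, with simple dicts standing in for Paddle `DataLoader`s:

```python
# Sketch of the override pattern from the diff above; loader objects
# are plain dicts here, not paddle.io.DataLoader instances.

class BaseModelSketch:
    def build_data_loader(self, dataset, batch_size, mode='train'):
        # Default loader used for training.
        return {'dataset': dataset, 'batch_size': batch_size,
                'mode': mode, 'use_shared_memory': True}


class ClassifierSketch(BaseModelSketch):
    def build_data_loader(self, dataset, batch_size, mode='train'):
        if mode != 'train':
            # Eval/test: build the loader directly, shared memory off.
            return {'dataset': dataset, 'batch_size': batch_size,
                    'mode': mode, 'use_shared_memory': False}
        # Training: fall back to the base-class behavior.
        return super().build_data_loader(dataset, batch_size, mode)
```

The design keeps the common training path in one place while letting the classifier special-case evaluation.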
+ 0 - 786
paddlers/tasks/image_restorer.py

@@ -1,786 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-import datetime
-
-import paddle
-from paddle.distributed import ParallelEnv
-
-from ..models.ppgan.datasets.builder import build_dataloader
-from ..models.ppgan.models.builder import build_model
-from ..models.ppgan.utils.visual import tensor2img, save_image
-from ..models.ppgan.utils.filesystem import makedirs, save, load
-from ..models.ppgan.utils.timer import TimeAverager
-from ..models.ppgan.utils.profiler import add_profiler_step
-from ..models.ppgan.utils.logger import setup_logger
-
-
-# Define the AttrDict class to provide dynamic attribute access
-class AttrDict(dict):
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setattr__(self, key, value):
-        if key in self.__dict__:
-            self.__dict__[key] = value
-        else:
-            self[key] = value
-
-
-# Create AttrDict objects recursively
-def create_attr_dict(config_dict):
-    from ast import literal_eval
-    for key, value in config_dict.items():
-        if type(value) is dict:
-            config_dict[key] = value = AttrDict(value)
-        if isinstance(value, str):
-            try:
-                value = literal_eval(value)
-            except BaseException:
-                pass
-        if isinstance(value, AttrDict):
-            create_attr_dict(config_dict[key])
-        else:
-            config_dict[key] = value
-
-
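The `AttrDict`/`create_attr_dict` pair being deleted here lets nested config dicts be read as attributes (`cfg.model.name`). A compact stdlib equivalent of the same idea (the `to_attr_dict` name is illustrative):

```python
# dict subclass with attribute access, plus recursive conversion of
# nested dicts -- a simplified form of `create_attr_dict` above.

class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)


def to_attr_dict(d):
    """Recursively wrap nested dicts so cfg.a.b works."""
    out = AttrDict(d)
    for k, v in out.items():
        if isinstance(v, dict):
            out[k] = to_attr_dict(v)
    return out
```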
-# Data loading class
-class IterLoader:
-    def __init__(self, dataloader):
-        self._dataloader = dataloader
-        self.iter_loader = iter(self._dataloader)
-        self._epoch = 1
-
-    @property
-    def epoch(self):
-        return self._epoch
-
-    def __next__(self):
-        try:
-            data = next(self.iter_loader)
-        except StopIteration:
-            self._epoch += 1
-            self.iter_loader = iter(self._dataloader)
-            data = next(self.iter_loader)
-
-        return data
-
-    def __len__(self):
-        return len(self._dataloader)
-
-
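`IterLoader` above turns a finite dataloader into an endless iterator, bumping an epoch counter whenever the underlying loader is exhausted. A stdlib sketch of that wrap-around behavior (class name is illustrative):

```python
# Cycle endlessly over a finite iterable, incrementing an epoch
# counter each time iteration wraps -- the IterLoader behavior above.

class IterLoaderSketch:
    def __init__(self, iterable):
        self._iterable = iterable
        self._iter = iter(iterable)
        self.epoch = 1

    def __next__(self):
        try:
            return next(self._iter)
        except StopIteration:
            # Exhausted: start a new epoch and restart iteration.
            self.epoch += 1
            self._iter = iter(self._iterable)
            return next(self._iter)
```

This lets an iteration-based training loop (`while current_iter < total_iters`) pull batches without managing epoch boundaries itself.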
-# Base trainer class
-class Restorer:
-    """
-    # trainer calling logic:
-    #
-    #                build_model                               ||    model(BaseModel)
-    #                     |                                    ||
-    #               build_dataloader                           ||    dataloader
-    #                     |                                    ||
-    #               model.setup_lr_schedulers                  ||    lr_scheduler
-    #                     |                                    ||
-    #               model.setup_optimizers                     ||    optimizers
-    #                     |                                    ||
-    #     train loop (model.setup_input + model.train_iter)    ||    train loop
-    #                     |                                    ||
-    #         print log (model.get_current_losses)             ||
-    #                     |                                    ||
-    #         save checkpoint (model.nets)                     \/
-    """
-
-    def __init__(self, cfg, logger):
-        # base config
-        # self.logger = logging.getLogger(__name__)
-        self.logger = logger
-        self.cfg = cfg
-        self.output_dir = cfg.output_dir
-        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
-
-        self.local_rank = ParallelEnv().local_rank
-        self.world_size = ParallelEnv().nranks
-        self.log_interval = cfg.log_config.interval
-        self.visual_interval = cfg.log_config.visiual_interval
-        self.weight_interval = cfg.snapshot_config.interval
-
-        self.start_epoch = 1
-        self.current_epoch = 1
-        self.current_iter = 1
-        self.inner_iter = 1
-        self.batch_id = 0
-        self.global_steps = 0
-
-        # build model
-        self.model = build_model(cfg.model)
-        # multiple gpus prepare
-        if ParallelEnv().nranks > 1:
-            self.distributed_data_parallel()
-
-        # build metrics
-        self.metrics = None
-        self.is_save_img = True
-        validate_cfg = cfg.get('validate', None)
-        if validate_cfg and 'metrics' in validate_cfg:
-            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
-        if validate_cfg and 'save_img' in validate_cfg:
-            self.is_save_img = validate_cfg['save_img']
-
-        self.enable_visualdl = cfg.get('enable_visualdl', False)
-        if self.enable_visualdl:
-            import visualdl
-            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
-
-        # evaluate only
-        if not cfg.is_train:
-            return
-
-        # build train dataloader
-        self.train_dataloader = build_dataloader(cfg.dataset.train)
-        self.iters_per_epoch = len(self.train_dataloader)
-
-        # build lr scheduler
-        # TODO: has a better way?
-        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
-            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
-        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
-
-        # build optimizers
-        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
-                                                      cfg.optimizer)
-
-        self.epochs = cfg.get('epochs', None)
-        if self.epochs:
-            self.total_iters = self.epochs * self.iters_per_epoch
-            self.by_epoch = True
-        else:
-            self.by_epoch = False
-            self.total_iters = cfg.total_iters
-
-        if self.by_epoch:
-            self.weight_interval *= self.iters_per_epoch
-
-        self.validate_interval = -1
-        if cfg.get('validate', None) is not None:
-            self.validate_interval = cfg.validate.get('interval', -1)
-
-        self.time_count = {}
-        self.best_metric = {}
-        self.model.set_total_iter(self.total_iters)
-        self.profiler_options = cfg.profiler_options
-
-    def distributed_data_parallel(self):
-        paddle.distributed.init_parallel_env()
-        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
-        for net_name, net in self.model.nets.items():
-            self.model.nets[net_name] = paddle.DataParallel(
-                net, find_unused_parameters=find_unused_parameters)
-
-    def learning_rate_scheduler_step(self):
-        if isinstance(self.model.lr_scheduler, dict):
-            for lr_scheduler in self.model.lr_scheduler.values():
-                lr_scheduler.step()
-        elif isinstance(self.model.lr_scheduler,
-                        paddle.optimizer.lr.LRScheduler):
-            self.model.lr_scheduler.step()
-        else:
-            raise ValueError(
-                'lr schedulter must be a dict or an instance of LRScheduler')
-
-    def train(self):
-        reader_cost_averager = TimeAverager()
-        batch_cost_averager = TimeAverager()
-
-        iter_loader = IterLoader(self.train_dataloader)
-
-        # set model.is_train = True
-        self.model.setup_train_mode(is_train=True)
-        while self.current_iter < (self.total_iters + 1):
-            self.current_epoch = iter_loader.epoch
-            self.inner_iter = self.current_iter % self.iters_per_epoch
-
-            add_profiler_step(self.profiler_options)
-
-            start_time = step_start_time = time.time()
-            data = next(iter_loader)
-            reader_cost_averager.record(time.time() - step_start_time)
-            # unpack data from dataset and apply preprocessing
-            # data input should be dict
-            self.model.setup_input(data)
-            self.model.train_iter(self.optimizers)
-
-            batch_cost_averager.record(
-                time.time() - step_start_time,
-                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))
-
-            step_start_time = time.time()
-
-            if self.current_iter % self.log_interval == 0:
-                self.data_time = reader_cost_averager.get_average()
-                self.step_time = batch_cost_averager.get_average()
-                self.ips = batch_cost_averager.get_ips_average()
-                self.print_log()
-
-                reader_cost_averager.reset()
-                batch_cost_averager.reset()
-
-            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
-                self.visual('visual_train')
-
-            self.learning_rate_scheduler_step()
-
-            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
-                self.test()
-
-            if self.current_iter % self.weight_interval == 0:
-                self.save(self.current_iter, 'weight', keep=-1)
-                self.save(self.current_iter)
-
-            self.current_iter += 1
-
-    def test(self):
-        if not hasattr(self, 'test_dataloader'):
-            self.test_dataloader = build_dataloader(
-                self.cfg.dataset.test, is_train=False)
-        iter_loader = IterLoader(self.test_dataloader)
-        if self.max_eval_steps is None:
-            self.max_eval_steps = len(self.test_dataloader)
-
-        if self.metrics:
-            for metric in self.metrics.values():
-                metric.reset()
-
-        # set model.is_train = False
-        self.model.setup_train_mode(is_train=False)
-
-        for i in range(self.max_eval_steps):
-            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
-                self.logger.info('Test iter: [%d/%d]' % (
-                    i * self.world_size, self.max_eval_steps * self.world_size))
-
-            data = next(iter_loader)
-            self.model.setup_input(data)
-            self.model.test_iter(metrics=self.metrics)
-
-            if self.is_save_img:
-                visual_results = {}
-                current_paths = self.model.get_image_paths()
-                current_visuals = self.model.get_current_visuals()
-
-                if len(current_visuals) > 0 and list(current_visuals.values())[
-                        0].shape == 4:
-                    num_samples = list(current_visuals.values())[0].shape[0]
-                else:
-                    num_samples = 1
-
-                for j in range(num_samples):
-                    if j < len(current_paths):
-                        short_path = os.path.basename(current_paths[j])
-                        basename = os.path.splitext(short_path)[0]
-                    else:
-                        basename = '{:04d}_{:04d}'.format(i, j)
-                    for k, img_tensor in current_visuals.items():
-                        name = '%s_%s' % (basename, k)
-                        if len(img_tensor.shape) == 4:
-                            visual_results.update({name: img_tensor[j]})
-                        else:
-                            visual_results.update({name: img_tensor})
-
-                self.visual(
-                    'visual_test',
-                    visual_results=visual_results,
-                    step=self.batch_id,
-                    is_save_image=True)
-
-        if self.metrics:
-            for metric_name, metric in self.metrics.items():
-                self.logger.info("Metric {}: {:.4f}".format(
-                    metric_name, metric.accumulate()))
-
-    def print_log(self):
-        losses = self.model.get_current_losses()
-
-        message = ''
-        if self.by_epoch:
-            message += 'Epoch: %d/%d, iter: %d/%d ' % (
-                self.current_epoch, self.epochs, self.inner_iter,
-                self.iters_per_epoch)
-        else:
-            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
-
-        message += f'lr: {self.current_learning_rate:.3e} '
-
-        for k, v in losses.items():
-            message += '%s: %.3f ' % (k, v)
-            if self.enable_visualdl:
-                self.vdl_logger.add_scalar(k, v, step=self.global_steps)
-
-        if hasattr(self, 'step_time'):
-            message += 'batch_cost: %.5f sec ' % self.step_time
-
-        if hasattr(self, 'data_time'):
-            message += 'reader_cost: %.5f sec ' % self.data_time
-
-        if hasattr(self, 'ips'):
-            message += 'ips: %.5f images/s ' % self.ips
-
-        if hasattr(self, 'step_time'):
-            eta = self.step_time * (self.total_iters - self.current_iter)
-            eta = eta if eta > 0 else 0
-
-            eta_str = str(datetime.timedelta(seconds=int(eta)))
-            message += f'eta: {eta_str}'
-
-        # print the message
-        self.logger.info(message)
-
-    @property
-    def current_learning_rate(self):
-        for optimizer in self.model.optimizers.values():
-            return optimizer.get_lr()
-
-    def visual(self,
-               results_dir,
-               visual_results=None,
-               step=None,
-               is_save_image=False):
-        """
-        visual the images, use visualdl or directly write to the directory
-        Parameters:
-            results_dir (str)     --  directory name which contains saved images
-            visual_results (dict) --  the results images dict
-            step (int)            --  global steps, used in visualdl
-            is_save_image (bool)  --  whether to write to the directory or visualdl
-        """
-        self.model.compute_visuals()
-
-        if visual_results is None:
-            visual_results = self.model.get_current_visuals()
-
-        min_max = self.cfg.get('min_max', None)
-        if min_max is None:
-            min_max = (-1., 1.)
-
-        image_num = self.cfg.get('image_num', None)
-        if (image_num is None) or (not self.enable_visualdl):
-            image_num = 1
-        for label, image in visual_results.items():
-            image_numpy = tensor2img(image, min_max, image_num)
-            if (not is_save_image) and self.enable_visualdl:
-                self.vdl_logger.add_image(
-                    results_dir + '/' + label,
-                    image_numpy,
-                    step=step if step else self.global_steps,
-                    dataformats="HWC" if image_num == 1 else "NCHW")
-            else:
-                if self.cfg.is_train:
-                    if self.by_epoch:
-                        msg = 'epoch%.3d_' % self.current_epoch
-                    else:
-                        msg = 'iter%.3d_' % self.current_iter
-                else:
-                    msg = ''
-                makedirs(os.path.join(self.output_dir, results_dir))
-                img_path = os.path.join(self.output_dir, results_dir,
-                                        msg + '%s.png' % (label))
-                save_image(image_numpy, img_path)
-
-    def save(self, epoch, name='checkpoint', keep=1):
-        if self.local_rank != 0:
-            return
-
-        assert name in ['checkpoint', 'weight']
-
-        state_dicts = {}
-        if self.by_epoch:
-            save_filename = 'epoch_%s_%s.pdparams' % (
-                epoch // self.iters_per_epoch, name)
-        else:
-            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
-
-        os.makedirs(self.output_dir, exist_ok=True)
-        save_path = os.path.join(self.output_dir, save_filename)
-        for net_name, net in self.model.nets.items():
-            state_dicts[net_name] = net.state_dict()
-
-        if name == 'weight':
-            save(state_dicts, save_path)
-            return
-
-        state_dicts['epoch'] = epoch
-
-        for opt_name, opt in self.model.optimizers.items():
-            state_dicts[opt_name] = opt.state_dict()
-
-        save(state_dicts, save_path)
-
-        if keep > 0:
-            try:
-                if self.by_epoch:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'epoch_%s_%s.pdparams' % (
-                            (epoch - keep * self.weight_interval) //
-                            self.iters_per_epoch, name))
-                else:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'iter_%s_%s.pdparams' %
-                        (epoch - keep * self.weight_interval, name))
-
-                if os.path.exists(checkpoint_name_to_be_removed):
-                    os.remove(checkpoint_name_to_be_removed)
-
-            except Exception as e:
-                self.logger.info('remove old checkpoints error: {}'.format(e))
-
-    def resume(self, checkpoint_path):
-        state_dicts = load(checkpoint_path)
-        if state_dicts.get('epoch', None) is not None:
-            self.start_epoch = state_dicts['epoch'] + 1
-            self.global_steps = self.iters_per_epoch * state_dicts['epoch']
-
-            self.current_iter = state_dicts['epoch'] + 1
-
-        for net_name, net in self.model.nets.items():
-            net.set_state_dict(state_dicts[net_name])
-
-        for opt_name, opt in self.model.optimizers.items():
-            opt.set_state_dict(state_dicts[opt_name])
-
-    def load(self, weight_path):
-        state_dicts = load(weight_path)
-
-        for net_name, net in self.model.nets.items():
-            if net_name in state_dicts:
-                net.set_state_dict(state_dicts[net_name])
-                self.logger.info('Loaded pretrained weight for net {}'.format(
-                    net_name))
-            else:
-                self.logger.warning(
-                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
-                    .format(net_name, net_name))
-
-    def close(self):
-        """
-        when finish the training need close file handler or other.
-        """
-        if self.enable_visualdl:
-            self.vdl_logger.close()
-
-
-# Base super-resolution model trainer class
-class BasicSRNet:
-    def __init__(self):
-        self.model = {}
-        self.optimizer = {}
-        self.lr_scheduler = {}
-        self.min_max = ''
-
-    def train(
-            self,
-            total_iters,
-            train_dataset,
-            test_dataset,
-            output_dir,
-            validate,
-            snapshot,
-            log,
-            lr_rate,
-            evaluate_weights='',
-            resume='',
-            pretrain_weights='',
-            periods=[100000],
-            restart_weights=[1], ):
-        self.lr_scheduler['learning_rate'] = lr_rate
-
-        if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR':
-            self.lr_scheduler['periods'] = periods
-            self.lr_scheduler['restart_weights'] = restart_weights
-
-        validate = {
-            'interval': validate,
-            'save_img': False,
-            'metrics': {
-                'psnr': {
-                    'name': 'PSNR',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                },
-                'ssim': {
-                    'name': 'SSIM',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                }
-            }
-        }
-        log_config = {'interval': log, 'visiual_interval': 500}
-        snapshot_config = {'interval': snapshot}
-
-        cfg = {
-            'total_iters': total_iters,
-            'output_dir': output_dir,
-            'min_max': self.min_max,
-            'model': self.model,
-            'dataset': {
-                'train': train_dataset,
-                'test': test_dataset
-            },
-            'lr_scheduler': self.lr_scheduler,
-            'optimizer': self.optimizer,
-            'validate': validate,
-            'log_config': log_config,
-            'snapshot_config': snapshot_config
-        }
-
-        cfg = AttrDict(cfg)
-        create_attr_dict(cfg)
-
-        cfg.is_train = True
-        cfg.profiler_options = None
-        cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
-
-        if cfg.model.name == 'BaseSRModel':
-            floderModelName = cfg.model.generator.name
-        else:
-            floderModelName = cfg.model.name
-        cfg.output_dir = os.path.join(cfg.output_dir,
-                                      floderModelName + cfg.timestamp)
-
-        logger_cfg = setup_logger(cfg.output_dir)
-        logger_cfg.info('Configs: {}'.format(cfg))
-
-        if paddle.is_compiled_with_cuda():
-            paddle.set_device('gpu')
-        else:
-            paddle.set_device('cpu')
-
-        # build trainer
-        trainer = Restorer(cfg, logger_cfg)
-
-        # To resume training or evaluation, the checkpoint must contain epoch and optimizer info
-        if len(resume) > 0:
-            trainer.resume(resume)
-        # For evaluation or fine-tuning, only load the generator weights
-        elif len(pretrain_weights) > 0:
-            trainer.load(pretrain_weights)
-        if len(evaluate_weights) > 0:
-            trainer.load(evaluate_weights)
-            trainer.test()
-            return
-        # Train; save weights on keyboard interrupt
-        try:
-            trainer.train()
-        except KeyboardInterrupt as e:
-            trainer.save(trainer.current_epoch)
-
-        trainer.close()
-
-
-# DRN model training
-class DRNet(BasicSRNet):
-    def __init__(self,
-                 n_blocks=30,
-                 n_feats=16,
-                 n_colors=3,
-                 rgb_range=255,
-                 negval=0.2):
-        super(DRNet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'DRNGenerator',
-            'scale': (2, 4),
-            'n_blocks': n_blocks,
-            'n_feats': n_feats,
-            'n_colors': n_colors,
-            'rgb_range': rgb_range,
-            'negval': negval
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'DRN',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'optimG': {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            },
-            'optimD': {
-                'name': 'Adam',
-                'net_names': ['dual_model_0', 'dual_model_1'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            }
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# Training for the lightweight super-resolution model LESRCNN
-class LESRCNNet(BasicSRNet):
-    def __init__(self, scale=4, multi_scale=False, group=1):
-        super(LESRCNNet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'LESRCNNGenerator',
-            'scale': scale,
-            'multi_scale': False,
-            'group': 1
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'BaseSRModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# ESRGAN model training
-# If loss_type='gan', use perceptual, adversarial, and pixel losses.
-# If loss_type='pixel', use only the pixel loss.
-class ESRGANet(BasicSRNet):
-    def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
-        super(ESRGANet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'RRDBNet',
-            'in_nc': in_nc,
-            'out_nc': out_nc,
-            'nf': nf,
-            'nb': nb
-        }
-
-        if loss_type == 'gan':
-            # Define the loss functions
-            self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01}
-            self.discriminator = {
-                'name': 'VGGDiscriminator128',
-                'in_channels': 3,
-                'num_feat': 64
-            }
-            self.perceptual_criterion = {
-                'name': 'PerceptualLoss',
-                'layer_weights': {
-                    '34': 1.0
-                },
-                'perceptual_weight': 1.0,
-                'style_weight': 0.0,
-                'norm_img': False
-            }
-            self.gan_criterion = {
-                'name': 'GANLoss',
-                'gan_mode': 'vanilla',
-                'loss_weight': 0.005
-            }
-            # Define the model
-            self.model = {
-                'name': 'ESRGAN',
-                'generator': self.generator,
-                'discriminator': self.discriminator,
-                'pixel_criterion': self.pixel_criterion,
-                'perceptual_criterion': self.perceptual_criterion,
-                'gan_criterion': self.gan_criterion
-            }
-            self.optimizer = {
-                'optimG': {
-                    'name': 'Adam',
-                    'net_names': ['generator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                },
-                'optimD': {
-                    'name': 'Adam',
-                    'net_names': ['discriminator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                }
-            }
-            self.lr_scheduler = {
-                'name': 'MultiStepDecay',
-                'milestones': [50000, 100000, 200000, 300000],
-                'gamma': 0.5
-            }
-        else:
-            self.pixel_criterion = {'name': 'L1Loss'}
-            self.model = {
-                'name': 'BaseSRModel',
-                'generator': self.generator,
-                'pixel_criterion': self.pixel_criterion
-            }
-            self.optimizer = {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'beta1': 0.9,
-                'beta2': 0.99
-            }
-            self.lr_scheduler = {
-                'name': 'CosineAnnealingRestartLR',
-                'eta_min': 1e-07
-            }
-
-
-# RCAN model training
-class RCANet(BasicSRNet):
-    def __init__(
-            self,
-            scale=2,
-            n_resgroups=10,
-            n_resblocks=20, ):
-        super(RCANet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'RCAN',
-            'scale': scale,
-            'n_resgroups': n_resgroups,
-            'n_resblocks': n_resblocks
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'RCANModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'MultiStepDecay',
-            'milestones': [250000, 500000, 750000, 1000000],
-            'gamma': 0.5
-        }
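The removed trainer classes above build nested dict configs and convert them with `AttrDict`/`create_attr_dict` before handing them to the `Restorer`. A minimal, self-contained sketch of that pattern (this `AttrDict` is a simplified stand-in, not the actual paddlers implementation):

```python
class AttrDict(dict):
    """Dict whose keys are also readable as attributes (simplified stand-in)."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)


def create_attr_dict(d):
    """Recursively convert nested plain dicts into AttrDict instances."""
    for k, v in d.items():
        if isinstance(v, dict):
            d[k] = create_attr_dict(AttrDict(v))
    return d


cfg = create_attr_dict(AttrDict({
    'total_iters': 100000,
    'model': {'name': 'RCANModel', 'generator': {'name': 'RCAN', 'scale': 2}},
    'lr_scheduler': {'name': 'MultiStepDecay', 'gamma': 0.5},
}))
```

After conversion, nested keys such as `cfg.model.generator.name` are reachable via attribute access, which is what lets the old trainer write `cfg.model.generator.name` instead of chained subscripts.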

+ 36 - 15
paddlers/tasks/object_detector.py

@@ -61,6 +61,11 @@ class BaseDetector(BaseModel):
             net = ppdet.modeling.__dict__[self.model_name](**params)
         return net
 
+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
 
@@ -485,7 +490,7 @@ class BaseDetector(BaseModel):
                 Defaults to False.
 
         Returns:
-            collections.OrderedDict with key-value pairs: 
+            If `return_details` is False, return collections.OrderedDict with key-value pairs: 
                 {"bbox_mmap":`mean average precision (0.50, 11point)`}.
         """
 
@@ -584,21 +589,17 @@ class BaseDetector(BaseModel):
 
         Returns:
             If `img_file` is a string or np.array, the result is a list of dict with 
-                key-value pairs:
-                {"category_id": `category_id`, 
-                 "category": `category`, 
-                 "bbox": `[x, y, w, h]`, 
-                 "score": `score`, 
-                 "mask": `mask`}.
-            If `img_file` is a list, the result is a list composed of list of dicts 
-                with the corresponding fields:
-                category_id(int): the predicted category ID. 0 represents the first 
+                the following key-value pairs:
+                category_id (int): Predicted category ID. 0 represents the first 
                     category in the dataset, and so on.
-                category(str): category name
-                bbox(list): bounding box in [x, y, w, h] format
-                score(str): confidence
-                mask(dict): Only for instance segmentation task. Mask of the object in 
-                    RLE format
+                category (str): Category name.
+                bbox (list): Bounding box in [x, y, w, h] format.
+                score (str): Confidence.
+                mask (dict): Only for instance segmentation task. Mask of the object in 
+                    RLE format.
+
+            If `img_file` is a list, the result is a list composed of list of dicts 
+                with the above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -926,6 +927,26 @@ class PicoDet(BaseDetector):
             in_args['optimizer'] = optimizer
         return in_args
 
+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The volume of the dataset ({}) must be larger than the batch size ({}).'
+                .format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseDetector, self).build_data_loader(dataset,
+                                                               batch_size, mode)
+
 
 class YOLOv3(BaseDetector):
     def __init__(self,

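The `PicoDet.build_data_loader` override above chooses loader settings by mode: non-training modes keep every sample (`drop_last=False`) and disable shared memory, while training falls back to the parent-class behavior. A framework-free sketch of that decision logic (function and keyword names here are illustrative, not the Paddle API):

```python
def loader_kwargs(num_samples, batch_size, mode, shuffle=True):
    """Pick DataLoader settings the way the override above does."""
    if num_samples < batch_size:
        raise ValueError(
            'The volume of the dataset ({}) must be larger than the batch '
            'size ({}).'.format(num_samples, batch_size))
    if mode != 'train':
        # Evaluation/test: keep all samples, avoid shared memory.
        return {'batch_size': batch_size, 'shuffle': shuffle,
                'drop_last': False, 'use_shared_memory': False}
    # Training keeps the default (parent-class) behavior.
    return {'batch_size': batch_size, 'shuffle': shuffle, 'drop_last': True}
```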
+ 934 - 0
paddlers/tasks/restorer.py

@@ -0,0 +1,934 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import numpy as np
+import cv2
+import paddle
+import paddle.nn.functional as F
+from paddle.static import InputSpec
+
+import paddlers
+import paddlers.models.ppgan as ppgan
+import paddlers.rs_models.res as cmres
+import paddlers.models.ppgan.metrics as metrics
+import paddlers.utils.logging as logging
+from paddlers.models import res_losses
+from paddlers.transforms import Resize, decode_image
+from paddlers.transforms.functions import calc_hr_shape
+from paddlers.utils import get_single_card_bs
+from .base import BaseModel
+from .utils.res_adapters import GANAdapter, OptimizerAdapter
+from .utils.infer_nets import InferResNet
+
+__all__ = ["DRN", "LESRCNN", "ESRGAN"]
+
+
+class BaseRestorer(BaseModel):
+    MIN_MAX = (0., 1.)
+    TEST_OUT_KEY = None
+
+    def __init__(self, model_name, losses=None, sr_factor=None, **params):
+        self.init_params = locals()
+        if 'with_net' in self.init_params:
+            del self.init_params['with_net']
+        super(BaseRestorer, self).__init__('restorer')
+        self.model_name = model_name
+        self.losses = losses
+        self.sr_factor = sr_factor
+        if params.get('with_net', True):
+            params.pop('with_net', None)
+            self.net = self.build_net(**params)
+        self.find_unused_parameters = True
+
+    def build_net(self, **params):
+        # Currently, only use models from cmres.
+        if not hasattr(cmres, self.model_name):
+            raise ValueError("ERROR: There is no model named {}.".format(
+                self.model_name))
+        net = dict(**cmres.__dict__)[self.model_name](**params)
+        return net
+
+    def _build_inference_net(self):
+        # For GAN models, only the generator will be used for inference.
+        if isinstance(self.net, GANAdapter):
+            infer_net = InferResNet(
+                self.net.generator, out_key=self.TEST_OUT_KEY)
+        else:
+            infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY)
+        infer_net.eval()
+        return infer_net
+
+    def _fix_transforms_shape(self, image_shape):
+        if hasattr(self, 'test_transforms'):
+            if self.test_transforms is not None:
+                has_resize_op = False
+                resize_op_idx = -1
+                normalize_op_idx = len(self.test_transforms.transforms)
+                for idx, op in enumerate(self.test_transforms.transforms):
+                    name = op.__class__.__name__
+                    if name == 'Normalize':
+                        normalize_op_idx = idx
+                    if 'Resize' in name:
+                        has_resize_op = True
+                        resize_op_idx = idx
+
+                if not has_resize_op:
+                    self.test_transforms.transforms.insert(
+                        normalize_op_idx, Resize(target_size=image_shape))
+                else:
+                    self.test_transforms.transforms[resize_op_idx] = Resize(
+                        target_size=image_shape)
+
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            if len(image_shape) == 2:
+                image_shape = [1, 3] + image_shape
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+        self.fixed_input_shape = image_shape
+        input_spec = [
+            InputSpec(
+                shape=image_shape, name='image', dtype='float32')
+        ]
+        return input_spec
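`_get_test_inputs` normalizes a user-supplied shape before building the `InputSpec`: a bare `[H, W]` is promoted to `[1, 3, H, W]`, and `None` becomes a fully dynamic NCHW shape. That rule alone can be sketched without Paddle:

```python
def normalize_input_shape(image_shape):
    """[H, W] -> [1, 3, H, W]; None -> fully dynamic NCHW shape."""
    if image_shape is None:
        return [None, 3, -1, -1]
    if len(image_shape) == 2:
        image_shape = [1, 3] + list(image_shape)
    return image_shape
```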
+
+    def run(self, net, inputs, mode):
+        outputs = OrderedDict()
+
+        if mode == 'test':
+            tar_shape = inputs[1]
+            if self.status == 'Infer':
+                net_out = net(inputs[0])
+                res_map_list = self.postprocess(
+                    net_out, tar_shape, transforms=inputs[2])
+            else:
+                if isinstance(net, GANAdapter):
+                    net_out = net.generator(inputs[0])
+                else:
+                    net_out = net(inputs[0])
+                if self.TEST_OUT_KEY is not None:
+                    net_out = net_out[self.TEST_OUT_KEY]
+                pred = self.postprocess(
+                    net_out, tar_shape, transforms=inputs[2])
+                res_map_list = []
+                for res_map in pred:
+                    res_map = self._tensor_to_images(res_map)
+                    res_map_list.append(res_map)
+            outputs['res_map'] = res_map_list
+
+        if mode == 'eval':
+            if isinstance(net, GANAdapter):
+                net_out = net.generator(inputs[0])
+            else:
+                net_out = net(inputs[0])
+            if self.TEST_OUT_KEY is not None:
+                net_out = net_out[self.TEST_OUT_KEY]
+            tar = inputs[1]
+            tar_shape = [tar.shape[-2:]]
+            pred = self.postprocess(
+                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self._tensor_to_images(pred)
+            outputs['pred'] = pred
+            tar = self._tensor_to_images(tar)
+            outputs['tar'] = tar
+
+        if mode == 'train':
+            # This is used by non-GAN models.
+            # For GAN models, self.run_gan() should be used.
+            net_out = net(inputs[0])
+            loss = self.losses(net_out, inputs[1])
+            outputs['loss'] = loss
+        return outputs
+
+    def run_gan(self, net, inputs, mode, gan_mode):
+        raise NotImplementedError
+
+    def default_loss(self):
+        return res_losses.L1Loss()
+
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          num_epochs,
+                          num_steps_each_epoch,
+                          lr_decay_power=0.9):
+        decay_step = num_epochs * num_steps_each_epoch
+        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=lr_scheduler,
+            parameters=parameters,
+            momentum=0.9,
+            weight_decay=4e-5)
+        return optimizer
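`default_optimizer` schedules the learning rate with `PolynomialDecay` over `num_epochs * num_steps_each_epoch` steps. The decay curve follows the standard polynomial-decay formula, sketched here in plain Python (an approximation of Paddle's scheduler, assuming `end_lr=0` as used above):

```python
def polynomial_decay(base_lr, step, decay_steps, end_lr=0.0, power=0.9):
    """lr(t) = (base_lr - end_lr) * (1 - t / decay_steps) ** power + end_lr."""
    step = min(step, decay_steps)  # clamp past the end of the schedule
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr
```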
+
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=2,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=2,
+              save_dir='output',
+              pretrain_weights=None,
+              learning_rate=0.01,
+              lr_decay_power=0.9,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+
+        Args:
+            num_epochs (int): Number of epochs.
+            train_dataset (paddlers.datasets.ResDataset): Training dataset.
+            train_batch_size (int, optional): Total batch size among all cards used in 
+                training. Defaults to 2.
+            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. 
+                If None, the model will not be evaluated during training process. 
+                Defaults to None.
+            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
+                training. If None, a default optimizer will be used. Defaults to None.
+            save_interval_epochs (int, optional): Epoch interval for saving the model. 
+                Defaults to 1.
+            log_interval_steps (int, optional): Step interval for printing training 
+                information. Defaults to 2.
+            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights (str|None, optional): None or name/path of pretrained 
+                weights. If None, no pretrained weights will be loaded. 
+                Defaults to None.
+            learning_rate (float, optional): Learning rate for training. Defaults to .01.
+            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
+            early_stop (bool, optional): Whether to adopt early stop strategy. Defaults 
+                to False.
+            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
+            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
+                process. Defaults to True.
+            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
+                training from. If None, no training checkpoint will be resumed. At most
+                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
+                Defaults to None.
+        """
+
+        if self.status == 'Infer':
+            logging.error(
+                "Exported inference model does not support training.",
+                exit=True)
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
+
+        if self.losses is None:
+            self.losses = self.default_loss()
+
+        if optimizer is None:
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            if isinstance(self.net, GANAdapter):
+                parameters = {'params_g': [], 'params_d': []}
+                for net_g in self.net.generators:
+                    parameters['params_g'].append(net_g.parameters())
+                for net_d in self.net.discriminators:
+                    parameters['params_d'].append(net_d.parameters())
+            else:
+                parameters = self.net.parameters()
+            self.optimizer = self.default_optimizer(
+                parameters, learning_rate, num_epochs, num_steps_each_epoch,
+                lr_decay_power)
+        else:
+            self.optimizer = optimizer
+
+        if pretrain_weights is not None and not osp.exists(pretrain_weights):
+            logging.warning("Path of pretrain_weights('{}') does not exist!".
+                            format(pretrain_weights))
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
+        pretrained_dir = osp.join(save_dir, 'pretrain')
+        is_backbone_weights = pretrain_weights == 'IMAGENET'
+        self.net_initialize(
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint,
+            is_backbone_weights=is_backbone_weights)
+
+        self.train_loop(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=2,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=2,
+                          save_dir='output',
+                          learning_rate=0.0001,
+                          lr_decay_power=0.9,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          resume_checkpoint=None,
+                          quant_config=None):
+        """
+        Quantization-aware training.
+
+        Args:
+            num_epochs (int): Number of epochs.
+            train_dataset (paddlers.datasets.ResDataset): Training dataset.
+            train_batch_size (int, optional): Total batch size among all cards used in 
+                training. Defaults to 2.
+            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset.
+                If None, the model will not be evaluated during training process. 
+                Defaults to None.
+            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
+                training. If None, a default optimizer will be used. Defaults to None.
+            save_interval_epochs (int, optional): Epoch interval for saving the model. 
+                Defaults to 1.
+            log_interval_steps (int, optional): Step interval for printing training 
+                information. Defaults to 2.
+            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
+            learning_rate (float, optional): Learning rate for training. 
+                Defaults to .0001.
+            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
+            early_stop (bool, optional): Whether to adopt early stop strategy. 
+                Defaults to False.
+            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
+            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
+                process. Defaults to True.
+            quant_config (dict|None, optional): Quantization configuration. If None, 
+                a default rule of thumb configuration will be used. Defaults to None.
+            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
+                quantization-aware training from. If None, no training checkpoint will
+                be resumed. Defaults to None.
+        """
+
+        self._prepare_qat(quant_config)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=None,
+            learning_rate=learning_rate,
+            lr_decay_power=lr_decay_power,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
+    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
+        """
+        Evaluate the model.
+
+        Args:
+            eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset.
+            batch_size (int, optional): Total batch size among all cards used for 
+                evaluation. Defaults to 1.
+            return_details (bool, optional): Whether to return evaluation details. 
+                Defaults to False.
+
+        Returns:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
+                {"psnr": `peak signal-to-noise ratio`,
+                 "ssim": `structural similarity`}.
+
+        """
+
+        self._check_transforms(eval_dataset.transforms, 'eval')
+
+        self.net.eval()
+        nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
+        if nranks > 1:
+            # Initialize parallel environment if not done.
+            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+            ):
+                paddle.distributed.init_parallel_env()
+
+        # TODO: Distributed evaluation
+        if batch_size > 1:
+            logging.warning(
+                "Restorer only supports single-card evaluation with "
+                "batch_size=1, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset, batch_size=batch_size, mode='eval')
+            # XXX: Hard-code crop_border and test_y_channel
+            psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
+            ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
+            logging.info(
+                "Start to evaluate (total_samples={}, total_steps={})...".format(
+                    eval_dataset.num_samples, eval_dataset.num_samples))
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    data.append(eval_dataset.transforms.transforms)
+                    outputs = self.run(self.net, data, 'eval')
+                    psnr.update(outputs['pred'], outputs['tar'])
+                    ssim.update(outputs['pred'], outputs['tar'])
+
+            # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
+            assert len(psnr.results) > 0
+            assert len(ssim.results) > 0
+            eval_metrics = OrderedDict(
+                zip(['psnr', 'ssim'],
+                    [np.mean(psnr.results), np.mean(ssim.results)]))
+
+            if return_details:
+                # TODO: Add details
+                return eval_metrics, None
+
+            return eval_metrics
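Evaluation above accumulates per-image PSNR/SSIM results and averages `metric.results` manually (deliberately avoiding `accumulate()` in the multi-card case). The PSNR part of that computation follows the standard definition, sketched here as a simplified stand-in for `ppgan.metrics.PSNR` (without border cropping or Y-channel conversion):

```python
import math


def psnr(pred, tar, data_range=255.0):
    """Peak signal-to-noise ratio between two equally-sized images
    given as nested lists of pixel values."""
    flat_p = [v for row in pred for v in row]
    flat_t = [v for row in tar for v in row]
    mse = sum((p - t) ** 2 for p, t in zip(flat_p, flat_t)) / len(flat_p)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * math.log10(data_range ** 2 / mse)
```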
+
+    def predict(self, img_file, transforms=None):
+        """
+        Do inference.
+
+        Args:
+            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded 
+                image data, which also could constitute a list, meaning all images to be 
+                predicted as a mini-batch.
+            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
+                inputs. If None, the transforms for evaluation process will be used. 
+                Defaults to None.
+
+        Returns:
+            If `img_file` is a string or np.ndarray, the result is a dict with 
+                the following key-value pairs:
+                res_map (np.ndarray): Restored image (HWC).
+
+            If `img_file` is a list, the result is a list composed of dicts with the 
+                above keys.
+        """
+
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise ValueError("transforms need to be defined, now is None.")
+        if transforms is None:
+            transforms = self.test_transforms
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+        batch_im, batch_tar_shape = self.preprocess(images, transforms,
+                                                    self.model_type)
+        self.net.eval()
+        data = (batch_im, batch_tar_shape, transforms.transforms)
+        outputs = self.run(self.net, data, 'test')
+        res_map_list = outputs['res_map']
+        if isinstance(img_file, list):
+            prediction = [{'res_map': m} for m in res_map_list]
+        else:
+            prediction = {'res_map': res_map_list[0]}
+        return prediction
+
+    def preprocess(self, images, transforms, to_tensor=True):
+        self._check_transforms(transforms, 'test')
+        batch_im = list()
+        batch_tar_shape = list()
+        for im in images:
+            if isinstance(im, str):
+                im = decode_image(im, to_rgb=False)
+            ori_shape = im.shape[:2]
+            sample = {'image': im}
+            im = transforms(sample)[0]
+            batch_im.append(im)
+            batch_tar_shape.append(self._get_target_shape(ori_shape))
+        if to_tensor:
+            batch_im = paddle.to_tensor(batch_im)
+        else:
+            batch_im = np.asarray(batch_im)
+
+        return batch_im, batch_tar_shape
+
+    def _get_target_shape(self, ori_shape):
+        if self.sr_factor is None:
+            return ori_shape
+        else:
+            return calc_hr_shape(ori_shape, self.sr_factor)
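`_get_target_shape` delegates to `calc_hr_shape` to map a low-resolution input shape to the expected high-resolution output shape. A minimal sketch of that mapping, assuming `calc_hr_shape` simply scales both dimensions by the SR factor (which matches its use here):

```python
def calc_hr_shape(lr_shape, sr_factor):
    """Scale an (H, W) low-resolution shape by the super-resolution factor."""
    return tuple(int(s * sr_factor) for s in lr_shape)


def get_target_shape(ori_shape, sr_factor=None):
    # Restoration tasks without upscaling keep the original shape.
    if sr_factor is None:
        return ori_shape
    return calc_hr_shape(ori_shape, sr_factor)
```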
+
+    @staticmethod
+    def get_transforms_shape_info(batch_tar_shape, transforms):
+        batch_restore_list = list()
+        for tar_shape in batch_tar_shape:
+            restore_list = list()
+            h, w = tar_shape[0], tar_shape[1]
+            for op in transforms:
+                if op.__class__.__name__ == 'Resize':
+                    restore_list.append(('resize', (h, w)))
+                    h, w = op.target_size
+                elif op.__class__.__name__ == 'ResizeByShort':
+                    restore_list.append(('resize', (h, w)))
+                    im_short_size = min(h, w)
+                    im_long_size = max(h, w)
+                    scale = float(op.short_size) / float(im_short_size)
+                    if 0 < op.max_size < np.round(scale * im_long_size):
+                        scale = float(op.max_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'ResizeByLong':
+                    restore_list.append(('resize', (h, w)))
+                    im_long_size = max(h, w)
+                    scale = float(op.long_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'Pad':
+                    if op.target_size:
+                        target_h, target_w = op.target_size
+                    else:
+                        target_h = int(
+                            (np.ceil(h / op.size_divisor) * op.size_divisor))
+                        target_w = int(
+                            (np.ceil(w / op.size_divisor) * op.size_divisor))
+
+                    if op.pad_mode == -1:
+                        offsets = op.offsets
+                    elif op.pad_mode == 0:
+                        offsets = [0, 0]
+                    elif op.pad_mode == 1:
+                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
+                    else:
+                        offsets = [target_h - h, target_w - w]
+                    restore_list.append(('padding', (h, w), offsets))
+                    h, w = target_h, target_w
+
+            batch_restore_list.append(restore_list)
+        return batch_restore_list
+
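The record-and-undo scheme used by `get_transforms_shape_info` and `postprocess` can be sketched standalone: each geometric op appends the shape it received, and postprocessing replays the records in reverse to recover the original spatial size. The variable names below are illustrative, not part of the PaddleRS API:

```python
import numpy as np

restore_list = []

im = np.zeros((100, 150))                  # original (H, W)
restore_list.append(('resize', im.shape))  # a Resize op records the pre-resize shape
im = np.zeros((64, 64))                    # ...and rescales to 64x64

restore_list.append(('padding', im.shape, (0, 0)))  # a Pad op records shape + offsets
im = np.pad(im, ((0, 8), (0, 8)))                   # ...and pads to 72x72

# Undo in reverse order: crop the padding away first, then resize back.
for item in restore_list[::-1]:
    if item[0] == 'padding':
        (h, w), (y, x) = item[1], item[2]
        im = im[y:y + h, x:x + w]
    elif item[0] == 'resize':
        h, w = item[1]
        im = np.zeros((h, w))  # stand-in for F.interpolate to (h, w)

print(im.shape)  # (100, 150)
```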
+    def postprocess(self, batch_pred, batch_tar_shape, transforms):
+        batch_restore_list = BaseRestorer.get_transforms_shape_info(
+            batch_tar_shape, transforms)
+        if self.status == 'Infer':
+            return self._infer_postprocess(
+                batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
+        results = []
+        if batch_pred.dtype == paddle.float32:
+            mode = 'bilinear'
+        else:
+            mode = 'nearest'
+        for pred, restore_list in zip(batch_pred, batch_restore_list):
+            pred = paddle.unsqueeze(pred, axis=0)
+            for item in restore_list[::-1]:
+                h, w = item[1][0], item[1][1]
+                if item[0] == 'resize':
+                    pred = F.interpolate(
+                        pred, (h, w), mode=mode, data_format='NCHW')
+                elif item[0] == 'padding':
+                    x, y = item[2]
+                    pred = pred[:, :, y:y + h, x:x + w]
+                else:
+                    pass
+            results.append(pred)
+        return results
+
+    def _infer_postprocess(self, batch_res_map, batch_restore_list):
+        res_maps = []
+        for res_map, restore_list in zip(batch_res_map, batch_restore_list):
+            if not isinstance(res_map, np.ndarray):
+                res_map = paddle.unsqueeze(res_map, axis=0)
+            for item in restore_list[::-1]:
+                h, w = item[1][0], item[1][1]
+                if item[0] == 'resize':
+                    if isinstance(res_map, np.ndarray):
+                        res_map = cv2.resize(
+                            res_map, (w, h), interpolation=cv2.INTER_LINEAR)
+                    else:
+                        res_map = F.interpolate(
+                            res_map, (h, w),
+                            mode='bilinear',
+                            data_format='NHWC')
+                elif item[0] == 'padding':
+                    x, y = item[2]
+                    if isinstance(res_map, np.ndarray):
+                        res_map = res_map[y:y + h, x:x + w]
+                    else:
+                        res_map = res_map[:, y:y + h, x:x + w, :]
+                else:
+                    pass
+            res_map = res_map.squeeze()
+            if not isinstance(res_map, np.ndarray):
+                res_map = res_map.numpy()
+            res_map = self._normalize(res_map)
+            res_maps.append(res_map.squeeze())
+        return res_maps
+
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeRestorer):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeRestorer object.")
+
+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The volume of the dataset ({}) must be larger than the batch size ({}).'
+                .format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseRestorer, self).build_data_loader(dataset,
+                                                               batch_size, mode)
+
+    def set_losses(self, losses):
+        self.losses = losses
+
+    def _tensor_to_images(self,
+                          tensor,
+                          transpose=True,
+                          squeeze=True,
+                          quantize=True):
+        if transpose:
+            tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
+        if squeeze:
+            tensor = tensor.squeeze()
+        images = tensor.numpy().astype('float32')
+        images = self._normalize(
+            images, copy=True, clip=True, quantize=quantize)
+        return images
+
+    def _normalize(self, im, copy=False, clip=True, quantize=True):
+        if copy:
+            im = im.copy()
+        if clip:
+            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
+        im -= im.min()
+        im /= im.max() + 1e-32
+        if quantize:
+            im *= 255
+            im = im.astype('uint8')
+        return im
+
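The min-max rescaling in `_normalize` can be exercised in isolation; a numpy sketch mirroring the method body (with `MIN_MAX = (0., 1.)`, the base-class default):

```python
import numpy as np

MIN_MAX = (0.0, 1.0)

def normalize(im, clip=True, quantize=True):
    # Clip to the valid range, min-max rescale to [0, 1], optionally quantize.
    im = im.astype('float32')
    if clip:
        im = np.clip(im, MIN_MAX[0], MIN_MAX[1])
    im -= im.min()
    im /= im.max() + 1e-32  # epsilon guards against division by zero
    if quantize:
        im = (im * 255).astype('uint8')
    return im

out = normalize(np.array([-0.5, 0.0, 0.5, 1.5]))
print(out)  # [  0   0 127 255]
```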
+
+class DRN(BaseRestorer):
+    TEST_OUT_KEY = -1
+
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 scale=(2, 4),
+                 n_blocks=30,
+                 n_feats=16,
+                 n_colors=3,
+                 rgb_range=1.0,
+                 negval=0.2,
+                 lq_loss_weight=0.1,
+                 dual_loss_weight=0.1,
+                 **params):
+        if sr_factor != max(scale):
+            raise ValueError("`sr_factor` must be equal to `max(scale)`.")
+        params.update({
+            'scale': scale,
+            'n_blocks': n_blocks,
+            'n_feats': n_feats,
+            'n_colors': n_colors,
+            'rgb_range': rgb_range,
+            'negval': negval
+        })
+        self.lq_loss_weight = lq_loss_weight
+        self.dual_loss_weight = dual_loss_weight
+        super(DRN, self).__init__(
+            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        from ppgan.modules.init import init_weights
+        generators = [ppgan.models.generators.DRNGenerator(**params)]
+        init_weights(generators[-1])
+        for _ in params['scale']:  # one dual (down-sampling) branch per scale
+            dual_model = ppgan.models.generators.drn.DownBlock(
+                params['negval'], params['n_feats'], params['n_colors'], 2)
+            generators.append(dual_model)
+            init_weights(generators[-1])
+        return GANAdapter(generators, [])
+
+    def default_optimizer(self, parameters, *args, **kwargs):
+        optims_g = [
+            super(DRN, self).default_optimizer(params_g, *args, **kwargs)
+            for params_g in parameters['params_g']
+        ]
+        return OptimizerAdapter(*optims_g)
+
+    def run_gan(self, net, inputs, mode, gan_mode='forward_primary'):
+        if mode != 'train':
+            raise ValueError("`mode` must be 'train'.")
+        outputs = OrderedDict()
+        if gan_mode == 'forward_primary':
+            sr = net.generator(inputs[0])
+            lr = [inputs[0]]
+            lr.extend([
+                F.interpolate(
+                    inputs[0], scale_factor=s, mode='bicubic')
+                for s in net.generator.scale[:-1]
+            ])
+            loss = self.losses(sr[-1], inputs[1])
+            for i in range(1, len(sr)):
+                if self.lq_loss_weight > 0:
+                    loss += self.losses(sr[i - 1 - len(sr)],
+                                        lr[i - len(sr)]) * self.lq_loss_weight
+            outputs['loss_prim'] = loss
+            outputs['sr'] = sr
+            outputs['lr'] = lr
+        elif gan_mode == 'forward_dual':
+            sr, lr = inputs[0], inputs[1]
+            sr2lr = []
+            n_scales = len(net.generator.scale)
+            for i in range(n_scales):
+                sr2lr_i = net.generators[1 + i](sr[i - n_scales])
+                sr2lr.append(sr2lr_i)
+            loss = self.losses(sr2lr[0], lr[0])
+            for i in range(1, n_scales):
+                if self.dual_loss_weight > 0.0:
+                    loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight
+            outputs['loss_dual'] = loss
+        else:
+            raise ValueError("Invalid `gan_mode`!")
+        return outputs
+
+    def train_step(self, step, data, net):
+        outputs = self.run_gan(
+            net, data, mode='train', gan_mode='forward_primary')
+        outputs.update(
+            self.run_gan(
+                net, (outputs['sr'], outputs['lr']),
+                mode='train',
+                gan_mode='forward_dual'))
+        self.optimizer.clear_grad()
+        (outputs['loss_prim'] + outputs['loss_dual']).backward()
+        self.optimizer.step()
+        return {
+            'loss_prim': outputs['loss_prim'],
+            'loss_dual': outputs['loss_dual']
+        }
+
+
+class LESRCNN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 multi_scale=False,
+                 group=1,
+                 **params):
+        params.update({
+            'scale': sr_factor,
+            'multi_scale': multi_scale,
+            'group': group
+        })
+        super(LESRCNN, self).__init__(
+            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        net = ppgan.models.generators.LESRCNNGenerator(**params)
+        return net
+
+
+class ESRGAN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 use_gan=True,
+                 in_channels=3,
+                 out_channels=3,
+                 nf=64,
+                 nb=23,
+                 **params):
+        if sr_factor != 4:
+            raise ValueError("`sr_factor` must be 4.")
+        params.update({
+            'in_nc': in_channels,
+            'out_nc': out_channels,
+            'nf': nf,
+            'nb': nb
+        })
+        self.use_gan = use_gan
+        super(ESRGAN, self).__init__(
+            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        from ppgan.modules.init import init_weights
+        generator = ppgan.models.generators.RRDBNet(**params)
+        init_weights(generator)
+        if self.use_gan:
+            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
+                in_channels=params['out_nc'], num_feat=64)
+            net = GANAdapter(
+                generators=[generator], discriminators=[discriminator])
+        else:
+            net = generator
+        return net
+
+    def default_loss(self):
+        if self.use_gan:
+            return {
+                'pixel': res_losses.L1Loss(loss_weight=0.01),
+                'perceptual': res_losses.PerceptualLoss(
+                    layer_weights={'34': 1.0},
+                    perceptual_weight=1.0,
+                    style_weight=0.0,
+                    norm_img=False),
+                'gan': res_losses.GANLoss(
+                    gan_mode='vanilla', loss_weight=0.005)
+            }
+        else:
+            return res_losses.L1Loss()
+
+    def default_optimizer(self, parameters, *args, **kwargs):
+        if self.use_gan:
+            optim_g = super(ESRGAN, self).default_optimizer(
+                parameters['params_g'][0], *args, **kwargs)
+            optim_d = super(ESRGAN, self).default_optimizer(
+                parameters['params_d'][0], *args, **kwargs)
+            return OptimizerAdapter(optim_g, optim_d)
+        else:
+            return super(ESRGAN, self).default_optimizer(parameters, *args,
+                                                         **kwargs)
+
+    def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
+        if mode != 'train':
+            raise ValueError("`mode` must be 'train'.")
+        outputs = OrderedDict()
+        if gan_mode == 'forward_g':
+            loss_g = 0
+            g_pred = net.generator(inputs[0])
+            loss_pix = self.losses['pixel'](g_pred, inputs[1])
+            loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1])
+            loss_g += loss_pix
+            if loss_perc is not None:
+                loss_g += loss_perc
+            if loss_sty is not None:
+                loss_g += loss_sty
+            self._set_requires_grad(net.discriminator, False)
+            real_d_pred = net.discriminator(inputs[1]).detach()
+            fake_g_pred = net.discriminator(g_pred)
+            loss_g_real = self.losses['gan'](
+                real_d_pred - paddle.mean(fake_g_pred), False,
+                is_disc=False) * 0.5
+            loss_g_fake = self.losses['gan'](
+                fake_g_pred - paddle.mean(real_d_pred), True,
+                is_disc=False) * 0.5
+            loss_g_gan = loss_g_real + loss_g_fake
+            outputs['g_pred'] = g_pred.detach()
+            outputs['loss_g_pps'] = loss_g
+            outputs['loss_g_gan'] = loss_g_gan
+        elif gan_mode == 'forward_d':
+            self._set_requires_grad(net.discriminator, True)
+            # Real
+            fake_d_pred = net.discriminator(inputs[0]).detach()
+            real_d_pred = net.discriminator(inputs[1])
+            loss_d_real = self.losses['gan'](
+                real_d_pred - paddle.mean(fake_d_pred), True,
+                is_disc=True) * 0.5
+            # Fake
+            fake_d_pred = net.discriminator(inputs[0].detach())
+            loss_d_fake = self.losses['gan'](
+                fake_d_pred - paddle.mean(real_d_pred.detach()),
+                False,
+                is_disc=True) * 0.5
+            outputs['loss_d'] = loss_d_real + loss_d_fake
+        else:
+            raise ValueError("Invalid `gan_mode`!")
+        return outputs
+
+    def train_step(self, step, data, net):
+        if self.use_gan:
+            optim_g, optim_d = self.optimizer
+
+            outputs = self.run_gan(
+                net, data, mode='train', gan_mode='forward_g')
+            optim_g.clear_grad()
+            (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
+            optim_g.step()
+
+            outputs.update(
+                self.run_gan(
+                    net, (outputs['g_pred'], data[1]),
+                    mode='train',
+                    gan_mode='forward_d'))
+            optim_d.clear_grad()
+            outputs['loss_d'].backward()
+            optim_d.step()
+
+            outputs['loss'] = outputs['loss_g_pps'] + outputs[
+                'loss_g_gan'] + outputs['loss_d']
+
+            return {
+                'loss': outputs['loss'],
+                'loss_g_pps': outputs['loss_g_pps'],
+                'loss_g_gan': outputs['loss_g_gan'],
+                'loss_d': outputs['loss_d']
+            }
+        else:
+            return super(ESRGAN, self).train_step(step, data, net)
+
+    def _set_requires_grad(self, net, requires_grad):
+        for p in net.parameters():
+            p.trainable = requires_grad
+
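The `forward_d` branch above follows the relativistic average GAN formulation: each real score is compared against the mean fake score and vice versa. A numpy sketch of that arithmetic, substituting a stable BCE-with-logits for `GANLoss(gan_mode='vanilla')`:

```python
import numpy as np

def bce_with_logits(logits, target):
    # Numerically stable binary cross-entropy with logits (mean reduction).
    return float(np.mean(
        np.maximum(logits, 0) - logits * target +
        np.log1p(np.exp(-np.abs(logits)))))

real_pred = np.array([2.0, 1.5])    # discriminator scores for real patches
fake_pred = np.array([-1.0, -0.5])  # discriminator scores for generated patches

# Relativistic logits: the discriminator learns *relative* realism.
loss_d_real = bce_with_logits(real_pred - fake_pred.mean(), 1.0) * 0.5
loss_d_fake = bce_with_logits(fake_pred - real_pred.mean(), 0.0) * 0.5
loss_d = loss_d_real + loss_d_fake
```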
+
+class RCAN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=1.0,
+                 kernel_size=3,
+                 reduction=16,
+                 **params):
+        params.update({
+            'n_resgroups': n_resgroups,
+            'n_resblocks': n_resblocks,
+            'n_feats': n_feats,
+            'n_colors': n_colors,
+            'rgb_range': rgb_range,
+            'kernel_size': kernel_size,
+            'reduction': reduction
+        })
+        super(RCAN, self).__init__(
+            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)

+ 17 - 12
paddlers/tasks/segmenter.py

@@ -33,6 +33,7 @@ from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferSegNet
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
@@ -64,11 +65,16 @@ class BaseSegmenter(BaseModel):
 
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
-        # DeepLabv3p and HRNet will raise a error
+        # DeepLabv3p and HRNet will raise an error.
         net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
             num_classes=self.num_classes, **params)
         return net
 
+    def _build_inference_net(self):
+        infer_net = InferSegNet(self.net)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -472,7 +478,6 @@ class BaseSegmenter(BaseModel):
                     conf_mat_all.append(conf_mat)
         class_iou, miou = ppseg.utils.metrics.mean_iou(
             intersect_area_all, pred_area_all, label_area_all)
-        # TODO 确认是按oacc还是macc
         class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                        pred_area_all)
         kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
@@ -504,13 +509,13 @@ class BaseSegmenter(BaseModel):
                 Defaults to None.
 
         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
-                pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+            If `img_file` is a string or np.ndarray, the result is a dict with 
+                the following key-value pairs:
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
+                above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -750,11 +755,11 @@ class BaseSegmenter(BaseModel):
                 elif item[0] == 'padding':
                     x, y = item[2]
                     if isinstance(label_map, np.ndarray):
-                        label_map = label_map[..., y:y + h, x:x + w]
-                        score_map = score_map[..., y:y + h, x:x + w]
+                        label_map = label_map[y:y + h, x:x + w]
+                        score_map = score_map[y:y + h, x:x + w]
                     else:
-                        label_map = label_map[:, :, y:y + h, x:x + w]
-                        score_map = score_map[:, :, y:y + h, x:x + w]
+                        label_map = label_map[:, y:y + h, x:x + w, :]
+                        score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                     pass
             label_map = label_map.squeeze()

+ 28 - 11
paddlers/tasks/utils/infer_nets.py

@@ -15,30 +15,36 @@
 import paddle
 
 
-class PostProcessor(paddle.nn.Layer):
-    def __init__(self, model_type):
-        super(PostProcessor, self).__init__()
-        self.model_type = model_type
-
+class SegPostProcessor(paddle.nn.Layer):
     def forward(self, net_outputs):
         # label_map [NHW], score_map [NHWC]
         logit = net_outputs[0]
         outputs = paddle.argmax(logit, axis=1, keepdim=False, dtype='int32'), \
                     paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1])
+        return outputs
+
+
+class ResPostProcessor(paddle.nn.Layer):
+    def __init__(self, out_key=None):
+        super(ResPostProcessor, self).__init__()
+        self.out_key = out_key
 
+    def forward(self, net_outputs):
+        if self.out_key is not None:
+            net_outputs = net_outputs[self.out_key]
+        outputs = paddle.transpose(net_outputs, perm=[0, 2, 3, 1])
         return outputs
 
 
-class InferNet(paddle.nn.Layer):
-    def __init__(self, net, model_type):
-        super(InferNet, self).__init__()
+class InferSegNet(paddle.nn.Layer):
+    def __init__(self, net):
+        super(InferSegNet, self).__init__()
         self.net = net
-        self.postprocessor = PostProcessor(model_type)
+        self.postprocessor = SegPostProcessor()
 
     def forward(self, x):
         net_outputs = self.net(x)
         outputs = self.postprocessor(net_outputs)
-
         return outputs
 
 
@@ -46,10 +52,21 @@ class InferCDNet(paddle.nn.Layer):
     def __init__(self, net):
         super(InferCDNet, self).__init__()
         self.net = net
-        self.postprocessor = PostProcessor('change_detector')
+        self.postprocessor = SegPostProcessor()
 
     def forward(self, x1, x2):
         net_outputs = self.net(x1, x2)
         outputs = self.postprocessor(net_outputs)
+        return outputs
+
+
+class InferResNet(paddle.nn.Layer):
+    def __init__(self, net, out_key=None):
+        super(InferResNet, self).__init__()
+        self.net = net
+        self.postprocessor = ResPostProcessor(out_key=out_key)
 
+    def forward(self, x):
+        net_outputs = self.net(x)
+        outputs = self.postprocessor(net_outputs)
         return outputs

+ 132 - 0
paddlers/tasks/utils/res_adapters.py

@@ -0,0 +1,132 @@
+from functools import wraps
+from inspect import isfunction, isgeneratorfunction, getmembers
+from collections.abc import Sequence
+from abc import ABC
+
+import paddle
+import paddle.nn as nn
+
+__all__ = ['GANAdapter', 'OptimizerAdapter']
+
+
+class _AttrDesc:
+    def __init__(self, key):
+        self.key = key
+
+    def __get__(self, instance, owner):
+        return tuple(getattr(ele, self.key) for ele in instance)
+
+    def __set__(self, instance, value):
+        for ele in instance:
+            setattr(ele, self.key, value)
+
+
+def _func_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
+
+    return _wrapper
+
+
+def _generator_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        for ele in self:
+            yield from getattr(ele, func_name)(*args, **kwargs)
+
+    return _wrapper
+
+
+class Adapter(Sequence, ABC):
+    __ducktype__ = object
+    __ava__ = ()
+
+    def __init__(self, *args):
+        if not all(map(self._check, args)):
+            raise TypeError("Please check the input type.")
+        self._seq = tuple(args)
+
+    def __getitem__(self, key):
+        return self._seq[key]
+
+    def __len__(self):
+        return len(self._seq)
+
+    def __repr__(self):
+        return repr(self._seq)
+
+    @classmethod
+    def _check(cls, obj):
+        for attr in cls.__ava__:
+            try:
+                getattr(obj, attr)
+                # TODO: Check function signature
+            except AttributeError:
+                return False
+        return True
+
+
+def make_adapter(cls):
+    members = dict(getmembers(cls.__ducktype__))
+    for k in cls.__ava__:
+        if hasattr(cls, k):
+            continue
+        if k in members:
+            v = members[k]
+            if isgeneratorfunction(v):
+                setattr(cls, k, _generator_deco(cls, k))
+            elif isfunction(v):
+                setattr(cls, k, _func_deco(cls, k))
+            else:
+                setattr(cls, k, _AttrDesc(k))
+    return cls
+
+
+class GANAdapter(nn.Layer):
+    __ducktype__ = nn.Layer
+    __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval')
+
+    def __init__(self, generators, discriminators):
+        super(GANAdapter, self).__init__()
+        self.generators = nn.LayerList(generators)
+        self.discriminators = nn.LayerList(discriminators)
+        self._m = [*generators, *discriminators]
+
+    def __len__(self):
+        return len(self._m)
+
+    def __getitem__(self, key):
+        return self._m[key]
+
+    def __contains__(self, m):
+        return m in self._m
+
+    def __repr__(self):
+        return repr(self._m)
+
+    @property
+    def generator(self):
+        return self.generators[0]
+
+    @property
+    def discriminator(self):
+        return self.discriminators[0]
+
+
+Adapter.register(GANAdapter)
+
+
+@make_adapter
+class OptimizerAdapter(Adapter):
+    __ducktype__ = paddle.optimizer.Optimizer
+    __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')
+
+    def set_state_dict(self, state_dicts):
+        # Special dispatching rule
+        for optim, state_dict in zip(self, state_dicts):
+            optim.set_state_dict(state_dict)
+
+    def get_lr(self):
+        # Return the lr of the first optimizer
+        return self[0].get_lr()
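The adapter machinery fans attribute access and method calls out to every wrapped object. A stdlib-only sketch of the behaviour `_func_deco` produces, using a hypothetical `_FakeOptim` stand-in so no Paddle install is needed:

```python
class _FakeOptim:
    # Hypothetical optimizer stand-in with the duck-typed surface we need.
    def __init__(self, lr):
        self.lr = lr
        self.steps = 0

    def step(self):
        self.steps += 1

    def get_lr(self):
        return self.lr


class MiniOptimizerAdapter:
    def __init__(self, *optims):
        self._seq = tuple(optims)

    def __getitem__(self, key):
        return self._seq[key]

    def __len__(self):
        return len(self._seq)

    def step(self):
        for optim in self._seq:  # broadcast the call, as _func_deco does
            optim.step()

    def get_lr(self):
        return self._seq[0].get_lr()  # first optimizer wins, as in OptimizerAdapter


optims = MiniOptimizerAdapter(_FakeOptim(0.1), _FakeOptim(0.01))
optims.step()
print(optims[0].steps, optims[1].steps, optims.get_lr())  # 1 1 0.1
```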

+ 4 - 0
paddlers/transforms/functions.py

@@ -638,3 +638,7 @@ def decode_seg_mask(mask_path):
     mask = np.asarray(Image.open(mask_path))
     mask = mask.astype('int64')
     return mask
+
+
+def calc_hr_shape(lr_shape, sr_factor):
+    return tuple(int(s * sr_factor) for s in lr_shape)
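The helper is trivially checkable: for a x4 SR model it maps an LR tile size to its HR counterpart (the function body below is copied verbatim from this diff):

```python
def calc_hr_shape(lr_shape, sr_factor):
    # HR spatial size is the LR size scaled by the SR factor, truncated to int.
    return tuple(int(s * sr_factor) for s in lr_shape)

print(calc_hr_shape((200, 300), 4))  # (800, 1200)
```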

+ 102 - 15
paddlers/transforms/operators.py

@@ -35,7 +35,7 @@ from .functions import (
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
     vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
     resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask)
+    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape)
 
 __all__ = [
     "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
@@ -44,7 +44,7 @@ __all__ = [
     "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
     "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
     "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
-    "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
+    "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask"
 ]
 
 interp_dict = {
@@ -154,6 +154,8 @@ class Transform(object):
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
 
         return sample
 
@@ -336,6 +338,14 @@ class DecodeImg(Transform):
                 map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
 
+        if 'target' in sample:
+            if self.read_geo_info:
+                target, geo_info_dict = self.apply_im(sample['target'])
+                sample['target'] = target
+                sample['geo_info_dict_tar'] = geo_info_dict
+            else:
+                sample['target'] = self.apply_im(sample['target'])
+
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
@@ -457,6 +467,17 @@ class Resize(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                # For SR tasks
+                sample['target'] = self.apply_im(
+                    sample['target'], interp,
+                    calc_hr_shape(target_size, sample['sr_factor']))
+            else:
+                # For non-SR tasks
+                sample['target'] = self.apply_im(sample['target'], interp,
+                                                 target_size)
+
         sample['im_shape'] = np.asarray(
             sample['image'].shape[:2], dtype=np.float32)
         if 'scale_factor' in sample:
@@ -730,6 +751,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     True)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 True)
         elif p_m < self.probs[1]:
             mode_p = random.random()
             mode_id = self.judge_probs_range(mode_p, self.probsr)
@@ -750,6 +774,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     False)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 False)
 
         return sample
 
@@ -809,6 +836,8 @@ class RandomHorizontalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
@@ -867,6 +896,8 @@ class RandomVerticalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
@@ -884,15 +915,18 @@ class Normalize(Transform):
             image(s). Defaults to [0.229, 0.224, 0.225].
         min_val (list[float] | tuple[float], optional): Minimum value of input 
             image(s). If None, use 0 for all channels. Defaults to None.
-        max_val (list[float] | tuple[float], optional): Max value of input image(s). 
-            If None, use 255. for all channels. Defaults to None.
+        max_val (list[float] | tuple[float], optional): Maximum value of input 
+            image(s). If None, use 255. for all channels. Defaults to None.
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
     def __init__(self,
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225],
                  min_val=None,
-                 max_val=None):
+                 max_val=None,
+                 apply_to_tar=True):
         super(Normalize, self).__init__()
         channel = len(mean)
         if min_val is None:
@@ -914,6 +948,7 @@ class Normalize(Transform):
         self.std = std
         self.min_val = min_val
         self.max_val = max_val
+        self.apply_to_tar = apply_to_tar
 
     def apply_im(self, image):
         image = image.astype(np.float32)
@@ -927,6 +962,8 @@ class Normalize(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
 
         return sample
 
@@ -964,6 +1001,8 @@ class CenterCrop(Transform):
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
@@ -1165,6 +1204,14 @@ class RandomCrop(Transform):
                         self.apply_mask, crop=crop_box),
                         sample['aux_masks']))
 
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    sample['target'] = self.apply_im(
+                        sample['target'],
+                        calc_hr_shape(crop_box, sample['sr_factor']))
+                else:
+                    sample['target'] = self.apply_im(sample['target'], crop_box)
+
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
 
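The `sr_factor` branch above scales the LR crop box up to the HR target. A minimal sketch of the shape computation the diff assumes from `paddlers.transforms.functions.calc_hr_shape` (the real helper may differ in details):

```python
# Hypothetical stand-in for calc_hr_shape: scale an LR shape (or box extent)
# by the super-resolution factor to get the matching HR shape.
def calc_hr_shape(lr_shape, sr_factor):
    return tuple(int(s * sr_factor) for s in lr_shape)

# A 64x48 LR crop with sr_factor=4 maps to a 256x192 HR crop.
print(calc_hr_shape((64, 48), 4))  # (256, 192)
```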
@@ -1266,6 +1313,7 @@ class Pad(Transform):
             pad_mode (int, optional): Pad mode. Currently only four modes are supported:
                 [-1, 0, 1, 2]. If -1, use specified offsets. If 0, only pad to right and bottom.
                 If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
+            offsets (list[int]|None, optional): Padding offsets. Defaults to None.
             im_padding_value (list[float] | tuple[float]): RGB value of padded area. 
                 Defaults to (127.5, 127.5, 127.5).
             label_padding_value (int, optional): Filling value for the mask. 
@@ -1332,6 +1380,17 @@ class Pad(Transform):
                     expand_rle(segm, x, y, height, width, h, w))
         return expanded_segms
 
+    def _get_offsets(self, im_h, im_w, h, w):
+        if self.pad_mode == -1:
+            offsets = self.offsets
+        elif self.pad_mode == 0:
+            offsets = [0, 0]
+        elif self.pad_mode == 1:
+            offsets = [(w - im_w) // 2, (h - im_h) // 2]
+        else:
+            offsets = [w - im_w, h - im_h]
+        return offsets
+
     def apply(self, sample):
         im_h, im_w = sample['image'].shape[:2]
         if self.target_size:
@@ -1349,14 +1408,7 @@ class Pad(Transform):
         if h == im_h and w == im_w:
             return sample
 
-        if self.pad_mode == -1:
-            offsets = self.offsets
-        elif self.pad_mode == 0:
-            offsets = [0, 0]
-        elif self.pad_mode == 1:
-            offsets = [(w - im_w) // 2, (h - im_h) // 2]
-        else:
-            offsets = [w - im_w, h - im_h]
+        offsets = self._get_offsets(im_h, im_w, h, w)
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
@@ -1373,6 +1425,16 @@ class Pad(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
+                hr_offsets = self._get_offsets(*sample['target'].shape[:2],
+                                               *hr_shape)
+                sample['target'] = self.apply_im(sample['target'], hr_offsets,
+                                                 hr_shape)
+            else:
+                sample['target'] = self.apply_im(sample['target'], offsets,
+                                                 (h, w))
         return sample
 
 
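The refactor above pulls the pad-mode dispatch into `_get_offsets` so that `apply` can reuse it for the HR target. A standalone sketch of that offset logic (names mirror the diff, but this is a toy re-implementation, not the library code):

```python
# Toy version of Pad._get_offsets: return the [x, y] position of the
# original image inside the padded (h, w) canvas for each pad mode.
def get_offsets(im_h, im_w, h, w, pad_mode=0, offsets=None):
    if pad_mode == -1:                       # use caller-specified offsets
        return offsets
    if pad_mode == 0:                        # pad right and bottom only
        return [0, 0]
    if pad_mode == 1:                        # center the image
        return [(w - im_w) // 2, (h - im_h) // 2]
    return [w - im_w, h - im_h]              # pad_mode == 2: pad left and top

# Centering a 100x100 image in a 128x128 canvas:
print(get_offsets(100, 100, 128, 128, pad_mode=1))  # [14, 14]
```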
@@ -1688,15 +1750,18 @@ class ReduceDim(Transform):
 
     Args: 
         joblib_path (str): Path of *.joblib file of PCA.
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
-    def __init__(self, joblib_path):
+    def __init__(self, joblib_path, apply_to_tar=True):
         super(ReduceDim, self).__init__()
         ext = joblib_path.split(".")[-1]
         if ext != "joblib":
             raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
                 ext))
         self.pca = load(joblib_path)
+        self.apply_to_tar = apply_to_tar
 
     def apply_im(self, image):
         H, W, C = image.shape
@@ -1709,6 +1774,8 @@ class ReduceDim(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
@@ -1719,11 +1786,14 @@ class SelectBand(Transform):
     Args: 
         band_list (list, optional): Bands to select (band index starts from 1). 
             Defaults to [1, 2, 3].
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
-    def __init__(self, band_list=[1, 2, 3]):
+    def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
         super(SelectBand, self).__init__()
         self.band_list = band_list
+        self.apply_to_tar = apply_to_tar
 
     def apply_im(self, image):
         image = select_bands(image, self.band_list)
@@ -1733,6 +1803,8 @@ class SelectBand(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
@@ -1820,6 +1892,8 @@ class _Permute(Transform):
         sample['image'] = permute(sample['image'], False)
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
+        if 'target' in sample:
+            sample['target'] = permute(sample['target'], False)
         return sample
 
 
@@ -1915,3 +1989,16 @@ class ArrangeDetector(Arrange):
         if self.mode == 'eval' and 'gt_poly' in sample:
             del sample['gt_poly']
         return sample
+
+
+class ArrangeRestorer(Arrange):
+    def apply(self, sample):
+        if 'target' in sample:
+            target = permute(sample['target'], False)
+        image = permute(sample['image'], False)
+        if self.mode in ('train', 'eval'):
+            return image, target
+        if self.mode == 'test':
+            return image,

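The new `ArrangeRestorer` emits `(image, target)` pairs for training and evaluation and a one-tuple for test mode. A toy re-implementation illustrating the behavior, where `permute` stands in for PaddleRS's HWC-to-CHW transpose:

```python
import numpy as np

def permute(im, to_bgr=False):
    # Stand-in for the PaddleRS permute: HWC -> CHW.
    return np.transpose(im, (2, 0, 1))

def arrange_restorer(sample, mode):
    # Mirrors ArrangeRestorer.apply from the diff (train/eval merged).
    if 'target' in sample:
        target = permute(sample['target'], False)
    image = permute(sample['image'], False)
    if mode in ('train', 'eval'):
        return image, target
    return (image,)

sample = {'image': np.zeros((32, 32, 3)), 'target': np.zeros((128, 128, 3))}
image, target = arrange_restorer(sample, 'train')
print(image.shape, target.shape)  # (3, 32, 32) (3, 128, 128)
```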
+ 1 - 1
paddlers/utils/__init__.py

@@ -16,7 +16,7 @@ from . import logging
 from . import utils
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
                     EarlyStop, norm_path, is_pic, MyEncoder, DisablePrint,
-                    Timer)
+                    Timer, to_data_parallel, scheduler_step)
 from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress

+ 32 - 1
paddlers/utils/utils.py

@@ -20,11 +20,12 @@ import math
 import imghdr
 import chardet
 import json
+import platform
 
 import numpy as np
+import paddle
 
 from . import logging
-import platform
 import paddlers
 
 
@@ -237,3 +238,33 @@ class Timer(Times):
         self.postprocess_time_s.reset()
         self.img_num = 0
         self.repeats = 0
+
+
+def to_data_parallel(layers, *args, **kwargs):
+    from paddlers.tasks.utils.res_adapters import GANAdapter
+    if isinstance(layers, GANAdapter):
+        # Inplace modification for efficiency
+        layers.generators = [
+            paddle.DataParallel(g, *args, **kwargs) for g in layers.generators
+        ]
+        layers.discriminators = [
+            paddle.DataParallel(d, *args, **kwargs)
+            for d in layers.discriminators
+        ]
+    else:
+        layers = paddle.DataParallel(layers, *args, **kwargs)
+    return layers
+
+
+def scheduler_step(optimizer, loss=None):
+    from paddlers.tasks.utils.res_adapters import OptimizerAdapter
+    if not isinstance(optimizer, OptimizerAdapter):
+        optimizer = [optimizer]
+    for optim in optimizer:
+        if isinstance(optim._learning_rate, paddle.optimizer.lr.LRScheduler):
+            # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
+            if isinstance(optim._learning_rate,
+                          paddle.optimizer.lr.ReduceOnPlateau):
+                optim._learning_rate.step(loss.item())
+            else:
+                optim._learning_rate.step()

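`scheduler_step` steps every scheduler in an `OptimizerAdapter`, special-casing `ReduceOnPlateau`, which must be stepped with the monitored metric (here, the loss). A pure-Python sketch of that dispatch, using toy stand-ins for Paddle's scheduler classes:

```python
# Toy scheduler classes standing in for paddle.optimizer.lr types.
class LRScheduler:
    def step(self):
        pass

class ReduceOnPlateau(LRScheduler):
    def __init__(self):
        self.metrics = []
    def step(self, metric):
        # ReduceOnPlateau consumes the monitored metric on each step.
        self.metrics.append(metric)

def scheduler_step(schedulers, loss):
    # Mirrors the dispatch: plain schedulers step with no argument,
    # ReduceOnPlateau steps with the loss value.
    if not isinstance(schedulers, list):
        schedulers = [schedulers]
    for sched in schedulers:
        if isinstance(sched, ReduceOnPlateau):
            sched.step(loss)
        else:
            sched.step()

rop = ReduceOnPlateau()
scheduler_step([LRScheduler(), rop], loss=0.42)
print(rop.metrics)  # [0.42]
```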
+ 28 - 26
tests/data/data_utils.py

@@ -14,7 +14,6 @@
 
 import os.path as osp
 import re
-import imghdr
 import platform
 from collections import OrderedDict
 from functools import partial, wraps
@@ -34,20 +33,6 @@ def norm_path(path):
     return path
 
 
-def is_pic(im_path):
-    valid_suffix = [
-        'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
-    ]
-    suffix = im_path.split('.')[-1]
-    if suffix in valid_suffix:
-        return True
-    im_format = imghdr.what(im_path)
-    _, ext = osp.splitext(im_path)
-    if im_format == 'tiff' or ext == '.img':
-        return True
-    return False
-
-
 def get_full_path(p, prefix=''):
     p = norm_path(p)
     return osp.join(prefix, p)
@@ -323,15 +308,34 @@ class ConstrDetSample(ConstrSample):
         return samples
 
 
-def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
+class ConstrResSample(ConstrSample):
+    def __init__(self, prefix, label_list, sr_factor=None):
+        super().__init__(prefix, label_list)
+        self.sr_factor = sr_factor
+
+    def __call__(self, src_path, tar_path):
+        sample = {
+            'image': self.get_full_path(src_path),
+            'target': self.get_full_path(tar_path)
+        }
+        if self.sr_factor is not None:
+            sample['sr_factor'] = self.sr_factor
+        return sample
+
+
+def build_input_from_file(file_list,
+                          prefix='',
+                          task='auto',
+                          label_list=None,
+                          **kwargs):
     """
     Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
 
     Args:
-        file_list (str): Path of file_list.
+        file_list (str): Path of file list.
         prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
-        task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input. 
-            Default: 'auto'.
+        task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto', 
+            automatically determine the task based on the input. Default: 'auto'.
         label_list (str|None, optional): Path of label_list. Default: None.
 
     Returns:
@@ -339,22 +343,21 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
     """
 
     def _determine_task(parts):
+        task = 'unknown'
         if len(parts) in (3, 5):
             task = 'cd'
         elif len(parts) == 2:
             if parts[1].isdigit():
                 task = 'clas'
-            elif is_pic(osp.join(prefix, parts[1])):
-                task = 'seg'
-            else:
+            elif parts[1].endswith('.xml'):
                 task = 'det'
-        else:
+        if task == 'unknown':
             raise RuntimeError(
                 "Cannot automatically determine the task type. Please specify `task` manually."
             )
         return task
 
-    if task not in ('seg', 'det', 'cd', 'clas', 'auto'):
+    if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'):
         raise ValueError("Invalid value of `task`")
 
     samples = []
@@ -366,9 +369,8 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
             if task == 'auto':
                 task = _determine_task(parts)
             if ctor is None:
-                # Select and build sample constructor
                 ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
-                ctor = ctor_class(prefix, label_list)
+                ctor = ctor_class(prefix, label_list, **kwargs)
             sample = ctor(*parts)
             if isinstance(sample, list):
                 samples.extend(sample)

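For the new `'res'` task, each file-list line holds a source path and a target path, and `ConstrResSample` turns it into a transform-ready dict. A hypothetical illustration of the resulting sample for a line like `lr/0001.png hr/0001.png` with `sr_factor=4` (paths are made up):

```python
import os.path as osp

def constr_res_sample(src_path, tar_path, prefix='', sr_factor=None):
    # Mirrors ConstrResSample.__call__: join paths with the prefix and
    # attach the optional super-resolution factor.
    sample = {
        'image': osp.join(prefix, src_path),
        'target': osp.join(prefix, tar_path),
    }
    if sr_factor is not None:
        sample['sr_factor'] = sr_factor
    return sample

print(constr_res_sample('lr/0001.png', 'hr/0001.png',
                        prefix='data/rssr', sr_factor=4))
```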
+ 60 - 7
tests/deploy/test_predictor.py

@@ -24,7 +24,7 @@ from testing_utils import CommonTest, run_script
 
 __all__ = [
     'TestCDPredictor', 'TestClasPredictor', 'TestDetPredictor',
-    'TestSegPredictor'
+    'TestResPredictor', 'TestSegPredictor'
 ]
 
 
@@ -105,7 +105,7 @@ class TestPredictor(CommonTest):
                     dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestCDPredictor(TestPredictor):
     MODULE = pdrs.tasks.change_detector
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -177,7 +177,7 @@ class TestCDPredictor(TestPredictor):
         self.assertEqual(len(out_multi_array_t), num_inputs)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestClasPredictor(TestPredictor):
     MODULE = pdrs.tasks.classifier
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -185,7 +185,7 @@ class TestClasPredictor(TestPredictor):
     }
 
     def check_predictor(self, predictor, trainer):
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@@ -242,7 +242,7 @@ class TestClasPredictor(TestPredictor):
         self.check_dict_equal(out_multi_array_p, out_multi_array_t)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestDetPredictor(TestPredictor):
     MODULE = pdrs.tasks.object_detector
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -253,7 +253,7 @@ class TestDetPredictor(TestPredictor):
         # For detection tasks, do NOT ensure the consistency of bboxes.
         # This is because the coordinates of bboxes were observed to be very sensitive to numeric errors, 
         # given that the network is (partially?) randomly initialized.
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@@ -303,6 +303,59 @@ class TestDetPredictor(TestPredictor):
 
 
 @TestPredictor.add_tests
+class TestResPredictor(TestPredictor):
+    MODULE = pdrs.tasks.restorer
+
+    def check_predictor(self, predictor, trainer):
+        # For restoration tasks, do NOT ensure the consistency of numeric values,
+        # because the output is of uint8 type.
+        single_input = "data/ssst/optical.bmp"
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeRestorer('test')
+        ])
+
+        # Single input (file path)
+        input_ = single_input
+        predictor.predict(input_, transforms=transforms)
+        trainer.predict(input_, transforms=transforms)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_t), 1)
+
+        # Single input (ndarray)
+        input_ = decode_image(
+            single_input, to_rgb=False)  # Reuse the name `input_`
+        predictor.predict(input_, transforms=transforms)
+        trainer.predict(input_, transforms=transforms)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_t), 1)
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_t), num_inputs)
+
+        # Multiple inputs (ndarrays)
+        input_ = [decode_image(
+            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_t), num_inputs)
+
+
+# @TestPredictor.add_tests
 class TestSegPredictor(TestPredictor):
     MODULE = pdrs.tasks.segmenter
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -310,7 +363,7 @@ class TestSegPredictor(TestPredictor):
     }
 
     def check_predictor(self, predictor, trainer):
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),

+ 1 - 3
tests/rs_models/test_cd_models.py

@@ -33,9 +33,7 @@ class TestCDModel(TestModel):
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
             o = o.numpy()
-            self.check_output_equal(o.shape[0], t.shape[0])
-            self.check_output_equal(len(o.shape), 4)
-            self.check_output_equal(o.shape[2:], t.shape[2:])
+            self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         if self.EF_MODE == 'Concat':

+ 3 - 0
tests/rs_models/test_det_models.py

@@ -32,3 +32,6 @@ class TestDetModel(TestModel):
 
     def set_inputs(self):
         self.inputs = cycle([self.get_randn_tensor(3)])
+
+    def set_targets(self):
+        self.targets = cycle([None])

+ 46 - 0
tests/rs_models/test_res_models.py

@@ -0,0 +1,46 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddlers
+from rs_models.test_model import TestModel
+
+__all__ = []
+
+
+class TestResModel(TestModel):
+    def check_output(self, output, target):
+        output = output.numpy()
+        self.check_output_equal(output.shape, target.shape)
+
+    def set_inputs(self):
+        def _gen_data(specs):
+            for spec in specs:
+                c = spec.get('in_channels', 3)
+                yield self.get_randn_tensor(c)
+
+        self.inputs = _gen_data(self.specs)
+
+    def set_targets(self):
+        def _gen_data(specs):
+            for spec in specs:
+                # XXX: Hard coding
+                if 'out_channels' in spec:
+                    c = spec['out_channels']
+                elif 'in_channels' in spec:
+                    c = spec['in_channels']
+                else:
+                    c = 3
+                yield [self.get_zeros_array(c)]
+
+        self.targets = _gen_data(self.specs)

+ 5 - 3
tests/rs_models/test_seg_models.py

@@ -26,9 +26,7 @@ class TestSegModel(TestModel):
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
             o = o.numpy()
-            self.check_output_equal(o.shape[0], t.shape[0])
-            self.check_output_equal(len(o.shape), 4)
-            self.check_output_equal(o.shape[2:], t.shape[2:])
+            self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         def _gen_data(specs):
@@ -54,3 +52,7 @@ class TestFarSegModel(TestSegModel):
         self.specs = [
             dict(), dict(num_classes=20), dict(encoder_pretrained=False)
         ]
+
+    def set_targets(self):
+        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(20)],
+                        [self.get_zeros_array(16)]]

+ 43 - 0
tests/transforms/test_operators.py

@@ -164,12 +164,15 @@ class TestTransform(CpuCommonTest):
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_optical_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_sar_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_multispectral_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_optical_det.txt",
@@ -185,7 +188,23 @@ class TestTransform(CpuCommonTest):
                 label_list="data/ssst/labels_det.txt"),
             build_input_from_file(
                 "data/ssst/test_det_coco.txt",
+                task='det',
                 prefix="./data/ssst"),
+            build_input_from_file(
+                "data/ssst/test_optical_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_sar_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_multispectral_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
             build_input_from_file(
                 "data/ssmt/test_mixed_binary.txt",
                 prefix="./data/ssmt"),
@@ -227,6 +246,8 @@ class TestTransform(CpuCommonTest):
                 self.aux_mask_values = [
                     set(aux_mask.ravel()) for aux_mask in sample['aux_masks']
                 ]
+            if 'target' in sample:
+                self.target_shape = sample['target'].shape
             return sample
 
         def _out_hook_not_keep_ratio(sample):
@@ -243,6 +264,21 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, amv in zip(sample['aux_masks'],
                                          self.aux_mask_values):
                     self.assertLessEqual(set(aux_mask.ravel()), amv)
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    self.check_output_equal(
+                        sample['target'].shape[:2],
+                        T.functions.calc_hr_shape(TARGET_SIZE,
+                                                  sample['sr_factor']))
+                else:
+                    self.check_output_equal(sample['target'].shape[:2],
+                                            TARGET_SIZE)
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 
@@ -260,6 +296,13 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, ori_aux_mask_shape in zip(sample['aux_masks'],
                                                         self.aux_mask_shapes):
                     __check_ratio(aux_mask.shape, ori_aux_mask_shape)
+            if 'target' in sample:
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 

+ 5 - 5
tutorials/train/README.md

@@ -9,17 +9,17 @@
 |change_detection/changeformer.py | 变化检测 | ChangeFormer |
 |change_detection/dsamnet.py | 变化检测 | DSAMNet |
 |change_detection/dsifn.py | 变化检测 | DSIFN |
-|change_detection/snunet.py | 变化检测 | SNUNet |
-|change_detection/stanet.py | 变化检测 | STANet |
 |change_detection/fc_ef.py | 变化检测 | FC-EF |
 |change_detection/fc_siam_conc.py | 变化检测 | FC-Siam-conc |
 |change_detection/fc_siam_diff.py | 变化检测 | FC-Siam-diff |
+|change_detection/snunet.py | 变化检测 | SNUNet |
+|change_detection/stanet.py | 变化检测 | STANet |
 |classification/hrnet.py | 场景分类 | HRNet |
 |classification/mobilenetv3.py | 场景分类 | MobileNetV3 |
 |classification/resnet50_vd.py | 场景分类 | ResNet50-vd |
-|image_restoration/drn.py | 超分辨率 | DRN |
-|image_restoration/esrgan.py | 超分辨率 | ESRGAN |
-|image_restoration/lesrcnn.py | 超分辨率 | LESRCNN |
+|image_restoration/drn.py | 图像复原 | DRN |
+|image_restoration/esrgan.py | 图像复原 | ESRGAN |
+|image_restoration/lesrcnn.py | 图像复原 | LESRCNN |
 |object_detection/faster_rcnn.py | 目标检测 | Faster R-CNN |
 |object_detection/ppyolo.py | 目标检测 | PP-YOLO |
 |object_detection/ppyolotiny.py | 目标检测 | PP-YOLO Tiny |

+ 1 - 1
tutorials/train/change_detection/changeformer.py

@@ -72,7 +72,7 @@ eval_dataset = pdrs.datasets.CDDataset(
     binarize_labels=True)
 
 # 使用默认参数构建ChangeFormer模型
-# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.cd.ChangeFormer()
 

+ 1 - 1
tutorials/train/classification/hrnet.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# 使用默认参数构建HRNet模型
+# 构建HRNet模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.HRNet_W18_C(num_classes=len(train_dataset.labels))

+ 1 - 1
tutorials/train/classification/mobilenetv3.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# 使用默认参数构建MobileNetV3模型
+# 构建MobileNetV3模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.MobileNetV3_small_x1_0(

+ 1 - 1
tutorials/train/classification/resnet50_vd.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# 使用默认参数构建ResNet50-vd模型
+# 构建ResNet50-vd模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.ResNet50_vd(num_classes=len(train_dataset.labels))

+ 3 - 0
tutorials/train/image_restoration/data/.gitignore

@@ -0,0 +1,3 @@
+*.zip
+*.tar.gz
+rssr/

+ 89 - 0
tutorials/train/image_restoration/drn.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# 图像复原模型DRN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rssr/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/drn/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 从输入影像中裁剪96x96大小的影像块
+    T.RandomCrop(crop_size=96),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # 将输入影像缩放到256x256大小
+    T.Resize(target_size=256),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# 使用默认参数构建DRN模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.DRN()
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # 初始学习率大小
+    learning_rate=0.001,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 0 - 80
tutorials/train/image_restoration/drn_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddle
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'lqx2', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4,
-        'scale_list': True
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to low-resolution images
-num_workers = 4
-batch_size = 8
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-train_dict = train_dataset()
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; network architecture parameters can be adjusted
-model = pdrs.tasks.res.DRNet(
-    n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
-
-model.train(
-    total_iters=100000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    lr_rate=0.0001,
-    log=10)

+ 89 - 0
tutorials/train/image_restoration/esrgan.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example script for training the ESRGAN image restoration model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/rssr/'
+# Path to the training set `file_list` file
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the validation set `file_list` file
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/esrgan/'
+
+# Download and extract the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms (augmentation, preprocessing, etc.) used for
+# training and validation. Compose chains multiple transforms, which are
+# applied sequentially in order.
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Randomly crop a 32x32 patch from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply random horizontal flip with 50% probability
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply random vertical flip with 50% probability
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0, 1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the training and validation datasets
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the ESRGAN model with default parameters
+# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.ESRGAN()
+
+# Run model training
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Log once every this many iterations
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use early stopping: terminate training early when accuracy
+    # stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint to resume training from
+    resume_checkpoint=None)

+ 0 - 80
tutorials/train/image_restoration/esrgan_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 128,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to low-resolution images
-num_workers = 6
-batch_size = 32
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; network architecture parameters can be adjusted
-# If loss_type='gan', perceptual, adversarial, and pixel losses are used
-# If loss_type='pixel', only the pixel loss is used
-model = pdrs.tasks.res.ESRGANet(loss_type='pixel')
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])

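The deleted `*_train.py` scripts built their pipelines from lists of `{'name': ..., **kwargs}` dicts passed to `ComposeTrans`, rather than from transform objects. A hedged sketch of that registry-driven pattern (class and config names here are illustrative stand-ins, not the PaddleRS API):

```python
# Hypothetical registry-driven pipeline builder mirroring the old
# ComposeTrans style of configuration; names are illustrative only.
TRANSFORMS = {}


def register(cls):
    """Register a transform class under its own name."""
    TRANSFORMS[cls.__name__] = cls
    return cls


@register
class Normalize:
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, pixels):
        return [(p - m) / s
                for p, m, s in zip(pixels, self.mean, self.std)]


def build_pipeline(configs):
    """Instantiate transforms from a list of {'name': ..., **kwargs} dicts."""
    ops = []
    for cfg in configs:
        cfg = dict(cfg)  # avoid mutating the caller's config
        ops.append(TRANSFORMS[cfg.pop('name')](**cfg))
    return ops


pipeline = build_pipeline([{
    'name': 'Normalize',
    'mean': [0.0, 0.0, 0.0],
    'std': [255.0, 255.0, 255.0]
}])
out = [255.0, 127.5, 0.0]
for op in pipeline:
    out = op(out)
print(out)  # → [1.0, 0.5, 0.0]
```

The refactor replaces this stringly-typed configuration with direct object construction (`T.Normalize(...)` inside `T.Compose`), which surfaces typos and bad arguments at construction time instead of deep inside the pipeline loop.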
+ 89 - 0
tutorials/train/image_restoration/lesrcnn.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example script for training the LESRCNN image restoration model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/rssr/'
+# Path to the training set `file_list` file
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the validation set `file_list` file
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/lesrcnn/'
+
+# Download and extract the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms (augmentation, preprocessing, etc.) used for
+# training and validation. Compose chains multiple transforms, which are
+# applied sequentially in order.
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Randomly crop a 32x32 patch from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply random horizontal flip with 50% probability
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply random vertical flip with 50% probability
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0, 1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the training and validation datasets
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the LESRCNN model with default parameters
+# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.LESRCNN()
+
+# Run model training
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Log once every this many iterations
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use early stopping: terminate training early when accuracy
+    # stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint to resume training from
+    resume_checkpoint=None)

+ 0 - 78
tutorials/train/image_restoration/lesrcnn_train.py

@@ -1,78 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to low-resolution images
-num_workers = 4
-batch_size = 16
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; network architecture parameters can be adjusted
-model = pdrs.tasks.res.LESRCNNet(scale=4, multi_scale=False, group=1)
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])
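All three refactored scripts pass `sr_factor=4` to `ResDataset`. Assuming the usual super-resolution convention (each HR side is `sr_factor` times the corresponding LR side), the patch-size bookkeeping implied by the crop sizes above is simply:

```python
def lr_size(hr_size, sr_factor):
    """Side length of the LR patch paired with an HR patch of hr_size.

    Assumes the standard SR convention that HR = sr_factor * LR per side;
    this helper is an illustration, not part of the PaddleRS API.
    """
    if hr_size % sr_factor != 0:
        raise ValueError('HR size must be divisible by the SR factor')
    return hr_size // sr_factor


# 96x96 HR crops (DRN) pair with 24x24 LR patches at sr_factor=4;
# 32x32 HR crops (ESRGAN/LESRCNN) pair with 8x8 LR patches.
print(lr_size(96, 4), lr_size(32, 4))  # → 24 8
```

The divisibility check matters in practice: crop sizes that are not multiples of `sr_factor` cannot be paired with an integer-sized LR patch, which is why the tutorials pick multiples of 4.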