
Merge pull request #24 from Bobholamovic/refactor_res

[Refactor] Refactor Models for Restoration Tasks
cc 2 years ago
parent
commit
6b02cf875f
48 changed files with 2038 additions and 1460 deletions
  1. docs/apis/data.md (+4 -0)
  2. docs/apis/infer.md (+7 -3)
  3. docs/apis/train.md (+11 -1)
  4. docs/dev/dev_guide.md (+3 -3)
  5. paddlers/datasets/__init__.py (+1 -1)
  6. paddlers/datasets/cd_dataset.py (+5 -5)
  7. paddlers/datasets/res_dataset.py (+83 -0)
  8. paddlers/datasets/seg_dataset.py (+1 -1)
  9. paddlers/datasets/sr_dataset.py (+0 -99)
  10. paddlers/deploy/predictor.py (+18 -1)
  11. paddlers/models/__init__.py (+1 -0)
  12. paddlers/models/ppdet/metrics/json_results.py (+1 -1)
  13. paddlers/rs_models/res/generators/param_init.py (+11 -10)
  14. paddlers/rs_models/res/generators/rcan.py (+34 -17)
  15. paddlers/rs_models/res/rcan_model.py (+0 -106)
  16. paddlers/tasks/__init__.py (+1 -1)
  17. paddlers/tasks/base.py (+24 -32)
  18. paddlers/tasks/change_detector.py (+18 -11)
  19. paddlers/tasks/classifier.py (+64 -41)
  20. paddlers/tasks/image_restorer.py (+0 -786)
  21. paddlers/tasks/object_detector.py (+36 -15)
  22. paddlers/tasks/restorer.py (+934 -0)
  23. paddlers/tasks/segmenter.py (+17 -12)
  24. paddlers/tasks/utils/infer_nets.py (+28 -11)
  25. paddlers/tasks/utils/res_adapters.py (+132 -0)
  26. paddlers/transforms/functions.py (+4 -0)
  27. paddlers/transforms/operators.py (+102 -15)
  28. paddlers/utils/__init__.py (+1 -1)
  29. paddlers/utils/utils.py (+32 -1)
  30. tests/data/data_utils.py (+28 -26)
  31. tests/deploy/test_predictor.py (+60 -7)
  32. tests/rs_models/test_cd_models.py (+1 -3)
  33. tests/rs_models/test_det_models.py (+3 -0)
  34. tests/rs_models/test_res_models.py (+46 -0)
  35. tests/rs_models/test_seg_models.py (+5 -3)
  36. tests/transforms/test_operators.py (+43 -0)
  37. tutorials/train/README.md (+5 -5)
  38. tutorials/train/change_detection/changeformer.py (+1 -1)
  39. tutorials/train/classification/hrnet.py (+1 -1)
  40. tutorials/train/classification/mobilenetv3.py (+1 -1)
  41. tutorials/train/classification/resnet50_vd.py (+1 -1)
  42. tutorials/train/image_restoration/data/.gitignore (+3 -0)
  43. tutorials/train/image_restoration/drn.py (+89 -0)
  44. tutorials/train/image_restoration/drn_train.py (+0 -80)
  45. tutorials/train/image_restoration/esrgan.py (+89 -0)
  46. tutorials/train/image_restoration/esrgan_train.py (+0 -80)
  47. tutorials/train/image_restoration/lesrcnn.py (+89 -0)
  48. tutorials/train/image_restoration/lesrcnn_train.py (+0 -78)

+ 4 - 0
docs/apis/data.md

@@ -84,6 +84,9 @@

 - Each line of the file list should contain 2 space-separated items, representing, in order, the path of the input image relative to `data_dir` and the path of the [Pascal VOC format](http://host.robots.ox.ac.uk/pascal/VOC/) annotation file relative to `data_dir`.

+### Image restoration dataset `ResDataset`
+
+
 ### Image segmentation dataset `SegDataset`

 `SegDataset` is defined at: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/seg_dataset.py
@@ -143,6 +146,7 @@
 |`'aux_masks'`|Paths or data of auxiliary labels in image segmentation/change detection tasks.|
 |`'gt_bbox'`|Bounding box annotation data in object detection tasks.|
 |`'gt_poly'`|Polygon annotation data in object detection tasks.|
+|`'target'`|Path or data of the target image in image restoration tasks.|

 ### Combined data transformation operators

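The new `'target'` key pairs each input image with the ground-truth image it should be restored to. As a stand-alone illustration (plain Python, not PaddleRS code; the helper name is ours), this is what loading a two-column restoration file list into sample dicts looks like:

```python
# Hypothetical sketch: each line of a restoration file list holds a source
# path and a target path separated by a single space; loading produces
# sample dicts keyed 'image' and 'target' (plus 'sr_factor' if given).
def parse_res_file_list(lines, sr_factor=None):
    samples = []
    for line in lines:
        items = line.strip().split()
        if len(items) != 2:
            raise ValueError("expected exactly 2 space-separated paths")
        sample = {'image': items[0], 'target': items[1]}
        if sr_factor is not None:
            sample['sr_factor'] = sr_factor
        samples.append(sample)
    return samples
```

For super-resolution data, passing a scaling factor simply tags every sample with it.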
+ 7 - 3
docs/apis/infer.md

@@ -26,7 +26,7 @@ def predict(self, img_file, transforms=None):
 若`img_file`是一个元组,则返回对象为包含下列键值对的字典:
 若`img_file`是一个元组,则返回对象为包含下列键值对的字典:
 
 
 ```
 ```
-{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
+{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
 ```
 ```
 
 
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
@@ -51,7 +51,7 @@ def predict(self, img_file, transforms=None):
 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
 
 
 ```
 ```
-{"label map": 输出类别标签,
+{"label_map": 输出类别标签,
  "scores_map": 输出类别概率,
  "scores_map": 输出类别概率,
  "label_names_map": 输出类别名称}
  "label_names_map": 输出类别名称}
 ```
 ```
@@ -87,6 +87,10 @@ def predict(self, img_file, transforms=None):
 
 
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个由字典(键值对如上所示)构成的列表,顺序对应`img_file`中的每个元素。
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个由字典(键值对如上所示)构成的列表,顺序对应`img_file`中的每个元素。
 
 
+#### `BaseRestorer.predict()`
+
+
+
 #### `BaseSegmenter.predict()`
 #### `BaseSegmenter.predict()`
 
 
 接口形式:
 接口形式:
@@ -107,7 +111,7 @@ def predict(self, img_file, transforms=None):
 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
 
 
 ```
 ```
-{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
+{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
 ```
 ```
 
 
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。

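The renamed `"label_map"` key and the `"score_map"` key describe the same prediction in two layouts: per-class probabilities in `[h, w, c]`, and the per-pixel winning class in `[h, w]`. A minimal sketch of that relationship (our helper name, not a PaddleRS API):

```python
import numpy as np

# Sketch of the documented relationship between the two keys: "score_map"
# holds per-class probabilities in [h, w, c] layout, and "label_map" is the
# per-pixel argmax over the class axis, in [h, w] layout.
def label_map_from_scores(score_map):
    score_map = np.asarray(score_map)
    assert score_map.ndim == 3  # expect [h, w, c]
    return score_map.argmax(axis=-1)  # [h, w]
```

This is only a consistency illustration; in practice `predict()` returns both keys already filled in.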
+ 11 - 1
docs/apis/train.md

@@ -18,11 +18,15 @@
 - The `use_mixed_loss` parameter will be deprecated in the future, so its use is not recommended.
 - Different subclasses support model-specific input parameters. For details, see the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/clas) and [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py).

-### Initializing a `Baseetector` subclass object
+### Initializing a `BaseDetector` subclass object

 - Generally, the `num_classes` and `backbone` parameters can be set, specifying the number of output classes and the type of backbone network used, respectively. Compared with other tasks, trainers for object detection support more initialization parameters, covering network structure, loss functions, post-processing strategies, and more.
 - Different subclasses support model-specific input parameters. For details, see the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/det) and [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/object_detector.py).

+### Initializing a `BaseRestorer` subclass object
+
+
 ### Initializing a `BaseSegmenter` subclass object

 - Generally, the `in_channels`, `num_classes`, and `use_mixed_loss` parameters can be set, specifying the number of input channels, the number of output classes, and whether to use the preset mixed loss, respectively. Some models, such as `FarSeg`, do not yet support setting the `in_channels` parameter.
@@ -170,6 +174,9 @@ def train(self,
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing model weights and optimizer weights stored during a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to non-`None` values.|`None`|

+### `BaseRestorer.train()`
+
+
 ### `BaseSegmenter.train()`

 The interface is:
@@ -311,6 +318,9 @@ def evaluate(self,
  "mask": predicted mask information}
 ```

+### `BaseRestorer.evaluate()`
+
+
 ### `BaseSegmenter.evaluate()`

 The interface is:

+ 3 - 3
docs/dev/dev_guide.md

@@ -22,7 +22,7 @@

 Create a new file in the subdirectory, named `{lowercase model name}.py`, and write the complete model definition in it.

-A new model must be a subclass of `paddle.nn.Layer`. For image segmentation, object detection, and scene classification tasks, the relevant specifications established in the [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg), [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection), and [PaddleClas](https://github.com/PaddlePaddle/PaddleClas) suites must be followed, respectively. **For change detection, scene classification, and image segmentation tasks, the `num_classes` parameter must be passed at model construction to specify the number of output classes.** For change detection tasks, the specifications the model definition must follow are similar to those for segmentation models, with the following differences:
+A new model must be a subclass of `paddle.nn.Layer`. For image segmentation, object detection, scene classification, and image restoration tasks, the relevant specifications established in the [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg), [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection), [PaddleClas](https://github.com/PaddlePaddle/PaddleClas), and [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN) suites must be followed, respectively. **For change detection, scene classification, and image segmentation tasks, the `num_classes` parameter must be passed at model construction to specify the number of output classes. For image restoration tasks, the `sr_factor` parameter must be passed at model construction to specify the super-resolution scaling factor (for non-super-resolution models, set this parameter to `None`).** For change detection tasks, the specifications the model definition must follow are similar to those for segmentation models, with the following differences:

 - The `forward()` method takes 3 input parameters: `self`, `t1`, and `t2`, where `t1` and `t2` are the input images of the earlier and later temporal phases, respectively.
 - For multi-task change detection models (e.g., a model that outputs both change detection results and building extraction results for the two temporal phases), the class attribute `USE_MULTITASK_DECODER` must be set to `True`, and the `OUT_TYPES` attribute must list the label type corresponding to each element of the model's forward output. See the definition of the `ChangeStar` model for reference.
@@ -64,7 +64,7 @@ Args:
 2. Find the trainer definition file corresponding to the task in the `paddlers/tasks` directory (e.g., `paddlers/tasks/change_detector.py` for change detection tasks).

 3. Append the new trainer definition at the end of the file. The trainer must inherit from the relevant base class (e.g., `BaseChangeDetector`), override the `__init__()` method, and override other methods as needed. The requirements for the trainer's `__init__()` method are as follows:
-    - For change detection, scene classification, object detection, and image segmentation tasks, the 1st input parameter of `__init__()` is `num_classes`, the number of model output classes. For change detection, scene classification, and image segmentation tasks, the 2nd input parameter is `use_mixed_loss`, indicating whether to use the default mixed loss.
+    - For change detection, scene classification, object detection, and image segmentation tasks, the 1st input parameter of `__init__()` is `num_classes`, the number of model output classes. For change detection, scene classification, and image segmentation tasks, the 2nd input parameter is `use_mixed_loss`, indicating whether to use the default mixed loss; the 3rd input parameter is `losses`, the loss functions used during training. For image restoration tasks, the 1st parameter is `losses`, with the same meaning as above; the 2nd parameter is `sr_factor`, the super-resolution scaling factor.
     - All input parameters of `__init__()` must have default values, and **with the default values, the model accepts 3-channel RGB input**.
     - The `params` dict must be updated in `__init__()`; its key-value pairs will be used as input parameters when the model is constructed.

@@ -78,7 +78,7 @@ Args:

 ### 2.2 Adding a data preprocessing/data augmentation operator

-Define the new operator in `paddlers/transforms/operators.py`. All operators inherit from the `paddlers.transforms.Transform` class. The operator's `apply()` method takes a dict `sample` as input, retrieves the relevant objects stored in it, modifies the dict in place after processing, and finally returns the modified dict. When defining an operator, the `apply()` method rarely needs to be overridden. In most cases, it suffices to override the `apply_im()`, `apply_mask()`, `apply_bbox()`, and `apply_segm()` methods to handle the input image, segmentation labels, bounding boxes, and polygons, respectively.
+Define the new operator in `paddlers/transforms/operators.py`. All operators inherit from the `paddlers.transforms.Transform` class. The operator's `apply()` method takes a dict `sample` as input, retrieves the relevant objects stored in it, modifies the dict in place after processing, and finally returns the modified dict. When defining an operator, the `apply()` method rarely needs to be overridden. In most cases, it suffices to override the `apply_im()`, `apply_mask()`, `apply_bbox()`, and `apply_segm()` methods to handle the image, segmentation labels, bounding boxes, and polygons, respectively.

 If the processing logic is complex, it is recommended to first encapsulate it as a function, add it to `paddlers/transforms/functions.py`, and then call that function from the operator's `apply*()` methods.

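The operator pattern described in section 2.2 can be sketched stand-alone. The tiny `_Transform` base below only mimics the dispatch of the real `paddlers.transforms.Transform` so the sketch runs without PaddleRS installed; only the per-field `apply_*()` hooks are what a contributor actually writes:

```python
# Minimal stand-alone sketch of the operator pattern (not PaddleRS code).
class _Transform(object):
    def __call__(self, sample):
        return self.apply(sample)

    def apply(self, sample):
        # Default dispatch: process each known field, modify the sample
        # dict in place, and return it -- rarely overridden directly.
        if 'image' in sample:
            sample['image'] = self.apply_im(sample['image'])
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'])
        return sample


class ScalePixels(_Transform):
    """Toy operator: only the per-field hooks are overridden."""

    def __init__(self, factor):
        self.factor = factor

    def apply_im(self, image):
        return [v * self.factor for v in image]

    def apply_mask(self, mask):
        return mask  # segmentation labels are left untouched
```

Calling `ScalePixels(2)({'image': [1, 2], 'mask': [0, 1]})` scales only the image values, leaving the labels intact.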
+ 1 - 1
paddlers/datasets/__init__.py

@@ -17,4 +17,4 @@ from .coco import COCODetDataset
 from .seg_dataset import SegDataset
 from .cd_dataset import CDDataset
 from .clas_dataset import ClasDataset
-from .sr_dataset import SRdataset, ComposeTrans
+from .res_dataset import ResDataset

+ 5 - 5
paddlers/datasets/cd_dataset.py

@@ -95,23 +95,23 @@ class CDDataset(BaseDataset):
                                      full_path_label))):
                     continue
                 if not osp.exists(full_path_im_t1):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t1))
                 if not osp.exists(full_path_im_t2):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t2))
                 if not osp.exists(full_path_label):
-                    raise IOError('Label file {} does not exist!'.format(
+                    raise IOError("Label file {} does not exist!".format(
                         full_path_label))

                 if with_seg_labels:
                     full_path_seg_label_t1 = osp.join(data_dir, items[3])
                     full_path_seg_label_t2 = osp.join(data_dir, items[4])
                     if not osp.exists(full_path_seg_label_t1):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t1))
                     if not osp.exists(full_path_seg_label_t2):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t2))

                 item_dict = dict(

+ 83 - 0
paddlers/datasets/res_dataset.py

@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import copy
+
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
+
+
+class ResDataset(BaseDataset):
+    """
+    Dataset for image restoration tasks.
+
+    Args:
+        data_dir (str): Root directory of the dataset.
+        file_list (str): Path of the file that contains relative paths of source and target image files.
+        transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
+        num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
+            the number of workers will be determined automatically according to the number of CPU cores: if
+            there are more than 16 cores, 8 workers will be used; otherwise, the number of workers will be half
+            the number of CPU cores. Defaults to 'auto'.
+        shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
+        sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image 
+            restoration tasks. Defaults to None.
+    """
+
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 transforms,
+                 num_workers='auto',
+                 shuffle=False,
+                 sr_factor=None):
+        super(ResDataset, self).__init__(data_dir, None, transforms,
+                                         num_workers, shuffle)
+        self.batch_transforms = None
+        self.file_list = list()
+
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                items = line.strip().split()
+                if len(items) != 2:
+                    raise ValueError(
+                        "A space is used as the delimiter between the source and target image paths, " \
+                        "so neither path may contain spaces. However, line {} of file list {} " \
+                        "does not contain exactly two space-separated paths.".format(line, file_list))
+                items[0] = norm_path(items[0])
+                items[1] = norm_path(items[1])
+                full_path_im = osp.join(data_dir, items[0])
+                full_path_tar = osp.join(data_dir, items[1])
+                if not is_pic(full_path_im) or not is_pic(full_path_tar):
+                    continue
+                if not osp.exists(full_path_im):
+                    raise IOError("Source image file {} does not exist!".format(
+                        full_path_im))
+                if not osp.exists(full_path_tar):
+                    raise IOError("Target image file {} does not exist!".format(
+                        full_path_tar))
+                sample = {
+                    'image': full_path_im,
+                    'target': full_path_tar,
+                }
+                if sr_factor is not None:
+                    sample['sr_factor'] = sr_factor
+                self.file_list.append(sample)
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+
+    def __len__(self):
+        return len(self.file_list)

+ 1 - 1
paddlers/datasets/seg_dataset.py

@@ -44,7 +44,7 @@ class SegDataset(BaseDataset):
                  shuffle=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
-        # TODO batch padding
+        # TODO: batch padding
         self.batch_transforms = None
         self.file_list = list()
         self.labels = list()

+ 0 - 99
paddlers/datasets/sr_dataset.py

@@ -1,99 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Super-resolution dataset definition
-class SRdataset(object):
-    def __init__(self,
-                 mode,
-                 gt_floder,
-                 lq_floder,
-                 transforms,
-                 scale,
-                 num_workers=4,
-                 batch_size=8):
-        if mode == 'train':
-            preprocess = []
-            preprocess.append({
-                'name': 'LoadImageFromFile',
-                'key': 'lq'
-            })  # Loading method
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)  # Transform pipeline
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'num_workers': num_workers,
-                'batch_size': batch_size,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-        if mode == "test":
-            preprocess = []
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-    def __call__(self):
-        return self.dataset
-
-
-# Compose the defined transforms and return a dict
-class ComposeTrans(object):
-    def __init__(self, input_keys, output_keys, pipelines):
-        if not isinstance(pipelines, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be List, but received is {}'
-                .format(type(pipelines)))
-        if len(pipelines) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(pipelines)))
-        self.transforms = pipelines
-        self.output_length = len(output_keys)  # when output_keys has length 3, this is DRN training
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
-    def __call__(self):
-        pipeline = []
-        for op in self.transforms:
-            if op['name'] == 'SRPairedRandomCrop':
-                op['keys'] = ['image'] * 2
-            else:
-                op['keys'] = ['image'] * self.output_length
-            pipeline.append(op)
-        if self.output_length == 2:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'pipeline': pipeline
-            }
-        else:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'output_keys': self.output_keys,
-                'pipeline': pipeline
-            }
-
-        return transform_dict

+ 18 - 1
paddlers/deploy/predictor.py

@@ -163,13 +163,23 @@ class Predictor(object):
                 'image2': preprocessed_samples[1],
                 'ori_shape': preprocessed_samples[2]
             }
+        elif self._model.model_type == 'restorer':
+            preprocessed_samples = {
+                'image': preprocessed_samples[0],
+                'tar_shape': preprocessed_samples[1]
+            }
         else:
             logging.error(
                 "Invalid model type {}".format(self._model.model_type),
                 exit=True)
         return preprocessed_samples

-    def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
+    def postprocess(self,
+                    net_outputs,
+                    topk=1,
+                    ori_shape=None,
+                    tar_shape=None,
+                    transforms=None):
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
@@ -201,6 +211,12 @@
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
             }
             preds = self._model.postprocess(net_outputs)
+        elif self._model.model_type == 'restorer':
+            res_maps = self._model.postprocess(
+                net_outputs[0],
+                batch_tar_shape=tar_shape,
+                transforms=transforms.transforms)
+            preds = [{'res_map': res_map} for res_map in res_maps]
         else:
             logging.error(
                 "Invalid model type {}.".format(self._model.model_type),
@@ -244,6 +260,7 @@
             net_outputs,
             topk,
             ori_shape=preprocessed_input.get('ori_shape', None),
+            tar_shape=preprocessed_input.get('tar_shape', None),
             transforms=transforms)
         self.timer.postprocess_time_s.end(iter_num=len(images))

+ 1 - 0
paddlers/models/__init__.py

@@ -16,3 +16,4 @@ from . import ppcls, ppdet, ppseg, ppgan
 import paddlers.models.ppseg.models.losses as seg_losses
 import paddlers.models.ppdet.modeling.losses as det_losses
 import paddlers.models.ppcls.loss as clas_losses
+import paddlers.models.ppgan.models.criterions as res_losses

+ 1 - 1
paddlers/models/ppdet/metrics/json_results.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .rcan_model import RCANModel
+from .generators import *

+ 11 - 10
paddlers/rs_models/res/generators/builder.py → paddlers/rs_models/res/generators/param_init.py

@@ -1,10 +1,10 @@
-#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
+import paddle
+import paddle.nn as nn

-from ....models.ppgan.utils.registry import Registry
+from paddlers.models.ppgan.modules.init import reset_parameters

-GENERATORS = Registry("GENERATOR")
-
-
-def build_generator(cfg):
-    cfg_copy = copy.deepcopy(cfg)
-    name = cfg_copy.pop('name')
-    generator = GENERATORS.get(name)(**cfg_copy)
-    return generator
+
+def init_sr_weight(net):
+    def reset_func(m):
+        if hasattr(m, 'weight') and (
+                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+            reset_parameters(m)
+
+    net.apply(reset_func)

+ 34 - 17
paddlers/rs_models/res/generators/rcan.py

@@ -1,10 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Based on https://github.com/kongdebug/RCAN-Paddle
+
 import math

 import paddle
 import paddle.nn as nn

-from .builder import GENERATORS
+from .param_init import init_sr_weight


 def default_conv(in_channels, out_channels, kernel_size, bias=True):
@@ -63,8 +78,10 @@ class RCAB(nn.Layer):
                  bias=True,
                  bn=False,
                  act=nn.ReLU(),
-                 res_scale=1):
+                 res_scale=1,
+                 use_init_weight=False):
         super(RCAB, self).__init__()
+
         modules_body = []
         for i in range(2):
             modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
@@ -74,6 +91,9 @@
         self.body = nn.Sequential(*modules_body)
         self.res_scale = res_scale

+        if use_init_weight:
+            init_sr_weight(self)
+
     def forward(self, x):
         res = self.body(x)
         res += x
@@ -128,21 +148,19 @@ class Upsampler(nn.Sequential):
         super(Upsampler, self).__init__(*m)


-@GENERATORS.register()
 class RCAN(nn.Layer):
-    def __init__(
-            self,
-            scale,
-            n_resgroups,
-            n_resblocks,
-            n_feats=64,
-            n_colors=3,
-            rgb_range=255,
-            kernel_size=3,
-            reduction=16,
-            conv=default_conv, ):
+    def __init__(self,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=255,
+                 kernel_size=3,
+                 reduction=16,
+                 conv=default_conv):
         super(RCAN, self).__init__()
-        self.scale = scale
+        self.scale = sr_factor
         act = nn.ReLU()

         n_resgroups = n_resgroups
@@ -150,7 +168,6 @@ class RCAN(nn.Layer):
         n_feats = n_feats
         kernel_size = kernel_size
         reduction = reduction
-        scale = scale
         act = nn.ReLU()

         rgb_mean = (0.4488, 0.4371, 0.4040)
@@ -171,7 +188,7 @@ class RCAN(nn.Layer):
         # Define tail module
         modules_tail = [
             Upsampler(
-                conv, scale, n_feats, act=False),
+                conv, self.scale, n_feats, act=False),
             conv(n_feats, n_colors, kernel_size)
         ]

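RCAN's tail still uses the sub-pixel `Upsampler`, now driven by `self.scale` (i.e. `sr_factor`). For power-of-two factors, this style of upsampler stacks log2(scale) stages, each a convolution that expands channels 4x followed by a PixelShuffle that trades those channels for 2x spatial resolution. A stand-alone sketch of that plan (the helper name is ours, not PaddleRS code; factor 3 is handled differently and is not covered here):

```python
import math

def upsampler_plan(scale, n_feats):
    """Return the (conv_out_channels, pixel_shuffle_factor) stages a
    power-of-two sub-pixel upsampler would stack, RCAN-style."""
    if scale < 2 or scale & (scale - 1) != 0:
        raise ValueError("sketch only covers power-of-two scaling factors")
    stages = []
    for _ in range(int(math.log2(scale))):
        # Each stage: conv expands channels 4x, then PixelShuffle(2)
        # rearranges the extra channels into 2x spatial resolution.
        stages.append((4 * n_feats, 2))
    return stages
```

For example, with `sr_factor=4` and `n_feats=64` the plan is two stages of a 256-channel conv plus PixelShuffle(2).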
+ 0 - 106
paddlers/rs_models/res/rcan_model.py

@@ -1,106 +0,0 @@
-#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle
-import paddle.nn as nn
-
-from .generators.builder import build_generator
-from ...models.ppgan.models.criterions.builder import build_criterion
-from ...models.ppgan.models.base_model import BaseModel
-from ...models.ppgan.models.builder import MODELS
-from ...models.ppgan.utils.visual import tensor2img
-from ...models.ppgan.modules.init import reset_parameters
-
-
-@MODELS.register()
-class RCANModel(BaseModel):
-    """
-    Base SR model for single image super-resolution.
-    """
-
-    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
-        """
-        Args:
-            generator (dict): config of generator.
-            pixel_criterion (dict): config of pixel criterion.
-        """
-        super(RCANModel, self).__init__()
-
-        self.nets['generator'] = build_generator(generator)
-        self.error_last = 1e8
-        self.batch = 0
-        if pixel_criterion:
-            self.pixel_criterion = build_criterion(pixel_criterion)
-        if use_init_weight:
-            init_sr_weight(self.nets['generator'])
-
-    def setup_input(self, input):
-        self.lq = paddle.to_tensor(input['lq'])
-        self.visual_items['lq'] = self.lq
-        if 'gt' in input:
-            self.gt = paddle.to_tensor(input['gt'])
-            self.visual_items['gt'] = self.gt
-        self.image_paths = input['lq_path']
-
-    def forward(self):
-        pass
-
-    def train_iter(self, optims=None):
-        optims['optim'].clear_grad()
-
-        self.output = self.nets['generator'](self.lq)
-        self.visual_items['output'] = self.output
-        # pixel loss
-        loss_pixel = self.pixel_criterion(self.output, self.gt)
-        self.losses['loss_pixel'] = loss_pixel
-
-        skip_threshold = 1e6
-
-        if loss_pixel.item() < skip_threshold * self.error_last:
-            loss_pixel.backward()
-            optims['optim'].step()
-        else:
-            print('Skip this batch {}! (Loss: {})'.format(self.batch + 1,
-                                                          loss_pixel.item()))
-        self.batch += 1
-
-        if self.batch % 1000 == 0:
-            self.error_last = loss_pixel.item() / 1000
-            print("update error_last:{}".format(self.error_last))
-
-    def test_iter(self, metrics=None):
-        self.nets['generator'].eval()
-        with paddle.no_grad():
-            self.output = self.nets['generator'](self.lq)
-            self.visual_items['output'] = self.output
-        self.nets['generator'].train()
-
-        out_img = []
-        gt_img = []
-        for out_tensor, gt_tensor in zip(self.output, self.gt):
-            out_img.append(tensor2img(out_tensor, (0., 255.)))
-            gt_img.append(tensor2img(gt_tensor, (0., 255.)))
-
-        if metrics is not None:
-            for metric in metrics.values():
-                metric.update(out_img, gt_img)
-
-
-def init_sr_weight(net):
-    def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
-            reset_parameters(m)
-
-    net.apply(reset_func)

+ 1 - 1
paddlers/tasks/__init__.py

@@ -16,7 +16,7 @@ import paddlers.tasks.object_detector as detector
 import paddlers.tasks.segmenter as segmenter
 import paddlers.tasks.change_detector as change_detector
 import paddlers.tasks.classifier as classifier
-import paddlers.tasks.image_restorer as restorer
+import paddlers.tasks.restorer as restorer
 from .load_model import load_model

 # Shorter aliases

+ 24 - 32
paddlers/tasks/base.py

@@ -30,12 +30,11 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 
 
 import paddlers
 import paddlers.utils.logging as logging
-from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
+from paddlers.utils import (
-                            get_pretrain_weights, load_pretrain_weights,
+    seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
-                            load_checkpoint, SmoothedValue, TrainingStats,
+    load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats,
-                            _get_shared_memory_size_in_M, EarlyStop)
+    _get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step)
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
-from .utils.infer_nets import InferNet, InferCDNet


 class ModelMeta(type):
@@ -268,7 +267,7 @@ class BaseModel(metaclass=ModelMeta):
                 'The volume of dataset({}) must be larger than batch size({}).'
                 .format(dataset.num_samples, batch_size))
         batch_size_each_card = get_single_card_bs(batch_size=batch_size)
-        # TODO: A check is needed for the detection eval phase.
+
         batch_sampler = DistributedBatchSampler(
             dataset,
             batch_size=batch_size_each_card,
@@ -308,7 +307,7 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms, 'train')

-        if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
+        if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
             nranks = 1
         else:
@@ -321,10 +320,10 @@ class BaseModel(metaclass=ModelMeta):
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
-                ddp_net = paddle.DataParallel(
+                ddp_net = to_data_parallel(
                     self.net, find_unused_parameters=find_unused_parameters)
             else:
-                ddp_net = paddle.DataParallel(
+                ddp_net = to_data_parallel(
                     self.net, find_unused_parameters=find_unused_parameters)

         if use_vdl:
@@ -365,24 +364,14 @@ class BaseModel(metaclass=ModelMeta):

             for step, data in enumerate(self.train_data_loader()):
                 if nranks > 1:
-                    outputs = self.run(ddp_net, data, mode='train')
+                    outputs = self.train_step(step, data, ddp_net)
                 else:
-                    outputs = self.run(self.net, data, mode='train')
+                    outputs = self.train_step(step, data, self.net)
-                loss = outputs['loss']
+
-                loss.backward()
+                scheduler_step(self.optimizer)
-                self.optimizer.step()
-                self.optimizer.clear_grad()
-                lr = self.optimizer.get_lr()
-                if isinstance(self.optimizer._learning_rate,
-                              paddle.optimizer.lr.LRScheduler):
-                    # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
-                    if isinstance(self.optimizer._learning_rate,
-                                  paddle.optimizer.lr.ReduceOnPlateau):
-                        self.optimizer._learning_rate.step(loss.item())
-                    else:
-                        self.optimizer._learning_rate.step()

                 train_avg_metrics.update(outputs)
+                lr = self.optimizer.get_lr()
                 outputs['lr'] = lr
                 if ema is not None:
                     ema.update(self.net)
@@ -622,14 +611,7 @@ class BaseModel(metaclass=ModelMeta):
         return pipeline_info

     def _build_inference_net(self):
-        if self.model_type in ('classifier', 'detector'):
+        raise NotImplementedError
-            infer_net = self.net
-        elif self.model_type == 'change_detector':
-            infer_net = InferCDNet(self.net)
-        else:
-            infer_net = InferNet(self.net, self.model_type)
-        infer_net.eval()
-        return infer_net

     def _export_inference_model(self, save_dir, image_shape=None):
         self.test_inputs = self._get_test_inputs(image_shape)
@@ -674,6 +656,16 @@ class BaseModel(metaclass=ModelMeta):
         logging.info("The inference model for deployment is saved in {}.".
                      format(save_dir))

+    def train_step(self, step, data, net):
+        outputs = self.run(net, data, mode='train')
+
+        loss = outputs['loss']
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+
+        return outputs
+
     def _check_transforms(self, transforms, mode):
         # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
         if not isinstance(transforms, paddlers.transforms.Compose):
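The training-loop refactor splits the old inline logic in two: `train_step` runs forward/backward and the optimizer update, while `scheduler_step` advances the learning-rate scheduler, feeding the loss value to `ReduceOnPlateau`-style schedulers. A minimal pure-Python sketch of that dispatch (the classes below are stand-ins for the Paddle types, and the real `scheduler_step` in `paddlers.utils` takes the optimizer rather than the scheduler directly):

```python
class LRScheduler:
    """Stand-in for paddle.optimizer.lr.LRScheduler."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

class ReduceOnPlateau(LRScheduler):
    """Stand-in: advances on an observed metric instead of unconditionally."""

    def __init__(self):
        super().__init__()
        self.last_metric = None

    def step(self, metric):
        self.last_metric = metric
        self.steps += 1

def scheduler_step(scheduler, loss_value):
    # Mirrors the removed inline branch: ReduceOnPlateau consumes the loss
    # as its metric; any other LRScheduler just advances one step.
    if isinstance(scheduler, ReduceOnPlateau):
        scheduler.step(loss_value)
    elif isinstance(scheduler, LRScheduler):
        scheduler.step()

sched = ReduceOnPlateau()
scheduler_step(sched, 0.42)
print(sched.steps, sched.last_metric)  # 1 0.42
```

Factoring the branch out of the loop lets subclasses such as the new restorer override `train_step` without duplicating the scheduler bookkeeping.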

+ 18 - 11
paddlers/tasks/change_detector.py

@@ -30,10 +30,11 @@ import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
-from paddlers.utils import get_single_card_bs, DisablePrint
+from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferCDNet

 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -69,6 +70,11 @@ class BaseChangeDetector(BaseModel):
                                              **params)
         return net

+    def _build_inference_net(self):
+        infer_net = InferCDNet(self.net)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -399,7 +405,8 @@ class BaseChangeDetector(BaseModel):
                 Defaults to False.

         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 For binary change detection (number of classes == 2), the key-value 
                     pairs are like:
                     {"iou": `intersection over union for the change class`,
@@ -527,12 +534,12 @@ class BaseChangeDetector(BaseModel):

         Returns:
             If `img_file` is a tuple of string or np.array, the result is a dict with 
-                key-value pairs:
+                the following key-value pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
+                above keys.
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
         """

         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -787,11 +794,11 @@ class BaseChangeDetector(BaseModel):
                 elif item[0] == 'padding':
                     x, y = item[2]
                     if isinstance(label_map, np.ndarray):
-                        label_map = label_map[..., y:y + h, x:x + w]
+                        label_map = label_map[y:y + h, x:x + w]
-                        score_map = score_map[..., y:y + h, x:x + w]
+                        score_map = score_map[y:y + h, x:x + w]
                     else:
-                        label_map = label_map[:, :, y:y + h, x:x + w]
+                        label_map = label_map[:, y:y + h, x:x + w, :]
-                        score_map = score_map[:, :, y:y + h, x:x + w]
+                        score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                     pass
             label_map = label_map.squeeze()
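The last hunk above fixes the padding-removal crop for the channels-last layouts: the label/score maps now arrive as HW/HWC arrays (or batched NHWC tensors), so the crop must slice the leading spatial axes rather than the trailing two. A small NumPy illustration of the corrected slicing (shapes chosen only for the example):

```python
import numpy as np

h, w = 2, 3  # original (pre-padding) spatial size
x, y = 1, 1  # padding offsets

# Single label map, HW layout: slice rows then columns directly.
label_map = np.arange(5 * 6).reshape(5, 6)
cropped_label = label_map[y:y + h, x:x + w]

# Batched score map, NHWC layout: keep batch and channel axes intact.
score_map = np.zeros((4, 5, 6, 3))
cropped_score = score_map[:, y:y + h, x:x + w, :]

print(cropped_label.shape, cropped_score.shape)  # (2, 3) (4, 2, 3, 3)
```

With the old `[..., y:y + h, x:x + w]` slices, an HWC array would have been cropped along height and channels instead of height and width.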

+ 64 - 41
paddlers/tasks/classifier.py

@@ -83,6 +83,11 @@ class BaseClassifier(BaseModel):
                 self.in_channels = 3
         return net

+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -373,7 +378,8 @@ class BaseClassifier(BaseModel):
                 Defaults to False.

         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 {"top1": `acc of top1`,
                  "top5": `acc of top5`}.
         """
@@ -389,38 +395,37 @@ class BaseClassifier(BaseModel):
             ):
                 paddle.distributed.init_parallel_env()

-        batch_size_each_card = get_single_card_bs(batch_size)
+        if batch_size > 1:
-        if batch_size_each_card > 1:
-            batch_size_each_card = 1
-            batch_size = batch_size_each_card * paddlers.env_info['num']
             logging.warning(
-                "Classifier only supports batch_size=1 for each gpu/cpu card " \
+                "Classifier only supports single card evaluation with batch_size=1 "
-                "during evaluation, so batch_size " \
+                "during evaluation, so batch_size is forcibly set to 1.")
-                "is forcibly set to {}.".format(batch_size))
+            batch_size = 1
-        self.eval_data_loader = self.build_data_loader(
+
-            eval_dataset, batch_size=batch_size, mode='eval')
+        if nranks < 2 or local_rank == 0:
-
+            self.eval_data_loader = self.build_data_loader(
-        logging.info(
+                eval_dataset, batch_size=batch_size, mode='eval')
-            "Start to evaluate(total_samples={}, total_steps={})...".format(
+            logging.info(
-                eval_dataset.num_samples,
+                "Start to evaluate(total_samples={}, total_steps={})...".format(
-                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
+                    eval_dataset.num_samples, eval_dataset.num_samples))
-
+
-        top1s = []
+            top1s = []
-        top5s = []
+            top5s = []
-        with paddle.no_grad():
+            with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
+                for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
+                    data.append(eval_dataset.transforms.transforms)
-                outputs = self.run(self.net, data, 'eval')
+                    outputs = self.run(self.net, data, 'eval')
-                top1s.append(outputs["top1"])
+                    top1s.append(outputs["top1"])
-                top5s.append(outputs["top5"])
+                    top5s.append(outputs["top5"])
-
+
-        top1 = np.mean(top1s)
+            top1 = np.mean(top1s)
-        top5 = np.mean(top5s)
+            top5 = np.mean(top5s)
-        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
+            eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
-        if return_details:
+
-            # TODO: add details
+            if return_details:
-            return eval_metrics, None
+                # TODO: Add details
-        return eval_metrics
+                return eval_metrics, None
+
+            return eval_metrics

     def predict(self, img_file, transforms=None):
         """
@@ -435,16 +440,14 @@ class BaseClassifier(BaseModel):
                 Defaults to None.

         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
+            If `img_file` is a string or np.array, the result is a dict with the 
-                pairs:
+                following key-value pairs:
-                {"label map": `class_ids_map`, 
+                class_ids_map (np.ndarray): IDs of predicted classes.
-                 "scores_map": `scores_map`, 
+                scores_map (np.ndarray): Scores of predicted classes.
-                 "label_names_map": `label_names_map`}.
+                label_names_map (np.ndarray): Names of predicted classes.
+            
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
+                above keys.
-                class_ids_map (np.ndarray): class_ids
-                scores_map (np.ndarray): scores
-                label_names_map (np.ndarray): label_names
         """

         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -555,6 +558,26 @@ class BaseClassifier(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeClassifier object.")

+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The volume of dataset({}) must be larger than batch size({}).'
+                .format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseClassifier, self).build_data_loader(
+                dataset, batch_size, mode)
+

 class ResNet50_vd(BaseClassifier):
     def __init__(self,

+ 0 - 786
paddlers/tasks/image_restorer.py

@@ -1,786 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-import datetime
-
-import paddle
-from paddle.distributed import ParallelEnv
-
-from ..models.ppgan.datasets.builder import build_dataloader
-from ..models.ppgan.models.builder import build_model
-from ..models.ppgan.utils.visual import tensor2img, save_image
-from ..models.ppgan.utils.filesystem import makedirs, save, load
-from ..models.ppgan.utils.timer import TimeAverager
-from ..models.ppgan.utils.profiler import add_profiler_step
-from ..models.ppgan.utils.logger import setup_logger
-
-
-# Define the AttrDict class to provide dynamic attribute access
-class AttrDict(dict):
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setattr__(self, key, value):
-        if key in self.__dict__:
-            self.__dict__[key] = value
-        else:
-            self[key] = value
-
-
-# Create AttrDict objects
-def create_attr_dict(config_dict):
-    from ast import literal_eval
-    for key, value in config_dict.items():
-        if type(value) is dict:
-            config_dict[key] = value = AttrDict(value)
-        if isinstance(value, str):
-            try:
-                value = literal_eval(value)
-            except BaseException:
-                pass
-        if isinstance(value, AttrDict):
-            create_attr_dict(config_dict[key])
-        else:
-            config_dict[key] = value
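For reference, the deleted `AttrDict`/`create_attr_dict` helpers give a plain config dict attribute-style access and `literal_eval` string scalars into Python values. A condensed, runnable version of the same idea (the original's `__setattr__` special-casing of `self.__dict__` is omitted here):

```python
from ast import literal_eval

class AttrDict(dict):
    """Dict whose keys are also readable/writable as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value

def create_attr_dict(config_dict):
    # Recursively wrap nested dicts in AttrDict and literal-eval string
    # scalars ("1e-4" -> 0.0001, "[1, 2]" -> list), as the removed helper did.
    for key, value in config_dict.items():
        if isinstance(value, dict):
            config_dict[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(config_dict[key])
        else:
            config_dict[key] = value

cfg = AttrDict({'model': {'name': 'RCAN'}, 'lr': '1e-4'})
create_attr_dict(cfg)
print(cfg.model.name, cfg.lr)  # RCAN 0.0001
```

Strings that are not valid Python literals (such as `'RCAN'`) are left untouched by the `literal_eval` pass.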
-
-
-# Data loading class
-class IterLoader:
-    def __init__(self, dataloader):
-        self._dataloader = dataloader
-        self.iter_loader = iter(self._dataloader)
-        self._epoch = 1
-
-    @property
-    def epoch(self):
-        return self._epoch
-
-    def __next__(self):
-        try:
-            data = next(self.iter_loader)
-        except StopIteration:
-            self._epoch += 1
-            self.iter_loader = iter(self._dataloader)
-            data = next(self.iter_loader)
-
-        return data
-
-    def __len__(self):
-        return len(self._dataloader)
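The deleted `IterLoader` wraps a dataloader in an endless iterator, bumping an epoch counter each time the underlying loader is exhausted and restarted. Since it only relies on `iter()` and `len()`, the pattern can be demonstrated with a plain list standing in for the dataloader:

```python
class IterLoader:
    """Endless iterator over a dataloader, tracking the epoch count."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            # Wrapped around: count a new epoch and start a fresh pass.
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            return next(self.iter_loader)

    def __len__(self):
        return len(self._dataloader)

loader = IterLoader([10, 20, 30])
batches = [next(loader) for _ in range(5)]
print(batches, loader.epoch)  # [10, 20, 30, 10, 20] 2
```

This is what let the removed `Restorer.train` loop run by total iteration count rather than by epoch.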
-
-
-# Base training class
-class Restorer:
-    """
-    # trainer calling logic:
-    #
-    #                build_model                               ||    model(BaseModel)
-    #                     |                                    ||
-    #               build_dataloader                           ||    dataloader
-    #                     |                                    ||
-    #               model.setup_lr_schedulers                  ||    lr_scheduler
-    #                     |                                    ||
-    #               model.setup_optimizers                     ||    optimizers
-    #                     |                                    ||
-    #     train loop (model.setup_input + model.train_iter)    ||    train loop
-    #                     |                                    ||
-    #         print log (model.get_current_losses)             ||
-    #                     |                                    ||
-    #         save checkpoint (model.nets)                     \/
-    """
-
-    def __init__(self, cfg, logger):
-        # base config
-        # self.logger = logging.getLogger(__name__)
-        self.logger = logger
-        self.cfg = cfg
-        self.output_dir = cfg.output_dir
-        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
-
-        self.local_rank = ParallelEnv().local_rank
-        self.world_size = ParallelEnv().nranks
-        self.log_interval = cfg.log_config.interval
-        self.visual_interval = cfg.log_config.visiual_interval
-        self.weight_interval = cfg.snapshot_config.interval
-
-        self.start_epoch = 1
-        self.current_epoch = 1
-        self.current_iter = 1
-        self.inner_iter = 1
-        self.batch_id = 0
-        self.global_steps = 0
-
-        # build model
-        self.model = build_model(cfg.model)
-        # multiple gpus prepare
-        if ParallelEnv().nranks > 1:
-            self.distributed_data_parallel()
-
-        # build metrics
-        self.metrics = None
-        self.is_save_img = True
-        validate_cfg = cfg.get('validate', None)
-        if validate_cfg and 'metrics' in validate_cfg:
-            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
-        if validate_cfg and 'save_img' in validate_cfg:
-            self.is_save_img = validate_cfg['save_img']
-
-        self.enable_visualdl = cfg.get('enable_visualdl', False)
-        if self.enable_visualdl:
-            import visualdl
-            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
-
-        # evaluate only
-        if not cfg.is_train:
-            return
-
-        # build train dataloader
-        self.train_dataloader = build_dataloader(cfg.dataset.train)
-        self.iters_per_epoch = len(self.train_dataloader)
-
-        # build lr scheduler
-        # TODO: has a better way?
-        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
-            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
-        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
-
-        # build optimizers
-        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
-                                                      cfg.optimizer)
-
-        self.epochs = cfg.get('epochs', None)
-        if self.epochs:
-            self.total_iters = self.epochs * self.iters_per_epoch
-            self.by_epoch = True
-        else:
-            self.by_epoch = False
-            self.total_iters = cfg.total_iters
-
-        if self.by_epoch:
-            self.weight_interval *= self.iters_per_epoch
-
-        self.validate_interval = -1
-        if cfg.get('validate', None) is not None:
-            self.validate_interval = cfg.validate.get('interval', -1)
-
-        self.time_count = {}
-        self.best_metric = {}
-        self.model.set_total_iter(self.total_iters)
-        self.profiler_options = cfg.profiler_options
-
-    def distributed_data_parallel(self):
-        paddle.distributed.init_parallel_env()
-        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
-        for net_name, net in self.model.nets.items():
-            self.model.nets[net_name] = paddle.DataParallel(
-                net, find_unused_parameters=find_unused_parameters)
-
-    def learning_rate_scheduler_step(self):
-        if isinstance(self.model.lr_scheduler, dict):
-            for lr_scheduler in self.model.lr_scheduler.values():
-                lr_scheduler.step()
-        elif isinstance(self.model.lr_scheduler,
-                        paddle.optimizer.lr.LRScheduler):
-            self.model.lr_scheduler.step()
-        else:
-            raise ValueError(
-                'lr schedulter must be a dict or an instance of LRScheduler')
-
-    def train(self):
-        reader_cost_averager = TimeAverager()
-        batch_cost_averager = TimeAverager()
-
-        iter_loader = IterLoader(self.train_dataloader)
-
-        # set model.is_train = True
-        self.model.setup_train_mode(is_train=True)
-        while self.current_iter < (self.total_iters + 1):
-            self.current_epoch = iter_loader.epoch
-            self.inner_iter = self.current_iter % self.iters_per_epoch
-
-            add_profiler_step(self.profiler_options)
-
-            start_time = step_start_time = time.time()
-            data = next(iter_loader)
-            reader_cost_averager.record(time.time() - step_start_time)
-            # unpack data from dataset and apply preprocessing
-            # data input should be dict
-            self.model.setup_input(data)
-            self.model.train_iter(self.optimizers)
-
-            batch_cost_averager.record(
-                time.time() - step_start_time,
-                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))
-
-            step_start_time = time.time()
-
-            if self.current_iter % self.log_interval == 0:
-                self.data_time = reader_cost_averager.get_average()
-                self.step_time = batch_cost_averager.get_average()
-                self.ips = batch_cost_averager.get_ips_average()
-                self.print_log()
-
-                reader_cost_averager.reset()
-                batch_cost_averager.reset()
-
-            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
-                self.visual('visual_train')
-
-            self.learning_rate_scheduler_step()
-
-            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
-                self.test()
-
-            if self.current_iter % self.weight_interval == 0:
-                self.save(self.current_iter, 'weight', keep=-1)
-                self.save(self.current_iter)
-
-            self.current_iter += 1
-
-    def test(self):
-        if not hasattr(self, 'test_dataloader'):
-            self.test_dataloader = build_dataloader(
-                self.cfg.dataset.test, is_train=False)
-        iter_loader = IterLoader(self.test_dataloader)
-        if self.max_eval_steps is None:
-            self.max_eval_steps = len(self.test_dataloader)
-
-        if self.metrics:
-            for metric in self.metrics.values():
-                metric.reset()
-
-        # set model.is_train = False
-        self.model.setup_train_mode(is_train=False)
-
-        for i in range(self.max_eval_steps):
-            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
-                self.logger.info('Test iter: [%d/%d]' % (
-                    i * self.world_size, self.max_eval_steps * self.world_size))
-
-            data = next(iter_loader)
-            self.model.setup_input(data)
-            self.model.test_iter(metrics=self.metrics)
-
-            if self.is_save_img:
-                visual_results = {}
-                current_paths = self.model.get_image_paths()
-                current_visuals = self.model.get_current_visuals()
-
-                if len(current_visuals) > 0 and len(
-                        list(current_visuals.values())[0].shape) == 4:
-                    num_samples = list(current_visuals.values())[0].shape[0]
-                else:
-                    num_samples = 1
-
-                for j in range(num_samples):
-                    if j < len(current_paths):
-                        short_path = os.path.basename(current_paths[j])
-                        basename = os.path.splitext(short_path)[0]
-                    else:
-                        basename = '{:04d}_{:04d}'.format(i, j)
-                    for k, img_tensor in current_visuals.items():
-                        name = '%s_%s' % (basename, k)
-                        if len(img_tensor.shape) == 4:
-                            visual_results.update({name: img_tensor[j]})
-                        else:
-                            visual_results.update({name: img_tensor})
-
-                self.visual(
-                    'visual_test',
-                    visual_results=visual_results,
-                    step=self.batch_id,
-                    is_save_image=True)
-
-        if self.metrics:
-            for metric_name, metric in self.metrics.items():
-                self.logger.info("Metric {}: {:.4f}".format(
-                    metric_name, metric.accumulate()))
-
-    def print_log(self):
-        losses = self.model.get_current_losses()
-
-        message = ''
-        if self.by_epoch:
-            message += 'Epoch: %d/%d, iter: %d/%d ' % (
-                self.current_epoch, self.epochs, self.inner_iter,
-                self.iters_per_epoch)
-        else:
-            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
-
-        message += f'lr: {self.current_learning_rate:.3e} '
-
-        for k, v in losses.items():
-            message += '%s: %.3f ' % (k, v)
-            if self.enable_visualdl:
-                self.vdl_logger.add_scalar(k, v, step=self.global_steps)
-
-        if hasattr(self, 'step_time'):
-            message += 'batch_cost: %.5f sec ' % self.step_time
-
-        if hasattr(self, 'data_time'):
-            message += 'reader_cost: %.5f sec ' % self.data_time
-
-        if hasattr(self, 'ips'):
-            message += 'ips: %.5f images/s ' % self.ips
-
-        if hasattr(self, 'step_time'):
-            eta = self.step_time * (self.total_iters - self.current_iter)
-            eta = eta if eta > 0 else 0
-
-            eta_str = str(datetime.timedelta(seconds=int(eta)))
-            message += f'eta: {eta_str}'
-
-        # print the message
-        self.logger.info(message)
-
-    @property
-    def current_learning_rate(self):
-        for optimizer in self.model.optimizers.values():
-            return optimizer.get_lr()
-
-    def visual(self,
-               results_dir,
-               visual_results=None,
-               step=None,
-               is_save_image=False):
-        """
-        Visualize images, either with VisualDL or by writing them directly to a directory.
-        Parameters:
-            results_dir (str)     --  directory in which saved images are stored
-            visual_results (dict) --  dict of result images
-            step (int)            --  global step, used by VisualDL
-            is_save_image (bool)  --  whether to write to the directory instead of VisualDL
-        """
-        self.model.compute_visuals()
-
-        if visual_results is None:
-            visual_results = self.model.get_current_visuals()
-
-        min_max = self.cfg.get('min_max', None)
-        if min_max is None:
-            min_max = (-1., 1.)
-
-        image_num = self.cfg.get('image_num', None)
-        if (image_num is None) or (not self.enable_visualdl):
-            image_num = 1
-        for label, image in visual_results.items():
-            image_numpy = tensor2img(image, min_max, image_num)
-            if (not is_save_image) and self.enable_visualdl:
-                self.vdl_logger.add_image(
-                    results_dir + '/' + label,
-                    image_numpy,
-                    step=step if step else self.global_steps,
-                    dataformats="HWC" if image_num == 1 else "NCHW")
-            else:
-                if self.cfg.is_train:
-                    if self.by_epoch:
-                        msg = 'epoch%.3d_' % self.current_epoch
-                    else:
-                        msg = 'iter%.3d_' % self.current_iter
-                else:
-                    msg = ''
-                makedirs(os.path.join(self.output_dir, results_dir))
-                img_path = os.path.join(self.output_dir, results_dir,
-                                        msg + '%s.png' % (label))
-                save_image(image_numpy, img_path)
-
-    def save(self, epoch, name='checkpoint', keep=1):
-        if self.local_rank != 0:
-            return
-
-        assert name in ['checkpoint', 'weight']
-
-        state_dicts = {}
-        if self.by_epoch:
-            save_filename = 'epoch_%s_%s.pdparams' % (
-                epoch // self.iters_per_epoch, name)
-        else:
-            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
-
-        os.makedirs(self.output_dir, exist_ok=True)
-        save_path = os.path.join(self.output_dir, save_filename)
-        for net_name, net in self.model.nets.items():
-            state_dicts[net_name] = net.state_dict()
-
-        if name == 'weight':
-            save(state_dicts, save_path)
-            return
-
-        state_dicts['epoch'] = epoch
-
-        for opt_name, opt in self.model.optimizers.items():
-            state_dicts[opt_name] = opt.state_dict()
-
-        save(state_dicts, save_path)
-
-        if keep > 0:
-            try:
-                if self.by_epoch:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'epoch_%s_%s.pdparams' % (
-                            (epoch - keep * self.weight_interval) //
-                            self.iters_per_epoch, name))
-                else:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'iter_%s_%s.pdparams' %
-                        (epoch - keep * self.weight_interval, name))
-
-                if os.path.exists(checkpoint_name_to_be_removed):
-                    os.remove(checkpoint_name_to_be_removed)
-
-            except Exception as e:
-                self.logger.info('Failed to remove old checkpoint: {}'.format(e))
-
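The `save()` method above rotates checkpoints by reconstructing the filename of the checkpoint that falls out of the keep window and deleting it. A minimal sketch of the iteration-based naming (the `by_epoch=False` branch); the helper name is hypothetical:

```python
def old_checkpoint_name(current_iter, keep, weight_interval, name='checkpoint'):
    # Filename of the checkpoint that falls out of the keep window,
    # matching the iteration-based naming used by save() above.
    return 'iter_%s_%s.pdparams' % (current_iter - keep * weight_interval, name)
```

For example, with `keep=1` and a save interval of 1000, saving at iteration 5000 would delete `iter_4000_checkpoint.pdparams`.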
-    def resume(self, checkpoint_path):
-        state_dicts = load(checkpoint_path)
-        if state_dicts.get('epoch', None) is not None:
-            self.start_epoch = state_dicts['epoch'] + 1
-            self.global_steps = self.iters_per_epoch * state_dicts['epoch']
-
-            self.current_iter = state_dicts['epoch'] + 1
-
-        for net_name, net in self.model.nets.items():
-            net.set_state_dict(state_dicts[net_name])
-
-        for opt_name, opt in self.model.optimizers.items():
-            opt.set_state_dict(state_dicts[opt_name])
-
-    def load(self, weight_path):
-        state_dicts = load(weight_path)
-
-        for net_name, net in self.model.nets.items():
-            if net_name in state_dicts:
-                net.set_state_dict(state_dicts[net_name])
-                self.logger.info('Loaded pretrained weight for net {}'.format(
-                    net_name))
-            else:
-                self.logger.warning(
-                    'Cannot find the state dict of net {}. Skipped loading pretrained weights for net {}.'
-                    .format(net_name, net_name))
-
-    def close(self):
-        """
-        Close file handlers and other resources when training finishes.
-        """
-        if self.enable_visualdl:
-            self.vdl_logger.close()
-
-
-# Base training class for super-resolution models
-class BasicSRNet:
-    def __init__(self):
-        self.model = {}
-        self.optimizer = {}
-        self.lr_scheduler = {}
-        self.min_max = ''
-
-    def train(
-            self,
-            total_iters,
-            train_dataset,
-            test_dataset,
-            output_dir,
-            validate,
-            snapshot,
-            log,
-            lr_rate,
-            evaluate_weights='',
-            resume='',
-            pretrain_weights='',
-            periods=[100000],
-            restart_weights=[1], ):
-        self.lr_scheduler['learning_rate'] = lr_rate
-
-        if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR':
-            self.lr_scheduler['periods'] = periods
-            self.lr_scheduler['restart_weights'] = restart_weights
-
-        validate = {
-            'interval': validate,
-            'save_img': False,
-            'metrics': {
-                'psnr': {
-                    'name': 'PSNR',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                },
-                'ssim': {
-                    'name': 'SSIM',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                }
-            }
-        }
-        log_config = {'interval': log, 'visiual_interval': 500}
-        snapshot_config = {'interval': snapshot}
-
-        cfg = {
-            'total_iters': total_iters,
-            'output_dir': output_dir,
-            'min_max': self.min_max,
-            'model': self.model,
-            'dataset': {
-                'train': train_dataset,
-                'test': test_dataset
-            },
-            'lr_scheduler': self.lr_scheduler,
-            'optimizer': self.optimizer,
-            'validate': validate,
-            'log_config': log_config,
-            'snapshot_config': snapshot_config
-        }
-
-        cfg = AttrDict(cfg)
-        create_attr_dict(cfg)
-
-        cfg.is_train = True
-        cfg.profiler_options = None
-        cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
-
-        if cfg.model.name == 'BaseSRModel':
-            folder_model_name = cfg.model.generator.name
-        else:
-            folder_model_name = cfg.model.name
-        cfg.output_dir = os.path.join(cfg.output_dir,
-                                      folder_model_name + cfg.timestamp)
-
-        logger_cfg = setup_logger(cfg.output_dir)
-        logger_cfg.info('Configs: {}'.format(cfg))
-
-        if paddle.is_compiled_with_cuda():
-            paddle.set_device('gpu')
-        else:
-            paddle.set_device('cpu')
-
-        # build trainer
-        trainer = Restorer(cfg, logger_cfg)
-
-        # Resume training or evaluation; the checkpoint must contain epoch and optimizer info
-        if len(resume) > 0:
-            trainer.resume(resume)
-        # For evaluation or fine-tuning, only load generator weights
-        elif len(pretrain_weights) > 0:
-            trainer.load(pretrain_weights)
-        if len(evaluate_weights) > 0:
-            trainer.load(evaluate_weights)
-            trainer.test()
-            return
-        # Train; save the current weights on keyboard interrupt
-        try:
-            trainer.train()
-        except KeyboardInterrupt:
-            trainer.save(trainer.current_epoch)
-
-        trainer.close()
-
-
-# DRN model training
-class DRNet(BasicSRNet):
-    def __init__(self,
-                 n_blocks=30,
-                 n_feats=16,
-                 n_colors=3,
-                 rgb_range=255,
-                 negval=0.2):
-        super(DRNet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'DRNGenerator',
-            'scale': (2, 4),
-            'n_blocks': n_blocks,
-            'n_feats': n_feats,
-            'n_colors': n_colors,
-            'rgb_range': rgb_range,
-            'negval': negval
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'DRN',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'optimG': {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            },
-            'optimD': {
-                'name': 'Adam',
-                'net_names': ['dual_model_0', 'dual_model_1'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            }
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# Training for LESRCNN, a lightweight super-resolution model
-class LESRCNNet(BasicSRNet):
-    def __init__(self, scale=4, multi_scale=False, group=1):
-        super(LESRCNNet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'LESRCNNGenerator',
-            'scale': scale,
-            'multi_scale': multi_scale,
-            'group': group
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'BaseSRModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# ESRGAN model training
-# If loss_type='gan', perceptual, adversarial, and pixel losses are used
-# If loss_type='pixel', only the pixel loss is used
-class ESRGANet(BasicSRNet):
-    def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
-        super(ESRGANet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'RRDBNet',
-            'in_nc': in_nc,
-            'out_nc': out_nc,
-            'nf': nf,
-            'nb': nb
-        }
-
-        if loss_type == 'gan':
-            # Define the loss functions
-            self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01}
-            self.discriminator = {
-                'name': 'VGGDiscriminator128',
-                'in_channels': 3,
-                'num_feat': 64
-            }
-            self.perceptual_criterion = {
-                'name': 'PerceptualLoss',
-                'layer_weights': {
-                    '34': 1.0
-                },
-                'perceptual_weight': 1.0,
-                'style_weight': 0.0,
-                'norm_img': False
-            }
-            self.gan_criterion = {
-                'name': 'GANLoss',
-                'gan_mode': 'vanilla',
-                'loss_weight': 0.005
-            }
-            # Define the model
-            self.model = {
-                'name': 'ESRGAN',
-                'generator': self.generator,
-                'discriminator': self.discriminator,
-                'pixel_criterion': self.pixel_criterion,
-                'perceptual_criterion': self.perceptual_criterion,
-                'gan_criterion': self.gan_criterion
-            }
-            self.optimizer = {
-                'optimG': {
-                    'name': 'Adam',
-                    'net_names': ['generator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                },
-                'optimD': {
-                    'name': 'Adam',
-                    'net_names': ['discriminator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                }
-            }
-            self.lr_scheduler = {
-                'name': 'MultiStepDecay',
-                'milestones': [50000, 100000, 200000, 300000],
-                'gamma': 0.5
-            }
-        else:
-            self.pixel_criterion = {'name': 'L1Loss'}
-            self.model = {
-                'name': 'BaseSRModel',
-                'generator': self.generator,
-                'pixel_criterion': self.pixel_criterion
-            }
-            self.optimizer = {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'beta1': 0.9,
-                'beta2': 0.99
-            }
-            self.lr_scheduler = {
-                'name': 'CosineAnnealingRestartLR',
-                'eta_min': 1e-07
-            }
-
-
-# RCAN model training
-class RCANet(BasicSRNet):
-    def __init__(
-            self,
-            scale=2,
-            n_resgroups=10,
-            n_resblocks=20, ):
-        super(RCANet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'RCAN',
-            'scale': scale,
-            'n_resgroups': n_resgroups,
-            'n_resblocks': n_resblocks
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'RCANModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'MultiStepDecay',
-            'milestones': [250000, 500000, 750000, 1000000],
-            'gamma': 0.5
-        }
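The legacy `BasicSRNet.train()` above works by assembling a nested config dict and handing it to the `Restorer` trainer. A minimal sketch of that assembly pattern, with placeholder values standing in for real dataset and model configs (the helper name is hypothetical):

```python
def build_sr_config(total_iters, output_dir, model, train_ds, test_ds,
                    lr_scheduler, optimizer, validate_interval,
                    log_interval, snapshot_interval):
    # Assemble a nested trainer config the way BasicSRNet.train() does.
    # ('visiual_interval' mirrors the key spelling used by the original code.)
    return {
        'total_iters': total_iters,
        'output_dir': output_dir,
        'model': model,
        'dataset': {'train': train_ds, 'test': test_ds},
        'lr_scheduler': lr_scheduler,
        'optimizer': optimizer,
        'validate': {'interval': validate_interval, 'save_img': False},
        'log_config': {'interval': log_interval, 'visiual_interval': 500},
        'snapshot_config': {'interval': snapshot_interval},
    }

cfg = build_sr_config(
    total_iters=100000, output_dir='output', model={'name': 'BaseSRModel'},
    train_ds='train_dataset_placeholder', test_ds='test_dataset_placeholder',
    lr_scheduler={'name': 'CosineAnnealingRestartLR', 'eta_min': 1e-07},
    optimizer={'name': 'Adam', 'net_names': ['generator']},
    validate_interval=5000, log_interval=100, snapshot_interval=5000)
```

In the real code the plain dict is then wrapped with `AttrDict`/`create_attr_dict` so that nested keys become attribute accesses.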

+ 36 - 15
paddlers/tasks/object_detector.py

@@ -61,6 +61,11 @@ class BaseDetector(BaseModel):
             net = ppdet.modeling.__dict__[self.model_name](**params)
         return net
 
+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
 
@@ -485,7 +490,7 @@ class BaseDetector(BaseModel):
                 Defaults to False.
 
         Returns:
-            collections.OrderedDict with key-value pairs: 
+            If `return_details` is False, return collections.OrderedDict with key-value pairs: 
                 {"bbox_mmap":`mean average precision (0.50, 11point)`}.
         """
 
@@ -584,21 +589,17 @@
 
         Returns:
             If `img_file` is a string or np.array, the result is a list of dict with 
-                key-value pairs:
-                {"category_id": `category_id`, 
-                 "category": `category`, 
-                 "bbox": `[x, y, w, h]`, 
-                 "score": `score`, 
-                 "mask": `mask`}.
-            If `img_file` is a list, the result is a list composed of list of dicts 
-                with the corresponding fields:
-                category_id(int): the predicted category ID. 0 represents the first 
+                the following key-value pairs:
+                category_id (int): Predicted category ID. 0 represents the first 
                     category in the dataset, and so on.
-                category(str): category name
-                bbox(list): bounding box in [x, y, w, h] format
-                score(str): confidence
-                mask(dict): Only for instance segmentation task. Mask of the object in 
-                    RLE format
+                category (str): Category name.
+                bbox (list): Bounding box in [x, y, w, h] format.
+                score (str): Confidence.
+                mask (dict): Only for instance segmentation task. Mask of the object in 
+                    RLE format.
+
+            If `img_file` is a list, the result is a list composed of list of dicts 
+                with the above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -926,6 +927,26 @@ class PicoDet(BaseDetector):
             in_args['optimizer'] = optimizer
         return in_args
 
+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The number of samples in the dataset ({}) must be no less '
+                'than the batch size ({}).'.format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseDetector, self).build_data_loader(dataset,
+                                                               batch_size, mode)
+
 
 class YOLOv3(BaseDetector):
     def __init__(self,

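The `build_data_loader` override added above first guards against datasets smaller than one batch before building a loader. The guard in isolation, as a pure-Python sketch (the function name is hypothetical):

```python
def check_batch_size(num_samples, batch_size):
    # Mirrors the guard in the override above: a dataset smaller than one
    # batch cannot produce a single step, so fail early with a clear error.
    if num_samples < batch_size:
        raise ValueError(
            'The number of samples in the dataset ({}) must be no less '
            'than the batch size ({}).'.format(num_samples, batch_size))

check_batch_size(100, 8)  # passes silently
```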
+ 934 - 0
paddlers/tasks/restorer.py

@@ -0,0 +1,934 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import numpy as np
+import cv2
+import paddle
+import paddle.nn.functional as F
+from paddle.static import InputSpec
+
+import paddlers
+import paddlers.models.ppgan as ppgan
+import paddlers.rs_models.res as cmres
+import paddlers.models.ppgan.metrics as metrics
+import paddlers.utils.logging as logging
+from paddlers.models import res_losses
+from paddlers.transforms import Resize, decode_image
+from paddlers.transforms.functions import calc_hr_shape
+from paddlers.utils import get_single_card_bs
+from .base import BaseModel
+from .utils.res_adapters import GANAdapter, OptimizerAdapter
+from .utils.infer_nets import InferResNet
+
+__all__ = ["DRN", "LESRCNN", "ESRGAN"]
+
+
+class BaseRestorer(BaseModel):
+    MIN_MAX = (0., 1.)
+    TEST_OUT_KEY = None
+
+    def __init__(self, model_name, losses=None, sr_factor=None, **params):
+        self.init_params = locals()
+        if 'with_net' in self.init_params:
+            del self.init_params['with_net']
+        super(BaseRestorer, self).__init__('restorer')
+        self.model_name = model_name
+        self.losses = losses
+        self.sr_factor = sr_factor
+        if params.get('with_net', True):
+            params.pop('with_net', None)
+            self.net = self.build_net(**params)
+        self.find_unused_parameters = True
+
+    def build_net(self, **params):
+        # Currently, only use models from cmres.
+        if not hasattr(cmres, self.model_name):
+            raise ValueError("ERROR: There is no model named {}.".format(
+                self.model_name))
+        net = dict(**cmres.__dict__)[self.model_name](**params)
+        return net
+
+    def _build_inference_net(self):
+        # For GAN models, only the generator will be used for inference.
+        if isinstance(self.net, GANAdapter):
+            infer_net = InferResNet(
+                self.net.generator, out_key=self.TEST_OUT_KEY)
+        else:
+            infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY)
+        infer_net.eval()
+        return infer_net
+
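`_build_inference_net` above exports only the generator when the model is GAN-based. The dispatch can be sketched with stand-in classes (all names here are hypothetical, not the real `GANAdapter` or `InferResNet`):

```python
class StandInGANAdapter:
    # Hypothetical stand-in for GANAdapter: bundles one generator
    # with a list of discriminators.
    def __init__(self, generator, discriminators):
        self.generator = generator
        self.discriminators = discriminators


def select_inference_net(net):
    # Mirrors the dispatch in _build_inference_net: GAN models export
    # only the generator; plain restorers are exported as-is.
    if isinstance(net, StandInGANAdapter):
        return net.generator
    return net
```

The real method additionally wraps the selected network in `InferResNet` and switches it to eval mode before export.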
+    def _fix_transforms_shape(self, image_shape):
+        if hasattr(self, 'test_transforms'):
+            if self.test_transforms is not None:
+                has_resize_op = False
+                resize_op_idx = -1
+                normalize_op_idx = len(self.test_transforms.transforms)
+                for idx, op in enumerate(self.test_transforms.transforms):
+                    name = op.__class__.__name__
+                    if name == 'Normalize':
+                        normalize_op_idx = idx
+                    if 'Resize' in name:
+                        has_resize_op = True
+                        resize_op_idx = idx
+
+                if not has_resize_op:
+                    self.test_transforms.transforms.insert(
+                        normalize_op_idx, Resize(target_size=image_shape))
+                else:
+                    self.test_transforms.transforms[resize_op_idx] = Resize(
+                        target_size=image_shape)
+
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            if len(image_shape) == 2:
+                image_shape = [1, 3] + image_shape
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+        self.fixed_input_shape = image_shape
+        input_spec = [
+            InputSpec(
+                shape=image_shape, name='image', dtype='float32')
+        ]
+        return input_spec
+
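`_get_test_inputs` above normalizes the user-supplied shape before constructing the `InputSpec`. The normalization alone is pure Python (helper name hypothetical):

```python
def normalize_image_shape(image_shape):
    # [H, W] -> [1, 3, H, W]; None -> fully dynamic NCHW with 3 channels,
    # matching the shape handling in _get_test_inputs above.
    if image_shape is None:
        return [None, 3, -1, -1]
    if len(image_shape) == 2:
        return [1, 3] + list(image_shape)
    return list(image_shape)
```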
+    def run(self, net, inputs, mode):
+        outputs = OrderedDict()
+
+        if mode == 'test':
+            tar_shape = inputs[1]
+            if self.status == 'Infer':
+                net_out = net(inputs[0])
+                res_map_list = self.postprocess(
+                    net_out, tar_shape, transforms=inputs[2])
+            else:
+                if isinstance(net, GANAdapter):
+                    net_out = net.generator(inputs[0])
+                else:
+                    net_out = net(inputs[0])
+                if self.TEST_OUT_KEY is not None:
+                    net_out = net_out[self.TEST_OUT_KEY]
+                pred = self.postprocess(
+                    net_out, tar_shape, transforms=inputs[2])
+                res_map_list = []
+                for res_map in pred:
+                    res_map = self._tensor_to_images(res_map)
+                    res_map_list.append(res_map)
+            outputs['res_map'] = res_map_list
+
+        if mode == 'eval':
+            if isinstance(net, GANAdapter):
+                net_out = net.generator(inputs[0])
+            else:
+                net_out = net(inputs[0])
+            if self.TEST_OUT_KEY is not None:
+                net_out = net_out[self.TEST_OUT_KEY]
+            tar = inputs[1]
+            tar_shape = [tar.shape[-2:]]
+            pred = self.postprocess(
+                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self._tensor_to_images(pred)
+            outputs['pred'] = pred
+            tar = self._tensor_to_images(tar)
+            outputs['tar'] = tar
+
+        if mode == 'train':
+            # This is used by non-GAN models.
+            # For GAN models, self.run_gan() should be used.
+            net_out = net(inputs[0])
+            loss = self.losses(net_out, inputs[1])
+            outputs['loss'] = loss
+        return outputs
+
+    def run_gan(self, net, inputs, mode, gan_mode):
+        raise NotImplementedError
+
+    def default_loss(self):
+        return res_losses.L1Loss()
+
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          num_epochs,
+                          num_steps_each_epoch,
+                          lr_decay_power=0.9):
+        decay_step = num_epochs * num_steps_each_epoch
+        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=lr_scheduler,
+            parameters=parameters,
+            momentum=0.9,
+            weight_decay=4e-5)
+        return optimizer
+
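The default optimizer above pairs SGD-with-momentum with a polynomial decay schedule. Under the usual non-cyclic definition, the rate at step t is `(lr0 - end_lr) * (1 - t/T)**power + end_lr`; a hedged sketch of that formula (consult the Paddle docs for `PolynomialDecay`'s exact cycle behavior):

```python
def polynomial_decay(lr0, step, decay_steps, end_lr=0.0, power=0.9):
    # lr(t) = (lr0 - end_lr) * (1 - t / T) ** power + end_lr,
    # with t clamped to T so the rate bottoms out at end_lr.
    t = min(step, decay_steps)
    return (lr0 - end_lr) * (1 - t / decay_steps) ** power + end_lr
```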
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=2,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=2,
+              save_dir='output',
+              pretrain_weights=None,
+              learning_rate=0.01,
+              lr_decay_power=0.9,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+
+        Args:
+            num_epochs (int): Number of epochs.
+            train_dataset (paddlers.datasets.ResDataset): Training dataset.
+            train_batch_size (int, optional): Total batch size among all cards used in 
+                training. Defaults to 2.
+            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. 
+                If None, the model will not be evaluated during training process. 
+                Defaults to None.
+            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
+                training. If None, a default optimizer will be used. Defaults to None.
+            save_interval_epochs (int, optional): Epoch interval for saving the model. 
+                Defaults to 1.
+            log_interval_steps (int, optional): Step interval for printing training 
+                information. Defaults to 2.
+            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights (str|None, optional): None or name/path of pretrained 
+                weights. If None, no pretrained weights will be loaded. 
+                Defaults to None.
+            learning_rate (float, optional): Learning rate for training. Defaults to .01.
+            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
+            early_stop (bool, optional): Whether to adopt early stop strategy. Defaults 
+                to False.
+            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
+            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
+                process. Defaults to True.
+            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
+                training from. If None, no training checkpoint will be resumed. At most
+                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
+                Defaults to None.
+        """
+
+        if self.status == 'Infer':
+            logging.error(
+                "Exported inference model does not support training.",
+                exit=True)
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
+
+        if self.losses is None:
+            self.losses = self.default_loss()
+
+        if optimizer is None:
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            if isinstance(self.net, GANAdapter):
+                parameters = {'params_g': [], 'params_d': []}
+                for net_g in self.net.generators:
+                    parameters['params_g'].append(net_g.parameters())
+                for net_d in self.net.discriminators:
+                    parameters['params_d'].append(net_d.parameters())
+            else:
+                parameters = self.net.parameters()
+            self.optimizer = self.default_optimizer(
+                parameters, learning_rate, num_epochs, num_steps_each_epoch,
+                lr_decay_power)
+        else:
+            self.optimizer = optimizer
+
+        if pretrain_weights is not None and not osp.exists(pretrain_weights):
+            logging.warning("Path of pretrain_weights('{}') does not exist!".
+                            format(pretrain_weights))
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
+        pretrained_dir = osp.join(save_dir, 'pretrain')
+        is_backbone_weights = pretrain_weights == 'IMAGENET'
+        self.net_initialize(
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint,
+            is_backbone_weights=is_backbone_weights)
+
+        self.train_loop(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
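The branching above groups parameters differently for GAN-style models (wrapped in a `GANAdapter`, with separate generator and discriminator parameter lists) than for plain nets (a single flat list). A minimal stand-alone sketch of that logic, using hypothetical mock classes (`MockNet`, `MockGANAdapter` are illustrative, not part of PaddleRS):

```python
class MockNet:
    """Stand-in for a paddle.nn.Layer exposing parameters()."""

    def __init__(self, params):
        self._params = params

    def parameters(self):
        return self._params


class MockGANAdapter:
    """Stand-in for the GANAdapter used in build_net()."""

    def __init__(self, generators, discriminators):
        self.generators = generators
        self.discriminators = discriminators


def group_parameters(net):
    # Mirror the branching in BaseRestorer.train(): GAN-style models
    # get per-subnet parameter lists, plain models a flat list.
    if isinstance(net, MockGANAdapter):
        return {
            'params_g': [g.parameters() for g in net.generators],
            'params_d': [d.parameters() for d in net.discriminators],
        }
    return net.parameters()


plain = MockNet(['w1', 'w2'])
gan = MockGANAdapter([MockNet(['g1'])], [MockNet(['d1'])])
print(group_parameters(plain))  # ['w1', 'w2']
print(group_parameters(gan))    # {'params_g': [['g1']], 'params_d': [['d1']]}
```

The grouped dict is what `default_optimizer` consumes in the GAN subclasses (e.g. DRN, ESRGAN) to build one optimizer per subnet.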
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=2,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=2,
+                          save_dir='output',
+                          learning_rate=0.0001,
+                          lr_decay_power=0.9,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          resume_checkpoint=None,
+                          quant_config=None):
+        """
+        Quantization-aware training.
+
+        Args:
+            num_epochs (int): Number of epochs.
+            train_dataset (paddlers.datasets.ResDataset): Training dataset.
+            train_batch_size (int, optional): Total batch size among all cards used in 
+                training. Defaults to 2.
+            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset.
+                If None, the model will not be evaluated during training process. 
+                Defaults to None.
+            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
+                training. If None, a default optimizer will be used. Defaults to None.
+            save_interval_epochs (int, optional): Epoch interval for saving the model. 
+                Defaults to 1.
+            log_interval_steps (int, optional): Step interval for printing training 
+                information. Defaults to 2.
+            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
+            learning_rate (float, optional): Learning rate for training. 
+                Defaults to .0001.
+            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
+            early_stop (bool, optional): Whether to adopt early stop strategy. 
+                Defaults to False.
+            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
+            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
+                process. Defaults to True.
+            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
+                quantization-aware training from. If None, no training checkpoint will
+                be resumed. Defaults to None.
+            quant_config (dict|None, optional): Quantization configuration. If None, 
+                a default rule-of-thumb configuration will be used. Defaults to None.
+        """
+
+        self._prepare_qat(quant_config)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=None,
+            learning_rate=learning_rate,
+            lr_decay_power=lr_decay_power,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
+    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
+        """
+        Evaluate the model.
+
+        Args:
+            eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset.
+            batch_size (int, optional): Total batch size among all cards used for 
+                evaluation. Defaults to 1.
+            return_details (bool, optional): Whether to return evaluation details. 
+                Defaults to False.
+
+        Returns:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
+                {"psnr": `peak signal-to-noise ratio`,
+                 "ssim": `structural similarity`}.
+
+        """
+
+        self._check_transforms(eval_dataset.transforms, 'eval')
+
+        self.net.eval()
+        nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
+        if nranks > 1:
+            # Initialize parallel environment if not done.
+            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+            ):
+                paddle.distributed.init_parallel_env()
+
+        # TODO: Distributed evaluation
+        if batch_size > 1:
+            logging.warning(
+                "Restorer only supports batch_size=1 for each GPU card during "
+                "evaluation, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset, batch_size=batch_size, mode='eval')
+            # XXX: Hard-code crop_border and test_y_channel
+            psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
+            ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
+            logging.info(
+                "Start to evaluate (total_samples={}, total_steps={})...".format(
+                    eval_dataset.num_samples, eval_dataset.num_samples))
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    data.append(eval_dataset.transforms.transforms)
+                    outputs = self.run(self.net, data, 'eval')
+                    psnr.update(outputs['pred'], outputs['tar'])
+                    ssim.update(outputs['pred'], outputs['tar'])
+
+            # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
+            assert len(psnr.results) > 0
+            assert len(ssim.results) > 0
+            eval_metrics = OrderedDict(
+                zip(['psnr', 'ssim'],
+                    [np.mean(psnr.results), np.mean(ssim.results)]))
+
+            if return_details:
+                # TODO: Add details
+                return eval_metrics, None
+
+            return eval_metrics
+
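`evaluate()` accumulates PSNR and SSIM through the paddle metric objects above, which additionally crop a 4-pixel border and convert RGB to the Y channel. For intuition, a bare numpy sketch of PSNR itself on 8-bit images (no border cropping, no Y-channel conversion, so it will not numerically match the metric used above):

```python
import numpy as np


def psnr_8bit(pred, target):
    """Peak signal-to-noise ratio for uint8 images: 10*log10(MAX^2 / MSE)."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(255.0 ** 2 / mse)


a = np.zeros((8, 8), dtype=np.uint8)
b = np.full((8, 8), 10, dtype=np.uint8)   # constant error of 10 -> MSE = 100
print(round(psnr_8bit(a, b), 2))  # 28.13
```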
+    def predict(self, img_file, transforms=None):
+        """
+        Do inference.
+
+        Args:
+            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded 
+                image data, or a list of image paths/decoded data, in which case all 
+                images are predicted as a mini-batch.
+            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
+                inputs. If None, the transforms for evaluation process will be used. 
+                Defaults to None.
+
+        Returns:
+            If `img_file` is a string or np.ndarray, the result is a dict with the 
+                following key-value pairs:
+                res_map (np.ndarray): Restored image (HWC).
+
+            If `img_file` is a list, the result is a list composed of dicts with the 
+                above keys.
+        """
+
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise ValueError(
+                "`transforms` must be specified when `test_transforms` is unavailable.")
+        if transforms is None:
+            transforms = self.test_transforms
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+        batch_im, batch_tar_shape = self.preprocess(images, transforms,
+                                                    self.model_type)
+        self.net.eval()
+        data = (batch_im, batch_tar_shape, transforms.transforms)
+        outputs = self.run(self.net, data, 'test')
+        res_map_list = outputs['res_map']
+        if isinstance(img_file, list):
+            prediction = [{'res_map': m} for m in res_map_list]
+        else:
+            prediction = {'res_map': res_map_list[0]}
+        return prediction
+
+    def preprocess(self, images, transforms, to_tensor=True):
+        self._check_transforms(transforms, 'test')
+        batch_im = list()
+        batch_tar_shape = list()
+        for im in images:
+            if isinstance(im, str):
+                im = decode_image(im, to_rgb=False)
+            ori_shape = im.shape[:2]
+            sample = {'image': im}
+            im = transforms(sample)[0]
+            batch_im.append(im)
+            batch_tar_shape.append(self._get_target_shape(ori_shape))
+        if to_tensor:
+            batch_im = paddle.to_tensor(batch_im)
+        else:
+            batch_im = np.asarray(batch_im)
+
+        return batch_im, batch_tar_shape
+
+    def _get_target_shape(self, ori_shape):
+        if self.sr_factor is None:
+            return ori_shape
+        else:
+            return calc_hr_shape(ori_shape, self.sr_factor)
+
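`calc_hr_shape` is imported from `paddlers.utils` elsewhere in this PR; presumably it scales each spatial dimension of the low-resolution shape by `sr_factor`, so that `_get_target_shape` yields the high-resolution output shape. A minimal sketch under that assumption:

```python
def calc_hr_shape(lr_shape, sr_factor):
    # Presumed behavior of paddlers.utils.calc_hr_shape: multiply each
    # spatial dimension of the LR shape by the super-resolution factor.
    return tuple(int(s * sr_factor) for s in lr_shape)


print(calc_hr_shape((128, 96), 4))  # (512, 384)
```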
+    @staticmethod
+    def get_transforms_shape_info(batch_tar_shape, transforms):
+        batch_restore_list = list()
+        for tar_shape in batch_tar_shape:
+            restore_list = list()
+            h, w = tar_shape[0], tar_shape[1]
+            for op in transforms:
+                if op.__class__.__name__ == 'Resize':
+                    restore_list.append(('resize', (h, w)))
+                    h, w = op.target_size
+                elif op.__class__.__name__ == 'ResizeByShort':
+                    restore_list.append(('resize', (h, w)))
+                    im_short_size = min(h, w)
+                    im_long_size = max(h, w)
+                    scale = float(op.short_size) / float(im_short_size)
+                    if 0 < op.max_size < np.round(scale * im_long_size):
+                        scale = float(op.max_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'ResizeByLong':
+                    restore_list.append(('resize', (h, w)))
+                    im_long_size = max(h, w)
+                    scale = float(op.long_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'Pad':
+                    if op.target_size:
+                        target_h, target_w = op.target_size
+                    else:
+                        target_h = int(
+                            (np.ceil(h / op.size_divisor) * op.size_divisor))
+                        target_w = int(
+                            (np.ceil(w / op.size_divisor) * op.size_divisor))
+
+                    if op.pad_mode == -1:
+                        offsets = op.offsets
+                    elif op.pad_mode == 0:
+                        offsets = [0, 0]
+                    elif op.pad_mode == 1:
+                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
+                    else:
+                        offsets = [target_h - h, target_w - w]
+                    restore_list.append(('padding', (h, w), offsets))
+                    h, w = target_h, target_w
+
+            batch_restore_list.append(restore_list)
+        return batch_restore_list
+
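`get_transforms_shape_info` walks the transform ops in application order and records, for each shape-changing op, the shape to restore to, so that `postprocess` can undo the ops in reverse. A trimmed stand-alone sketch covering just `Resize` and fixed-offset `Pad` (the stand-in op classes here are hypothetical simplifications of the real transforms):

```python
class Resize:
    def __init__(self, target_size):
        self.target_size = target_size  # (h, w)


class Pad:
    def __init__(self, target_size, offsets):
        self.target_size = target_size  # (h, w)
        self.offsets = offsets          # (y, x), as with pad_mode == -1


def shape_info(tar_shape, ops):
    """Minimal analogue of get_transforms_shape_info for one sample."""
    restore_list = []
    h, w = tar_shape
    for op in ops:
        if isinstance(op, Resize):
            restore_list.append(('resize', (h, w)))  # remember pre-op shape
            h, w = op.target_size
        elif isinstance(op, Pad):
            restore_list.append(('padding', (h, w), op.offsets))
            h, w = op.target_size
    return restore_list, (h, w)


ops = [Resize((256, 256)), Pad((288, 288), (0, 0))]
restore, final = shape_info((200, 300), ops)
print(restore)  # [('resize', (200, 300)), ('padding', (256, 256), (0, 0))]
print(final)    # (288, 288)
```

Iterating `restore` in reverse, as `postprocess` does, first crops the padding back to (256, 256), then resizes back to the original (200, 300).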
+    def postprocess(self, batch_pred, batch_tar_shape, transforms):
+        batch_restore_list = BaseRestorer.get_transforms_shape_info(
+            batch_tar_shape, transforms)
+        if self.status == 'Infer':
+            return self._infer_postprocess(
+                batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
+        results = []
+        if batch_pred.dtype == paddle.float32:
+            mode = 'bilinear'
+        else:
+            mode = 'nearest'
+        for pred, restore_list in zip(batch_pred, batch_restore_list):
+            pred = paddle.unsqueeze(pred, axis=0)
+            for item in restore_list[::-1]:
+                h, w = item[1][0], item[1][1]
+                if item[0] == 'resize':
+                    pred = F.interpolate(
+                        pred, (h, w), mode=mode, data_format='NCHW')
+                elif item[0] == 'padding':
+                    x, y = item[2]
+                    pred = pred[:, :, y:y + h, x:x + w]
+                else:
+                    pass
+            results.append(pred)
+        return results
+
+    def _infer_postprocess(self, batch_res_map, batch_restore_list):
+        res_maps = []
+        for res_map, restore_list in zip(batch_res_map, batch_restore_list):
+            if not isinstance(res_map, np.ndarray):
+                res_map = paddle.unsqueeze(res_map, axis=0)
+            for item in restore_list[::-1]:
+                h, w = item[1][0], item[1][1]
+                if item[0] == 'resize':
+                    if isinstance(res_map, np.ndarray):
+                        res_map = cv2.resize(
+                            res_map, (w, h), interpolation=cv2.INTER_LINEAR)
+                    else:
+                        res_map = F.interpolate(
+                            res_map, (h, w),
+                            mode='bilinear',
+                            data_format='NHWC')
+                elif item[0] == 'padding':
+                    x, y = item[2]
+                    if isinstance(res_map, np.ndarray):
+                        res_map = res_map[y:y + h, x:x + w]
+                    else:
+                        res_map = res_map[:, y:y + h, x:x + w, :]
+                else:
+                    pass
+            res_map = res_map.squeeze()
+            if not isinstance(res_map, np.ndarray):
+                res_map = res_map.numpy()
+            res_map = self._normalize(res_map)
+            res_maps.append(res_map.squeeze())
+        return res_maps
+
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeRestorer):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeRestorer object.")
+
+    def build_data_loader(self, dataset, batch_size, mode='train'):
+        if dataset.num_samples < batch_size:
+            raise ValueError(
+                'The number of samples in the dataset ({}) must be no less than '
+                'the batch size ({}).'.format(dataset.num_samples, batch_size))
+
+        if mode != 'train':
+            return paddle.io.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=dataset.shuffle,
+                drop_last=False,
+                collate_fn=dataset.batch_transforms,
+                num_workers=dataset.num_workers,
+                return_list=True,
+                use_shared_memory=False)
+        else:
+            return super(BaseRestorer, self).build_data_loader(dataset,
+                                                               batch_size, mode)
+
+    def set_losses(self, losses):
+        self.losses = losses
+
+    def _tensor_to_images(self,
+                          tensor,
+                          transpose=True,
+                          squeeze=True,
+                          quantize=True):
+        if transpose:
+            tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
+        if squeeze:
+            tensor = tensor.squeeze()
+        images = tensor.numpy().astype('float32')
+        images = self._normalize(
+            images, copy=True, clip=True, quantize=quantize)
+        return images
+
+    def _normalize(self, im, copy=False, clip=True, quantize=True):
+        if copy:
+            im = im.copy()
+        if clip:
+            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
+        im -= im.min()
+        im /= im.max() + 1e-32
+        if quantize:
+            im *= 255
+            im = im.astype('uint8')
+        return im
+
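`_normalize` clips predictions to the model's dynamic range (`MIN_MAX`), min-max rescales to [0, 1], and optionally quantizes to uint8 for display. A worked numpy example of the same steps, assuming `MIN_MAX = (0., 1.)` (the in-place `copy` handling is omitted for brevity):

```python
import numpy as np


def normalize(im, min_max=(0.0, 1.0), clip=True, quantize=True):
    # Same steps as BaseRestorer._normalize.
    if clip:
        im = np.clip(im, min_max[0], min_max[1])
    im = im - im.min()
    im = im / (im.max() + 1e-32)  # epsilon guards against division by zero
    if quantize:
        im = (im * 255).astype('uint8')
    return im


# [0.5, 1.0, 2.0] -> clip [0.5, 1.0, 1.0] -> rescale [0, 1, 1] -> [0, 255, 255]
print(normalize(np.array([0.5, 1.0, 2.0])))  # [  0 255 255]
```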
+
+class DRN(BaseRestorer):
+    TEST_OUT_KEY = -1
+
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 scale=(2, 4),
+                 n_blocks=30,
+                 n_feats=16,
+                 n_colors=3,
+                 rgb_range=1.0,
+                 negval=0.2,
+                 lq_loss_weight=0.1,
+                 dual_loss_weight=0.1,
+                 **params):
+        if sr_factor != max(scale):
+            raise ValueError("`sr_factor` must be equal to `max(scale)`.")
+        params.update({
+            'scale': scale,
+            'n_blocks': n_blocks,
+            'n_feats': n_feats,
+            'n_colors': n_colors,
+            'rgb_range': rgb_range,
+            'negval': negval
+        })
+        self.lq_loss_weight = lq_loss_weight
+        self.dual_loss_weight = dual_loss_weight
+        super(DRN, self).__init__(
+            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        from ppgan.modules.init import init_weights
+        generators = [ppgan.models.generators.DRNGenerator(**params)]
+        init_weights(generators[-1])
+        for scale in params['scale']:
+            dual_model = ppgan.models.generators.drn.DownBlock(
+                params['negval'], params['n_feats'], params['n_colors'], 2)
+            generators.append(dual_model)
+            init_weights(generators[-1])
+        return GANAdapter(generators, [])
+
+    def default_optimizer(self, parameters, *args, **kwargs):
+        optims_g = [
+            super(DRN, self).default_optimizer(params_g, *args, **kwargs)
+            for params_g in parameters['params_g']
+        ]
+        return OptimizerAdapter(*optims_g)
+
+    def run_gan(self, net, inputs, mode, gan_mode='forward_primary'):
+        if mode != 'train':
+            raise ValueError("`mode` is not 'train'.")
+        outputs = OrderedDict()
+        if gan_mode == 'forward_primary':
+            sr = net.generator(inputs[0])
+            lr = [inputs[0]]
+            lr.extend([
+                F.interpolate(
+                    inputs[0], scale_factor=s, mode='bicubic')
+                for s in net.generator.scale[:-1]
+            ])
+            loss = self.losses(sr[-1], inputs[1])
+            for i in range(1, len(sr)):
+                if self.lq_loss_weight > 0:
+                    loss += self.losses(sr[i - 1 - len(sr)],
+                                        lr[i - len(sr)]) * self.lq_loss_weight
+            outputs['loss_prim'] = loss
+            outputs['sr'] = sr
+            outputs['lr'] = lr
+        elif gan_mode == 'forward_dual':
+            sr, lr = inputs[0], inputs[1]
+            sr2lr = []
+            n_scales = len(net.generator.scale)
+            for i in range(n_scales):
+                sr2lr_i = net.generators[1 + i](sr[i - n_scales])
+                sr2lr.append(sr2lr_i)
+            loss = self.losses(sr2lr[0], lr[0])
+            for i in range(1, n_scales):
+                if self.dual_loss_weight > 0.0:
+                    loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight
+            outputs['loss_dual'] = loss
+        else:
+            raise ValueError("Invalid `gan_mode`!")
+        return outputs
+
+    def train_step(self, step, data, net):
+        outputs = self.run_gan(
+            net, data, mode='train', gan_mode='forward_primary')
+        outputs.update(
+            self.run_gan(
+                net, (outputs['sr'], outputs['lr']),
+                mode='train',
+                gan_mode='forward_dual'))
+        self.optimizer.clear_grad()
+        (outputs['loss_prim'] + outputs['loss_dual']).backward()
+        self.optimizer.step()
+        return {
+            'loss_prim': outputs['loss_prim'],
+            'loss_dual': outputs['loss_dual']
+        }
+
+
+class LESRCNN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 multi_scale=False,
+                 group=1,
+                 **params):
+        params.update({
+            'scale': sr_factor,
+            'multi_scale': multi_scale,
+            'group': group
+        })
+        super(LESRCNN, self).__init__(
+            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        net = ppgan.models.generators.LESRCNNGenerator(**params)
+        return net
+
+
+class ESRGAN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 use_gan=True,
+                 in_channels=3,
+                 out_channels=3,
+                 nf=64,
+                 nb=23,
+                 **params):
+        if sr_factor != 4:
+            raise ValueError("`sr_factor` must be 4.")
+        params.update({
+            'in_nc': in_channels,
+            'out_nc': out_channels,
+            'nf': nf,
+            'nb': nb
+        })
+        self.use_gan = use_gan
+        super(ESRGAN, self).__init__(
+            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
+
+    def build_net(self, **params):
+        from ppgan.modules.init import init_weights
+        generator = ppgan.models.generators.RRDBNet(**params)
+        init_weights(generator)
+        if self.use_gan:
+            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
+                in_channels=params['out_nc'], num_feat=64)
+            net = GANAdapter(
+                generators=[generator], discriminators=[discriminator])
+        else:
+            net = generator
+        return net
+
+    def default_loss(self):
+        if self.use_gan:
+            return {
+                'pixel': res_losses.L1Loss(loss_weight=0.01),
+                'perceptual': res_losses.PerceptualLoss(
+                    layer_weights={'34': 1.0},
+                    perceptual_weight=1.0,
+                    style_weight=0.0,
+                    norm_img=False),
+                'gan': res_losses.GANLoss(
+                    gan_mode='vanilla', loss_weight=0.005)
+            }
+        else:
+            return res_losses.L1Loss()
+
+    def default_optimizer(self, parameters, *args, **kwargs):
+        if self.use_gan:
+            optim_g = super(ESRGAN, self).default_optimizer(
+                parameters['params_g'][0], *args, **kwargs)
+            optim_d = super(ESRGAN, self).default_optimizer(
+                parameters['params_d'][0], *args, **kwargs)
+            return OptimizerAdapter(optim_g, optim_d)
+        else:
+            return super(ESRGAN, self).default_optimizer(parameters, *args,
+                                                         **kwargs)
+
+    def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
+        if mode != 'train':
+            raise ValueError("`mode` is not 'train'.")
+        outputs = OrderedDict()
+        if gan_mode == 'forward_g':
+            loss_g = 0
+            g_pred = net.generator(inputs[0])
+            loss_pix = self.losses['pixel'](g_pred, inputs[1])
+            loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1])
+            loss_g += loss_pix
+            if loss_perc is not None:
+                loss_g += loss_perc
+            if loss_sty is not None:
+                loss_g += loss_sty
+            self._set_requires_grad(net.discriminator, False)
+            real_d_pred = net.discriminator(inputs[1]).detach()
+            fake_g_pred = net.discriminator(g_pred)
+            loss_g_real = self.losses['gan'](
+                real_d_pred - paddle.mean(fake_g_pred), False,
+                is_disc=False) * 0.5
+            loss_g_fake = self.losses['gan'](
+                fake_g_pred - paddle.mean(real_d_pred), True,
+                is_disc=False) * 0.5
+            loss_g_gan = loss_g_real + loss_g_fake
+            outputs['g_pred'] = g_pred.detach()
+            outputs['loss_g_pps'] = loss_g
+            outputs['loss_g_gan'] = loss_g_gan
+        elif gan_mode == 'forward_d':
+            self._set_requires_grad(net.discriminator, True)
+            # Real
+            fake_d_pred = net.discriminator(inputs[0]).detach()
+            real_d_pred = net.discriminator(inputs[1])
+            loss_d_real = self.losses['gan'](
+                real_d_pred - paddle.mean(fake_d_pred), True,
+                is_disc=True) * 0.5
+            # Fake
+            fake_d_pred = net.discriminator(inputs[0].detach())
+            loss_d_fake = self.losses['gan'](
+                fake_d_pred - paddle.mean(real_d_pred.detach()),
+                False,
+                is_disc=True) * 0.5
+            outputs['loss_d'] = loss_d_real + loss_d_fake
+        else:
+            raise ValueError("Invalid `gan_mode`!")
+        return outputs
+
+    def train_step(self, step, data, net):
+        if self.use_gan:
+            optim_g, optim_d = self.optimizer
+
+            outputs = self.run_gan(
+                net, data, mode='train', gan_mode='forward_g')
+            optim_g.clear_grad()
+            (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
+            optim_g.step()
+
+            outputs.update(
+                self.run_gan(
+                    net, (outputs['g_pred'], data[1]),
+                    mode='train',
+                    gan_mode='forward_d'))
+            optim_d.clear_grad()
+            outputs['loss_d'].backward()
+            optim_d.step()
+
+            outputs['loss'] = outputs['loss_g_pps'] + outputs[
+                'loss_g_gan'] + outputs['loss_d']
+
+            return {
+                'loss': outputs['loss'],
+                'loss_g_pps': outputs['loss_g_pps'],
+                'loss_g_gan': outputs['loss_g_gan'],
+                'loss_d': outputs['loss_d']
+            }
+        else:
+            return super(ESRGAN, self).train_step(step, data, net)
+
+    def _set_requires_grad(self, net, requires_grad):
+        for p in net.parameters():
+            p.trainable = requires_grad
+
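The GAN terms in `ESRGAN.run_gan` use relativistic average logits: each raw discriminator score is offset by the mean of the opposing batch (the `real_d_pred - paddle.mean(fake_g_pred)` terms) before the vanilla GAN loss is applied, so the discriminator judges whether real images look *more* realistic than the average fake. A numpy illustration of just the offsetting step (no Paddle involved):

```python
import numpy as np


def relativistic_logits(this_pred, other_pred):
    """Offset raw discriminator scores by the mean of the opposing batch,
    mirroring the `pred - paddle.mean(other)` pattern in run_gan()."""
    return this_pred - np.mean(other_pred)


real = np.array([2.0, 4.0])  # discriminator scores on real images
fake = np.array([1.0, 1.0])  # discriminator scores on generated images
print(relativistic_logits(real, fake))  # [1. 3.]
print(relativistic_logits(fake, real))  # [-2. -2.]
```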
+
+class RCAN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=1.0,
+                 kernel_size=3,
+                 reduction=16,
+                 **params):
+        params.update({
+            'n_resgroups': n_resgroups,
+            'n_resblocks': n_resblocks,
+            'n_feats': n_feats,
+            'n_colors': n_colors,
+            'rgb_range': rgb_range,
+            'kernel_size': kernel_size,
+            'reduction': reduction
+        })
+        super(RCAN, self).__init__(
+            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)

+ 17 - 12
paddlers/tasks/segmenter.py

@@ -33,6 +33,7 @@ from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferSegNet
 
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
 
@@ -64,11 +65,16 @@ class BaseSegmenter(BaseModel):
 
 
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
-        # DeepLabv3p and HRNet will raise a error
+        # DeepLabv3p and HRNet will raise an error.
         net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
             num_classes=self.num_classes, **params)
         return net
 
 
+    def _build_inference_net(self):
+        infer_net = InferSegNet(self.net)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
@@ -472,7 +478,6 @@ class BaseSegmenter(BaseModel):
                     conf_mat_all.append(conf_mat)
         class_iou, miou = ppseg.utils.metrics.mean_iou(
             intersect_area_all, pred_area_all, label_area_all)
-        # TODO 确认是按oacc还是macc
         class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                        pred_area_all)
         kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
@@ -504,13 +509,13 @@ class BaseSegmenter(BaseModel):
                 Defaults to None.
 
         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
-                pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+            If `img_file` is a string or np.array, the result is a dict with 
+                the following key-value pairs:
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
+                above keys.
         """
 
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -750,11 +755,11 @@ class BaseSegmenter(BaseModel):
                 elif item[0] == 'padding':
                     x, y = item[2]
                     if isinstance(label_map, np.ndarray):
-                        label_map = label_map[..., y:y + h, x:x + w]
-                        score_map = score_map[..., y:y + h, x:x + w]
+                        label_map = label_map[y:y + h, x:x + w]
+                        score_map = score_map[y:y + h, x:x + w]
                     else:
-                        label_map = label_map[:, :, y:y + h, x:x + w]
-                        score_map = score_map[:, :, y:y + h, x:x + w]
+                        label_map = label_map[:, y:y + h, x:x + w, :]
+                        score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                     pass
             label_map = label_map.squeeze()

+ 28 - 11
paddlers/tasks/utils/infer_nets.py

@@ -15,30 +15,36 @@
 import paddle
 
 
-class PostProcessor(paddle.nn.Layer):
-    def __init__(self, model_type):
-        super(PostProcessor, self).__init__()
-        self.model_type = model_type
-
+class SegPostProcessor(paddle.nn.Layer):
     def forward(self, net_outputs):
         # label_map [NHW], score_map [NHWC]
         logit = net_outputs[0]
         outputs = paddle.argmax(logit, axis=1, keepdim=False, dtype='int32'), \
                     paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1])
+        return outputs
+
+
+class ResPostProcessor(paddle.nn.Layer):
+    def __init__(self, out_key=None):
+        super(ResPostProcessor, self).__init__()
+        self.out_key = out_key
+
+    def forward(self, net_outputs):
+        if self.out_key is not None:
+            net_outputs = net_outputs[self.out_key]
+        outputs = paddle.transpose(net_outputs, perm=[0, 2, 3, 1])
         return outputs
 
 
-class InferNet(paddle.nn.Layer):
-    def __init__(self, net, model_type):
-        super(InferNet, self).__init__()
+class InferSegNet(paddle.nn.Layer):
+    def __init__(self, net):
+        super(InferSegNet, self).__init__()
         self.net = net
-        self.postprocessor = PostProcessor(model_type)
+        self.postprocessor = SegPostProcessor()
 
     def forward(self, x):
         net_outputs = self.net(x)
         outputs = self.postprocessor(net_outputs)
-
         return outputs
 
 
@@ -46,10 +52,21 @@ class InferCDNet(paddle.nn.Layer):
     def __init__(self, net):
         super(InferCDNet, self).__init__()
         self.net = net
-        self.postprocessor = PostProcessor('change_detector')
+        self.postprocessor = SegPostProcessor()
 
     def forward(self, x1, x2):
         net_outputs = self.net(x1, x2)
         outputs = self.postprocessor(net_outputs)
+        return outputs
+
+
+class InferResNet(paddle.nn.Layer):
+    def __init__(self, net, out_key=None):
+        super(InferResNet, self).__init__()
+        self.net = net
+        self.postprocessor = ResPostProcessor(out_key=out_key)
+
+    def forward(self, x):
+        net_outputs = self.net(x)
+        outputs = self.postprocessor(net_outputs)
+        return outputs
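For readers following the layout conventions here: the new `ResPostProcessor` simply moves the channel axis so that restoration outputs come back as ordinary images. A framework-free NumPy sketch of the same transpose (shapes are illustrative only):

```python
import numpy as np

# A restoration network emits NCHW tensors (batch, channels, height, width).
# The postprocessor transposes them to NHWC so each item reads as an HWC
# image. NumPy equivalent of perm=[0, 2, 3, 1]:
def res_postprocess(net_outputs, out_key=None):
    if out_key is not None:
        # Some networks return a dict; pick the tensor of interest first.
        net_outputs = net_outputs[out_key]
    return np.transpose(net_outputs, (0, 2, 3, 1))

batch = np.zeros((2, 3, 64, 48), dtype=np.float32)  # NCHW
print(res_postprocess(batch).shape)  # (2, 64, 48, 3)
```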

+ 132 - 0
paddlers/tasks/utils/res_adapters.py

@@ -0,0 +1,132 @@
+from functools import wraps
+from inspect import isfunction, isgeneratorfunction, getmembers
+from collections.abc import Sequence
+from abc import ABC
+
+import paddle
+import paddle.nn as nn
+
+__all__ = ['GANAdapter', 'OptimizerAdapter']
+
+
+class _AttrDesc:
+    def __init__(self, key):
+        self.key = key
+
+    def __get__(self, instance, owner):
+        return tuple(getattr(ele, self.key) for ele in instance)
+
+    def __set__(self, instance, value):
+        for ele in instance:
+            setattr(ele, self.key, value)
+
+
+def _func_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
+
+    return _wrapper
+
+
+def _generator_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        for ele in self:
+            yield from getattr(ele, func_name)(*args, **kwargs)
+
+    return _wrapper
+
+
+class Adapter(Sequence, ABC):
+    __ducktype__ = object
+    __ava__ = ()
+
+    def __init__(self, *args):
+        if not all(map(self._check, args)):
+            raise TypeError("Please check the input type.")
+        self._seq = tuple(args)
+
+    def __getitem__(self, key):
+        return self._seq[key]
+
+    def __len__(self):
+        return len(self._seq)
+
+    def __repr__(self):
+        return repr(self._seq)
+
+    @classmethod
+    def _check(cls, obj):
+        for attr in cls.__ava__:
+            try:
+                getattr(obj, attr)
+                # TODO: Check function signature
+            except AttributeError:
+                return False
+        return True
+
+
+def make_adapter(cls):
+    members = dict(getmembers(cls.__ducktype__))
+    for k in cls.__ava__:
+        if hasattr(cls, k):
+            continue
+        if k in members:
+            v = members[k]
+            if isgeneratorfunction(v):
+                setattr(cls, k, _generator_deco(cls, k))
+            elif isfunction(v):
+                setattr(cls, k, _func_deco(cls, k))
+            else:
+                setattr(cls, k, _AttrDesc(k))
+    return cls
+
+
+class GANAdapter(nn.Layer):
+    __ducktype__ = nn.Layer
+    __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval')
+
+    def __init__(self, generators, discriminators):
+        super(GANAdapter, self).__init__()
+        self.generators = nn.LayerList(generators)
+        self.discriminators = nn.LayerList(discriminators)
+        self._m = [*generators, *discriminators]
+
+    def __len__(self):
+        return len(self._m)
+
+    def __getitem__(self, key):
+        return self._m[key]
+
+    def __contains__(self, m):
+        return m in self._m
+
+    def __repr__(self):
+        return repr(self._m)
+
+    @property
+    def generator(self):
+        return self.generators[0]
+
+    @property
+    def discriminator(self):
+        return self.discriminators[0]
+
+
+Adapter.register(GANAdapter)
+
+
+@make_adapter
+class OptimizerAdapter(Adapter):
+    __ducktype__ = paddle.optimizer.Optimizer
+    __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')
+
+    def set_state_dict(self, state_dicts):
+        # Special dispatching rule
+        for optim, state_dict in zip(self, state_dicts):
+            optim.set_state_dict(state_dict)
+
+    def get_lr(self):
+        # Return the lr of the first optimizer
+        return self[0].get_lr()
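The adapter machinery above fans a single attribute access or method call out to every wrapped object, so one `GANAdapter` or `OptimizerAdapter` can be driven like a single model or optimizer. A minimal, framework-free sketch of that dispatch idea (the `Counter` classes are hypothetical stand-ins, not part of PaddleRS):

```python
class _AttrDesc:
    """Fan an attribute read out to every wrapped element, as a tuple."""

    def __init__(self, key):
        self.key = key

    def __get__(self, instance, owner):
        return tuple(getattr(ele, self.key) for ele in instance._seq)

class Counter:
    def __init__(self):
        self.value = 0

    def step(self):
        self.value += 1
        return self.value

class CounterAdapter:
    """Duck-types a single Counter but drives several at once."""
    value = _AttrDesc('value')  # attribute access is broadcast

    def __init__(self, *counters):
        self._seq = tuple(counters)

    def step(self):
        # Method calls are broadcast too; results come back as a tuple.
        return tuple(c.step() for c in self._seq)

pair = CounterAdapter(Counter(), Counter())
pair.step()
print(pair.value)  # (1, 1)
```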

+ 4 - 0
paddlers/transforms/functions.py

@@ -638,3 +638,7 @@ def decode_seg_mask(mask_path):
     mask = np.asarray(Image.open(mask_path))
     mask = mask.astype('int64')
     return mask
+
+
+def calc_hr_shape(lr_shape, sr_factor):
+    return tuple(int(s * sr_factor) for s in lr_shape)
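`calc_hr_shape` maps a low-resolution shape to its high-resolution counterpart by scaling every dimension by the super-resolution factor. A quick sketch of the behavior:

```python
def calc_hr_shape(lr_shape, sr_factor):
    # Scale each spatial dimension by the SR factor, truncating to int
    # (mirrors the helper added above).
    return tuple(int(s * sr_factor) for s in lr_shape)

print(calc_hr_shape((128, 96), 4))   # (512, 384)
print(calc_hr_shape((50, 50), 2.5))  # (125, 125)
```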

+ 102 - 15
paddlers/transforms/operators.py

@@ -35,7 +35,7 @@ from .functions import (
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
     vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
     resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask)
+    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape)
 
 
 __all__ = [
     "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
@@ -44,7 +44,7 @@ __all__ = [
     "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
     "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
     "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
-    "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
+    "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask"
 ]
 
 
 interp_dict = {
@@ -154,6 +154,8 @@ class Transform(object):
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
 
 
         return sample
 
 
@@ -336,6 +338,14 @@ class DecodeImg(Transform):
                 map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
 
 
+        if 'target' in sample:
+            if self.read_geo_info:
+                target, geo_info_dict = self.apply_im(sample['target'])
+                sample['target'] = target
+                sample['geo_info_dict_tar'] = geo_info_dict
+            else:
+                sample['target'] = self.apply_im(sample['target'])
+
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
@@ -457,6 +467,17 @@ class Resize(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                # For SR tasks
+                sample['target'] = self.apply_im(
+                    sample['target'], interp,
+                    calc_hr_shape(target_size, sample['sr_factor']))
+            else:
+                # For non-SR tasks
+                sample['target'] = self.apply_im(sample['target'], interp,
+                                                 target_size)
+
         sample['im_shape'] = np.asarray(
             sample['image'].shape[:2], dtype=np.float32)
         if 'scale_factor' in sample:
@@ -730,6 +751,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     True)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 True)
         elif p_m < self.probs[1]:
             mode_p = random.random()
             mode_id = self.judge_probs_range(mode_p, self.probsr)
@@ -750,6 +774,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     False)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 False)
 
 
         return sample
 
 
@@ -809,6 +836,8 @@ class RandomHorizontalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
 
 
@@ -867,6 +896,8 @@ class RandomVerticalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
 
 
@@ -884,15 +915,18 @@ class Normalize(Transform):
             image(s). Defaults to [0.229, 0.224, 0.225].
         min_val (list[float] | tuple[float], optional): Minimum value of input 
             image(s). If None, use 0 for all channels. Defaults to None.
-        max_val (list[float] | tuple[float], optional): Max value of input image(s). 
-            If None, use 255. for all channels. Defaults to None.
+        max_val (list[float] | tuple[float], optional): Maximum value of input 
+            image(s). If None, use 255. for all channels. Defaults to None.
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
 
     def __init__(self,
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225],
                  min_val=None,
-                 max_val=None):
+                 max_val=None,
+                 apply_to_tar=True):
         super(Normalize, self).__init__()
         channel = len(mean)
         if min_val is None:
@@ -914,6 +948,7 @@ class Normalize(Transform):
         self.std = std
         self.min_val = min_val
         self.max_val = max_val
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
         image = image.astype(np.float32)
@@ -927,6 +962,8 @@ class Normalize(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
 
 
         return sample
 
 
@@ -964,6 +1001,8 @@ class CenterCrop(Transform):
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
 
 
@@ -1165,6 +1204,14 @@ class RandomCrop(Transform):
                         self.apply_mask, crop=crop_box),
                         sample['aux_masks']))
 
 
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    sample['target'] = self.apply_im(
+                        sample['target'],
+                        calc_hr_shape(crop_box, sample['sr_factor']))
+                else:
+                    sample['target'] = self.apply_im(sample['target'], crop_box)
+
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
 
 
@@ -1266,6 +1313,7 @@ class Pad(Transform):
             pad_mode (int, optional): Pad mode. Currently only four modes are supported:
                 [-1, 0, 1, 2]. If -1, use specified offsets. If 0, only pad to right and bottom.
                 If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
+            offsets (list[int]|None, optional): Padding offsets. Defaults to None.
             im_padding_value (list[float] | tuple[float]): RGB value of padded area. 
                 Defaults to (127.5, 127.5, 127.5).
             label_padding_value (int, optional): Filling value for the mask. 
@@ -1332,6 +1380,17 @@ class Pad(Transform):
                     expand_rle(segm, x, y, height, width, h, w))
         return expanded_segms
 
 
+    def _get_offsets(self, im_h, im_w, h, w):
+        if self.pad_mode == -1:
+            offsets = self.offsets
+        elif self.pad_mode == 0:
+            offsets = [0, 0]
+        elif self.pad_mode == 1:
+            offsets = [(w - im_w) // 2, (h - im_h) // 2]
+        else:
+            offsets = [w - im_w, h - im_h]
+        return offsets
+
     def apply(self, sample):
         im_h, im_w = sample['image'].shape[:2]
         if self.target_size:
@@ -1349,14 +1408,7 @@ class Pad(Transform):
         if h == im_h and w == im_w:
             return sample
 
 
-        if self.pad_mode == -1:
-            offsets = self.offsets
-        elif self.pad_mode == 0:
-            offsets = [0, 0]
-        elif self.pad_mode == 1:
-            offsets = [(w - im_w) // 2, (h - im_h) // 2]
-        else:
-            offsets = [w - im_w, h - im_h]
+        offsets = self._get_offsets(im_h, im_w, h, w)
 
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
@@ -1373,6 +1425,16 @@ class Pad(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
+                hr_offsets = self._get_offsets(*sample['target'].shape[:2],
+                                               *hr_shape)
+                sample['target'] = self.apply_im(sample['target'], hr_offsets,
+                                                 hr_shape)
+            else:
+                sample['target'] = self.apply_im(sample['target'], offsets,
+                                                 (h, w))
         return sample
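The new `_get_offsets` helper centralizes the four pad modes that `Pad.apply` previously computed inline. A standalone sketch of the offset arithmetic it implements (a plain function here, for illustration):

```python
def get_offsets(im_h, im_w, h, w, pad_mode=0, offsets=None):
    # pad_mode -1: caller-specified offsets; 0: pad right/bottom only;
    # 1: center the image; 2: pad left/top only.
    if pad_mode == -1:
        return offsets
    if pad_mode == 0:
        return [0, 0]
    if pad_mode == 1:
        return [(w - im_w) // 2, (h - im_h) // 2]
    return [w - im_w, h - im_h]

# Padding a 100x100 image up to 128x128 under each mode:
print(get_offsets(100, 100, 128, 128, pad_mode=0))  # [0, 0]
print(get_offsets(100, 100, 128, 128, pad_mode=1))  # [14, 14]
print(get_offsets(100, 100, 128, 128, pad_mode=2))  # [28, 28]
```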
 
 
 
 
@@ -1688,15 +1750,18 @@ class ReduceDim(Transform):
 
 
     Args: 
         joblib_path (str): Path of *.joblib file of PCA.
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
 
-    def __init__(self, joblib_path):
+    def __init__(self, joblib_path, apply_to_tar=True):
         super(ReduceDim, self).__init__()
         ext = joblib_path.split(".")[-1]
         if ext != "joblib":
             raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
                 ext))
         self.pca = load(joblib_path)
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
         H, W, C = image.shape
@@ -1709,6 +1774,8 @@ class ReduceDim(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
 
 
@@ -1719,11 +1786,14 @@ class SelectBand(Transform):
     Args: 
         band_list (list, optional): Bands to select (band index starts from 1). 
             Defaults to [1, 2, 3].
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
 
 
-    def __init__(self, band_list=[1, 2, 3]):
+    def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
         super(SelectBand, self).__init__()
         self.band_list = band_list
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
         image = select_bands(image, self.band_list)
@@ -1733,6 +1803,8 @@ class SelectBand(Transform):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
 
 
 
 
@@ -1820,6 +1892,8 @@ class _Permute(Transform):
         sample['image'] = permute(sample['image'], False)
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
+        if 'target' in sample:
+            sample['target'] = permute(sample['target'], False)
         return sample
 
 
 
 
@@ -1915,3 +1989,16 @@ class ArrangeDetector(Arrange):
         if self.mode == 'eval' and 'gt_poly' in sample:
             del sample['gt_poly']
         return sample
+
+
+class ArrangeRestorer(Arrange):
+    def apply(self, sample):
+        if 'target' in sample:
+            target = permute(sample['target'], False)
+        image = permute(sample['image'], False)
+        if self.mode == 'train':
+            return image, target
+        if self.mode == 'eval':
+            return image, target
+        if self.mode == 'test':
+            return image,
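`ArrangeRestorer` relies on `permute` to turn HWC arrays into the CHW layout the network consumes, and drops the target in test mode. A NumPy sketch of that arrangement step (the `permute` and `arrange_restorer` here are simplified stand-ins, not the PaddleRS helpers themselves):

```python
import numpy as np

def permute(im):
    # HWC -> CHW, matching the layout the network expects as input.
    return np.transpose(im, (2, 0, 1))

def arrange_restorer(sample, mode):
    # 'train'/'eval' yield (input, target); 'test' yields the input only.
    image = permute(sample['image'])
    if mode in ('train', 'eval'):
        return image, permute(sample['target'])
    return (image,)

# For SR, the target may be larger than the input by the SR factor.
sample = {'image': np.zeros((32, 32, 3)), 'target': np.zeros((64, 64, 3))}
inp, tar = arrange_restorer(sample, 'train')
print(inp.shape, tar.shape)  # (3, 32, 32) (3, 64, 64)
```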

+ 1 - 1
paddlers/utils/__init__.py

@@ -16,7 +16,7 @@ from . import logging
 from . import utils
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
                     EarlyStop, norm_path, is_pic, MyEncoder, DisablePrint,
-                    Timer)
+                    Timer, to_data_parallel, scheduler_step)
 from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress

+ 32 - 1
paddlers/utils/utils.py

@@ -20,11 +20,12 @@ import math
 import imghdr
 import chardet
 import json
+import platform
 
 
 import numpy as np
+import paddle
 
 
 from . import logging
-import platform
 import paddlers
 
 
 
 
@@ -237,3 +238,33 @@ class Timer(Times):
         self.postprocess_time_s.reset()
         self.img_num = 0
         self.repeats = 0
+
+
+def to_data_parallel(layers, *args, **kwargs):
+    from paddlers.tasks.utils.res_adapters import GANAdapter
+    if isinstance(layers, GANAdapter):
+        # Inplace modification for efficiency
+        layers.generators = [
+            paddle.DataParallel(g, *args, **kwargs) for g in layers.generators
+        ]
+        layers.discriminators = [
+            paddle.DataParallel(d, *args, **kwargs)
+            for d in layers.discriminators
+        ]
+    else:
+        layers = paddle.DataParallel(layers, *args, **kwargs)
+    return layers
+
+
+def scheduler_step(optimizer, loss=None):
+    from paddlers.tasks.utils.res_adapters import OptimizerAdapter
+    if not isinstance(optimizer, OptimizerAdapter):
+        optimizer = [optimizer]
+    for optim in optimizer:
+        if isinstance(optim._learning_rate, paddle.optimizer.lr.LRScheduler):
+            # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
+            if isinstance(optim._learning_rate,
+                          paddle.optimizer.lr.ReduceOnPlateau):
+                optim._learning_rate.step(loss.item())
+            else:
+                optim._learning_rate.step()
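`scheduler_step` normalizes a single optimizer into a list so the GAN-style multi-optimizer case and the ordinary single-optimizer case share one code path, and it passes a metric only to plateau-type schedulers. A framework-free sketch of that dispatch (the scheduler classes below are stand-ins, not Paddle APIs):

```python
class StepScheduler:
    """Steps unconditionally, like a time-based LR schedule."""
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

class PlateauScheduler(StepScheduler):
    """Needs a metric (e.g. the loss) to decide whether to decay."""
    def __init__(self):
        super().__init__()
        self.last_metric = None

    def step(self, metric):
        self.steps += 1
        self.last_metric = metric

def scheduler_step(schedulers, loss=None):
    # Accept one scheduler or a list; plateau-style schedulers get the loss.
    if not isinstance(schedulers, list):
        schedulers = [schedulers]
    for sched in schedulers:
        if isinstance(sched, PlateauScheduler):
            sched.step(loss)
        else:
            sched.step()

s1, s2 = StepScheduler(), PlateauScheduler()
scheduler_step([s1, s2], loss=0.42)
print(s1.steps, s2.last_metric)  # 1 0.42
```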

+ 28 - 26
tests/data/data_utils.py

@@ -14,7 +14,6 @@
 
 
 import os.path as osp
 import re
-import imghdr
 import platform
 from collections import OrderedDict
 from functools import partial, wraps
@@ -34,20 +33,6 @@ def norm_path(path):
     return path
 
 
 
 
-def is_pic(im_path):
-    valid_suffix = [
-        'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
-    ]
-    suffix = im_path.split('.')[-1]
-    if suffix in valid_suffix:
-        return True
-    im_format = imghdr.what(im_path)
-    _, ext = osp.splitext(im_path)
-    if im_format == 'tiff' or ext == '.img':
-        return True
-    return False
-
-
 def get_full_path(p, prefix=''):
     p = norm_path(p)
     return osp.join(prefix, p)
@@ -323,15 +308,34 @@ class ConstrDetSample(ConstrSample):
         return samples
 
 
 
 
-def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
+class ConstrResSample(ConstrSample):
+    def __init__(self, prefix, label_list, sr_factor=None):
+        super().__init__(prefix, label_list)
+        self.sr_factor = sr_factor
+
+    def __call__(self, src_path, tar_path):
+        sample = {
+            'image': self.get_full_path(src_path),
+            'target': self.get_full_path(tar_path)
+        }
+        if self.sr_factor is not None:
+            sample['sr_factor'] = self.sr_factor
+        return sample
+
+
+def build_input_from_file(file_list,
+                          prefix='',
+                          task='auto',
+                          label_list=None,
+                          **kwargs):
     """
     """
     Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
     Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
 
 
     Args:
     Args:
-        file_list (str): Path of file_list.
+        file_list (str): Path of file list.
         prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
         prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
-        task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input. 
+        task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto', 
-            Default: 'auto'.
+            automatically determine the task based on the input. Default: 'auto'.
         label_list (str|None, optional): Path of label_list. Default: None.
         label_list (str|None, optional): Path of label_list. Default: None.
 
 
     Returns:
     Returns:
@@ -339,22 +343,21 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
     """
     """
 
 
     def _determine_task(parts):
     def _determine_task(parts):
+        task = 'unknown'
         if len(parts) in (3, 5):
         if len(parts) in (3, 5):
             task = 'cd'
             task = 'cd'
         elif len(parts) == 2:
         elif len(parts) == 2:
             if parts[1].isdigit():
             if parts[1].isdigit():
                 task = 'clas'
                 task = 'clas'
-            elif is_pic(osp.join(prefix, parts[1])):
+            elif parts[1].endswith('.xml'):
-                task = 'seg'
-            else:
                 task = 'det'
                 task = 'det'
-        else:
+        if task == 'unknown':
             raise RuntimeError(
             raise RuntimeError(
                 "Cannot automatically determine the task type. Please specify `task` manually."
                 "Cannot automatically determine the task type. Please specify `task` manually."
             )
             )
         return task
         return task
 
 
-    if task not in ('seg', 'det', 'cd', 'clas', 'auto'):
+    if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'):
         raise ValueError("Invalid value of `task`")
         raise ValueError("Invalid value of `task`")
 
 
     samples = []
     samples = []
@@ -366,9 +369,8 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
             if task == 'auto':
                 task = _determine_task(parts)
             if ctor is None:
-                # Select and build sample constructor
                 ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
-                ctor = ctor_class(prefix, label_list)
+                ctor = ctor_class(prefix, label_list, **kwargs)
             sample = ctor(*parts)
             if isinstance(sample, list):
                 samples.extend(sample)
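The auto-detection heuristic changed in this hunk can be exercised on its own. This standalone sketch mirrors the new `_determine_task`: 3 or 5 fields mean change detection; 2 fields mean classification when the second field is an integer label, detection when it is an XML annotation; everything else must be specified manually:

```python
# Standalone sketch of the task-inference heuristic in build_input_from_file.
def determine_task(parts):
    task = 'unknown'
    if len(parts) in (3, 5):
        task = 'cd'
    elif len(parts) == 2:
        if parts[1].isdigit():
            task = 'clas'
        elif parts[1].endswith('.xml'):
            task = 'det'
    if task == 'unknown':
        raise RuntimeError(
            "Cannot automatically determine the task type. Please specify `task` manually."
        )
    return task


print(determine_task(['t1.png', 't2.png', 'label.png']))  # cd
print(determine_task(['img.png', '3']))                   # clas
print(determine_task(['img.png', 'img.xml']))             # det
```

Note that a two-field segmentation list (image plus mask image) no longer matches any branch, which is consistent with the updated tests in `tests/transforms/test_operators.py` now passing `task='seg'` explicitly.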

+ 60 - 7
tests/deploy/test_predictor.py

@@ -24,7 +24,7 @@ from testing_utils import CommonTest, run_script
 
 __all__ = [
     'TestCDPredictor', 'TestClasPredictor', 'TestDetPredictor',
-    'TestSegPredictor'
+    'TestResPredictor', 'TestSegPredictor'
 ]
 
 
@@ -105,7 +105,7 @@ class TestPredictor(CommonTest):
                     dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestCDPredictor(TestPredictor):
     MODULE = pdrs.tasks.change_detector
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -177,7 +177,7 @@ class TestCDPredictor(TestPredictor):
         self.assertEqual(len(out_multi_array_t), num_inputs)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestClasPredictor(TestPredictor):
     MODULE = pdrs.tasks.classifier
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -185,7 +185,7 @@ class TestClasPredictor(TestPredictor):
     }
 
     def check_predictor(self, predictor, trainer):
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@@ -242,7 +242,7 @@ class TestClasPredictor(TestPredictor):
         self.check_dict_equal(out_multi_array_p, out_multi_array_t)
 
 
-@TestPredictor.add_tests
+# @TestPredictor.add_tests
 class TestDetPredictor(TestPredictor):
     MODULE = pdrs.tasks.object_detector
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -253,7 +253,7 @@ class TestDetPredictor(TestPredictor):
         # For detection tasks, do NOT ensure the consistence of bboxes.
         # This is because the coordinates of bboxes were observed to be very sensitive to numeric errors, 
         # given that the network is (partially?) randomly initialized.
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@@ -303,6 +303,59 @@ class TestDetPredictor(TestPredictor):
 
 
 @TestPredictor.add_tests
+class TestResPredictor(TestPredictor):
+    MODULE = pdrs.tasks.restorer
+
+    def check_predictor(self, predictor, trainer):
+        # For restoration tasks, do NOT ensure the consistence of numeric values, 
+        # because the output is of uint8 type.
+        single_input = "data/ssst/optical.bmp"
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeRestorer('test')
+        ])
+
+        # Single input (file path)
+        input_ = single_input
+        predictor.predict(input_, transforms=transforms)
+        trainer.predict(input_, transforms=transforms)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_t), 1)
+
+        # Single input (ndarray)
+        input_ = decode_image(
+            single_input, to_rgb=False)  # Reuse the name `input_`
+        predictor.predict(input_, transforms=transforms)
+        trainer.predict(input_, transforms=transforms)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_t), 1)
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_t), num_inputs)
+
+        # Multiple inputs (ndarrays)
+        input_ = [decode_image(
+            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_t), num_inputs)
+
+
+# @TestPredictor.add_tests
 class TestSegPredictor(TestPredictor):
     MODULE = pdrs.tasks.segmenter
     TRAINER_NAME_TO_EXPORT_OPTS = {
@@ -310,7 +363,7 @@ class TestSegPredictor(TestPredictor):
     }
 
     def check_predictor(self, predictor, trainer):
-        single_input = "data/ssmt/optical_t1.bmp"
+        single_input = "data/ssst/optical.bmp"
         num_inputs = 2
         transforms = pdrs.transforms.Compose([
             pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
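The comment in `TestResPredictor.check_predictor` notes that numeric consistency is not checked because the output is of uint8 type. A quick pure-Python illustration (with values chosen for the example) of why quantization makes bitwise comparison fragile:

```python
# After a float prediction is quantized to uint8, two values that differ by a
# tiny numeric error can land on adjacent integer levels.
def quantize(x):
    """Map a float in [0, 1] to a uint8 level."""
    return max(0, min(255, round(x * 255)))


a = 0.499999
b = a + 1e-5  # differs from a by only 1e-5
print(quantize(a), quantize(b))  # 127 128
```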

+ 1 - 3
tests/rs_models/test_cd_models.py

@@ -33,9 +33,7 @@ class TestCDModel(TestModel):
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
             o = o.numpy()
-            self.check_output_equal(o.shape[0], t.shape[0])
-            self.check_output_equal(len(o.shape), 4)
-            self.check_output_equal(o.shape[2:], t.shape[2:])
+            self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         if self.EF_MODE == 'Concat':

+ 3 - 0
tests/rs_models/test_det_models.py

@@ -32,3 +32,6 @@ class TestDetModel(TestModel):
 
     def set_inputs(self):
         self.inputs = cycle([self.get_randn_tensor(3)])
+
+    def set_targets(self):
+        self.targets = cycle([None])

+ 46 - 0
tests/rs_models/test_res_models.py

@@ -0,0 +1,46 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddlers
+from rs_models.test_model import TestModel
+
+__all__ = []
+
+
+class TestResModel(TestModel):
+    def check_output(self, output, target):
+        output = output.numpy()
+        self.check_output_equal(output.shape, target.shape)
+
+    def set_inputs(self):
+        def _gen_data(specs):
+            for spec in specs:
+                c = spec.get('in_channels', 3)
+                yield self.get_randn_tensor(c)
+
+        self.inputs = _gen_data(self.specs)
+
+    def set_targets(self):
+        def _gen_data(specs):
+            for spec in specs:
+                # XXX: Hard coding
+                if 'out_channels' in spec:
+                    c = spec['out_channels']
+                elif 'in_channels' in spec:
+                    c = spec['in_channels']
+                else:
+                    c = 3
+                yield [self.get_zeros_array(c)]
+
+        self.targets = _gen_data(self.specs)
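The `set_targets` fallback above (flagged `XXX: Hard coding`) picks the expected output channel count from each model spec. Isolated as a plain function, the precedence is: `out_channels` wins, then `in_channels`, and 3 (RGB) is the default:

```python
# Channel-count fallback used when generating target arrays for the
# restoration model tests.
def infer_channels(spec):
    if 'out_channels' in spec:
        return spec['out_channels']
    if 'in_channels' in spec:
        return spec['in_channels']
    return 3


specs = [{'out_channels': 1, 'in_channels': 4}, {'in_channels': 4}, {}]
print([infer_channels(s) for s in specs])  # [1, 4, 3]
```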

+ 5 - 3
tests/rs_models/test_seg_models.py

@@ -26,9 +26,7 @@ class TestSegModel(TestModel):
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
             o = o.numpy()
-            self.check_output_equal(o.shape[0], t.shape[0])
-            self.check_output_equal(len(o.shape), 4)
-            self.check_output_equal(o.shape[2:], t.shape[2:])
+            self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         def _gen_data(specs):
@@ -54,3 +52,7 @@ class TestFarSegModel(TestSegModel):
         self.specs = [
             dict(), dict(num_classes=20), dict(encoder_pretrained=False)
         ]
+
+    def set_targets(self):
+        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(20)],
+                        [self.get_zeros_array(16)]]

+ 43 - 0
tests/transforms/test_operators.py

@@ -164,12 +164,15 @@ class TestTransform(CpuCommonTest):
                 prefix="./data/ssst"),
                 prefix="./data/ssst"),
             build_input_from_file(
             build_input_from_file(
                 "data/ssst/test_optical_seg.txt",
                 "data/ssst/test_optical_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
                 prefix="./data/ssst"),
             build_input_from_file(
             build_input_from_file(
                 "data/ssst/test_sar_seg.txt",
                 "data/ssst/test_sar_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
                 prefix="./data/ssst"),
             build_input_from_file(
             build_input_from_file(
                 "data/ssst/test_multispectral_seg.txt",
                 "data/ssst/test_multispectral_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
                 prefix="./data/ssst"),
             build_input_from_file(
             build_input_from_file(
                 "data/ssst/test_optical_det.txt",
                 "data/ssst/test_optical_det.txt",
@@ -185,7 +188,23 @@ class TestTransform(CpuCommonTest):
                 label_list="data/ssst/labels_det.txt"),
                 label_list="data/ssst/labels_det.txt"),
             build_input_from_file(
             build_input_from_file(
                 "data/ssst/test_det_coco.txt",
                 "data/ssst/test_det_coco.txt",
+                task='det',
                 prefix="./data/ssst"),
                 prefix="./data/ssst"),
+            build_input_from_file(
+                "data/ssst/test_optical_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_sar_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_multispectral_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
             build_input_from_file(
                 "data/ssmt/test_mixed_binary.txt",
                 prefix="./data/ssmt"),
@@ -227,6 +246,8 @@ class TestTransform(CpuCommonTest):
                 self.aux_mask_values = [
                     set(aux_mask.ravel()) for aux_mask in sample['aux_masks']
                 ]
+            if 'target' in sample:
+                self.target_shape = sample['target'].shape
             return sample
 
         def _out_hook_not_keep_ratio(sample):
@@ -243,6 +264,21 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, amv in zip(sample['aux_masks'],
                                          self.aux_mask_values):
                     self.assertLessEqual(set(aux_mask.ravel()), amv)
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    self.check_output_equal(
+                        sample['target'].shape[:2],
+                        T.functions.calc_hr_shape(TARGET_SIZE,
+                                                  sample['sr_factor']))
+                else:
+                    self.check_output_equal(sample['target'].shape[:2],
+                                            TARGET_SIZE)
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 
@@ -260,6 +296,13 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, ori_aux_mask_shape in zip(sample['aux_masks'],
                                                         self.aux_mask_shapes):
                     __check_ratio(aux_mask.shape, ori_aux_mask_shape)
+            if 'target' in sample:
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 
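The assertions above rely on `T.functions.calc_hr_shape`, added to `paddlers/transforms/functions.py` in this PR. Its body is not shown in this chunk; a hypothetical minimal version consistent with how the tests use it (scaling an LR shape by `sr_factor`) would be:

```python
# Hypothetical sketch: scale each spatial dimension of a low-resolution shape
# by the super-resolution factor to get the expected high-resolution shape.
def calc_hr_shape(lr_shape, sr_factor):
    return tuple(int(s * sr_factor) for s in lr_shape)


print(calc_hr_shape((64, 64), 4))  # (256, 256)
```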

+ 5 - 5
tutorials/train/README.md

@@ -9,17 +9,17 @@
 |change_detection/changeformer.py | Change Detection | ChangeFormer |
 |change_detection/dsamnet.py | Change Detection | DSAMNet |
 |change_detection/dsifn.py | Change Detection | DSIFN |
-|change_detection/snunet.py | Change Detection | SNUNet |
-|change_detection/stanet.py | Change Detection | STANet |
 |change_detection/fc_ef.py | Change Detection | FC-EF |
 |change_detection/fc_siam_conc.py | Change Detection | FC-Siam-conc |
 |change_detection/fc_siam_diff.py | Change Detection | FC-Siam-diff |
+|change_detection/snunet.py | Change Detection | SNUNet |
+|change_detection/stanet.py | Change Detection | STANet |
 |classification/hrnet.py | Scene Classification | HRNet |
 |classification/mobilenetv3.py | Scene Classification | MobileNetV3 |
 |classification/resnet50_vd.py | Scene Classification | ResNet50-vd |
-|image_restoration/drn.py | Super Resolution | DRN |
-|image_restoration/esrgan.py | Super Resolution | ESRGAN |
-|image_restoration/lesrcnn.py | Super Resolution | LESRCNN |
+|image_restoration/drn.py | Image Restoration | DRN |
+|image_restoration/esrgan.py | Image Restoration | ESRGAN |
+|image_restoration/lesrcnn.py | Image Restoration | LESRCNN |
 |object_detection/faster_rcnn.py | Object Detection | Faster R-CNN |
 |object_detection/ppyolo.py | Object Detection | PP-YOLO |
 |object_detection/ppyolotiny.py | Object Detection | PP-YOLO Tiny |

+ 1 - 1
tutorials/train/change_detection/changeformer.py

@@ -72,7 +72,7 @@ eval_dataset = pdrs.datasets.CDDataset(
     binarize_labels=True)
 
 # Build a ChangeFormer model with default parameters
-# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
+# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.cd.ChangeFormer()
 

+ 1 - 1
tutorials/train/classification/hrnet.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build an HRNet model with default parameters
+# Build an HRNet model
 # For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.HRNet_W18_C(num_classes=len(train_dataset.labels))

+ 1 - 1
tutorials/train/classification/mobilenetv3.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build a MobileNetV3 model with default parameters
+# Build a MobileNetV3 model
 # For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.MobileNetV3_small_x1_0(

+ 1 - 1
tutorials/train/classification/resnet50_vd.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build a ResNet50-vd model with default parameters
+# Build a ResNet50-vd model
 # For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.ResNet50_vd(num_classes=len(train_dataset.labels))

+ 3 - 0
tutorials/train/image_restoration/data/.gitignore

@@ -0,0 +1,3 @@
+*.zip
+*.tar.gz
+rssr/

+ 89 - 0
tutorials/train/image_restoration/drn.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example script for training the DRN image restoration model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/rssr/'
+# Path of the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path of the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/drn/'
+
+# Download and decompress the remote sensing super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms (augmentation, preprocessing, etc.) used for training and validation
+# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Crop 96x96 patches from the input image
+    T.RandomCrop(crop_size=96),
+    # Apply random horizontal flipping with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply random vertical flipping with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the training and validation datasets
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build a DRN model with default parameters
+# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.DRN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Log once every this many iterations
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use the early stopping strategy, terminating training early when accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
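The script reads `train.txt` via `ResDataset`. The exact layout of the downloaded rssr lists is not shown in this diff; judging from `ConstrResSample` earlier in this PR, each line is assumed to pair a source (LR) image path with its target (HR) image path:

```python
# Hypothetical illustration of a restoration file list: one "source target"
# pair per line, with paths relative to the dataset directory.
pairs = [("lr/0001.png", "hr/0001.png"), ("lr/0002.png", "hr/0002.png")]
file_list = "\n".join(f"{src} {tar}" for src, tar in pairs)
print(file_list)
```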

+ 0 - 80
tutorials/train/image_restoration/drn_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddle
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'lqx2', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4,
-        'scale_list': True
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path of the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path of the low-resolution images
-num_workers = 4
-batch_size = 8
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-train_dict = train_dataset()
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-model = pdrs.tasks.res.DRNet(
-    n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
-
-model.train(
-    total_iters=100000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    lr_rate=0.0001,
-    log=10)

+ 89 - 0
tutorials/train/image_restoration/esrgan.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example script for training the ESRGAN image restoration model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/rssr/'
+# Path of the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path of the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/esrgan/'
+
+# Download and decompress the remote sensing super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms (augmentation, preprocessing, etc.) used for training and validation
+# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Crop 32x32 patches from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply random horizontal flipping with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply random vertical flipping with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the training and validation datasets
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build an ESRGAN model with default parameters
+# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.ESRGAN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Log once every this many iterations
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use the early stopping strategy, terminating training early when accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
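In the script above, `T.RandomCrop(crop_size=32)` crops training patches while `sr_factor=4` tells `ResDataset` the super-resolution scale. Assuming the crop size refers to the low-resolution side (our reading, not stated in the diff), each training pair couples a 32x32 LR patch with a 128x128 HR patch. A minimal sketch of that relationship (the helper name is ours, not part of PaddleRS):

```python
def hr_patch_size(lr_crop_size: int, sr_factor: int) -> int:
    """With an x4 model, every LR pixel maps to a 4x4 block of HR pixels,
    so the HR patch side length is the LR side length times the factor."""
    return lr_crop_size * sr_factor

# The settings used in the tutorial scripts: 32x32 LR crops at scale 4.
print(hr_patch_size(32, 4))  # -> 128
```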

+ 0 - 80
tutorials/train/image_restoration/esrgan_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 128,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path to the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path to the low-resolution images
-num_workers = 6
-batch_size = 32
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-# If loss_type='gan', perceptual loss, adversarial loss, and pixel loss are used
-# If loss_type='pixel', only pixel loss is used
-model = pdrs.tasks.res.ESRGANet(loss_type='pixel')
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])

+ 89 - 0
tutorials/train/image_restoration/lesrcnn.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example script for training the LESRCNN image restoration model
+# Before running this script, please make sure PaddleRS has been correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/rssr/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/lesrcnn/'
+
+# Download and unzip the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms used during training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Crop a 32x32 patch from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply random horizontal flipping with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply random vertical flipping with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # The normalization used in the validation phase must be identical to that used in the training phase
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the training and validation datasets, respectively
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the LESRCNN model with default parameters
+# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.LESRCNN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Log once every this many iterations
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use the early stopping strategy and terminate training early when accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Specify a checkpoint to resume training from
+    resume_checkpoint=None)
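The refactored trainer above is configured in epochs, whereas the legacy scripts removed by this PR counted total iterations (`total_iters=1000000`). To roughly compare the two schedules, iterations are epochs times optimizer steps per epoch; a small sketch (the sample count is a hypothetical figure, not taken from the rssr dataset):

```python
import math

def epochs_to_iterations(num_epochs: int, num_samples: int, batch_size: int) -> int:
    # One epoch performs ceil(num_samples / batch_size) optimizer steps
    # when the last, possibly partial, batch is kept.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return num_epochs * steps_per_epoch

# e.g. 10 epochs over a hypothetical 4000 training samples at batch size 8
print(epochs_to_iterations(10, 4000, 8))  # -> 5000
```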

+ 0 - 78
tutorials/train/image_restoration/lesrcnn_train.py

@@ -1,78 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path to the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path to the low-resolution images
-num_workers = 4
-batch_size = 16
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-model = pdrs.tasks.res.LESRCNNet(scale=4, multi_scale=False, group=1)
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])