Pārlūkot izejas kodu

[Refactor] Refactor image restoration

Bobholamovic 2 gadi atpakaļ
vecāks
revīzija
b86f9d435c
33 mainītis faili ar 1589 papildinājumiem un 1364 dzēšanām
  1. 1 1
      paddlers/datasets/__init__.py
  2. 5 5
      paddlers/datasets/cd_dataset.py
  3. 83 0
      paddlers/datasets/res_dataset.py
  4. 1 1
      paddlers/datasets/seg_dataset.py
  5. 0 99
      paddlers/datasets/sr_dataset.py
  6. 1 0
      paddlers/models/__init__.py
  7. 1 1
      paddlers/rs_models/res/__init__.py
  8. 0 26
      paddlers/rs_models/res/generators/builder.py
  9. 11 15
      paddlers/rs_models/res/generators/rcan.py
  10. 0 106
      paddlers/rs_models/res/rcan_model.py
  11. 1 1
      paddlers/tasks/__init__.py
  12. 24 25
      paddlers/tasks/base.py
  13. 14 7
      paddlers/tasks/change_detector.py
  14. 15 11
      paddlers/tasks/classifier.py
  15. 0 786
      paddlers/tasks/image_restorer.py
  16. 16 15
      paddlers/tasks/object_detector.py
  17. 818 0
      paddlers/tasks/restorer.py
  18. 13 8
      paddlers/tasks/segmenter.py
  19. 128 0
      paddlers/tasks/utils/res_adapters.py
  20. 4 0
      paddlers/transforms/functions.py
  21. 99 13
      paddlers/transforms/operators.py
  22. 4 3
      tutorials/train/README.md
  23. 1 1
      tutorials/train/classification/hrnet.py
  24. 1 1
      tutorials/train/classification/mobilenetv3.py
  25. 1 1
      tutorials/train/classification/resnet50_vd.py
  26. 3 0
      tutorials/train/image_restoration/data/.gitignore
  27. 86 0
      tutorials/train/image_restoration/drn.py
  28. 0 80
      tutorials/train/image_restoration/drn_train.py
  29. 86 0
      tutorials/train/image_restoration/esrgan.py
  30. 0 80
      tutorials/train/image_restoration/esrgan_train.py
  31. 86 0
      tutorials/train/image_restoration/lesrcnn.py
  32. 0 78
      tutorials/train/image_restoration/lesrcnn_train.py
  33. 86 0
      tutorials/train/image_restoration/rcan.py

+ 1 - 1
paddlers/datasets/__init__.py

@@ -17,4 +17,4 @@ from .coco import COCODetDataset
 from .seg_dataset import SegDataset
 from .seg_dataset import SegDataset
 from .cd_dataset import CDDataset
 from .cd_dataset import CDDataset
 from .clas_dataset import ClasDataset
 from .clas_dataset import ClasDataset
-from .sr_dataset import SRdataset, ComposeTrans
+from .res_dataset import ResDataset

+ 5 - 5
paddlers/datasets/cd_dataset.py

@@ -95,23 +95,23 @@ class CDDataset(BaseDataset):
                                      full_path_label))):
                                      full_path_label))):
                     continue
                     continue
                 if not osp.exists(full_path_im_t1):
                 if not osp.exists(full_path_im_t1):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t1))
                         full_path_im_t1))
                 if not osp.exists(full_path_im_t2):
                 if not osp.exists(full_path_im_t2):
-                    raise IOError('Image file {} does not exist!'.format(
+                    raise IOError("Image file {} does not exist!".format(
                         full_path_im_t2))
                         full_path_im_t2))
                 if not osp.exists(full_path_label):
                 if not osp.exists(full_path_label):
-                    raise IOError('Label file {} does not exist!'.format(
+                    raise IOError("Label file {} does not exist!".format(
                         full_path_label))
                         full_path_label))
 
 
                 if with_seg_labels:
                 if with_seg_labels:
                     full_path_seg_label_t1 = osp.join(data_dir, items[3])
                     full_path_seg_label_t1 = osp.join(data_dir, items[3])
                     full_path_seg_label_t2 = osp.join(data_dir, items[4])
                     full_path_seg_label_t2 = osp.join(data_dir, items[4])
                     if not osp.exists(full_path_seg_label_t1):
                     if not osp.exists(full_path_seg_label_t1):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t1))
                             full_path_seg_label_t1))
                     if not osp.exists(full_path_seg_label_t2):
                     if not osp.exists(full_path_seg_label_t2):
-                        raise IOError('Label file {} does not exist!'.format(
+                        raise IOError("Label file {} does not exist!".format(
                             full_path_seg_label_t2))
                             full_path_seg_label_t2))
 
 
                 item_dict = dict(
                 item_dict = dict(

+ 83 - 0
paddlers/datasets/res_dataset.py

@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import copy
+
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
+
+
+class ResDataset(BaseDataset):
+    """
+    Dataset for image restoration tasks.
+
+    Args:
+        data_dir (str): Root directory of the dataset.
+        file_list (str): Path of the file that contains relative paths of source and target image files.
+        transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
+        num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
+            the number of workers will be automatically determined according to the number of CPU cores: If 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
+            the number of CPU cores. Defaults to 'auto'.
+        shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
+        sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image 
+            restoration tasks. Defaults to None.
+    """
+
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 transforms,
+                 num_workers='auto',
+                 shuffle=False,
+                 sr_factor=None):
+        super(ResDataset, self).__init__(data_dir, None, transforms,
+                                         num_workers, shuffle)
+        self.batch_transforms = None
+        self.file_list = list()
+
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                items = line.strip().split()
+                if len(items) > 2:
+                    raise ValueError(
+                        "A space is defined as the delimiter to separate the source and target image path, " \
+                        "so the space cannot be in the source image or target image path, but the line[{}] of " \
+                        " file_list[{}] has a space in the two paths.".format(line, file_list))
+                items[0] = norm_path(items[0])
+                items[1] = norm_path(items[1])
+                full_path_im = osp.join(data_dir, items[0])
+                full_path_tar = osp.join(data_dir, items[1])
+                if not is_pic(full_path_im) or not is_pic(full_path_tar):
+                    continue
+                if not osp.exists(full_path_im):
+                    raise IOError("Source image file {} does not exist!".format(
+                        full_path_im))
+                if not osp.exists(full_path_tar):
+                    raise IOError("Target image file {} does not exist!".format(
+                        full_path_tar))
+                sample = {
+                    'image': full_path_im,
+                    'target': full_path_tar,
+                }
+                if sr_factor is not None:
+                    sample['sr_factor'] = sr_factor
+                self.file_list.append(sample)
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+
+    def __len__(self):
+        return len(self.file_list)

+ 1 - 1
paddlers/datasets/seg_dataset.py

@@ -44,7 +44,7 @@ class SegDataset(BaseDataset):
                  shuffle=False):
                  shuffle=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
                                          num_workers, shuffle)
-        # TODO batch padding
+        # TODO: batch padding
         self.batch_transforms = None
         self.batch_transforms = None
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()

+ 0 - 99
paddlers/datasets/sr_dataset.py

@@ -1,99 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# 超分辨率数据集定义
-class SRdataset(object):
-    def __init__(self,
-                 mode,
-                 gt_floder,
-                 lq_floder,
-                 transforms,
-                 scale,
-                 num_workers=4,
-                 batch_size=8):
-        if mode == 'train':
-            preprocess = []
-            preprocess.append({
-                'name': 'LoadImageFromFile',
-                'key': 'lq'
-            })  # 加载方式
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)  # 变换方式
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'num_workers': num_workers,
-                'batch_size': batch_size,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-        if mode == "test":
-            preprocess = []
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
-            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
-            preprocess.append(transforms)
-            self.dataset = {
-                'name': 'SRDataset',
-                'gt_folder': gt_floder,
-                'lq_folder': lq_floder,
-                'scale': scale,
-                'preprocess': preprocess
-            }
-
-    def __call__(self):
-        return self.dataset
-
-
-# 对定义的transforms处理方式组合,返回字典
-class ComposeTrans(object):
-    def __init__(self, input_keys, output_keys, pipelines):
-        if not isinstance(pipelines, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be List, but received is {}'
-                .format(type(pipelines)))
-        if len(pipelines) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(pipelines)))
-        self.transforms = pipelines
-        self.output_length = len(output_keys)  # 当output_keys的长度为3时,是DRN训练
-        self.input_keys = input_keys
-        self.output_keys = output_keys
-
-    def __call__(self):
-        pipeline = []
-        for op in self.transforms:
-            if op['name'] == 'SRPairedRandomCrop':
-                op['keys'] = ['image'] * 2
-            else:
-                op['keys'] = ['image'] * self.output_length
-            pipeline.append(op)
-        if self.output_length == 2:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'pipeline': pipeline
-            }
-        else:
-            transform_dict = {
-                'name': 'Transforms',
-                'input_keys': self.input_keys,
-                'output_keys': self.output_keys,
-                'pipeline': pipeline
-            }
-
-        return transform_dict

+ 1 - 0
paddlers/models/__init__.py

@@ -16,3 +16,4 @@ from . import ppcls, ppdet, ppseg, ppgan
 import paddlers.models.ppseg.models.losses as seg_losses
 import paddlers.models.ppseg.models.losses as seg_losses
 import paddlers.models.ppdet.modeling.losses as det_losses
 import paddlers.models.ppdet.modeling.losses as det_losses
 import paddlers.models.ppcls.loss as clas_losses
 import paddlers.models.ppcls.loss as clas_losses
+import paddlers.models.ppgan.models.criterions as res_losses

+ 1 - 1
paddlers/rs_models/res/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from .rcan_model import RCANModel
+from .generators import *

+ 0 - 26
paddlers/rs_models/res/generators/builder.py

@@ -1,26 +0,0 @@
-#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-
-from ....models.ppgan.utils.registry import Registry
-
-GENERATORS = Registry("GENERATOR")
-
-
-def build_generator(cfg):
-    cfg_copy = copy.deepcopy(cfg)
-    name = cfg_copy.pop('name')
-    generator = GENERATORS.get(name)(**cfg_copy)
-    return generator

+ 11 - 15
paddlers/rs_models/res/generators/rcan.py

@@ -4,8 +4,6 @@ import math
 import paddle
 import paddle
 import paddle.nn as nn
 import paddle.nn as nn
 
 
-from .builder import GENERATORS
-
 
 
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     weight_attr = paddle.ParamAttr(
     weight_attr = paddle.ParamAttr(
@@ -128,21 +126,19 @@ class Upsampler(nn.Sequential):
         super(Upsampler, self).__init__(*m)
         super(Upsampler, self).__init__(*m)
 
 
 
 
-@GENERATORS.register()
 class RCAN(nn.Layer):
 class RCAN(nn.Layer):
-    def __init__(
-            self,
-            scale,
-            n_resgroups,
-            n_resblocks,
-            n_feats=64,
-            n_colors=3,
-            rgb_range=255,
-            kernel_size=3,
-            reduction=16,
-            conv=default_conv, ):
+    def __init__(self,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=255,
+                 kernel_size=3,
+                 reduction=16,
+                 conv=default_conv):
         super(RCAN, self).__init__()
         super(RCAN, self).__init__()
-        self.scale = scale
+        self.scale = sr_factor
         act = nn.ReLU()
         act = nn.ReLU()
 
 
         n_resgroups = n_resgroups
         n_resgroups = n_resgroups

+ 0 - 106
paddlers/rs_models/res/rcan_model.py

@@ -1,106 +0,0 @@
-#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle
-import paddle.nn as nn
-
-from .generators.builder import build_generator
-from ...models.ppgan.models.criterions.builder import build_criterion
-from ...models.ppgan.models.base_model import BaseModel
-from ...models.ppgan.models.builder import MODELS
-from ...models.ppgan.utils.visual import tensor2img
-from ...models.ppgan.modules.init import reset_parameters
-
-
-@MODELS.register()
-class RCANModel(BaseModel):
-    """
-    Base SR model for single image super-resolution.
-    """
-
-    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
-        """
-        Args:
-            generator (dict): config of generator.
-            pixel_criterion (dict): config of pixel criterion.
-        """
-        super(RCANModel, self).__init__()
-
-        self.nets['generator'] = build_generator(generator)
-        self.error_last = 1e8
-        self.batch = 0
-        if pixel_criterion:
-            self.pixel_criterion = build_criterion(pixel_criterion)
-        if use_init_weight:
-            init_sr_weight(self.nets['generator'])
-
-    def setup_input(self, input):
-        self.lq = paddle.to_tensor(input['lq'])
-        self.visual_items['lq'] = self.lq
-        if 'gt' in input:
-            self.gt = paddle.to_tensor(input['gt'])
-            self.visual_items['gt'] = self.gt
-        self.image_paths = input['lq_path']
-
-    def forward(self):
-        pass
-
-    def train_iter(self, optims=None):
-        optims['optim'].clear_grad()
-
-        self.output = self.nets['generator'](self.lq)
-        self.visual_items['output'] = self.output
-        # pixel loss
-        loss_pixel = self.pixel_criterion(self.output, self.gt)
-        self.losses['loss_pixel'] = loss_pixel
-
-        skip_threshold = 1e6
-
-        if loss_pixel.item() < skip_threshold * self.error_last:
-            loss_pixel.backward()
-            optims['optim'].step()
-        else:
-            print('Skip this batch {}! (Loss: {})'.format(self.batch + 1,
-                                                          loss_pixel.item()))
-        self.batch += 1
-
-        if self.batch % 1000 == 0:
-            self.error_last = loss_pixel.item() / 1000
-            print("update error_last:{}".format(self.error_last))
-
-    def test_iter(self, metrics=None):
-        self.nets['generator'].eval()
-        with paddle.no_grad():
-            self.output = self.nets['generator'](self.lq)
-            self.visual_items['output'] = self.output
-        self.nets['generator'].train()
-
-        out_img = []
-        gt_img = []
-        for out_tensor, gt_tensor in zip(self.output, self.gt):
-            out_img.append(tensor2img(out_tensor, (0., 255.)))
-            gt_img.append(tensor2img(gt_tensor, (0., 255.)))
-
-        if metrics is not None:
-            for metric in metrics.values():
-                metric.update(out_img, gt_img)
-
-
-def init_sr_weight(net):
-    def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
-            reset_parameters(m)
-
-    net.apply(reset_func)

+ 1 - 1
paddlers/tasks/__init__.py

@@ -16,7 +16,7 @@ import paddlers.tasks.object_detector as detector
 import paddlers.tasks.segmenter as segmenter
 import paddlers.tasks.segmenter as segmenter
 import paddlers.tasks.change_detector as change_detector
 import paddlers.tasks.change_detector as change_detector
 import paddlers.tasks.classifier as classifier
 import paddlers.tasks.classifier as classifier
-import paddlers.tasks.image_restorer as restorer
+import paddlers.tasks.restorer as restorer
 from .load_model import load_model
 from .load_model import load_model
 
 
 # Shorter aliases
 # Shorter aliases

+ 24 - 25
paddlers/tasks/base.py

@@ -35,7 +35,6 @@ from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             _get_shared_memory_size_in_M, EarlyStop)
                             _get_shared_memory_size_in_M, EarlyStop)
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
-from .utils.infer_nets import InferNet, InferCDNet
 
 
 
 
 class ModelMeta(type):
 class ModelMeta(type):
@@ -268,7 +267,7 @@ class BaseModel(metaclass=ModelMeta):
                 'The volume of dataset({}) must be larger than batch size({}).'
                 'The volume of dataset({}) must be larger than batch size({}).'
                 .format(dataset.num_samples, batch_size))
                 .format(dataset.num_samples, batch_size))
         batch_size_each_card = get_single_card_bs(batch_size=batch_size)
         batch_size_each_card = get_single_card_bs(batch_size=batch_size)
-        # TODO detection eval阶段需做判断
+        # TODO: Make judgement in detection eval phase.
         batch_sampler = DistributedBatchSampler(
         batch_sampler = DistributedBatchSampler(
             dataset,
             dataset,
             batch_size=batch_size_each_card,
             batch_size=batch_size_each_card,
@@ -365,24 +364,12 @@ class BaseModel(metaclass=ModelMeta):
 
 
             for step, data in enumerate(self.train_data_loader()):
             for step, data in enumerate(self.train_data_loader()):
                 if nranks > 1:
                 if nranks > 1:
-                    outputs = self.run(ddp_net, data, mode='train')
+                    outputs = self.train_step(step, data, ddp_net)
                 else:
                 else:
-                    outputs = self.run(self.net, data, mode='train')
-                loss = outputs['loss']
-                loss.backward()
-                self.optimizer.step()
-                self.optimizer.clear_grad()
-                lr = self.optimizer.get_lr()
-                if isinstance(self.optimizer._learning_rate,
-                              paddle.optimizer.lr.LRScheduler):
-                    # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
-                    if isinstance(self.optimizer._learning_rate,
-                                  paddle.optimizer.lr.ReduceOnPlateau):
-                        self.optimizer._learning_rate.step(loss.item())
-                    else:
-                        self.optimizer._learning_rate.step()
+                    outputs = self.train_step(step, data, self.net)
 
 
                 train_avg_metrics.update(outputs)
                 train_avg_metrics.update(outputs)
+                lr = self.optimizer.get_lr()
                 outputs['lr'] = lr
                 outputs['lr'] = lr
                 if ema is not None:
                 if ema is not None:
                     ema.update(self.net)
                     ema.update(self.net)
@@ -622,14 +609,7 @@ class BaseModel(metaclass=ModelMeta):
         return pipeline_info
         return pipeline_info
 
 
     def _build_inference_net(self):
     def _build_inference_net(self):
-        if self.model_type in ('classifier', 'detector'):
-            infer_net = self.net
-        elif self.model_type == 'change_detector':
-            infer_net = InferCDNet(self.net)
-        else:
-            infer_net = InferNet(self.net, self.model_type)
-        infer_net.eval()
-        return infer_net
+        raise NotImplementedError
 
 
     def _export_inference_model(self, save_dir, image_shape=None):
     def _export_inference_model(self, save_dir, image_shape=None):
         self.test_inputs = self._get_test_inputs(image_shape)
         self.test_inputs = self._get_test_inputs(image_shape)
@@ -674,6 +654,25 @@ class BaseModel(metaclass=ModelMeta):
         logging.info("The inference model for deployment is saved in {}.".
         logging.info("The inference model for deployment is saved in {}.".
                      format(save_dir))
                      format(save_dir))
 
 
+    def train_step(self, step, data, net):
+        outputs = self.run(net, data, mode='train')
+
+        loss = outputs['loss']
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+
+        if isinstance(self.optimizer._learning_rate,
+                      paddle.optimizer.lr.LRScheduler):
+            # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
+            if isinstance(self.optimizer._learning_rate,
+                          paddle.optimizer.lr.ReduceOnPlateau):
+                self.optimizer._learning_rate.step(loss.item())
+            else:
+                self.optimizer._learning_rate.step()
+
+        return outputs
+
     def _check_transforms(self, transforms, mode):
     def _check_transforms(self, transforms, mode):
         # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
         # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
         if not isinstance(transforms, paddlers.transforms.Compose):
         if not isinstance(transforms, paddlers.transforms.Compose):

+ 14 - 7
paddlers/tasks/change_detector.py

@@ -30,10 +30,11 @@ import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
 from paddlers.transforms import Resize, decode_image
-from paddlers.utils import get_single_card_bs, DisablePrint
+from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferCDNet
 
 
 __all__ = [
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -71,6 +72,11 @@ class BaseChangeDetector(BaseModel):
                                              **params)
                                              **params)
         return net
         return net
 
 
+    def _build_inference_net(self):
+        infer_net = InferCDNet(self.net)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
             if self.test_transforms is not None:
@@ -401,7 +407,8 @@ class BaseChangeDetector(BaseModel):
                 Defaults to False.
                 Defaults to False.
 
 
         Returns:
         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 For binary change detection (number of classes == 2), the key-value 
                 For binary change detection (number of classes == 2), the key-value 
                     pairs are like:
                     pairs are like:
                     {"iou": `intersection over union for the change class`,
                     {"iou": `intersection over union for the change class`,
@@ -529,12 +536,12 @@ class BaseChangeDetector(BaseModel):
 
 
         Returns:
         Returns:
             If `img_file` is a tuple of string or np.array, the result is a dict with 
             If `img_file` is a tuple of string or np.array, the result is a dict with 
-                key-value pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+                the following key-value pairs:
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
+                above keys.
         """
         """
 
 
         if transforms is None and not hasattr(self, 'test_transforms'):
         if transforms is None and not hasattr(self, 'test_transforms'):

+ 15 - 11
paddlers/tasks/classifier.py

@@ -83,6 +83,11 @@ class BaseClassifier(BaseModel):
                 self.in_channels = 3
                 self.in_channels = 3
         return net
         return net
 
 
+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
             if self.test_transforms is not None:
@@ -375,7 +380,8 @@ class BaseClassifier(BaseModel):
                 Defaults to False.
                 Defaults to False.
 
 
         Returns:
         Returns:
-            collections.OrderedDict with key-value pairs:
+            If `return_details` is False, return collections.OrderedDict with 
+                key-value pairs:
                 {"top1": `acc of top1`,
                 {"top1": `acc of top1`,
                  "top5": `acc of top5`}.
                  "top5": `acc of top5`}.
         """
         """
@@ -420,7 +426,7 @@ class BaseClassifier(BaseModel):
         top5 = np.mean(top5s)
         top5 = np.mean(top5s)
         eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
         eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
         if return_details:
         if return_details:
-            # TODO: add details
+            # TODO: Add details
             return eval_metrics, None
             return eval_metrics, None
         return eval_metrics
         return eval_metrics
 
 
@@ -437,16 +443,14 @@ class BaseClassifier(BaseModel):
                 Defaults to None.
                 Defaults to None.
 
 
         Returns:
         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
-                pairs:
-                {"label map": `class_ids_map`, 
-                 "scores_map": `scores_map`, 
-                 "label_names_map": `label_names_map`}.
+            If `img_file` is a string or np.array, the result is a dict with the 
+                following key-value pairs:
+                class_ids_map (np.ndarray): IDs of predicted classes.
+                scores_map (np.ndarray): Scores of predicted classes.
+                label_names_map (np.ndarray): Names of predicted classes.
+            
             If `img_file` is a list, the result is a list composed of dicts with the 
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                class_ids_map (np.ndarray): class_ids
-                scores_map (np.ndarray): scores
-                label_names_map (np.ndarray): label_names
+                above keys.
         """
         """
 
 
         if transforms is None and not hasattr(self, 'test_transforms'):
         if transforms is None and not hasattr(self, 'test_transforms'):

+ 0 - 786
paddlers/tasks/image_restorer.py

@@ -1,786 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-import datetime
-
-import paddle
-from paddle.distributed import ParallelEnv
-
-from ..models.ppgan.datasets.builder import build_dataloader
-from ..models.ppgan.models.builder import build_model
-from ..models.ppgan.utils.visual import tensor2img, save_image
-from ..models.ppgan.utils.filesystem import makedirs, save, load
-from ..models.ppgan.utils.timer import TimeAverager
-from ..models.ppgan.utils.profiler import add_profiler_step
-from ..models.ppgan.utils.logger import setup_logger
-
-
-# 定义AttrDict类实现动态属性
-class AttrDict(dict):
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setattr__(self, key, value):
-        if key in self.__dict__:
-            self.__dict__[key] = value
-        else:
-            self[key] = value
-
-
-# 创建AttrDict类
-def create_attr_dict(config_dict):
-    from ast import literal_eval
-    for key, value in config_dict.items():
-        if type(value) is dict:
-            config_dict[key] = value = AttrDict(value)
-        if isinstance(value, str):
-            try:
-                value = literal_eval(value)
-            except BaseException:
-                pass
-        if isinstance(value, AttrDict):
-            create_attr_dict(config_dict[key])
-        else:
-            config_dict[key] = value
-
-
-# 数据加载类
-class IterLoader:
-    def __init__(self, dataloader):
-        self._dataloader = dataloader
-        self.iter_loader = iter(self._dataloader)
-        self._epoch = 1
-
-    @property
-    def epoch(self):
-        return self._epoch
-
-    def __next__(self):
-        try:
-            data = next(self.iter_loader)
-        except StopIteration:
-            self._epoch += 1
-            self.iter_loader = iter(self._dataloader)
-            data = next(self.iter_loader)
-
-        return data
-
-    def __len__(self):
-        return len(self._dataloader)
-
-
-# 基础训练类
-class Restorer:
-    """
-    # trainer calling logic:
-    #
-    #                build_model                               ||    model(BaseModel)
-    #                     |                                    ||
-    #               build_dataloader                           ||    dataloader
-    #                     |                                    ||
-    #               model.setup_lr_schedulers                  ||    lr_scheduler
-    #                     |                                    ||
-    #               model.setup_optimizers                     ||    optimizers
-    #                     |                                    ||
-    #     train loop (model.setup_input + model.train_iter)    ||    train loop
-    #                     |                                    ||
-    #         print log (model.get_current_losses)             ||
-    #                     |                                    ||
-    #         save checkpoint (model.nets)                     \/
-    """
-
-    def __init__(self, cfg, logger):
-        # base config
-        # self.logger = logging.getLogger(__name__)
-        self.logger = logger
-        self.cfg = cfg
-        self.output_dir = cfg.output_dir
-        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
-
-        self.local_rank = ParallelEnv().local_rank
-        self.world_size = ParallelEnv().nranks
-        self.log_interval = cfg.log_config.interval
-        self.visual_interval = cfg.log_config.visiual_interval
-        self.weight_interval = cfg.snapshot_config.interval
-
-        self.start_epoch = 1
-        self.current_epoch = 1
-        self.current_iter = 1
-        self.inner_iter = 1
-        self.batch_id = 0
-        self.global_steps = 0
-
-        # build model
-        self.model = build_model(cfg.model)
-        # multiple gpus prepare
-        if ParallelEnv().nranks > 1:
-            self.distributed_data_parallel()
-
-        # build metrics
-        self.metrics = None
-        self.is_save_img = True
-        validate_cfg = cfg.get('validate', None)
-        if validate_cfg and 'metrics' in validate_cfg:
-            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
-        if validate_cfg and 'save_img' in validate_cfg:
-            self.is_save_img = validate_cfg['save_img']
-
-        self.enable_visualdl = cfg.get('enable_visualdl', False)
-        if self.enable_visualdl:
-            import visualdl
-            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
-
-        # evaluate only
-        if not cfg.is_train:
-            return
-
-        # build train dataloader
-        self.train_dataloader = build_dataloader(cfg.dataset.train)
-        self.iters_per_epoch = len(self.train_dataloader)
-
-        # build lr scheduler
-        # TODO: has a better way?
-        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
-            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
-        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
-
-        # build optimizers
-        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
-                                                      cfg.optimizer)
-
-        self.epochs = cfg.get('epochs', None)
-        if self.epochs:
-            self.total_iters = self.epochs * self.iters_per_epoch
-            self.by_epoch = True
-        else:
-            self.by_epoch = False
-            self.total_iters = cfg.total_iters
-
-        if self.by_epoch:
-            self.weight_interval *= self.iters_per_epoch
-
-        self.validate_interval = -1
-        if cfg.get('validate', None) is not None:
-            self.validate_interval = cfg.validate.get('interval', -1)
-
-        self.time_count = {}
-        self.best_metric = {}
-        self.model.set_total_iter(self.total_iters)
-        self.profiler_options = cfg.profiler_options
-
-    def distributed_data_parallel(self):
-        paddle.distributed.init_parallel_env()
-        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
-        for net_name, net in self.model.nets.items():
-            self.model.nets[net_name] = paddle.DataParallel(
-                net, find_unused_parameters=find_unused_parameters)
-
-    def learning_rate_scheduler_step(self):
-        if isinstance(self.model.lr_scheduler, dict):
-            for lr_scheduler in self.model.lr_scheduler.values():
-                lr_scheduler.step()
-        elif isinstance(self.model.lr_scheduler,
-                        paddle.optimizer.lr.LRScheduler):
-            self.model.lr_scheduler.step()
-        else:
-            raise ValueError(
-                'lr schedulter must be a dict or an instance of LRScheduler')
-
-    def train(self):
-        reader_cost_averager = TimeAverager()
-        batch_cost_averager = TimeAverager()
-
-        iter_loader = IterLoader(self.train_dataloader)
-
-        # set model.is_train = True
-        self.model.setup_train_mode(is_train=True)
-        while self.current_iter < (self.total_iters + 1):
-            self.current_epoch = iter_loader.epoch
-            self.inner_iter = self.current_iter % self.iters_per_epoch
-
-            add_profiler_step(self.profiler_options)
-
-            start_time = step_start_time = time.time()
-            data = next(iter_loader)
-            reader_cost_averager.record(time.time() - step_start_time)
-            # unpack data from dataset and apply preprocessing
-            # data input should be dict
-            self.model.setup_input(data)
-            self.model.train_iter(self.optimizers)
-
-            batch_cost_averager.record(
-                time.time() - step_start_time,
-                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))
-
-            step_start_time = time.time()
-
-            if self.current_iter % self.log_interval == 0:
-                self.data_time = reader_cost_averager.get_average()
-                self.step_time = batch_cost_averager.get_average()
-                self.ips = batch_cost_averager.get_ips_average()
-                self.print_log()
-
-                reader_cost_averager.reset()
-                batch_cost_averager.reset()
-
-            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
-                self.visual('visual_train')
-
-            self.learning_rate_scheduler_step()
-
-            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
-                self.test()
-
-            if self.current_iter % self.weight_interval == 0:
-                self.save(self.current_iter, 'weight', keep=-1)
-                self.save(self.current_iter)
-
-            self.current_iter += 1
-
-    def test(self):
-        if not hasattr(self, 'test_dataloader'):
-            self.test_dataloader = build_dataloader(
-                self.cfg.dataset.test, is_train=False)
-        iter_loader = IterLoader(self.test_dataloader)
-        if self.max_eval_steps is None:
-            self.max_eval_steps = len(self.test_dataloader)
-
-        if self.metrics:
-            for metric in self.metrics.values():
-                metric.reset()
-
-        # set model.is_train = False
-        self.model.setup_train_mode(is_train=False)
-
-        for i in range(self.max_eval_steps):
-            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
-                self.logger.info('Test iter: [%d/%d]' % (
-                    i * self.world_size, self.max_eval_steps * self.world_size))
-
-            data = next(iter_loader)
-            self.model.setup_input(data)
-            self.model.test_iter(metrics=self.metrics)
-
-            if self.is_save_img:
-                visual_results = {}
-                current_paths = self.model.get_image_paths()
-                current_visuals = self.model.get_current_visuals()
-
-                if len(current_visuals) > 0 and list(current_visuals.values())[
-                        0].shape == 4:
-                    num_samples = list(current_visuals.values())[0].shape[0]
-                else:
-                    num_samples = 1
-
-                for j in range(num_samples):
-                    if j < len(current_paths):
-                        short_path = os.path.basename(current_paths[j])
-                        basename = os.path.splitext(short_path)[0]
-                    else:
-                        basename = '{:04d}_{:04d}'.format(i, j)
-                    for k, img_tensor in current_visuals.items():
-                        name = '%s_%s' % (basename, k)
-                        if len(img_tensor.shape) == 4:
-                            visual_results.update({name: img_tensor[j]})
-                        else:
-                            visual_results.update({name: img_tensor})
-
-                self.visual(
-                    'visual_test',
-                    visual_results=visual_results,
-                    step=self.batch_id,
-                    is_save_image=True)
-
-        if self.metrics:
-            for metric_name, metric in self.metrics.items():
-                self.logger.info("Metric {}: {:.4f}".format(
-                    metric_name, metric.accumulate()))
-
-    def print_log(self):
-        losses = self.model.get_current_losses()
-
-        message = ''
-        if self.by_epoch:
-            message += 'Epoch: %d/%d, iter: %d/%d ' % (
-                self.current_epoch, self.epochs, self.inner_iter,
-                self.iters_per_epoch)
-        else:
-            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
-
-        message += f'lr: {self.current_learning_rate:.3e} '
-
-        for k, v in losses.items():
-            message += '%s: %.3f ' % (k, v)
-            if self.enable_visualdl:
-                self.vdl_logger.add_scalar(k, v, step=self.global_steps)
-
-        if hasattr(self, 'step_time'):
-            message += 'batch_cost: %.5f sec ' % self.step_time
-
-        if hasattr(self, 'data_time'):
-            message += 'reader_cost: %.5f sec ' % self.data_time
-
-        if hasattr(self, 'ips'):
-            message += 'ips: %.5f images/s ' % self.ips
-
-        if hasattr(self, 'step_time'):
-            eta = self.step_time * (self.total_iters - self.current_iter)
-            eta = eta if eta > 0 else 0
-
-            eta_str = str(datetime.timedelta(seconds=int(eta)))
-            message += f'eta: {eta_str}'
-
-        # print the message
-        self.logger.info(message)
-
-    @property
-    def current_learning_rate(self):
-        for optimizer in self.model.optimizers.values():
-            return optimizer.get_lr()
-
-    def visual(self,
-               results_dir,
-               visual_results=None,
-               step=None,
-               is_save_image=False):
-        """
-        visual the images, use visualdl or directly write to the directory
-        Parameters:
-            results_dir (str)     --  directory name which contains saved images
-            visual_results (dict) --  the results images dict
-            step (int)            --  global steps, used in visualdl
-            is_save_image (bool)  --  weather write to the directory or visualdl
-        """
-        self.model.compute_visuals()
-
-        if visual_results is None:
-            visual_results = self.model.get_current_visuals()
-
-        min_max = self.cfg.get('min_max', None)
-        if min_max is None:
-            min_max = (-1., 1.)
-
-        image_num = self.cfg.get('image_num', None)
-        if (image_num is None) or (not self.enable_visualdl):
-            image_num = 1
-        for label, image in visual_results.items():
-            image_numpy = tensor2img(image, min_max, image_num)
-            if (not is_save_image) and self.enable_visualdl:
-                self.vdl_logger.add_image(
-                    results_dir + '/' + label,
-                    image_numpy,
-                    step=step if step else self.global_steps,
-                    dataformats="HWC" if image_num == 1 else "NCHW")
-            else:
-                if self.cfg.is_train:
-                    if self.by_epoch:
-                        msg = 'epoch%.3d_' % self.current_epoch
-                    else:
-                        msg = 'iter%.3d_' % self.current_iter
-                else:
-                    msg = ''
-                makedirs(os.path.join(self.output_dir, results_dir))
-                img_path = os.path.join(self.output_dir, results_dir,
-                                        msg + '%s.png' % (label))
-                save_image(image_numpy, img_path)
-
-    def save(self, epoch, name='checkpoint', keep=1):
-        if self.local_rank != 0:
-            return
-
-        assert name in ['checkpoint', 'weight']
-
-        state_dicts = {}
-        if self.by_epoch:
-            save_filename = 'epoch_%s_%s.pdparams' % (
-                epoch // self.iters_per_epoch, name)
-        else:
-            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
-
-        os.makedirs(self.output_dir, exist_ok=True)
-        save_path = os.path.join(self.output_dir, save_filename)
-        for net_name, net in self.model.nets.items():
-            state_dicts[net_name] = net.state_dict()
-
-        if name == 'weight':
-            save(state_dicts, save_path)
-            return
-
-        state_dicts['epoch'] = epoch
-
-        for opt_name, opt in self.model.optimizers.items():
-            state_dicts[opt_name] = opt.state_dict()
-
-        save(state_dicts, save_path)
-
-        if keep > 0:
-            try:
-                if self.by_epoch:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'epoch_%s_%s.pdparams' % (
-                            (epoch - keep * self.weight_interval) //
-                            self.iters_per_epoch, name))
-                else:
-                    checkpoint_name_to_be_removed = os.path.join(
-                        self.output_dir, 'iter_%s_%s.pdparams' %
-                        (epoch - keep * self.weight_interval, name))
-
-                if os.path.exists(checkpoint_name_to_be_removed):
-                    os.remove(checkpoint_name_to_be_removed)
-
-            except Exception as e:
-                self.logger.info('remove old checkpoints error: {}'.format(e))
-
-    def resume(self, checkpoint_path):
-        state_dicts = load(checkpoint_path)
-        if state_dicts.get('epoch', None) is not None:
-            self.start_epoch = state_dicts['epoch'] + 1
-            self.global_steps = self.iters_per_epoch * state_dicts['epoch']
-
-            self.current_iter = state_dicts['epoch'] + 1
-
-        for net_name, net in self.model.nets.items():
-            net.set_state_dict(state_dicts[net_name])
-
-        for opt_name, opt in self.model.optimizers.items():
-            opt.set_state_dict(state_dicts[opt_name])
-
-    def load(self, weight_path):
-        state_dicts = load(weight_path)
-
-        for net_name, net in self.model.nets.items():
-            if net_name in state_dicts:
-                net.set_state_dict(state_dicts[net_name])
-                self.logger.info('Loaded pretrained weight for net {}'.format(
-                    net_name))
-            else:
-                self.logger.warning(
-                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
-                    .format(net_name, net_name))
-
-    def close(self):
-        """
-        when finish the training need close file handler or other.
-        """
-        if self.enable_visualdl:
-            self.vdl_logger.close()
-
-
-# 基础超分模型训练类
-class BasicSRNet:
-    def __init__(self):
-        self.model = {}
-        self.optimizer = {}
-        self.lr_scheduler = {}
-        self.min_max = ''
-
-    def train(
-            self,
-            total_iters,
-            train_dataset,
-            test_dataset,
-            output_dir,
-            validate,
-            snapshot,
-            log,
-            lr_rate,
-            evaluate_weights='',
-            resume='',
-            pretrain_weights='',
-            periods=[100000],
-            restart_weights=[1], ):
-        self.lr_scheduler['learning_rate'] = lr_rate
-
-        if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR':
-            self.lr_scheduler['periods'] = periods
-            self.lr_scheduler['restart_weights'] = restart_weights
-
-        validate = {
-            'interval': validate,
-            'save_img': False,
-            'metrics': {
-                'psnr': {
-                    'name': 'PSNR',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                },
-                'ssim': {
-                    'name': 'SSIM',
-                    'crop_border': 4,
-                    'test_y_channel': True
-                }
-            }
-        }
-        log_config = {'interval': log, 'visiual_interval': 500}
-        snapshot_config = {'interval': snapshot}
-
-        cfg = {
-            'total_iters': total_iters,
-            'output_dir': output_dir,
-            'min_max': self.min_max,
-            'model': self.model,
-            'dataset': {
-                'train': train_dataset,
-                'test': test_dataset
-            },
-            'lr_scheduler': self.lr_scheduler,
-            'optimizer': self.optimizer,
-            'validate': validate,
-            'log_config': log_config,
-            'snapshot_config': snapshot_config
-        }
-
-        cfg = AttrDict(cfg)
-        create_attr_dict(cfg)
-
-        cfg.is_train = True
-        cfg.profiler_options = None
-        cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
-
-        if cfg.model.name == 'BaseSRModel':
-            floderModelName = cfg.model.generator.name
-        else:
-            floderModelName = cfg.model.name
-        cfg.output_dir = os.path.join(cfg.output_dir,
-                                      floderModelName + cfg.timestamp)
-
-        logger_cfg = setup_logger(cfg.output_dir)
-        logger_cfg.info('Configs: {}'.format(cfg))
-
-        if paddle.is_compiled_with_cuda():
-            paddle.set_device('gpu')
-        else:
-            paddle.set_device('cpu')
-
-        # build trainer
-        trainer = Restorer(cfg, logger_cfg)
-
-        # continue train or evaluate, checkpoint need contain epoch and optimizer info
-        if len(resume) > 0:
-            trainer.resume(resume)
-        # evaluate or finute, only load generator weights
-        elif len(pretrain_weights) > 0:
-            trainer.load(pretrain_weights)
-        if len(evaluate_weights) > 0:
-            trainer.load(evaluate_weights)
-            trainer.test()
-            return
-        # training, when keyboard interrupt save weights
-        try:
-            trainer.train()
-        except KeyboardInterrupt as e:
-            trainer.save(trainer.current_epoch)
-
-        trainer.close()
-
-
-# DRN模型训练
-class DRNet(BasicSRNet):
-    def __init__(self,
-                 n_blocks=30,
-                 n_feats=16,
-                 n_colors=3,
-                 rgb_range=255,
-                 negval=0.2):
-        super(DRNet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'DRNGenerator',
-            'scale': (2, 4),
-            'n_blocks': n_blocks,
-            'n_feats': n_feats,
-            'n_colors': n_colors,
-            'rgb_range': rgb_range,
-            'negval': negval
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'DRN',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'optimG': {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            },
-            'optimD': {
-                'name': 'Adam',
-                'net_names': ['dual_model_0', 'dual_model_1'],
-                'weight_decay': 0.0,
-                'beta1': 0.9,
-                'beta2': 0.999
-            }
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# 轻量化超分模型LESRCNN训练
-class LESRCNNet(BasicSRNet):
-    def __init__(self, scale=4, multi_scale=False, group=1):
-        super(LESRCNNet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'LESRCNNGenerator',
-            'scale': scale,
-            'multi_scale': False,
-            'group': 1
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'BaseSRModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'CosineAnnealingRestartLR',
-            'eta_min': 1e-07
-        }
-
-
-# ESRGAN模型训练
-# 若loss_type='gan' 使用感知损失、对抗损失和像素损失
-# 若loss_type = 'pixel' 只使用像素损失
-class ESRGANet(BasicSRNet):
-    def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
-        super(ESRGANet, self).__init__()
-        self.min_max = '(0., 1.)'
-        self.generator = {
-            'name': 'RRDBNet',
-            'in_nc': in_nc,
-            'out_nc': out_nc,
-            'nf': nf,
-            'nb': nb
-        }
-
-        if loss_type == 'gan':
-            # 定义损失函数
-            self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01}
-            self.discriminator = {
-                'name': 'VGGDiscriminator128',
-                'in_channels': 3,
-                'num_feat': 64
-            }
-            self.perceptual_criterion = {
-                'name': 'PerceptualLoss',
-                'layer_weights': {
-                    '34': 1.0
-                },
-                'perceptual_weight': 1.0,
-                'style_weight': 0.0,
-                'norm_img': False
-            }
-            self.gan_criterion = {
-                'name': 'GANLoss',
-                'gan_mode': 'vanilla',
-                'loss_weight': 0.005
-            }
-            # 定义模型 
-            self.model = {
-                'name': 'ESRGAN',
-                'generator': self.generator,
-                'discriminator': self.discriminator,
-                'pixel_criterion': self.pixel_criterion,
-                'perceptual_criterion': self.perceptual_criterion,
-                'gan_criterion': self.gan_criterion
-            }
-            self.optimizer = {
-                'optimG': {
-                    'name': 'Adam',
-                    'net_names': ['generator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                },
-                'optimD': {
-                    'name': 'Adam',
-                    'net_names': ['discriminator'],
-                    'weight_decay': 0.0,
-                    'beta1': 0.9,
-                    'beta2': 0.99
-                }
-            }
-            self.lr_scheduler = {
-                'name': 'MultiStepDecay',
-                'milestones': [50000, 100000, 200000, 300000],
-                'gamma': 0.5
-            }
-        else:
-            self.pixel_criterion = {'name': 'L1Loss'}
-            self.model = {
-                'name': 'BaseSRModel',
-                'generator': self.generator,
-                'pixel_criterion': self.pixel_criterion
-            }
-            self.optimizer = {
-                'name': 'Adam',
-                'net_names': ['generator'],
-                'beta1': 0.9,
-                'beta2': 0.99
-            }
-            self.lr_scheduler = {
-                'name': 'CosineAnnealingRestartLR',
-                'eta_min': 1e-07
-            }
-
-
-# RCAN模型训练
-class RCANet(BasicSRNet):
-    def __init__(
-            self,
-            scale=2,
-            n_resgroups=10,
-            n_resblocks=20, ):
-        super(RCANet, self).__init__()
-        self.min_max = '(0., 255.)'
-        self.generator = {
-            'name': 'RCAN',
-            'scale': scale,
-            'n_resgroups': n_resgroups,
-            'n_resblocks': n_resblocks
-        }
-        self.pixel_criterion = {'name': 'L1Loss'}
-        self.model = {
-            'name': 'RCANModel',
-            'generator': self.generator,
-            'pixel_criterion': self.pixel_criterion
-        }
-        self.optimizer = {
-            'name': 'Adam',
-            'net_names': ['generator'],
-            'beta1': 0.9,
-            'beta2': 0.99
-        }
-        self.lr_scheduler = {
-            'name': 'MultiStepDecay',
-            'milestones': [250000, 500000, 750000, 1000000],
-            'gamma': 0.5
-        }

+ 16 - 15
paddlers/tasks/object_detector.py

@@ -61,6 +61,11 @@ class BaseDetector(BaseModel):
             net = ppdet.modeling.__dict__[self.model_name](**params)
             net = ppdet.modeling.__dict__[self.model_name](**params)
         return net
         return net
 
 
+    def _build_inference_net(self):
+        infer_net = self.net
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
     def _fix_transforms_shape(self, image_shape):
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
         raise NotImplementedError("_fix_transforms_shape: not implemented!")
 
 
@@ -457,7 +462,7 @@ class BaseDetector(BaseModel):
                 Defaults to False.
                 Defaults to False.
 
 
         Returns:
         Returns:
-            collections.OrderedDict with key-value pairs: 
+            If `return_details` is False, return collections.OrderedDict with key-value pairs: 
                 {"bbox_mmap":`mean average precision (0.50, 11point)`}.
                 {"bbox_mmap":`mean average precision (0.50, 11point)`}.
         """
         """
 
 
@@ -556,21 +561,17 @@ class BaseDetector(BaseModel):
 
 
         Returns:
         Returns:
             If `img_file` is a string or np.array, the result is a list of dict with 
             If `img_file` is a string or np.array, the result is a list of dict with 
-                key-value pairs:
-                {"category_id": `category_id`, 
-                 "category": `category`, 
-                 "bbox": `[x, y, w, h]`, 
-                 "score": `score`, 
-                 "mask": `mask`}.
-            If `img_file` is a list, the result is a list composed of list of dicts 
-                with the corresponding fields:
-                category_id(int): the predicted category ID. 0 represents the first 
+                the following key-value pairs:
+                category_id (int): Predicted category ID. 0 represents the first 
                     category in the dataset, and so on.
                     category in the dataset, and so on.
-                category(str): category name
-                bbox(list): bounding box in [x, y, w, h] format
-                score(str): confidence
-                mask(dict): Only for instance segmentation task. Mask of the object in 
-                    RLE format
+                category (str): Category name.
+                bbox (list): Bounding box in [x, y, w, h] format.
+                score (str): Confidence.
+                mask (dict): Only for instance segmentation task. Mask of the object in 
+                    RLE format.
+
+            If `img_file` is a list, the result is a list composed of list of dicts 
+                with the above keys.
         """
         """
 
 
         if transforms is None and not hasattr(self, 'test_transforms'):
         if transforms is None and not hasattr(self, 'test_transforms'):

+ 818 - 0
paddlers/tasks/restorer.py

@@ -0,0 +1,818 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import numpy as np
+import cv2
+import paddle
+import paddle.nn.functional as F
+from paddle.static import InputSpec
+
+import paddlers
+import paddlers.models.ppgan as ppgan
+import paddlers.rs_models.res as cmres
+import paddlers.utils.logging as logging
+from paddlers.models import res_losses
+from paddlers.transforms import Resize, decode_image
+from paddlers.transforms.functions import calc_hr_shape
+from paddlers.utils import get_single_card_bs
+from .base import BaseModel
+from .utils.res_adapters import GANAdapter, OptimizerAdapter
+
+__all__ = []
+
+
class BaseRestorer(BaseModel):
    """
    Base class of image restoration models (super-resolution, etc.).

    Args:
        model_name (str): Name of the model architecture to build.
        losses (paddle.nn.Layer|dict|None, optional): Loss(es) used in training.
            If None, a default loss is constructed lazily. Defaults to None.
        sr_factor (int|None, optional): Super-resolution scale factor. If None,
            the output has the same spatial size as the input. Defaults to None.
    """

    # Range of valid pixel values of the network output. Subclasses working in
    # a different dynamic range (e.g. [0, 1]) override this.
    MIN_MAX = (0., 255.)

    def __init__(self, model_name, losses=None, sr_factor=None, **params):
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseRestorer, self).__init__('restorer')
        self.model_name = model_name
        self.losses = losses
        self.sr_factor = sr_factor
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        self.find_unused_parameters = True

    def build_net(self, **params):
        # Currently, only use models from cmres.
        if not hasattr(cmres, self.model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                self.model_name))
        net = dict(**cmres.__dict__)[self.model_name](**params)
        return net

    def _build_inference_net(self):
        # For GAN models, only the generator will be used for inference.
        if isinstance(self.net, GANAdapter):
            infer_net = self.net.generator
        else:
            infer_net = self.net
        infer_net.eval()
        return infer_net

    def _fix_transforms_shape(self, image_shape):
        # Force the test-time pipeline to produce `image_shape`-sized inputs by
        # inserting (or overwriting) a Resize op before Normalize.
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Normalize':
                        normalize_op_idx = idx
                    if 'Resize' in name:
                        has_resize_op = True
                        resize_op_idx = idx

                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx, Resize(target_size=image_shape))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            # Dynamic spatial dims by default.
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec

    def run(self, net, inputs, mode):
        outputs = OrderedDict()

        if mode == 'test':
            if isinstance(net, GANAdapter):
                net_out = net.generator(inputs[0])
            else:
                net_out = net(inputs[0])
            tar_shape = inputs[1]
            if self.status == 'Infer':
                res_map_list = self._postprocess(
                    net_out, tar_shape, transforms=inputs[2])
            else:
                pred = self._postprocess(
                    net_out, tar_shape, transforms=inputs[2])
                res_map_list = []
                for res_map in pred:
                    res_map = self._tensor_to_images(res_map)
                    res_map_list.append(res_map)
            outputs['res_map'] = res_map_list

        if mode == 'eval':
            if isinstance(net, GANAdapter):
                net_out = net.generator(inputs[0])
            else:
                net_out = net(inputs[0])
            tar = inputs[1]
            tar_shape = [tar.shape[-2:]]
            pred = self._postprocess(
                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
            pred = self._tensor_to_images(pred)
            outputs['pred'] = pred
            # NOTE: was `self.tensor_to_images`, which does not exist.
            tar = self._tensor_to_images(tar)
            outputs['tar'] = tar

        if mode == 'train':
            # This is used by non-GAN models.
            # For GAN models, self.run_gan() should be used.
            net_out = net(inputs[0])
            loss = self.losses(net_out, inputs[1])
            outputs['loss'] = loss
        return outputs

    def run_gan(self, net, inputs, mode, gan_mode):
        # GAN subclasses implement the generator/discriminator passes here.
        raise NotImplementedError

    def default_loss(self):
        return res_losses.L1Loss()

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=lr_scheduler,
            parameters=parameters,
            momentum=0.9,
            weight_decay=4e-5)
        return optimizer

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.ResDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in 
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. 
                If None, the model will not be evaluated during training process. 
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model. 
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training 
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained 
                weights. If None, no pretrained weights will be loaded. 
                Defaults to None.
            learning_rate (float, optional): Learning rate for training. Defaults to .01.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt early stop strategy. Defaults 
                to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to None.
        """

        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)

        if self.losses is None:
            self.losses = self.default_loss()

        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            if isinstance(self.net, GANAdapter):
                # GAN models get separate parameter groups for generators and
                # discriminators, consumed by a GAN-aware default_optimizer().
                parameters = {'params_g': [], 'params_d': []}
                for net_g in self.net.generators:
                    parameters['params_g'].append(net_g.parameters())
                for net_d in self.net.discriminators:
                    parameters['params_d'].append(net_d.parameters())
            else:
                parameters = self.net.parameters()
            self.optimizer = self.default_optimizer(
                parameters, learning_rate, num_epochs, num_steps_each_epoch,
                lr_decay_power)
        else:
            self.optimizer = optimizer

        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            logging.warning("Path of pretrain_weights('{}') does not exist!".
                            format(pretrain_weights))
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.ResDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in 
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during training process. 
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in 
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model. 
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training 
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training. 
                Defaults to .0001.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt early stop strategy. 
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training 
                process. Defaults to True.
            quant_config (dict|None, optional): Quantization configuration. If None, 
                a default rule of thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """

        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)

    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for 
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details. 
                Defaults to False.

        Returns:
            If `return_details` is False, return collections.OrderedDict with 
                key-value pairs:
                {"psnr": `peak signal-to-noise ratio`,
                 "ssim": `structural similarity`}.
            On non-main processes of a distributed run, None is returned.

        """

        self._check_transforms(eval_dataset.transforms, 'eval')

        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        batch_size_each_card = get_single_card_bs(batch_size)
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "Restorer only supports batch_size=1 for each gpu/cpu card " \
                "during evaluation, so batch_size " \
                "is forcibly set to {}.".format(batch_size))

        # TODO: Distributed evaluation
        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            # XXX: Hard-code crop_border and test_y_channel
            psnr = ppgan.metrics.PSNR(crop_border=4, test_y_channel=True)
            ssim = ppgan.metrics.SSIM(crop_border=4, test_y_channel=True)
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    outputs = self.run(self.net, data, 'eval')
                    psnr.update(outputs['pred'], outputs['tar'])
                    ssim.update(outputs['pred'], outputs['tar'])

            # `psnr` and `ssim` only exist on the main process, so the metrics
            # must be assembled inside this branch (doing it outside raised
            # NameError on other ranks).
            eval_metrics = OrderedDict(
                zip(['psnr', 'ssim'], [psnr.accumulate(), ssim.accumulate()]))

            if return_details:
                # TODO: Add details
                return eval_metrics, None

            return eval_metrics

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded 
                image data, which also could constitute a list, meaning all images to be 
                predicted as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
                inputs. If None, the transforms for evaluation process will be used. 
                Defaults to None.

        Returns:
            If `img_file` is a string or np.array, the result is a dict with 
                the following key-value pairs:
                res_map (np.ndarray): Restored image (HWC).

            If `img_file` is a list, the result is a list composed of dicts with the 
                above keys.
        """

        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError("transforms need to be defined, now is None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        # NOTE(review): `self.model_type` is passed where `_preprocess` expects
        # its boolean `to_tensor` flag; any non-empty string is truthy, so the
        # batch is converted to a tensor — confirm this is intentional.
        batch_im, batch_tar_shape = self._preprocess(images, transforms,
                                                     self.model_type)
        self.net.eval()
        data = (batch_im, batch_tar_shape, transforms.transforms)
        outputs = self.run(self.net, data, 'test')
        res_map_list = outputs['res_map']
        if isinstance(img_file, list):
            prediction = [{'res_map': m} for m in res_map_list]
        else:
            prediction = {'res_map': res_map_list[0]}
        return prediction

    def _preprocess(self, images, transforms, to_tensor=True):
        # Apply test-time transforms and remember each image's target output
        # shape so predictions can be restored to it later.
        self._check_transforms(transforms, 'test')
        batch_im = list()
        batch_tar_shape = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, to_rgb=False)
            ori_shape = im.shape[:2]
            sample = {'image': im}
            im = transforms(sample)[0]
            batch_im.append(im)
            batch_tar_shape.append(self._get_target_shape(ori_shape))
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)

        return batch_im, batch_tar_shape

    def _get_target_shape(self, ori_shape):
        # For super-resolution models the output is `sr_factor` times larger.
        if self.sr_factor is None:
            return ori_shape
        else:
            return calc_hr_shape(ori_shape, self.sr_factor)

    @staticmethod
    def get_transforms_shape_info(batch_tar_shape, transforms):
        # Replay the geometric transforms to build, per sample, the list of
        # inverse operations ('resize'/'padding') needed to restore outputs.
        batch_restore_list = list()
        for tar_shape in batch_tar_shape:
            restore_list = list()
            h, w = tar_shape[0], tar_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Pad':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))

                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w

            batch_restore_list.append(restore_list)
        return batch_restore_list

    def _postprocess(self, batch_pred, batch_tar_shape, transforms):
        batch_restore_list = BaseRestorer.get_transforms_shape_info(
            batch_tar_shape, transforms)
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_res_map=batch_pred[0],
                batch_restore_list=batch_restore_list)
        results = []
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            # Undo the recorded transforms in reverse order.
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results

    def _infer_postprocess(self, batch_res_map, batch_restore_list):
        res_maps = []
        # NOTE: the loop variable was named `score_map` while the body used
        # `res_map`, which raised NameError; the names are now consistent.
        for res_map, restore_list in zip(batch_res_map, batch_restore_list):
            if not isinstance(res_map, np.ndarray):
                res_map = paddle.unsqueeze(res_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(res_map, np.ndarray):
                        res_map = cv2.resize(
                            res_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        res_map = F.interpolate(
                            res_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(res_map, np.ndarray):
                        res_map = res_map[..., y:y + h, x:x + w]
                    else:
                        res_map = res_map[:, :, y:y + h, x:x + w]
                else:
                    pass
            res_map = res_map.squeeze()
            if not isinstance(res_map, np.ndarray):
                res_map = res_map.numpy()
            res_map = self._normalize(res_map)
            res_maps.append(res_map.squeeze())
        return res_maps

    def _check_transforms(self, transforms, mode):
        super()._check_transforms(transforms, mode)
        if not isinstance(transforms.arrange,
                          paddlers.transforms.ArrangeRestorer):
            raise TypeError(
                "`transforms.arrange` must be an ArrangeRestorer object.")

    def set_losses(self, losses):
        self.losses = losses

    def _tensor_to_images(self, tensor, squeeze=True, quantize=True):
        # Convert an NCHW tensor to HWC image(s), clipped to MIN_MAX and
        # optionally quantized to uint8.
        tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
        if squeeze:
            tensor = tensor.squeeze()
        images = tensor.numpy().astype('float32')
        images = np.clip(images, self.MIN_MAX[0], self.MIN_MAX[1])
        images = self._normalize(images, copy=True, quantize=quantize)
        return images

    def _normalize(self, im, copy=False, quantize=True):
        # Min-max normalize to [0, 1]; optionally scale to [0, 255] uint8.
        if copy:
            im = im.copy()
        im -= im.min()
        im /= im.max() + 1e-32
        if quantize:
            im *= 255
            im = im.astype('uint8')
        return im
+
+
class RCAN(BaseRestorer):
    """
    Residual Channel Attention Network for image super-resolution.

    All keyword arguments are forwarded to the underlying network builder;
    `sr_factor` is also recorded by the base class for shape inference.
    """

    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 n_resgroups=10,
                 n_resblocks=20,
                 n_feats=64,
                 n_colors=3,
                 rgb_range=255,
                 kernel_size=3,
                 reduction=16,
                 **params):
        # Merge the architecture hyperparameters into the network kwargs.
        net_params = dict(
            params,
            factor=sr_factor,
            n_resgroups=n_resgroups,
            n_resblocks=n_resblocks,
            n_feats=n_feats,
            n_colors=n_colors,
            rgb_range=rgb_range,
            kernel_size=kernel_size,
            reduction=reduction)
        super(RCAN, self).__init__(
            model_name='RCAN',
            losses=losses,
            sr_factor=sr_factor,
            **net_params)
+
+
class DRN(BaseRestorer):
    """
    Dual Regression Network for image super-resolution.

    Args:
        losses (paddle.nn.Layer|None, optional): Training loss. Defaults to None.
        sr_factor (int, optional): Super-resolution factor; must equal
            `max(scale)`. Defaults to 4.
        scale (tuple[int], optional): Scales handled by the dual network.
            Defaults to (2, 4).
    """

    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 scale=(2, 4),
                 n_blocks=30,
                 n_feats=16,
                 n_colors=3,
                 rgb_range=255,
                 negval=0.2,
                 **params):
        if sr_factor != max(scale):
            # Plain string: the original used an f-string with no placeholder.
            raise ValueError("`sr_factor` must be equal to `max(scale)`.")
        params.update({
            'scale': scale,
            'n_blocks': n_blocks,
            'n_feats': n_feats,
            'n_colors': n_colors,
            'rgb_range': rgb_range,
            'negval': negval
        })
        super(DRN, self).__init__(
            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        # Use the DRN generator implemented in ppgan.
        net = ppgan.models.generators.DRNGenerator(**params)
        return net
+
+
class LESRCNN(BaseRestorer):
    """
    Lightweight Enhanced Super-Resolution CNN.

    Args:
        losses (paddle.nn.Layer|None, optional): Training loss. Defaults to None.
        sr_factor (int, optional): Super-resolution factor. Defaults to 4.
        multi_scale (bool, optional): Whether to train a multi-scale model.
            Defaults to False.
        group (int, optional): Number of convolution groups. Defaults to 1.
    """

    def __init__(self, losses=None, sr_factor=4, multi_scale=False, group=1):
        # `params` must be created here (the original mutated an undefined
        # name), and the user-supplied `multi_scale`/`group` values are now
        # honored instead of being hard-coded.
        params = {
            'scale': sr_factor,
            'multi_scale': multi_scale,
            'group': group
        }
        super(LESRCNN, self).__init__(
            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        net = ppgan.models.generators.LESRCNNGenerator(**params)
        return net
+
+
class ESRGAN(BaseRestorer):
    """
    Enhanced SRGAN. When `use_gan` is True the model is trained adversarially
    with a VGG-style discriminator; otherwise only the RRDB generator is
    trained with an L1 loss.
    """

    # ESRGAN works in the [0, 1] dynamic range.
    MIN_MAX = (0., 1.)

    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 use_gan=True,
                 in_channels=3,
                 out_channels=3,
                 nf=64,
                 nb=23):
        # `params` must be created here (the original mutated an undefined name).
        params = {
            'scale': sr_factor,
            'in_nc': in_channels,
            'out_nc': out_channels,
            'nf': nf,
            'nb': nb
        }
        self.use_gan = use_gan
        super(ESRGAN, self).__init__(
            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        generator = ppgan.models.generators.RRDBNet(**params)
        if self.use_gan:
            # NOTE: class name fixed from the misspelled `VGGDiscrinimator128`.
            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
                in_channels=params['out_nc'], num_feat=64)
            net = GANAdapter(
                generators=[generator], discriminators=[discriminator])
        else:
            net = generator
        return net

    def default_loss(self):
        if self.use_gan:
            # The GAN branch must return the losses (the original assigned
            # `self.losses` and fell through, returning None).
            return {
                'pixel': res_losses.L1Loss(loss_weight=0.01),
                'perceptual':
                res_losses.PerceptualLoss(layer_weights={'34': 1.0}),
                'gan': res_losses.GANLoss(
                    gan_mode='vanilla', loss_weight=0.005)
            }
        else:
            return res_losses.L1Loss()

    def default_optimizer(self, parameters, *args, **kwargs):
        if self.use_gan:
            # `parameters` is the {'params_g': [...], 'params_d': [...]} dict
            # built by BaseRestorer.train() (the original indexed the
            # nonexistent 'optims_g'/'optims_d' keys).
            optim_g = super(ESRGAN, self).default_optimizer(
                parameters['params_g'][0], *args, **kwargs)
            optim_d = super(ESRGAN, self).default_optimizer(
                parameters['params_d'][0], *args, **kwargs)
            return OptimizerAdapter(optim_g, optim_d)
        else:
            return super(ESRGAN, self).default_optimizer(parameters, *args,
                                                         **kwargs)

    def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
        if mode != 'train':
            raise ValueError("`mode` is not 'train'.")
        outputs = OrderedDict()
        if gan_mode == 'forward_g':
            # Generator pass: pixel + perceptual + relativistic GAN losses.
            loss_g = 0
            g_pred = net.generator(inputs[0])
            tar = inputs[1]
            loss_pix = self.losses['pixel'](g_pred, tar)
            loss_perc, loss_sty = self.losses['perceptual'](g_pred, tar)
            loss_g += loss_pix
            if loss_perc is not None:
                loss_g += loss_perc
            if loss_sty is not None:
                loss_g += loss_sty
            self._set_requires_grad(net.discriminator, False)
            real_d_pred = net.discriminator(tar).detach()
            fake_g_pred = net.discriminator(g_pred)
            loss_g_real = self.losses['gan'](
                real_d_pred - paddle.mean(fake_g_pred), False,
                is_disc=False) * 0.5
            loss_g_fake = self.losses['gan'](
                fake_g_pred - paddle.mean(real_d_pred), True,
                is_disc=False) * 0.5
            loss_g_gan = loss_g_real + loss_g_fake
            outputs['g_pred'] = g_pred.detach()
            outputs['loss_g_pps'] = loss_g
            outputs['loss_g_gan'] = loss_g_gan
        elif gan_mode == 'forward_d':
            # Discriminator pass. inputs[0] is the (already detached)
            # generator output, inputs[1] the ground-truth image.
            self._set_requires_grad(net.discriminator, True)
            # Real
            fake_d_pred = net.discriminator(inputs[0]).detach()
            real_d_pred = net.discriminator(inputs[1])
            loss_d_real = self.losses['gan'](
                real_d_pred - paddle.mean(fake_d_pred), True,
                is_disc=True) * 0.5
            # Fake
            fake_d_pred = net.discriminator(inputs[0])
            loss_d_fake = self.losses['gan'](
                fake_d_pred - paddle.mean(real_d_pred.detach()),
                False,
                is_disc=True) * 0.5
            outputs['loss_d'] = loss_d_real + loss_d_fake
        else:
            raise ValueError("Invalid `gan_mode`!")
        return outputs

    def train_step(self, step, data, net):
        if self.use_gan:
            optim_g, optim_d = self.optimizer

            # `mode` is a required positional argument of run_gan().
            outputs = self.run_gan(
                net, data, mode='train', gan_mode='forward_g')
            optim_g.clear_grad()
            (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
            optim_g.step()

            outputs.update(
                self.run_gan(
                    net, (outputs['g_pred'], data[1]),
                    mode='train',
                    gan_mode='forward_d'))
            optim_d.clear_grad()
            outputs['loss_d'].backward()
            optim_d.step()

            outputs['loss'] = outputs['loss_g_pps'] + outputs[
                'loss_g_gan'] + outputs['loss_d']

            # Step both schedulers; ReduceOnPlateau needs the loss value.
            for optim in (optim_g, optim_d):
                if isinstance(optim._learning_rate,
                              paddle.optimizer.lr.LRScheduler):
                    if isinstance(optim._learning_rate,
                                  paddle.optimizer.lr.ReduceOnPlateau):
                        optim._learning_rate.step(outputs['loss'].item())
                    else:
                        optim._learning_rate.step()

            return outputs
        else:
            # Propagate the step outputs (the original dropped the return
            # value).
            return super(ESRGAN, self).train_step(step, data, net)

    def _set_requires_grad(self, net, requires_grad):
        # Freeze/unfreeze all parameters of `net`.
        for p in net.parameters():
            p.trainable = requires_grad

+ 13 - 8
paddlers/tasks/segmenter.py

@@ -33,6 +33,7 @@ from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
+from .utils.infer_nets import InferNet
 
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
 
@@ -64,11 +65,16 @@ class BaseSegmenter(BaseModel):
 
 
     def build_net(self, **params):
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
         # TODO: when using paddle.utils.unique_name.guard,
-        # DeepLabv3p and HRNet will raise a error
+        # DeepLabv3p and HRNet will raise an error.
         net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
         net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
             num_classes=self.num_classes, **params)
             num_classes=self.num_classes, **params)
         return net
         return net
 
 
+    def _build_inference_net(self):
+        infer_net = InferNet(self.net, self.model_type)
+        infer_net.eval()
+        return infer_net
+
     def _fix_transforms_shape(self, image_shape):
     def _fix_transforms_shape(self, image_shape):
         if hasattr(self, 'test_transforms'):
         if hasattr(self, 'test_transforms'):
             if self.test_transforms is not None:
             if self.test_transforms is not None:
@@ -472,7 +478,6 @@ class BaseSegmenter(BaseModel):
                     conf_mat_all.append(conf_mat)
                     conf_mat_all.append(conf_mat)
         class_iou, miou = ppseg.utils.metrics.mean_iou(
         class_iou, miou = ppseg.utils.metrics.mean_iou(
             intersect_area_all, pred_area_all, label_area_all)
             intersect_area_all, pred_area_all, label_area_all)
-        # TODO 确认是按oacc还是macc
         class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
         class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                        pred_area_all)
                                                        pred_area_all)
         kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
         kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
@@ -504,13 +509,13 @@ class BaseSegmenter(BaseModel):
                 Defaults to None.
                 Defaults to None.
 
 
         Returns:
         Returns:
-            If `img_file` is a string or np.array, the result is a dict with key-value 
-                pairs:
-                {"label map": `label map`, "score_map": `score map`}.
+            If `img_file` is a tuple of string or np.array, the result is a dict with 
+                the following key-value pairs:
+                label_map (np.ndarray): Predicted label map (HW).
+                score_map (np.ndarray): Prediction score map (HWC).
+
             If `img_file` is a list, the result is a list composed of dicts with the 
             If `img_file` is a list, the result is a list composed of dicts with the 
-                corresponding fields:
-                label_map (np.ndarray): the predicted label map (HW)
-                score_map (np.ndarray): the prediction score map (HWC)
+                above keys.
         """
         """
 
 
         if transforms is None and not hasattr(self, 'test_transforms'):
         if transforms is None and not hasattr(self, 'test_transforms'):

+ 128 - 0
paddlers/tasks/utils/res_adapters.py

@@ -0,0 +1,128 @@
+from functools import wraps
+from inspect import isfunction, isgeneratorfunction, getmembers
+from collections.abc import Sequence
+from abc import ABC
+
+import paddle
+import paddle.nn as nn
+
+__all__ = ['GANAdapter', 'OptimizerAdapter']
+
+
+class _AttrDesc:
+    def __init__(self, key):
+        self.key = key
+
+    def __get__(self, instance, owner):
+        return tuple(getattr(ele, self.key) for ele in instance)
+
+    def __set__(self, instance, value):
+        for ele in instance:
+            setattr(ele, self.key, value)
+
+
+def _func_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
+
+    return _wrapper
+
+
+def _generator_deco(cls, func_name):
+    @wraps(getattr(cls.__ducktype__, func_name))
+    def _wrapper(self, *args, **kwargs):
+        for ele in self:
+            yield from getattr(ele, func_name)(*args, **kwargs)
+
+    return _wrapper
+
+
+class Adapter(Sequence, ABC):
+    __ducktype__ = object
+    __ava__ = ()
+
+    def __init__(self, *args):
+        if not all(map(self._check, args)):
+            raise TypeError("Please check the input type.")
+        self._seq = tuple(args)
+
+    def __getitem__(self, key):
+        return self._seq[key]
+
+    def __len__(self):
+        return len(self._seq)
+
+    def __repr__(self):
+        return repr(self._seq)
+
+    @classmethod
+    def _check(cls, obj):
+        for attr in cls.__ava__:
+            try:
+                getattr(obj, attr)
+                # TODO: Check function signature
+            except AttributeError:
+                return False
+        return True
+
+
+def make_adapter(cls):
+    members = dict(getmembers(cls.__ducktype__))
+    for k in cls.__ava__:
+        if hasattr(cls, k):
+            continue
+        if k in members:
+            v = members[k]
+            if isgeneratorfunction(v):
+                setattr(cls, k, _generator_deco(cls, k))
+            elif isfunction(v):
+                setattr(cls, k, _func_deco(cls, k))
+            else:
+                setattr(cls, k, _AttrDesc(k))
+    return cls
+
+
+class GANAdapter(nn.Layer):
+    __ducktype__ = nn.Layer
+    __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval')
+
+    def __init__(self, generators, discriminators):
+        super(GANAdapter, self).__init__()
+        self.generators = nn.LayerList(generators)
+        self.discriminators = nn.LayerList(discriminators)
+        self._m = [*generators, *discriminators]
+
+    def __len__(self):
+        return len(self._m)
+
+    def __getitem__(self, key):
+        return self._m[key]
+
+    def __contains__(self, m):
+        return m in self._m
+
+    def __repr__(self):
+        return repr(self._m)
+
+    @property
+    def generator(self):
+        return self.generators[0]
+
+    @property
+    def discriminator(self):
+        return self.discriminators[0]
+
+
+Adapter.register(GANAdapter)
+
+
+@make_adapter
+class OptimizerAdapter(Adapter):
+    __ducktype__ = paddle.optimizer.Optimizer
+    __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')
+
+    # Special dispatching rule
+    def set_state_dict(self, state_dicts):
+        for optim, state_dict in zip(self, state_dicts):
+            optim.set_state_dict(state_dict)

+ 4 - 0
paddlers/transforms/functions.py

@@ -638,3 +638,7 @@ def decode_seg_mask(mask_path):
     mask = np.asarray(Image.open(mask_path))
     mask = np.asarray(Image.open(mask_path))
     mask = mask.astype('int64')
     mask = mask.astype('int64')
     return mask
     return mask
+
+
+def calc_hr_shape(lr_shape, sr_factor):
+    return tuple(int(s * sr_factor) for s in lr_shape)

+ 99 - 13
paddlers/transforms/operators.py

@@ -35,7 +35,7 @@ from .functions import (
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
     vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
     vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
     resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
     resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask)
+    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape)
 
 
 __all__ = [
 __all__ = [
     "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
     "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
@@ -44,7 +44,7 @@ __all__ = [
     "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
     "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
     "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
     "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
     "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
     "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
-    "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
+    "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask"
 ]
 ]
 
 
 interp_dict = {
 interp_dict = {
@@ -154,6 +154,8 @@ class Transform(object):
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
 
 
         return sample
         return sample
 
 
@@ -336,6 +338,14 @@ class DecodeImg(Transform):
                 map(self.apply_mask, sample['aux_masks']))
                 map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
             # TODO: check the shape of auxiliary masks
 
 
+        if 'target' in sample:
+            if self.read_geo_info:
+                target, geo_info_dict = self.apply_im(sample['target'])
+                sample['target'] = target
+                sample['geo_info_dict_tar'] = geo_info_dict
+            else:
+                sample['target'] = self.apply_im(sample['target'])
+
         sample['im_shape'] = np.array(
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
@@ -457,6 +467,17 @@ class Resize(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
                 sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                # For SR tasks
+                sample['target'] = self.apply_im(
+                    sample['target'], interp,
+                    calc_hr_shape(target_size, sample['sr_factor']))
+            else:
+                # For non-SR tasks
+                sample['target'] = self.apply_im(sample['target'], interp,
+                                                 target_size)
+
         sample['im_shape'] = np.asarray(
         sample['im_shape'] = np.asarray(
             sample['image'].shape[:2], dtype=np.float32)
             sample['image'].shape[:2], dtype=np.float32)
         if 'scale_factor' in sample:
         if 'scale_factor' in sample:
@@ -730,6 +751,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     True)
                                                     True)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 True)
         elif p_m < self.probs[1]:
         elif p_m < self.probs[1]:
             mode_p = random.random()
             mode_p = random.random()
             mode_id = self.judge_probs_range(mode_p, self.probsr)
             mode_id = self.judge_probs_range(mode_p, self.probsr)
@@ -750,6 +774,9 @@ class RandomFlipOrRotate(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                     False)
                                                     False)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'], mode_id,
+                                                 False)
 
 
         return sample
         return sample
 
 
@@ -809,6 +836,8 @@ class RandomHorizontalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
         return sample
 
 
 
 
@@ -867,6 +896,8 @@ class RandomVerticalFlip(Transform):
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                 sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                     im_w)
                                                     im_w)
+            if 'target' in sample:
+                sample['target'] = self.apply_im(sample['target'])
         return sample
         return sample
 
 
 
 
@@ -886,13 +917,16 @@ class Normalize(Transform):
             image(s). Defaults to [0, 0, 0, ].
             image(s). Defaults to [0, 0, 0, ].
         max_val (list[float] | tuple[float], optional): Max value of input image(s). 
         max_val (list[float] | tuple[float], optional): Max value of input image(s). 
             Defaults to [255., 255., 255.].
             Defaults to [255., 255., 255.].
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
     """
 
 
     def __init__(self,
     def __init__(self,
                  mean=[0.485, 0.456, 0.406],
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225],
                  std=[0.229, 0.224, 0.225],
                  min_val=None,
                  min_val=None,
-                 max_val=None):
+                 max_val=None,
+                 apply_to_tar=True):
         super(Normalize, self).__init__()
         super(Normalize, self).__init__()
         channel = len(mean)
         channel = len(mean)
         if min_val is None:
         if min_val is None:
@@ -914,6 +948,7 @@ class Normalize(Transform):
         self.std = std
         self.std = std
         self.min_val = min_val
         self.min_val = min_val
         self.max_val = max_val
         self.max_val = max_val
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
     def apply_im(self, image):
         image = image.astype(np.float32)
         image = image.astype(np.float32)
@@ -927,6 +962,8 @@ class Normalize(Transform):
         sample['image'] = self.apply_im(sample['image'])
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
 
 
         return sample
         return sample
 
 
@@ -964,6 +1001,8 @@ class CenterCrop(Transform):
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
                 map(self.apply_mask, sample['aux_masks']))
+        if 'target' in sample:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
         return sample
 
 
 
 
@@ -1165,6 +1204,14 @@ class RandomCrop(Transform):
                         self.apply_mask, crop=crop_box),
                         self.apply_mask, crop=crop_box),
                         sample['aux_masks']))
                         sample['aux_masks']))
 
 
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    sample['target'] = self.apply_im(
+                        sample['target'],
+                        calc_hr_shape(crop_box, sample['sr_factor']))
+                else:
+                    sample['target'] = self.apply_im(sample['target'],
+                                                     crop_box)
+
         if self.crop_size is not None:
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
             sample = Resize(self.crop_size)(sample)
 
 
@@ -1266,6 +1313,7 @@ class Pad(Transform):
             pad_mode (int, optional): Pad mode. Currently only four modes are supported:
             pad_mode (int, optional): Pad mode. Currently only four modes are supported:
                 [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom
                 [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom
                 If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
                 If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
+            offsets (list[int]|None, optional): Padding offsets. Defaults to None.
             im_padding_value (list[float] | tuple[float]): RGB value of padded area. 
             im_padding_value (list[float] | tuple[float]): RGB value of padded area. 
                 Defaults to (127.5, 127.5, 127.5).
                 Defaults to (127.5, 127.5, 127.5).
             label_padding_value (int, optional): Filling value for the mask. 
             label_padding_value (int, optional): Filling value for the mask. 
@@ -1332,6 +1380,17 @@ class Pad(Transform):
                     expand_rle(segm, x, y, height, width, h, w))
                     expand_rle(segm, x, y, height, width, h, w))
         return expanded_segms
         return expanded_segms
 
 
+    def _get_offsets(self, im_h, im_w, h, w):
+        if self.pad_mode == -1:
+            offsets = self.offsets
+        elif self.pad_mode == 0:
+            offsets = [0, 0]
+        elif self.pad_mode == 1:
+            offsets = [(w - im_w) // 2, (h - im_h) // 2]
+        else:
+            offsets = [w - im_w, h - im_h]
+        return offsets
+
     def apply(self, sample):
     def apply(self, sample):
         im_h, im_w = sample['image'].shape[:2]
         im_h, im_w = sample['image'].shape[:2]
         if self.target_size:
         if self.target_size:
@@ -1349,14 +1408,7 @@ class Pad(Transform):
         if h == im_h and w == im_w:
         if h == im_h and w == im_w:
             return sample
             return sample
 
 
-        if self.pad_mode == -1:
-            offsets = self.offsets
-        elif self.pad_mode == 0:
-            offsets = [0, 0]
-        elif self.pad_mode == 1:
-            offsets = [(w - im_w) // 2, (h - im_h) // 2]
-        else:
-            offsets = [w - im_w, h - im_h]
+        offsets = self._get_offsets(im_h, im_w, h, w)
 
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
         if 'image2' in sample:
@@ -1373,6 +1425,16 @@ class Pad(Transform):
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(
             sample['gt_poly'] = self.apply_segm(
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
+        if 'target' in sample:
+            if 'sr_factor' in sample:
+                hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
+                hr_offsets = self._get_offsets(*sample['target'].shape[:2],
+                                               *hr_shape)
+                sample['target'] = self.apply_im(sample['target'], hr_offsets,
+                                                 hr_shape)
+            else:
+                sample['target'] = self.apply_im(sample['target'], offsets,
+                                                 (h, w))
         return sample
         return sample
 
 
 
 
@@ -1688,15 +1750,18 @@ class ReduceDim(Transform):
 
 
     Args: 
     Args: 
         joblib_path (str): Path of *.joblib file of PCA.
         joblib_path (str): Path of *.joblib file of PCA.
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
     """
 
 
-    def __init__(self, joblib_path):
+    def __init__(self, joblib_path, apply_to_tar=True):
         super(ReduceDim, self).__init__()
         super(ReduceDim, self).__init__()
         ext = joblib_path.split(".")[-1]
         ext = joblib_path.split(".")[-1]
         if ext != "joblib":
         if ext != "joblib":
             raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
             raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
                 ext))
                 ext))
         self.pca = load(joblib_path)
         self.pca = load(joblib_path)
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
     def apply_im(self, image):
         H, W, C = image.shape
         H, W, C = image.shape
@@ -1709,6 +1774,8 @@ class ReduceDim(Transform):
         sample['image'] = self.apply_im(sample['image'])
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
         return sample
 
 
 
 
@@ -1719,11 +1786,14 @@ class SelectBand(Transform):
     Args: 
     Args: 
         band_list (list, optional): Bands to select (band index starts from 1). 
         band_list (list, optional): Bands to select (band index starts from 1). 
             Defaults to [1, 2, 3].
             Defaults to [1, 2, 3].
+        apply_to_tar (bool, optional): Whether to apply transformation to the target
+            image. Defaults to True.
     """
     """
 
 
-    def __init__(self, band_list=[1, 2, 3]):
+    def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
         super(SelectBand, self).__init__()
         super(SelectBand, self).__init__()
         self.band_list = band_list
         self.band_list = band_list
+        self.apply_to_tar = apply_to_tar
 
 
     def apply_im(self, image):
     def apply_im(self, image):
         image = select_bands(image, self.band_list)
         image = select_bands(image, self.band_list)
@@ -1733,6 +1803,8 @@ class SelectBand(Transform):
         sample['image'] = self.apply_im(sample['image'])
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
             sample['image2'] = self.apply_im(sample['image2'])
+        if 'target' in sample and self.apply_to_tar:
+            sample['target'] = self.apply_im(sample['target'])
         return sample
         return sample
 
 
 
 
@@ -1820,6 +1892,8 @@ class _Permute(Transform):
         sample['image'] = permute(sample['image'], False)
         sample['image'] = permute(sample['image'], False)
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
             sample['image2'] = permute(sample['image2'], False)
+        if 'target' in sample:
+            sample['target'] = permute(sample['target'], False)
         return sample
         return sample
 
 
 
 
@@ -1915,3 +1989,15 @@ class ArrangeDetector(Arrange):
         if self.mode == 'eval' and 'gt_poly' in sample:
         if self.mode == 'eval' and 'gt_poly' in sample:
             del sample['gt_poly']
             del sample['gt_poly']
         return sample
         return sample
+
+
+class ArrangeRestorer(Arrange):
+    def apply(self, sample):
+        image = permute(sample['image'], False)
+        target = permute(sample['target'], False)
+        if self.mode == 'train':
+            return image, target
+        if self.mode == 'eval':
+            return image, target
+        if self.mode == 'test':
+            return image,

+ 4 - 3
tutorials/train/README.md

@@ -17,9 +17,10 @@
 |classification/hrnet.py | 场景分类 | HRNet |
 |classification/hrnet.py | 场景分类 | HRNet |
 |classification/mobilenetv3.py | 场景分类 | MobileNetV3 |
 |classification/mobilenetv3.py | 场景分类 | MobileNetV3 |
 |classification/resnet50_vd.py | 场景分类 | ResNet50-vd |
 |classification/resnet50_vd.py | 场景分类 | ResNet50-vd |
-|image_restoration/drn.py | 超分辨率 | DRN |
-|image_restoration/esrgan.py | 超分辨率 | ESRGAN |
-|image_restoration/lesrcnn.py | 超分辨率 | LESRCNN |
+|image_restoration/drn.py | 图像复原 | DRN |
+|image_restoration/esrgan.py | 图像复原 | ESRGAN |
+|image_restoration/lesrcnn.py | 图像复原 | LESRCNN |
+|image_restoration/rcan.py | 图像复原 | RCAN |
 |object_detection/faster_rcnn.py | 目标检测 | Faster R-CNN |
 |object_detection/faster_rcnn.py | 目标检测 | Faster R-CNN |
 |object_detection/ppyolo.py | 目标检测 | PP-YOLO |
 |object_detection/ppyolo.py | 目标检测 | PP-YOLO |
 |object_detection/ppyolotiny.py | 目标检测 | PP-YOLO Tiny |
 |object_detection/ppyolotiny.py | 目标检测 | PP-YOLO Tiny |

+ 1 - 1
tutorials/train/classification/hrnet.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     num_workers=0,
     shuffle=False)
     shuffle=False)
 
 
-# 使用默认参数构建HRNet模型
+# 构建HRNet模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.HRNet_W18_C(num_classes=len(train_dataset.labels))
 model = pdrs.tasks.clas.HRNet_W18_C(num_classes=len(train_dataset.labels))

+ 1 - 1
tutorials/train/classification/mobilenetv3.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     num_workers=0,
     shuffle=False)
     shuffle=False)
 
 
-# 使用默认参数构建MobileNetV3模型
+# 构建MobileNetV3模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.MobileNetV3_small_x1_0(
 model = pdrs.tasks.clas.MobileNetV3_small_x1_0(

+ 1 - 1
tutorials/train/classification/resnet50_vd.py

@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     num_workers=0,
     shuffle=False)
     shuffle=False)
 
 
-# 使用默认参数构建ResNet50-vd模型
+# 构建ResNet50-vd模型
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.ResNet50_vd(num_classes=len(train_dataset.labels))
 model = pdrs.tasks.clas.ResNet50_vd(num_classes=len(train_dataset.labels))

+ 3 - 0
tutorials/train/image_restoration/data/.gitignore

@@ -0,0 +1,3 @@
+*.zip
+*.tar.gz
+rssr/

+ 86 - 0
tutorials/train/image_restoration/drn.py

@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# 图像复原模型DRN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rssr/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/drn/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 将输入影像缩放到256x256大小
+    T.Resize(target_size=256),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=256),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 使用默认参数构建DRN模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.DRN()
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 初始学习率大小
+    learning_rate=0.01,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 0 - 80
tutorials/train/image_restoration/drn_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddle
-import paddlers as pdrs
-
-# 定义训练和验证时的transforms
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'lqx2', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4,
-        'scale_list': True
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-# 定义训练集
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # 高分辨率影像所在路径
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # 低分辨率影像所在路径
-num_workers = 4
-batch_size = 8
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-train_dict = train_dataset()
-
-# 定义测试集
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# 初始化模型,可以对网络结构的参数进行调整
-model = pdrs.tasks.res.DRNet(
-    n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
-
-model.train(
-    total_iters=100000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    lr_rate=0.0001,
-    log=10)

+ 86 - 0
tutorials/train/image_restoration/esrgan.py

@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# 图像复原模型ESRGAN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rssr/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/esrgan/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 将输入影像缩放到256x256大小
+    T.Resize(target_size=256),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=256),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 使用默认参数构建ESRGAN模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.ESRGAN()
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 初始学习率大小
+    learning_rate=0.01,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 0 - 80
tutorials/train/image_restoration/esrgan_train.py

@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# 定义训练和验证时的transforms
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 128,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# 定义训练集
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # 高分辨率影像所在路径
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # 低分辨率影像所在路径
-num_workers = 6
-batch_size = 32
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# 定义测试集
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# 初始化模型,可以对网络结构的参数进行调整
-# 若loss_type='gan' 使用感知损失、对抗损失和像素损失
-# 若loss_type = 'pixel' 只使用像素损失
-model = pdrs.tasks.res.ESRGANet(loss_type='pixel')
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])

+ 86 - 0
tutorials/train/image_restoration/lesrcnn.py

@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# 图像复原模型LESRCNN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rssr/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/lesrcnn/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 将输入影像缩放到256x256大小
+    T.Resize(target_size=256),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=256),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 使用默认参数构建LESRCNN模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.LESRCNN()
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 初始学习率大小
+    learning_rate=0.01,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 0 - 78
tutorials/train/image_restoration/lesrcnn_train.py

@@ -1,78 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# 定义训练和验证时的transforms
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# 定义训练集
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # 高分辨率影像所在路径
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # 低分辨率影像所在路径
-num_workers = 4
-batch_size = 16
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# 定义测试集
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# 初始化模型,可以对网络结构的参数进行调整
-model = pdrs.tasks.res.LESRCNNet(scale=4, multi_scale=False, group=1)
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])

+ 86 - 0
tutorials/train/image_restoration/rcan.py

@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# 图像复原模型RCAN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rssr/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/rcan/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 将输入影像缩放到256x256大小
+    T.Resize(target_size=256),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=256),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 使用默认参数构建RCAN模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.RCAN()
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 初始学习率大小
+    learning_rate=0.01,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)