Przeglądaj źródła

[Fix] Remediate Trainers, Datasets, and Transformation Operators (#104)

* Update paddleslim and scikit-learn version requirements

* Polish TIPC

* Ignore .path files

* Fix setup.py

* Update examples

* Fix trans_info bugs

* Update paddleslim version

* Store version separately

* Fix bugs
Lin Manhui 2 lat temu
rodzic
commit
ff928ecb80

+ 1 - 2
deploy/export/export_model.py

@@ -73,5 +73,4 @@ if __name__ == '__main__':
     model = load_model(args.model_dir)
 
     # Do dynamic-to-static cast
-    # XXX: Invoke a protected (single underscore) method outside of subclasses.
-    model._export_inference_model(args.save_dir, fixed_input_shape)
+    model.export_inference_model(args.save_dir, fixed_input_shape)

+ 1 - 0
examples/README.md

@@ -5,6 +5,7 @@ PaddleRS提供从科学研究到产业应用的丰富示例,希望帮助遥感
 ## 1 官方案例
 
 - [PaddleRS科研实战:设计深度学习变化检测模型](./rs_research/)
+- [基于PaddleRS的遥感图像小目标语义分割优化方法](./c2fnet/)
 
 ## 2 社区贡献案例
 

+ 1 - 0
paddlers/.version

@@ -0,0 +1 @@
+0.0.0.dev0

+ 4 - 1
paddlers/__init__.py

@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = '0.0.0.dev0'
+import os
 
 from paddlers.utils.env import get_environ_info, init_parallel_env
 from . import tasks, datasets, transforms, utils, tools, models, deploy
 
 init_parallel_env()
 env_info = get_environ_info()
+
+with open(os.path.join(os.path.dirname(__file__), ".version"), 'r') as fv:
+    __version__ = fv.read().rstrip()

+ 16 - 3
paddlers/datasets/base.py

@@ -15,11 +15,15 @@
 from copy import deepcopy
 
 from paddle.io import Dataset
+from paddle.fluid.dataloader.collate import default_collate_fn
 
 from paddlers.utils import get_num_workers
+from paddlers.transforms import construct_sample_from_dict
 
 
 class BaseDataset(Dataset):
+    _collate_trans_info = False
+
     def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
         super(BaseDataset, self).__init__()
 
@@ -30,6 +34,15 @@ class BaseDataset(Dataset):
         self.shuffle = shuffle
 
     def __getitem__(self, idx):
-        sample = deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
-        return outputs
+        sample = construct_sample_from_dict(self.file_list[idx])
+        # `trans_info` will be used to store meta info about image shape
+        sample['trans_info'] = []
+        outputs, trans_info = self.transforms(sample)
+        return outputs, trans_info
+
+    def collate_fn(self, batch):
+        if self._collate_trans_info:
+            return default_collate_fn(
+                [s[0] for s in batch]), [s[1] for s in batch]
+        else:
+            return default_collate_fn([s[0] for s in batch])

+ 6 - 5
paddlers/datasets/cd_dataset.py

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 from enum import IntEnum
 import os.path as osp
 
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
+from paddlers.transforms import construct_sample_from_dict
 
 
 class CDDataset(BaseDataset):
@@ -44,6 +44,8 @@ class CDDataset(BaseDataset):
             Defaults to False.
     """
 
+    _collate_trans_info = True
+
     def __init__(self,
                  data_dir,
                  file_list,
@@ -58,8 +60,6 @@ class CDDataset(BaseDataset):
 
         DELIMETER = ' '
 
-        # TODO: batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.labels = list()
         self.with_seg_labels = with_seg_labels
@@ -130,7 +130,8 @@ class CDDataset(BaseDataset):
             len(self.file_list), file_list))
 
     def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
+        sample = construct_sample_from_dict(self.file_list[idx])
+        sample['trans_info'] = []
         sample = self.transforms.apply_transforms(sample)
 
         if self.binarize_labels:
@@ -142,7 +143,7 @@ class CDDataset(BaseDataset):
 
         outputs = self.transforms.arrange_outputs(sample)
 
-        return outputs
+        return outputs, sample['trans_info']
 
     def __len__(self):
         return len(self.file_list)

+ 0 - 2
paddlers/datasets/clas_dataset.py

@@ -43,8 +43,6 @@ class ClasDataset(BaseDataset):
                  shuffle=False):
         super(ClasDataset, self).__init__(data_dir, label_list, transforms,
                                           num_workers, shuffle)
-        # TODO batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.labels = list()
 

+ 6 - 6
paddlers/datasets/coco.py

@@ -23,7 +23,7 @@ import numpy as np
 
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
-from paddlers.transforms import DecodeImg, MixupImage
+from paddlers.transforms import DecodeImg, MixupImage, construct_sample_from_dict
 from paddlers.tools import YOLOAnchorCluster
 
 
@@ -78,7 +78,6 @@ class COCODetDataset(BaseDataset):
                     self.num_max_boxes *= 2
                     break
 
-        self.batch_transforms = None
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.file_list = list()
@@ -243,7 +242,7 @@ class COCODetDataset(BaseDataset):
         self._epoch = 0
 
     def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
+        sample = construct_sample_from_dict(self.file_list[idx])
         if self.data_fields is not None:
             sample = {k: sample[k] for k in self.data_fields}
         if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
@@ -253,15 +252,16 @@ class COCODetDataset(BaseDataset):
                 mix_pos = (mix_idx + idx) % self.num_samples
             else:
                 mix_pos = 0
-            sample_mix = copy.deepcopy(self.file_list[mix_pos])
+            sample_mix = construct_sample_from_dict(self.file_list[mix_pos])
             if self.data_fields is not None:
                 sample_mix = {k: sample_mix[k] for k in self.data_fields}
             sample = self.mixup_op(sample=[
                 DecodeImg(to_rgb=False)(sample),
                 DecodeImg(to_rgb=False)(sample_mix)
             ])
-        sample = self.transforms(sample)
-        return sample
+        sample['trans_info'] = []
+        sample, trans_info = self.transforms(sample)
+        return sample, trans_info
 
     def __len__(self):
         return self.num_samples

+ 2 - 1
paddlers/datasets/res_dataset.py

@@ -36,6 +36,8 @@ class ResDataset(BaseDataset):
             restoration tasks. Defaults to None.
     """
 
+    _collate_trans_info = True
+
     def __init__(self,
                  data_dir,
                  file_list,
@@ -45,7 +47,6 @@ class ResDataset(BaseDataset):
                  sr_factor=None):
         super(ResDataset, self).__init__(data_dir, None, transforms,
                                          num_workers, shuffle)
-        self.batch_transforms = None
         self.file_list = list()
 
         with open(file_list, encoding=get_encoding(file_list)) as f:

+ 2 - 2
paddlers/datasets/seg_dataset.py

@@ -35,6 +35,8 @@ class SegDataset(BaseDataset):
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
     """
 
+    _collate_trans_info = True
+
     def __init__(self,
                  data_dir,
                  file_list,
@@ -44,8 +46,6 @@ class SegDataset(BaseDataset):
                  shuffle=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
-        # TODO: batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.labels = list()
 

+ 16 - 33
paddlers/deploy/predictor.py

@@ -26,6 +26,8 @@ from paddlers.tasks import load_model
 from paddlers.utils import logging, Timer
 from paddlers.tasks.utils.slider_predict import slider_predict
 
+# TODO: Refactor
+
 
 class Predictor(object):
     def __init__(self,
@@ -148,44 +150,32 @@ class Predictor(object):
         return predictor
 
     def preprocess(self, images, transforms):
-        preprocessed_samples = self._model.preprocess(
+        preprocessed_samples, batch_trans_info = self._model.preprocess(
             images, transforms, to_tensor=False)
         if self.model_type == 'classifier':
-            preprocessed_samples = {'image': preprocessed_samples[0]}
+            preprocessed_samples = {'image': preprocessed_samples}
         elif self.model_type == 'segmenter':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'ori_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         elif self.model_type == 'detector':
             pass
         elif self.model_type == 'change_detector':
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
-                'image2': preprocessed_samples[1],
-                'ori_shape': preprocessed_samples[2]
+                'image2': preprocessed_samples[1]
             }
         elif self.model_type == 'restorer':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'tar_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         else:
             logging.error(
                 "Invalid model type {}".format(self.model_type), exit=True)
-        return preprocessed_samples
-
-    def postprocess(self,
-                    net_outputs,
-                    topk=1,
-                    ori_shape=None,
-                    tar_shape=None,
-                    transforms=None):
+        return preprocessed_samples, batch_trans_info
+
+    def postprocess(self, net_outputs, batch_restore_list, topk=1):
         if self.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
-            # XXX: Convert ndarray to tensor as self._model.postprocess requires
+            # XXX: Convert ndarray to tensor as `self._model.postprocess` requires
             assert len(net_outputs) == 1
             net_outputs = paddle.to_tensor(net_outputs[0])
             outputs = self._model.postprocess(net_outputs)
@@ -199,9 +189,7 @@ class Predictor(object):
             } for l, s, n in zip(class_ids, scores, label_names)]
         elif self.model_type in ('segmenter', 'change_detector'):
             label_map, score_map = self._model.postprocess(
-                net_outputs,
-                batch_origin_shape=ori_shape,
-                transforms=transforms.transforms)
+                net_outputs, batch_restore_list=batch_restore_list)
             preds = [{
                 'label_map': l,
                 'score_map': s
@@ -214,9 +202,7 @@ class Predictor(object):
             preds = self._model.postprocess(net_outputs)
         elif self.model_type == 'restorer':
             res_maps = self._model.postprocess(
-                net_outputs[0],
-                batch_tar_shape=tar_shape,
-                transforms=transforms.transforms)
+                net_outputs[0], batch_restore_list=batch_restore_list)
             preds = [{'res_map': res_map} for res_map in res_maps]
         else:
             logging.error(
@@ -248,7 +234,8 @@ class Predictor(object):
 
     def _run(self, images, topk=1, transforms=None):
         self.timer.preprocess_time_s.start()
-        preprocessed_input = self.preprocess(images, transforms)
+        preprocessed_input, batch_trans_info = self.preprocess(images,
+                                                               transforms)
         self.timer.preprocess_time_s.end(iter_num=len(images))
 
         self.timer.inference_time_s.start()
@@ -257,11 +244,7 @@ class Predictor(object):
 
         self.timer.postprocess_time_s.start()
         results = self.postprocess(
-            net_outputs,
-            topk,
-            ori_shape=preprocessed_input.get('ori_shape', None),
-            tar_shape=preprocessed_input.get('tar_shape', None),
-            transforms=transforms)
+            net_outputs, batch_restore_list=batch_trans_info, topk=topk)
         self.timer.postprocess_time_s.end(iter_num=len(images))
 
         return results

+ 23 - 13
paddlers/tasks/base.py

@@ -61,6 +61,9 @@ class ModelMeta(type):
 
 
 class BaseModel(metaclass=ModelMeta):
+
+    find_unused_parameters = False
+
     def __init__(self, model_type):
         self.model_type = model_type
         self.in_channels = None
@@ -98,6 +101,7 @@ class BaseModel(metaclass=ModelMeta):
                 if osp.exists(save_dir):
                     os.remove(save_dir)
                 os.makedirs(save_dir)
+            # XXX: Hard-coding
             if self.model_type == 'classifier':
                 pretrain_weights = get_pretrain_weights(
                     pretrain_weights, self.model_name, save_dir)
@@ -214,10 +218,7 @@ class BaseModel(metaclass=ModelMeta):
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
-        pruner_inputs = self.pruner.inputs
-        if self.model_type == 'detector':
-            pruner_inputs = {k: v.tolist() for k, v in pruner_inputs[0].items()}
-        info['pruner_inputs'] = pruner_inputs
+        info['pruner_inputs'] = self.pruner.inputs
 
         return info
 
@@ -266,7 +267,11 @@ class BaseModel(metaclass=ModelMeta):
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("Model saved in {}.".format(save_dir))
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -291,7 +296,7 @@ class BaseModel(metaclass=ModelMeta):
         loader = DataLoader(
             dataset,
             batch_sampler=batch_sampler,
-            collate_fn=dataset.batch_transforms,
+            collate_fn=dataset.collate_fn if collate_fn is None else collate_fn,
             num_workers=dataset.num_workers,
             return_list=True,
             use_shared_memory=use_shared_memory)
@@ -312,6 +317,7 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms, 'train')
 
+        # XXX: Hard-coding
         if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
             nranks = 1
@@ -319,17 +325,17 @@ class BaseModel(metaclass=ModelMeta):
             nranks = paddle.distributed.get_world_size()
         local_rank = paddle.distributed.get_rank()
         if nranks > 1:
-            find_unused_parameters = getattr(self, 'find_unused_parameters',
-                                             False)
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
                 ddp_net = to_data_parallel(
-                    self.net, find_unused_parameters=find_unused_parameters)
+                    self.net,
+                    find_unused_parameters=self.find_unused_parameters)
             else:
                 ddp_net = to_data_parallel(
-                    self.net, find_unused_parameters=find_unused_parameters)
+                    self.net,
+                    find_unused_parameters=self.find_unused_parameters)
 
         if use_vdl:
             from visualdl import LogWriter
@@ -488,12 +494,13 @@ class BaseModel(metaclass=ModelMeta):
         assert criterion in {'l1_norm', 'fpgm'}, \
             "Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
         self._check_transforms(dataset.transforms, 'eval')
+        # XXX: Hard-coding
         if self.model_type == 'detector':
             self.net.eval()
         else:
             self.net.train()
         inputs = _pruner_template_input(
-            sample=dataset[0], model_type=self.model_type)
+            sample=dataset[0][0], model_type=self.model_type)
         if criterion == 'l1_norm':
             self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
         else:
@@ -618,7 +625,10 @@ class BaseModel(metaclass=ModelMeta):
     def _build_inference_net(self):
         raise NotImplementedError
 
-    def _export_inference_model(self, save_dir, image_shape=None):
+    def _get_test_inputs(self, image_shape):
+        raise NotImplementedError
+
+    def export_inference_model(self, save_dir, image_shape=None):
         self.test_inputs = self._get_test_inputs(image_shape)
         infer_net = self._build_inference_net()
 
@@ -696,4 +706,4 @@ class BaseModel(metaclass=ModelMeta):
         raise NotImplementedError
 
     def postprocess(self, *args, **kwargs):
-        raise NotImplementedError
+        raise NotImplementedError

+ 9 - 62
paddlers/tasks/change_detector.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import os
 import os.path as osp
 from collections import OrderedDict
 from operator import attrgetter
@@ -29,7 +28,7 @@ import paddlers.models.paddleseg as ppseg
 import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
@@ -63,7 +62,6 @@ class BaseChangeDetector(BaseModel):
         if params.get('with_net', True):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
     def build_net(self, **params):
         # TODO: add other model
@@ -112,11 +110,11 @@ class BaseChangeDetector(BaseModel):
         ]
 
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         net_out = net(inputs[0], inputs[1])
         logit = net_out[0]
         outputs = OrderedDict()
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
                     net_out, batch_restore_list)
@@ -137,7 +135,6 @@ class BaseChangeDetector(BaseModel):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
@@ -560,10 +557,8 @@ class BaseChangeDetector(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im1, batch_im2, batch_trans_info = self.preprocess(
-            images, transforms, self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im1, batch_im2, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
@@ -631,10 +626,10 @@ class BaseChangeDetector(BaseModel):
                 im1 = decode_image(im1, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
-            sample = {'image': im1, 'image2': im2}
+            sample = construct_sample(image=im1, image2=im2)
             data = transforms(sample)
-            im1, im2 = data[:2]
-            trans_info = data[-1]
+            im1, im2 = data[0][:2]
+            trans_info = data[1]
             batch_im1.append(im1)
             batch_im2.append(im2)
             batch_trans_info.append(trans_info)
@@ -645,55 +640,7 @@ class BaseChangeDetector(BaseModel):
             batch_im1 = np.asarray(batch_im1)
             batch_im2 = np.asarray(batch_im2)
 
-        return batch_im1, batch_im2, batch_trans_info
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
+        return (batch_im1, batch_im2), batch_trans_info
 
     def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
@@ -717,7 +664,7 @@ class BaseChangeDetector(BaseModel):
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
-                    pass
+                    raise RuntimeError
             results.append(pred)
         return results
 
@@ -756,7 +703,7 @@ class BaseChangeDetector(BaseModel):
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
-                    pass
+                    raise RuntimeError
             label_map = label_map.squeeze()
             score_map = score_map.squeeze()
             if not isinstance(label_map, np.ndarray):

+ 13 - 64
paddlers/tasks/classifier.py

@@ -12,26 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 import os.path as osp
 from collections import OrderedDict
 from operator import itemgetter
 
 import numpy as np
 import paddle
-import paddle.nn.functional as F
 from paddle.static import InputSpec
 
 import paddlers
 import paddlers.models.ppcls as ppcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.utils.logging as logging
-from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models import clas_losses
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from .base import BaseModel
 
 __all__ = ["ResNet50_vd", "MobileNetV3", "HRNet", "CondenseNetV2"]
@@ -64,7 +61,6 @@ class BaseClassifier(BaseModel):
         if params.get('with_net', True):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
     def build_net(self, **params):
         with paddle.utils.unique_name.guard():
@@ -459,10 +455,8 @@ class BaseClassifier(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im, batch_origin_shape = self.preprocess(images, transforms,
-                                                       self.model_type)
+        data, _ = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
 
         if self.postprocess is None:
             self.build_postprocess_from_labels()
@@ -488,69 +482,19 @@ class BaseClassifier(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         batch_im = list()
-        batch_ori_shape = list()
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
-            sample = {'image': im}
-            im = transforms(sample)
+            sample = construct_sample(image=im)
+            data = transforms(sample)
+            im = data[0][0]
             batch_im.append(im)
-            batch_ori_shape.append(ori_shape)
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
         else:
             batch_im = np.asarray(batch_im)
 
-        return batch_im, batch_ori_shape
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
+        return batch_im, None
 
     def _check_transforms(self, transforms, mode):
         super()._check_transforms(transforms, mode)
@@ -559,7 +503,11 @@ class BaseClassifier(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeClassifier object.")
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -571,7 +519,8 @@ class BaseClassifier(BaseModel):
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 use_shared_memory=False)

+ 44 - 11
paddlers/tasks/object_detector.py

@@ -24,7 +24,7 @@ from paddle.static import InputSpec
 import paddlers
 import paddlers.models.ppdet as ppdet
 from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
-from paddlers.transforms import decode_image
+from paddlers.transforms import decode_image, construct_sample
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
     _BatchPad, _Gt2YoloTarget
@@ -38,6 +38,8 @@ __all__ = [
     "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
 ]
 
+# TODO: Prune and decoupling
+
 
 class BaseDetector(BaseModel):
     def __init__(self, model_name, num_classes=80, **params):
@@ -307,6 +309,8 @@ class BaseDetector(BaseModel):
         self.num_max_boxes = train_dataset.num_max_boxes
         train_dataset.batch_transforms = self._compose_batch_transform(
             train_dataset.transforms, mode='train')
+        train_dataset.collate_fn = self._build_collate_fn(
+            train_dataset.batch_transforms)
 
         # Build optimizer if not defined
         if optimizer is None:
@@ -372,6 +376,17 @@ class BaseDetector(BaseModel):
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
 
+    def _build_collate_fn(self, compose):
+        def _collate_fn(batch):
+            # We drop `trans_info` as it is not required in detection tasks
+            samples = [s[0] for s in batch]
+            return compose(samples)
+
+        return _collate_fn
+
+    def _compose_batch_transform(self, transforms, mode):
+        raise NotImplementedError
+
     def quant_aware_train(self,
                           num_epochs,
                           train_dataset,
@@ -534,9 +549,13 @@ class BaseDetector(BaseModel):
 
         if nranks < 2 or local_rank == 0:
             self.eval_data_loader = self.build_data_loader(
-                eval_dataset, batch_size=batch_size, mode='eval')
+                eval_dataset,
+                batch_size=batch_size,
+                mode='eval',
+                collate_fn=self._build_collate_fn(
+                    eval_dataset.batch_transforms))
             is_bbox_normalized = False
-            if eval_dataset.batch_transforms is not None:
+            if hasattr(eval_dataset, 'batch_transforms'):
                 is_bbox_normalized = any(
                     isinstance(t, _NormalizeBox)
                     for t in eval_dataset.batch_transforms.batch_transforms)
@@ -604,7 +623,7 @@ class BaseDetector(BaseModel):
         else:
             images = img_file
 
-        batch_samples = self.preprocess(images, transforms)
+        batch_samples, _ = self.preprocess(images, transforms)
         self.net.eval()
         outputs = self.run(self.net, batch_samples, 'test')
         prediction = self.postprocess(outputs)
@@ -619,16 +638,17 @@ class BaseDetector(BaseModel):
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            sample = {'image': im}
+            sample = construct_sample(image=im)
             sample = transforms(sample)
-            batch_samples.append(sample)
+            data = sample[0]
+            batch_samples.append(data)
         batch_transforms = self._compose_batch_transform(transforms, 'test')
         batch_samples = batch_transforms(batch_samples)
         if to_tensor:
             for k in batch_samples:
                 batch_samples[k] = paddle.to_tensor(batch_samples[k])
 
-        return batch_samples
+        return batch_samples, None
 
     def postprocess(self, batch_pred):
         infer_result = {}
@@ -705,6 +725,14 @@ class BaseDetector(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeDetector object.")
 
+    def get_pruning_info(self):
+        info = super().get_pruning_info()
+        info['pruner_inputs'] = {
+            k: v.tolist()
+            for k, v in info['pruner_inputs'][0].items()
+        }
+        return info
+
 
 class PicoDet(BaseDetector):
     def __init__(self,
@@ -920,7 +948,11 @@ class PicoDet(BaseDetector):
             in_args['optimizer'] = optimizer
         return in_args
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -932,13 +964,14 @@ class PicoDet(BaseDetector):
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 use_shared_memory=False)
         else:
-            return super(BaseDetector, self).build_data_loader(dataset,
-                                                               batch_size, mode)
+            return super(BaseDetector, self).build_data_loader(
+                dataset, batch_size, mode, collate_fn)
 
 
 class YOLOv3(BaseDetector):

+ 35 - 82
paddlers/tasks/restorer.py

@@ -28,7 +28,7 @@ import paddlers.models.ppgan.metrics as metrics
 import paddlers.utils.logging as logging
 from paddlers.models import res_losses
 from paddlers.models.ppgan.modules.init import init_weights
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms.functions import calc_hr_shape
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from .base import BaseModel
@@ -58,7 +58,6 @@ class BaseRestorer(BaseModel):
         if params.get('with_net', True):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
         if min_max is None:
             self.min_max = self.MIN_MAX
 
@@ -116,14 +115,13 @@ class BaseRestorer(BaseModel):
         return input_spec
 
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         outputs = OrderedDict()
 
         if mode == 'test':
-            tar_shape = inputs[1]
             if self.status == 'Infer':
                 net_out = net(inputs[0])
-                res_map_list = self.postprocess(
-                    net_out, tar_shape, transforms=inputs[2])
+                res_map_list = self.postprocess(net_out, batch_restore_list)
             else:
                 if isinstance(net, GANAdapter):
                     net_out = net.generator(inputs[0])
@@ -131,8 +129,7 @@ class BaseRestorer(BaseModel):
                     net_out = net(inputs[0])
                 if self.TEST_OUT_KEY is not None:
                     net_out = net_out[self.TEST_OUT_KEY]
-                pred = self.postprocess(
-                    net_out, tar_shape, transforms=inputs[2])
+                pred = self.postprocess(net_out, batch_restore_list)
                 res_map_list = []
                 for res_map in pred:
                     res_map = self._tensor_to_images(res_map)
@@ -147,9 +144,7 @@ class BaseRestorer(BaseModel):
             if self.TEST_OUT_KEY is not None:
                 net_out = net_out[self.TEST_OUT_KEY]
             tar = inputs[1]
-            tar_shape = [tar.shape[-2:]]
-            pred = self.postprocess(
-                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(net_out, batch_restore_list)[0]  # NCHW
             pred = self._tensor_to_images(pred)
             outputs['pred'] = pred
             tar = self._tensor_to_images(tar)
@@ -424,7 +419,6 @@ class BaseRestorer(BaseModel):
                     eval_dataset.num_samples, eval_dataset.num_samples))
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
-                    data.append(eval_dataset.transforms.transforms)
                     outputs = self.run(self.net, data, 'eval')
                     psnr.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
@@ -472,10 +466,8 @@ class BaseRestorer(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im, batch_tar_shape = self.preprocess(images, transforms,
-                                                    self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im, batch_tar_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
         res_map_list = outputs['res_map']
         if isinstance(img_file, list):
@@ -487,79 +479,24 @@ class BaseRestorer(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         batch_im = list()
-        batch_tar_shape = list()
+        batch_trans_info = list()
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
-            sample = {'image': im}
-            im = transforms(sample)[0]
+            sample = construct_sample(image=im)
+            data = transforms(sample)
+            im = data[0][0]
+            trans_info = data[1]
             batch_im.append(im)
-            batch_tar_shape.append(self._get_target_shape(ori_shape))
+            batch_trans_info.append(trans_info)
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
         else:
             batch_im = np.asarray(batch_im)
 
-        return batch_im, batch_tar_shape
+        return (batch_im, ), batch_trans_info
 
-    def _get_target_shape(self, ori_shape):
-        if self.sr_factor is None:
-            return ori_shape
-        else:
-            return calc_hr_shape(ori_shape, self.sr_factor)
-
-    @staticmethod
-    def get_transforms_shape_info(batch_tar_shape, transforms):
-        batch_restore_list = list()
-        for tar_shape in batch_tar_shape:
-            restore_list = list()
-            h, w = tar_shape[0], tar_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
-
-    def postprocess(self, batch_pred, batch_tar_shape, transforms):
-        batch_restore_list = BaseRestorer.get_transforms_shape_info(
-            batch_tar_shape, transforms)
+    def postprocess(self, batch_pred, batch_restore_list):
         if self.status == 'Infer':
             return self._infer_postprocess(
                 batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
@@ -572,11 +509,15 @@ class BaseRestorer(BaseModel):
             pred = paddle.unsqueeze(pred, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
+                if self.sr_factor:
+                    h, w = calc_hr_shape((h, w), self.sr_factor)
                 if item[0] == 'resize':
                     pred = F.interpolate(
                         pred, (h, w), mode=mode, data_format='NCHW')
                 elif item[0] == 'padding':
                     x, y = item[2]
+                    if self.sr_factor:
+                        x, y = calc_hr_shape((x, y), self.sr_factor)
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                     pass
@@ -590,6 +531,8 @@ class BaseRestorer(BaseModel):
                 res_map = paddle.unsqueeze(res_map, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
+                if self.sr_factor:
+                    h, w = calc_hr_shape((h, w), self.sr_factor)
                 if item[0] == 'resize':
                     if isinstance(res_map, np.ndarray):
                         res_map = cv2.resize(
@@ -601,6 +544,8 @@ class BaseRestorer(BaseModel):
                             data_format='NHWC')
                 elif item[0] == 'padding':
                     x, y = item[2]
+                    if self.sr_factor:
+                        x, y = calc_hr_shape((x, y), self.sr_factor)
                     if isinstance(res_map, np.ndarray):
                         res_map = res_map[y:y + h, x:x + w]
                     else:
@@ -621,7 +566,11 @@ class BaseRestorer(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeRestorer object.")
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -633,7 +582,8 @@ class BaseRestorer(BaseModel):
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 use_shared_memory=False)
@@ -758,7 +708,7 @@ class DRN(BaseRestorer):
 
     def train_step(self, step, data, net):
         outputs = self.run_gan(
-            net, data, mode='train', gan_mode='forward_primary')
+            net, data[0], mode='train', gan_mode='forward_primary')
         outputs.update(
             self.run_gan(
                 net, (outputs['sr'], outputs['lr']),
@@ -800,6 +750,9 @@ class LESRCNN(BaseRestorer):
 
 
 class ESRGAN(BaseRestorer):
+
+    find_unused_parameters = True
+
     def __init__(self,
                  losses=None,
                  sr_factor=4,
@@ -915,14 +868,14 @@ class ESRGAN(BaseRestorer):
             optim_g, optim_d = self.optimizer
 
             outputs = self.run_gan(
-                net, data, mode='train', gan_mode='forward_g')
+                net, data[0], mode='train', gan_mode='forward_g')
             optim_g.clear_grad()
             (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
             optim_g.step()
 
             outputs.update(
                 self.run_gan(
-                    net, (outputs['g_pred'], data[1]),
+                    net, (outputs['g_pred'], data[0][1]),
                     mode='train',
                     gan_mode='forward_d'))
             optim_d.clear_grad()

+ 10 - 15
paddlers/tasks/segmenter.py

@@ -27,7 +27,7 @@ import paddlers.models.paddleseg as ppseg
 import paddlers.rs_models.seg as cmseg
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
@@ -64,7 +64,6 @@ class BaseSegmenter(BaseModel):
         if params.get('with_net', True):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
@@ -114,11 +113,11 @@ class BaseSegmenter(BaseModel):
         return input_spec
 
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         net_out = net(inputs[0])
         logit = net_out[0]
         outputs = OrderedDict()
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
                     net_out, batch_restore_list)
@@ -139,7 +138,6 @@ class BaseSegmenter(BaseModel):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
@@ -526,10 +524,8 @@ class BaseSegmenter(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im, batch_trans_info = self.preprocess(images, transforms,
-                                                     self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
@@ -595,10 +591,10 @@ class BaseSegmenter(BaseModel):
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            sample = {'image': im}
+            sample = construct_sample(image=im)
             data = transforms(sample)
-            im = data[0]
-            trans_info = data[-1]
+            im = data[0][0]
+            trans_info = data[1]
             batch_im.append(im)
             batch_trans_info.append(trans_info)
         if to_tensor:
@@ -606,7 +602,7 @@ class BaseSegmenter(BaseModel):
         else:
             batch_im = np.asarray(batch_im)
 
-        return batch_im, batch_trans_info
+        return (batch_im, ), batch_trans_info
 
     def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
@@ -630,7 +626,7 @@ class BaseSegmenter(BaseModel):
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
-                    pass
+                    raise RuntimeError
             results.append(pred)
         return results
 
@@ -669,7 +665,7 @@ class BaseSegmenter(BaseModel):
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
-                    pass
+                    raise RuntimeError
             label_map = label_map.squeeze()
             score_map = score_map.squeeze()
             if not isinstance(label_map, np.ndarray):
@@ -921,13 +917,13 @@ class C2FNet(BaseSegmenter):
             **params)
 
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         with paddle.no_grad():
             pre_coarse = self.coarse_model(inputs[0])
             pre_coarse = pre_coarse[0]
             heatmaps = pre_coarse
 
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()
@@ -952,7 +948,6 @@ class C2FNet(BaseSegmenter):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()

+ 11 - 7
paddlers/transforms/batch_operators.py

@@ -27,7 +27,11 @@ from .box_utils import jaccard_overlap
 from paddlers.utils import logging
 
 
-class BatchCompose(Transform):
+class BatchTransform(Transform):
+    is_batch_transform = True
+
+
+class BatchCompose(BatchTransform):
     def __init__(self, batch_transforms=None, collate_batch=True):
         super(BatchCompose, self).__init__()
         self.batch_transforms = batch_transforms
@@ -40,14 +44,14 @@ class BatchCompose(Transform):
                     samples = op(samples)
                 except Exception as e:
                     stack_info = traceback.format_exc()
-                    logging.warning("fail to map batch transform [{}] "
+                    logging.warning("Fail to map batch transform [{}] "
                                     "with error: {} and stack:\n{}".format(
                                         op, e, str(stack_info)))
                     raise e
 
         samples = _Permute()(samples)
 
-        extra_key = ['h', 'w', 'flipped']
+        extra_key = ['h', 'w', 'flipped', 'trans_info']
         for k in extra_key:
             for sample in samples:
                 if k in sample:
@@ -70,7 +74,7 @@ class BatchCompose(Transform):
         return batch_data
 
 
-class BatchRandomResize(Transform):
+class BatchRandomResize(BatchTransform):
     """
     Resize a batch of inputs to random sizes.
 
@@ -111,7 +115,7 @@ class BatchRandomResize(Transform):
         return samples
 
 
-class BatchRandomResizeByShort(Transform):
+class BatchRandomResizeByShort(BatchTransform):
     """
     Resize a batch of inputs to random sizes while keeping the aspect ratio.
 
@@ -157,7 +161,7 @@ class BatchRandomResizeByShort(Transform):
         return samples
 
 
-class _BatchPad(Transform):
+class _BatchPad(BatchTransform):
     def __init__(self, pad_to_stride=0):
         super(_BatchPad, self).__init__()
         self.pad_to_stride = pad_to_stride
@@ -182,7 +186,7 @@ class _BatchPad(Transform):
         return samples
 
 
-class _Gt2YoloTarget(Transform):
+class _Gt2YoloTarget(BatchTransform):
     """
     Generate YOLOv3 targets by groud truth data, this operator is only used in
         fine grained YOLOv3 loss mode.

+ 29 - 18
paddlers/transforms/operators.py

@@ -18,6 +18,7 @@ import random
 from numbers import Number
 from functools import partial
 from operator import methodcaller
+from collections import OrderedDict
 from collections.abc import Sequence
 
 import numpy as np
@@ -32,6 +33,8 @@ import paddlers.transforms.indices as indices
 import paddlers.transforms.satellites as satellites
 
 __all__ = [
+    "construct_sample",
+    "construct_sample_from_dict",
     "Compose",
     "DecodeImg",
     "Resize",
@@ -74,6 +77,19 @@ interp_dict = {
 }
 
 
+def construct_sample(**kwargs):
+    sample = OrderedDict()
+    for k, v in kwargs.items():
+        sample[k] = v
+    if 'trans_info' not in sample:
+        sample['trans_info'] = []
+    return sample
+
+
+def construct_sample_from_dict(dict_like_obj):
+    return construct_sample(**dict_like_obj)
+
+
 class Compose(object):
     """
     Apply a series of data augmentation strategies to the input.
@@ -107,17 +123,17 @@ class Compose(object):
         This is equivalent to sequentially calling compose_obj.apply_transforms() 
             and compose_obj.arrange_outputs().
         """
-
+        if 'trans_info' not in sample:
+            sample['trans_info'] = []
         sample = self.apply_transforms(sample)
+        trans_info = sample['trans_info']
         sample = self.arrange_outputs(sample)
-        return sample
+        return sample, trans_info
 
     def apply_transforms(self, sample):
         for op in self.transforms:
-            # Skip batch transforms amd mixup
-            if isinstance(op, (paddlers.transforms.BatchRandomResize,
-                               paddlers.transforms.BatchRandomResizeByShort,
-                               MixupImage)):
+            # Skip batch transforms
+            if getattr(op, 'is_batch_transform', False):
                 continue
             sample = op(sample)
         return sample
@@ -373,11 +389,6 @@ class DecodeImg(Transform):
             else:
                 sample['target'] = self.apply_im(sample['target'])
 
-        # the `trans_info` will save the process of image shape,
-        # and will be used in evaluation and prediction.
-        if 'trans_info' not in sample:
-            sample['trans_info'] = []
-
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
@@ -1474,6 +1485,8 @@ class Pad(Transform):
 
 
 class MixupImage(Transform):
+    is_batch_transform = True
+
     def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
         """
         Mixup two images and their gt_bbbox/gt_score.
@@ -2073,13 +2086,12 @@ class ArrangeSegmenter(Arrange):
             mask = sample['mask']
             mask = mask.astype('int64')
         image = F.permute(sample['image'], False)
-        trans_info = sample['trans_info']
         if self.mode == 'train':
             return image, mask
         if self.mode == 'eval':
-            return image, mask, trans_info
+            return image, mask
         if self.mode == 'test':
-            return image, trans_info,
+            return image,
 
 
 class ArrangeChangeDetector(Arrange):
@@ -2089,7 +2101,6 @@ class ArrangeChangeDetector(Arrange):
             mask = mask.astype('int64')
         image_t1 = F.permute(sample['image'], False)
         image_t2 = F.permute(sample['image2'], False)
-        trans_info = sample['trans_info']
         if self.mode == 'train':
             masks = [mask]
             if 'aux_masks' in sample:
@@ -2099,9 +2110,9 @@ class ArrangeChangeDetector(Arrange):
                 image_t1,
                 image_t2, ) + tuple(masks)
         if self.mode == 'eval':
-            return image_t1, image_t2, mask, trans_info
+            return image_t1, image_t2, mask
         if self.mode == 'test':
-            return image_t1, image_t2, trans_info
+            return image_t1, image_t2
 
 
 class ArrangeClassifier(Arrange):
@@ -2110,7 +2121,7 @@ class ArrangeClassifier(Arrange):
         if self.mode in ['train', 'eval']:
             return image, sample['label']
         else:
-            return image
+            return image,
 
 
 class ArrangeDetector(Arrange):

+ 0 - 1
paddlers/utils/download.py

@@ -22,7 +22,6 @@ import hashlib
 import tarfile
 import zipfile
 
-import filelock
 import paddle
 
 from . import logging

+ 1 - 1
paddlers/utils/postprocs/__init__.py

@@ -23,5 +23,5 @@ try:
     from .crf import conditional_random_field
 except ImportError:
     print(
-        "Can not use `conditional_random_field`. Please install pydensecrf first!"
+        "Can not use `conditional_random_field`. Please install pydensecrf first."
     )

+ 4 - 2
requirements.txt

@@ -15,12 +15,14 @@ opencv-contrib-python >= 4.3.0
 openpyxl
 # paddlepaddle >= 2.2.0
 # paddlepaddle-gpu >= 2.2.0
-paddleslim >= 2.2.1,< 2.3.5
+paddleslim >= 2.2.1, < 2.3.5
 pandas
+protobuf >= 3.1.0, <= 3.20.0
 pycocotools
 # pydensecrf
-scikit-learn == 0.23.2
+scikit-learn
 scikit-image >= 0.14.0
 scipy
 shapely
+spyndex
 visualdl >= 2.1.1

+ 31 - 27
setup.py

@@ -13,34 +13,38 @@
 # limitations under the License.
 
 import setuptools
-import paddlers
 
-DESCRIPTION = "Awesome Remote Sensing Toolkit based on PaddlePaddle"
+if __name__ == '__main__':
+    DESCRIPTION = "Awesome Remote Sensing Toolkit based on PaddlePaddle"
 
-with open("README.md", "r", encoding='utf8') as fh:
-    LONG_DESCRIPTION = fh.read()
+    with open("README_EN.md", 'r', encoding='utf8') as fh:
+        LONG_DESCRIPTION = fh.read()
 
-with open("requirements.txt") as fin:
-    REQUIRED_PACKAGES = fin.read()
+    with open("requirements.txt", 'r') as fin:
+        REQUIRED_PACKAGES = fin.read()
 
-setuptools.setup(
-    name="paddlers",
-    version=paddlers.__version__.replace('-', ''),
-    author='PaddleRS Authors',
-    author_email="",
-    description=DESCRIPTION,
-    long_description=LONG_DESCRIPTION,
-    long_description_content_type="text/plain",
-    url="https://github.com/PaddlePaddle/PaddleRS",
-    packages=setuptools.find_packages(include=['paddlers', 'paddlers.*']),
-    python_requires='>=3.7',
-    setup_requires=['cython', 'numpy'],
-    install_requires=REQUIRED_PACKAGES,
-    classifiers=[
-        "Programming Language :: Python :: 3.7",
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "License :: OSI Approved :: Apache Software License",
-        "Operating System :: OS Independent",
-    ],
-    license='Apache 2.0', )
+    with open("paddlers/.version", 'r') as fv:
+        VERSION = fv.read().rstrip()
+
+    setuptools.setup(
+        name="paddlers",
+        version=VERSION.replace('-', ''),
+        author='PaddleRS Authors',
+        author_email="",
+        description=DESCRIPTION,
+        long_description=LONG_DESCRIPTION,
+        long_description_content_type="text/plain",
+        url="https://github.com/PaddlePaddle/PaddleRS",
+        packages=setuptools.find_packages(include=['paddlers', 'paddlers.*']) +
+        setuptools.find_namespace_packages(include=['paddlers', 'paddlers.*']),
+        python_requires='>=3.7',
+        setup_requires=['cython', 'numpy'],
+        install_requires=REQUIRED_PACKAGES,
+        classifiers=[
+            "Programming Language :: Python :: 3.7",
+            "Programming Language :: Python :: 3.8",
+            "Programming Language :: Python :: 3.9",
+            "License :: OSI Approved :: Apache Software License",
+            "Operating System :: OS Independent",
+        ],
+        license='Apache 2.0', )

+ 30 - 31
test_tipc/common_func.sh

@@ -1,35 +1,35 @@
 #!/bin/bash
 
 function func_parser_key() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    tmp=${array[0]}
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local tmp=${array[0]}
     echo ${tmp}
 }
 
 function func_parser_value() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    tmp=${array[1]}
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local tmp=${array[1]}
     echo ${tmp}
 }
 
 function func_parser_value_lite() {
-    strs=$1
-    IFS=$2
-    array=(${strs})
-    tmp=${array[1]}
+    local strs=$1
+    local IFS=$2
+    local array=(${strs})
+    local tmp=${array[1]}
     echo ${tmp}
 }
 
 function func_set_params() {
-    key=$1
-    value=$2
-    if [ ${key}x = "null"x ];then
+    local key=$1
+    local value=$2
+    if [ ${key}x = 'null'x ];then
         echo " "
-    elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+    elif [[ ${value} = 'null' ]] || [[ ${value} = ' ' ]] || [ ${#value} -le 0 ];then
         echo " "
     else 
         echo "${key}=${value}"
@@ -37,21 +37,20 @@ function func_set_params() {
 }
 
 function func_parser_params() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    key=${array[0]}
-    tmp=${array[1]}
-    IFS="|"
-    res=""
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local key=${array[0]}
+    local tmp=${array[1]}
+    local IFS='|'
+    local res=''
     for _params in ${tmp[*]}; do
-        IFS="="
-        array=(${_params})
-        mode=${array[0]}
-        value=${array[1]}
+        local IFS='='
+        local array=(${_params})
+        local mode=${array[0]}
+        local value=${array[1]}
         if [[ ${mode} = ${MODE} ]]; then
-            IFS="|"
-            #echo $(func_set_params "${mode}" "${value}")
+            local IFS='|'
             echo $value
             break
         fi
@@ -112,14 +111,14 @@ function add_suffix() {
 
 function parse_first_value() {
     local key_values=$1
-    local IFS=":"
+    local IFS=':'
     local arr=(${key_values})
     echo ${arr[1]}
 }
 
 function parse_second_value() {
     local key_values=$1
-    local IFS=":"
+    local IFS=':'
     local arr=(${key_values})
     echo ${arr[2]}
 }

+ 2 - 0
test_tipc/configs/seg/_base_/rsseg.yaml

@@ -51,6 +51,8 @@ transforms:
           args:
             mean: [0.5, 0.5, 0.5]
             std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ReloadMask
         - !Node
           type: ArrangeSegmenter
           args: ['eval']

+ 1 - 1
test_tipc/configs/seg/factseg/train_infer_python.txt

@@ -1,7 +1,7 @@
 ===========================train_params===========================
 model_name:seg:factseg
 python:python
-gpu_list:0
+gpu_list:0|0,1
 use_gpu:null|null
 --precision:null
 --num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20

+ 16 - 37
test_tipc/infer.py

@@ -143,45 +143,32 @@ class TIPCPredictor(object):
         return config
 
     def preprocess(self, images, transforms):
-        preprocessed_samples = self._model.preprocess(
+        preprocessed_samples, batch_trans_info = self._model.preprocess(
             images, transforms, to_tensor=False)
         if self._model.model_type == 'classifier':
-            preprocessed_samples = {'image': preprocessed_samples[0]}
+            preprocessed_samples = {'image': preprocessed_samples}
         elif self._model.model_type == 'segmenter':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'ori_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         elif self._model.model_type == 'detector':
             pass
         elif self._model.model_type == 'change_detector':
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
-                'image2': preprocessed_samples[1],
-                'ori_shape': preprocessed_samples[2]
+                'image2': preprocessed_samples[1]
             }
         elif self._model.model_type == 'restorer':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'tar_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         else:
             logging.error(
-                "Invalid model type {}".format(self._model.model_type),
-                exit=True)
-        return preprocessed_samples
-
-    def postprocess(self,
-                    net_outputs,
-                    topk=1,
-                    ori_shape=None,
-                    tar_shape=None,
-                    transforms=None):
+                "Invalid model type {}".format(self.model_type), exit=True)
+        return preprocessed_samples, batch_trans_info
+
+    def postprocess(self, net_outputs, batch_restore_list, topk=1):
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
-            # XXX: Convert ndarray to tensor as self._model.postprocess requires
+            # XXX: Convert ndarray to tensor as `self._model.postprocess` requires
             assert len(net_outputs) == 1
             net_outputs = paddle.to_tensor(net_outputs[0])
             outputs = self._model.postprocess(net_outputs)
@@ -195,9 +182,7 @@ class TIPCPredictor(object):
             } for l, s, n in zip(class_ids, scores, label_names)]
         elif self._model.model_type in ('segmenter', 'change_detector'):
             label_map, score_map = self._model.postprocess(
-                net_outputs,
-                batch_origin_shape=ori_shape,
-                transforms=transforms.transforms)
+                net_outputs, batch_restore_list=batch_restore_list)
             preds = [{
                 'label_map': l,
                 'score_map': s
@@ -210,14 +195,11 @@ class TIPCPredictor(object):
             preds = self._model.postprocess(net_outputs)
         elif self._model.model_type == 'restorer':
             res_maps = self._model.postprocess(
-                net_outputs[0],
-                batch_tar_shape=tar_shape,
-                transforms=transforms.transforms)
+                net_outputs[0], batch_restore_list=batch_restore_list)
             preds = [{'res_map': res_map} for res_map in res_maps]
         else:
             logging.error(
-                "Invalid model type {}.".format(self._model.model_type),
-                exit=True)
+                "Invalid model type {}.".format(self.model_type), exit=True)
 
         return preds
 
@@ -225,7 +207,8 @@ class TIPCPredictor(object):
         if self.benchmark and time_it:
             self.autolog.times.start()
 
-        preprocessed_input = self.preprocess(images, transforms)
+        preprocessed_input, batch_trans_info = self.preprocess(images,
+                                                               transforms)
 
         input_names = self.predictor.get_input_names()
         for name in input_names:
@@ -247,11 +230,7 @@ class TIPCPredictor(object):
             self.autolog.times.stamp()
 
         res = self.postprocess(
-            net_outputs,
-            topk,
-            ori_shape=preprocessed_input.get('ori_shape', None),
-            tar_shape=preprocessed_input.get('tar_shape', None),
-            transforms=transforms)
+            net_outputs, batch_restore_list=batch_trans_info, topk=topk)
 
         if self.benchmark and time_it:
             self.autolog.times.end(stamp=True)

+ 16 - 16
tests/data/data_utils.py

@@ -20,6 +20,8 @@ from functools import partial, wraps
 
 import numpy as np
 
+from paddlers.transforms import construct_sample
+
 __all__ = ['build_input_from_file']
 
 
@@ -78,19 +80,17 @@ class ConstrSample(object):
 
 class ConstrSegSample(ConstrSample):
     def __call__(self, im_path, mask_path):
-        return {
-            'image': self.get_full_path(im_path),
-            'mask': self.get_full_path(mask_path)
-        }
+        return construct_sample(
+            image=self.get_full_path(im_path),
+            mask=self.get_full_path(mask_path))
 
 
 class ConstrCdSample(ConstrSample):
     def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
-        sample = {
-            'image_t1': self.get_full_path(im1_path),
-            'image_t2': self.get_full_path(im2_path),
-            'mask': self.get_full_path(mask_path)
-        }
+        sample = construct_sample(
+            image_t1=self.get_full_path(im1_path),
+            image_t2=self.get_full_path(im2_path),
+            mask=self.get_full_path(mask_path))
         if len(aux_mask_paths) > 0:
             sample['aux_masks'] = [
                 self.get_full_path(p) for p in aux_mask_paths
@@ -100,7 +100,8 @@ class ConstrCdSample(ConstrSample):
 
 class ConstrClasSample(ConstrSample):
     def __call__(self, im_path, label):
-        return {'image': self.get_full_path(im_path), 'label': int(label)}
+        return construct_sample(
+            image=self.get_full_path(im_path), label=int(label))
 
 
 class ConstrDetSample(ConstrSample):
@@ -234,7 +235,7 @@ class ConstrDetSample(ConstrSample):
         }
 
         self.ct += 1
-        return {'image': im_path, ** im_info, ** label_info}
+        return construct_sample(image=im_path, **im_info, **label_info)
 
     @silent
     def _parse_coco_files(self, im_dir, ann_path):
@@ -303,7 +304,7 @@ class ConstrDetSample(ConstrSample):
                 'difficult': np.array(difficults),
             }
 
-            samples.append({ ** im_info, ** label_info})
+            samples.append(construct_sample(**im_info, **label_info))
 
         return samples
 
@@ -314,10 +315,9 @@ class ConstrResSample(ConstrSample):
         self.sr_factor = sr_factor
 
     def __call__(self, src_path, tar_path):
-        sample = {
-            'image': self.get_full_path(src_path),
-            'target': self.get_full_path(tar_path)
-        }
+        sample = construct_sample(
+            image=self.get_full_path(src_path),
+            target=self.get_full_path(tar_path))
         if self.sr_factor is not None:
             sample['sr_factor'] = self.sr_factor
         return sample

+ 1 - 1
tests/postpros/test_postpros.py

@@ -16,8 +16,8 @@ import copy
 from PIL import Image
 
 import numpy as np
-
 import paddle
+
 import paddlers.utils.postprocs as P
 from testing_utils import CpuCommonTest
 

+ 1 - 0
tutorials/train/change_detection/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.tar.gz
 airchange/

+ 1 - 0
tutorials/train/classification/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.tar.gz
 ucmerced/

+ 1 - 0
tutorials/train/image_restoration/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.tar.gz
 rssr/

+ 1 - 0
tutorials/train/object_detection/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.tar.gz
 sarship/

+ 1 - 0
tutorials/train/semantic_segmentation/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.tar.gz
 rsseg/