소스 검색

[Fix] Remediate Trainers, Datasets, and Transformation Operators (#104)

* Update paddleslim and scikit-learn version requirements

* Polish TIPC

* Ignore .path files

* Fix setup.py

* Update examples

* Fix trans_info bugs

* Update paddleslim version

* Store version separately

* Fix bugs
Lin Manhui 2 년 전
부모
커밋
ff928ecb80

+ 1 - 2
deploy/export/export_model.py

@@ -73,5 +73,4 @@ if __name__ == '__main__':
     model = load_model(args.model_dir)
     model = load_model(args.model_dir)
 
 
     # Do dynamic-to-static cast
     # Do dynamic-to-static cast
-    # XXX: Invoke a protected (single underscore) method outside of subclasses.
-    model._export_inference_model(args.save_dir, fixed_input_shape)
+    model.export_inference_model(args.save_dir, fixed_input_shape)

+ 1 - 0
examples/README.md

@@ -5,6 +5,7 @@ PaddleRS提供从科学研究到产业应用的丰富示例,希望帮助遥感
 ## 1 官方案例
 ## 1 官方案例
 
 
 - [PaddleRS科研实战:设计深度学习变化检测模型](./rs_research/)
 - [PaddleRS科研实战:设计深度学习变化检测模型](./rs_research/)
+- [基于PaddleRS的遥感图像小目标语义分割优化方法](./c2fnet/)
 
 
 ## 2 社区贡献案例
 ## 2 社区贡献案例
 
 

+ 1 - 0
paddlers/.version

@@ -0,0 +1 @@
+0.0.0.dev0

+ 4 - 1
paddlers/__init__.py

@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-__version__ = '0.0.0.dev0'
+import os
 
 
 from paddlers.utils.env import get_environ_info, init_parallel_env
 from paddlers.utils.env import get_environ_info, init_parallel_env
 from . import tasks, datasets, transforms, utils, tools, models, deploy
 from . import tasks, datasets, transforms, utils, tools, models, deploy
 
 
 init_parallel_env()
 init_parallel_env()
 env_info = get_environ_info()
 env_info = get_environ_info()
+
+with open(os.path.join(os.path.dirname(__file__), ".version"), 'r') as fv:
+    __version__ = fv.read().rstrip()

+ 16 - 3
paddlers/datasets/base.py

@@ -15,11 +15,15 @@
 from copy import deepcopy
 from copy import deepcopy
 
 
 from paddle.io import Dataset
 from paddle.io import Dataset
+from paddle.fluid.dataloader.collate import default_collate_fn
 
 
 from paddlers.utils import get_num_workers
 from paddlers.utils import get_num_workers
+from paddlers.transforms import construct_sample_from_dict
 
 
 
 
 class BaseDataset(Dataset):
 class BaseDataset(Dataset):
+    _collate_trans_info = False
+
     def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
     def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
         super(BaseDataset, self).__init__()
         super(BaseDataset, self).__init__()
 
 
@@ -30,6 +34,15 @@ class BaseDataset(Dataset):
         self.shuffle = shuffle
         self.shuffle = shuffle
 
 
     def __getitem__(self, idx):
     def __getitem__(self, idx):
-        sample = deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
-        return outputs
+        sample = construct_sample_from_dict(self.file_list[idx])
+        # `trans_info` will be used to store meta info about image shape
+        sample['trans_info'] = []
+        outputs, trans_info = self.transforms(sample)
+        return outputs, trans_info
+
+    def collate_fn(self, batch):
+        if self._collate_trans_info:
+            return default_collate_fn(
+                [s[0] for s in batch]), [s[1] for s in batch]
+        else:
+            return default_collate_fn([s[0] for s in batch])

+ 6 - 5
paddlers/datasets/cd_dataset.py

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import copy
 from enum import IntEnum
 from enum import IntEnum
 import os.path as osp
 import os.path as osp
 
 
 from .base import BaseDataset
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
+from paddlers.transforms import construct_sample_from_dict
 
 
 
 
 class CDDataset(BaseDataset):
 class CDDataset(BaseDataset):
@@ -44,6 +44,8 @@ class CDDataset(BaseDataset):
             Defaults to False.
             Defaults to False.
     """
     """
 
 
+    _collate_trans_info = True
+
     def __init__(self,
     def __init__(self,
                  data_dir,
                  data_dir,
                  file_list,
                  file_list,
@@ -58,8 +60,6 @@ class CDDataset(BaseDataset):
 
 
         DELIMETER = ' '
         DELIMETER = ' '
 
 
-        # TODO: batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
         self.with_seg_labels = with_seg_labels
         self.with_seg_labels = with_seg_labels
@@ -130,7 +130,8 @@ class CDDataset(BaseDataset):
             len(self.file_list), file_list))
             len(self.file_list), file_list))
 
 
     def __getitem__(self, idx):
     def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
+        sample = construct_sample_from_dict(self.file_list[idx])
+        sample['trans_info'] = []
         sample = self.transforms.apply_transforms(sample)
         sample = self.transforms.apply_transforms(sample)
 
 
         if self.binarize_labels:
         if self.binarize_labels:
@@ -142,7 +143,7 @@ class CDDataset(BaseDataset):
 
 
         outputs = self.transforms.arrange_outputs(sample)
         outputs = self.transforms.arrange_outputs(sample)
 
 
-        return outputs
+        return outputs, sample['trans_info']
 
 
     def __len__(self):
     def __len__(self):
         return len(self.file_list)
         return len(self.file_list)

+ 0 - 2
paddlers/datasets/clas_dataset.py

@@ -43,8 +43,6 @@ class ClasDataset(BaseDataset):
                  shuffle=False):
                  shuffle=False):
         super(ClasDataset, self).__init__(data_dir, label_list, transforms,
         super(ClasDataset, self).__init__(data_dir, label_list, transforms,
                                           num_workers, shuffle)
                                           num_workers, shuffle)
-        # TODO batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
 
 

+ 6 - 6
paddlers/datasets/coco.py

@@ -23,7 +23,7 @@ import numpy as np
 
 
 from .base import BaseDataset
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
-from paddlers.transforms import DecodeImg, MixupImage
+from paddlers.transforms import DecodeImg, MixupImage, construct_sample_from_dict
 from paddlers.tools import YOLOAnchorCluster
 from paddlers.tools import YOLOAnchorCluster
 
 
 
 
@@ -78,7 +78,6 @@ class COCODetDataset(BaseDataset):
                     self.num_max_boxes *= 2
                     self.num_max_boxes *= 2
                     break
                     break
 
 
-        self.batch_transforms = None
         self.allow_empty = allow_empty
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.empty_ratio = empty_ratio
         self.file_list = list()
         self.file_list = list()
@@ -243,7 +242,7 @@ class COCODetDataset(BaseDataset):
         self._epoch = 0
         self._epoch = 0
 
 
     def __getitem__(self, idx):
     def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
+        sample = construct_sample_from_dict(self.file_list[idx])
         if self.data_fields is not None:
         if self.data_fields is not None:
             sample = {k: sample[k] for k in self.data_fields}
             sample = {k: sample[k] for k in self.data_fields}
         if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
         if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
@@ -253,15 +252,16 @@ class COCODetDataset(BaseDataset):
                 mix_pos = (mix_idx + idx) % self.num_samples
                 mix_pos = (mix_idx + idx) % self.num_samples
             else:
             else:
                 mix_pos = 0
                 mix_pos = 0
-            sample_mix = copy.deepcopy(self.file_list[mix_pos])
+            sample_mix = construct_sample_from_dict(self.file_list[mix_pos])
             if self.data_fields is not None:
             if self.data_fields is not None:
                 sample_mix = {k: sample_mix[k] for k in self.data_fields}
                 sample_mix = {k: sample_mix[k] for k in self.data_fields}
             sample = self.mixup_op(sample=[
             sample = self.mixup_op(sample=[
                 DecodeImg(to_rgb=False)(sample),
                 DecodeImg(to_rgb=False)(sample),
                 DecodeImg(to_rgb=False)(sample_mix)
                 DecodeImg(to_rgb=False)(sample_mix)
             ])
             ])
-        sample = self.transforms(sample)
-        return sample
+        sample['trans_info'] = []
+        sample, trans_info = self.transforms(sample)
+        return sample, trans_info
 
 
     def __len__(self):
     def __len__(self):
         return self.num_samples
         return self.num_samples

+ 2 - 1
paddlers/datasets/res_dataset.py

@@ -36,6 +36,8 @@ class ResDataset(BaseDataset):
             restoration tasks. Defaults to None.
             restoration tasks. Defaults to None.
     """
     """
 
 
+    _collate_trans_info = True
+
     def __init__(self,
     def __init__(self,
                  data_dir,
                  data_dir,
                  file_list,
                  file_list,
@@ -45,7 +47,6 @@ class ResDataset(BaseDataset):
                  sr_factor=None):
                  sr_factor=None):
         super(ResDataset, self).__init__(data_dir, None, transforms,
         super(ResDataset, self).__init__(data_dir, None, transforms,
                                          num_workers, shuffle)
                                          num_workers, shuffle)
-        self.batch_transforms = None
         self.file_list = list()
         self.file_list = list()
 
 
         with open(file_list, encoding=get_encoding(file_list)) as f:
         with open(file_list, encoding=get_encoding(file_list)) as f:

+ 2 - 2
paddlers/datasets/seg_dataset.py

@@ -35,6 +35,8 @@ class SegDataset(BaseDataset):
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
     """
     """
 
 
+    _collate_trans_info = True
+
     def __init__(self,
     def __init__(self,
                  data_dir,
                  data_dir,
                  file_list,
                  file_list,
@@ -44,8 +46,6 @@ class SegDataset(BaseDataset):
                  shuffle=False):
                  shuffle=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
                                          num_workers, shuffle)
-        # TODO: batch padding
-        self.batch_transforms = None
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
 
 

+ 16 - 33
paddlers/deploy/predictor.py

@@ -26,6 +26,8 @@ from paddlers.tasks import load_model
 from paddlers.utils import logging, Timer
 from paddlers.utils import logging, Timer
 from paddlers.tasks.utils.slider_predict import slider_predict
 from paddlers.tasks.utils.slider_predict import slider_predict
 
 
+# TODO: Refactor
+
 
 
 class Predictor(object):
 class Predictor(object):
     def __init__(self,
     def __init__(self,
@@ -148,44 +150,32 @@ class Predictor(object):
         return predictor
         return predictor
 
 
     def preprocess(self, images, transforms):
     def preprocess(self, images, transforms):
-        preprocessed_samples = self._model.preprocess(
+        preprocessed_samples, batch_trans_info = self._model.preprocess(
             images, transforms, to_tensor=False)
             images, transforms, to_tensor=False)
         if self.model_type == 'classifier':
         if self.model_type == 'classifier':
-            preprocessed_samples = {'image': preprocessed_samples[0]}
+            preprocessed_samples = {'image': preprocessed_samples}
         elif self.model_type == 'segmenter':
         elif self.model_type == 'segmenter':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'ori_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         elif self.model_type == 'detector':
         elif self.model_type == 'detector':
             pass
             pass
         elif self.model_type == 'change_detector':
         elif self.model_type == 'change_detector':
             preprocessed_samples = {
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
                 'image': preprocessed_samples[0],
-                'image2': preprocessed_samples[1],
-                'ori_shape': preprocessed_samples[2]
+                'image2': preprocessed_samples[1]
             }
             }
         elif self.model_type == 'restorer':
         elif self.model_type == 'restorer':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'tar_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         else:
         else:
             logging.error(
             logging.error(
                 "Invalid model type {}".format(self.model_type), exit=True)
                 "Invalid model type {}".format(self.model_type), exit=True)
-        return preprocessed_samples
-
-    def postprocess(self,
-                    net_outputs,
-                    topk=1,
-                    ori_shape=None,
-                    tar_shape=None,
-                    transforms=None):
+        return preprocessed_samples, batch_trans_info
+
+    def postprocess(self, net_outputs, batch_restore_list, topk=1):
         if self.model_type == 'classifier':
         if self.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
             if self._model.postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
                 self._model.build_postprocess_from_labels(topk)
-            # XXX: Convert ndarray to tensor as self._model.postprocess requires
+            # XXX: Convert ndarray to tensor as `self._model.postprocess` requires
             assert len(net_outputs) == 1
             assert len(net_outputs) == 1
             net_outputs = paddle.to_tensor(net_outputs[0])
             net_outputs = paddle.to_tensor(net_outputs[0])
             outputs = self._model.postprocess(net_outputs)
             outputs = self._model.postprocess(net_outputs)
@@ -199,9 +189,7 @@ class Predictor(object):
             } for l, s, n in zip(class_ids, scores, label_names)]
             } for l, s, n in zip(class_ids, scores, label_names)]
         elif self.model_type in ('segmenter', 'change_detector'):
         elif self.model_type in ('segmenter', 'change_detector'):
             label_map, score_map = self._model.postprocess(
             label_map, score_map = self._model.postprocess(
-                net_outputs,
-                batch_origin_shape=ori_shape,
-                transforms=transforms.transforms)
+                net_outputs, batch_restore_list=batch_restore_list)
             preds = [{
             preds = [{
                 'label_map': l,
                 'label_map': l,
                 'score_map': s
                 'score_map': s
@@ -214,9 +202,7 @@ class Predictor(object):
             preds = self._model.postprocess(net_outputs)
             preds = self._model.postprocess(net_outputs)
         elif self.model_type == 'restorer':
         elif self.model_type == 'restorer':
             res_maps = self._model.postprocess(
             res_maps = self._model.postprocess(
-                net_outputs[0],
-                batch_tar_shape=tar_shape,
-                transforms=transforms.transforms)
+                net_outputs[0], batch_restore_list=batch_restore_list)
             preds = [{'res_map': res_map} for res_map in res_maps]
             preds = [{'res_map': res_map} for res_map in res_maps]
         else:
         else:
             logging.error(
             logging.error(
@@ -248,7 +234,8 @@ class Predictor(object):
 
 
     def _run(self, images, topk=1, transforms=None):
     def _run(self, images, topk=1, transforms=None):
         self.timer.preprocess_time_s.start()
         self.timer.preprocess_time_s.start()
-        preprocessed_input = self.preprocess(images, transforms)
+        preprocessed_input, batch_trans_info = self.preprocess(images,
+                                                               transforms)
         self.timer.preprocess_time_s.end(iter_num=len(images))
         self.timer.preprocess_time_s.end(iter_num=len(images))
 
 
         self.timer.inference_time_s.start()
         self.timer.inference_time_s.start()
@@ -257,11 +244,7 @@ class Predictor(object):
 
 
         self.timer.postprocess_time_s.start()
         self.timer.postprocess_time_s.start()
         results = self.postprocess(
         results = self.postprocess(
-            net_outputs,
-            topk,
-            ori_shape=preprocessed_input.get('ori_shape', None),
-            tar_shape=preprocessed_input.get('tar_shape', None),
-            transforms=transforms)
+            net_outputs, batch_restore_list=batch_trans_info, topk=topk)
         self.timer.postprocess_time_s.end(iter_num=len(images))
         self.timer.postprocess_time_s.end(iter_num=len(images))
 
 
         return results
         return results

+ 23 - 13
paddlers/tasks/base.py

@@ -61,6 +61,9 @@ class ModelMeta(type):
 
 
 
 
 class BaseModel(metaclass=ModelMeta):
 class BaseModel(metaclass=ModelMeta):
+
+    find_unused_parameters = False
+
     def __init__(self, model_type):
     def __init__(self, model_type):
         self.model_type = model_type
         self.model_type = model_type
         self.in_channels = None
         self.in_channels = None
@@ -98,6 +101,7 @@ class BaseModel(metaclass=ModelMeta):
                 if osp.exists(save_dir):
                 if osp.exists(save_dir):
                     os.remove(save_dir)
                     os.remove(save_dir)
                 os.makedirs(save_dir)
                 os.makedirs(save_dir)
+            # XXX: Hard-coding
             if self.model_type == 'classifier':
             if self.model_type == 'classifier':
                 pretrain_weights = get_pretrain_weights(
                 pretrain_weights = get_pretrain_weights(
                     pretrain_weights, self.model_name, save_dir)
                     pretrain_weights, self.model_name, save_dir)
@@ -214,10 +218,7 @@ class BaseModel(metaclass=ModelMeta):
         info = dict()
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
         info['pruning_ratios'] = self.pruning_ratios
-        pruner_inputs = self.pruner.inputs
-        if self.model_type == 'detector':
-            pruner_inputs = {k: v.tolist() for k, v in pruner_inputs[0].items()}
-        info['pruner_inputs'] = pruner_inputs
+        info['pruner_inputs'] = self.pruner.inputs
 
 
         return info
         return info
 
 
@@ -266,7 +267,11 @@ class BaseModel(metaclass=ModelMeta):
         open(osp.join(save_dir, '.success'), 'w').close()
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("Model saved in {}.".format(save_dir))
         logging.info("Model saved in {}.".format(save_dir))
 
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
         if dataset.num_samples < batch_size:
             raise ValueError(
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -291,7 +296,7 @@ class BaseModel(metaclass=ModelMeta):
         loader = DataLoader(
         loader = DataLoader(
             dataset,
             dataset,
             batch_sampler=batch_sampler,
             batch_sampler=batch_sampler,
-            collate_fn=dataset.batch_transforms,
+            collate_fn=dataset.collate_fn if collate_fn is None else collate_fn,
             num_workers=dataset.num_workers,
             num_workers=dataset.num_workers,
             return_list=True,
             return_list=True,
             use_shared_memory=use_shared_memory)
             use_shared_memory=use_shared_memory)
@@ -312,6 +317,7 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms, 'train')
         self._check_transforms(train_dataset.transforms, 'train')
 
 
+        # XXX: Hard-coding
         if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
         if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
                 train_dataset.file_list):
             nranks = 1
             nranks = 1
@@ -319,17 +325,17 @@ class BaseModel(metaclass=ModelMeta):
             nranks = paddle.distributed.get_world_size()
             nranks = paddle.distributed.get_world_size()
         local_rank = paddle.distributed.get_rank()
         local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
-            find_unused_parameters = getattr(self, 'find_unused_parameters',
-                                             False)
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
             ):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
                 ddp_net = to_data_parallel(
                 ddp_net = to_data_parallel(
-                    self.net, find_unused_parameters=find_unused_parameters)
+                    self.net,
+                    find_unused_parameters=self.find_unused_parameters)
             else:
             else:
                 ddp_net = to_data_parallel(
                 ddp_net = to_data_parallel(
-                    self.net, find_unused_parameters=find_unused_parameters)
+                    self.net,
+                    find_unused_parameters=self.find_unused_parameters)
 
 
         if use_vdl:
         if use_vdl:
             from visualdl import LogWriter
             from visualdl import LogWriter
@@ -488,12 +494,13 @@ class BaseModel(metaclass=ModelMeta):
         assert criterion in {'l1_norm', 'fpgm'}, \
         assert criterion in {'l1_norm', 'fpgm'}, \
             "Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
             "Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
         self._check_transforms(dataset.transforms, 'eval')
         self._check_transforms(dataset.transforms, 'eval')
+        # XXX: Hard-coding
         if self.model_type == 'detector':
         if self.model_type == 'detector':
             self.net.eval()
             self.net.eval()
         else:
         else:
             self.net.train()
             self.net.train()
         inputs = _pruner_template_input(
         inputs = _pruner_template_input(
-            sample=dataset[0], model_type=self.model_type)
+            sample=dataset[0][0], model_type=self.model_type)
         if criterion == 'l1_norm':
         if criterion == 'l1_norm':
             self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
             self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
         else:
         else:
@@ -618,7 +625,10 @@ class BaseModel(metaclass=ModelMeta):
     def _build_inference_net(self):
     def _build_inference_net(self):
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def _export_inference_model(self, save_dir, image_shape=None):
+    def _get_test_inputs(self, image_shape):
+        raise NotImplementedError
+
+    def export_inference_model(self, save_dir, image_shape=None):
         self.test_inputs = self._get_test_inputs(image_shape)
         self.test_inputs = self._get_test_inputs(image_shape)
         infer_net = self._build_inference_net()
         infer_net = self._build_inference_net()
 
 
@@ -696,4 +706,4 @@ class BaseModel(metaclass=ModelMeta):
         raise NotImplementedError
         raise NotImplementedError
 
 
     def postprocess(self, *args, **kwargs):
     def postprocess(self, *args, **kwargs):
-        raise NotImplementedError
+        raise NotImplementedError

+ 9 - 62
paddlers/tasks/change_detector.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import math
 import math
-import os
 import os.path as osp
 import os.path as osp
 from collections import OrderedDict
 from collections import OrderedDict
 from operator import attrgetter
 from operator import attrgetter
@@ -29,7 +28,7 @@ import paddlers.models.paddleseg as ppseg
 import paddlers.rs_models.cd as cmcd
 import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.models import seg_losses
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.utils import get_single_card_bs
 from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
@@ -63,7 +62,6 @@ class BaseChangeDetector(BaseModel):
         if params.get('with_net', True):
         if params.get('with_net', True):
             params.pop('with_net', None)
             params.pop('with_net', None)
             self.net = self.build_net(**params)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
 
     def build_net(self, **params):
     def build_net(self, **params):
         # TODO: add other model
         # TODO: add other model
@@ -112,11 +110,11 @@ class BaseChangeDetector(BaseModel):
         ]
         ]
 
 
     def run(self, net, inputs, mode):
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         net_out = net(inputs[0], inputs[1])
         net_out = net(inputs[0], inputs[1])
         logit = net_out[0]
         logit = net_out[0]
         outputs = OrderedDict()
         outputs = OrderedDict()
         if mode == 'test':
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
                 label_map_list, score_map_list = self.postprocess(
                     net_out, batch_restore_list)
                     net_out, batch_restore_list)
@@ -137,7 +135,6 @@ class BaseChangeDetector(BaseModel):
             outputs['score_map'] = score_map_list
             outputs['score_map'] = score_map_list
 
 
         if mode == 'eval':
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
             else:
@@ -560,10 +557,8 @@ class BaseChangeDetector(BaseModel):
             images = [img_file]
             images = [img_file]
         else:
         else:
             images = img_file
             images = img_file
-        batch_im1, batch_im2, batch_trans_info = self.preprocess(
-            images, transforms, self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
         self.net.eval()
-        data = (batch_im1, batch_im2, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
         score_map_list = outputs['score_map']
@@ -631,10 +626,10 @@ class BaseChangeDetector(BaseModel):
                 im1 = decode_image(im1, read_raw=True)
                 im1 = decode_image(im1, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
-            sample = {'image': im1, 'image2': im2}
+            sample = construct_sample(image=im1, image2=im2)
             data = transforms(sample)
             data = transforms(sample)
-            im1, im2 = data[:2]
-            trans_info = data[-1]
+            im1, im2 = data[0][:2]
+            trans_info = data[1]
             batch_im1.append(im1)
             batch_im1.append(im1)
             batch_im2.append(im2)
             batch_im2.append(im2)
             batch_trans_info.append(trans_info)
             batch_trans_info.append(trans_info)
@@ -645,55 +640,7 @@ class BaseChangeDetector(BaseModel):
             batch_im1 = np.asarray(batch_im1)
             batch_im1 = np.asarray(batch_im1)
             batch_im2 = np.asarray(batch_im2)
             batch_im2 = np.asarray(batch_im2)
 
 
-        return batch_im1, batch_im2, batch_trans_info
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
+        return (batch_im1, batch_im2), batch_trans_info
 
 
     def postprocess(self, batch_pred, batch_restore_list):
     def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
@@ -717,7 +664,7 @@ class BaseChangeDetector(BaseModel):
                     x, y = item[2]
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                 else:
-                    pass
+                    raise RuntimeError
             results.append(pred)
             results.append(pred)
         return results
         return results
 
 
@@ -756,7 +703,7 @@ class BaseChangeDetector(BaseModel):
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                 else:
-                    pass
+                    raise RuntimeError
             label_map = label_map.squeeze()
             label_map = label_map.squeeze()
             score_map = score_map.squeeze()
             score_map = score_map.squeeze()
             if not isinstance(label_map, np.ndarray):
             if not isinstance(label_map, np.ndarray):

+ 13 - 64
paddlers/tasks/classifier.py

@@ -12,26 +12,23 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import math
 import os.path as osp
 import os.path as osp
 from collections import OrderedDict
 from collections import OrderedDict
 from operator import itemgetter
 from operator import itemgetter
 
 
 import numpy as np
 import numpy as np
 import paddle
 import paddle
-import paddle.nn.functional as F
 from paddle.static import InputSpec
 from paddle.static import InputSpec
 
 
 import paddlers
 import paddlers
 import paddlers.models.ppcls as ppcls
 import paddlers.models.ppcls as ppcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
-from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models import clas_losses
 from paddlers.models import clas_losses
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from .base import BaseModel
 from .base import BaseModel
 
 
 __all__ = ["ResNet50_vd", "MobileNetV3", "HRNet", "CondenseNetV2"]
 __all__ = ["ResNet50_vd", "MobileNetV3", "HRNet", "CondenseNetV2"]
@@ -64,7 +61,6 @@ class BaseClassifier(BaseModel):
         if params.get('with_net', True):
         if params.get('with_net', True):
             params.pop('with_net', None)
             params.pop('with_net', None)
             self.net = self.build_net(**params)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
 
     def build_net(self, **params):
     def build_net(self, **params):
         with paddle.utils.unique_name.guard():
         with paddle.utils.unique_name.guard():
@@ -459,10 +455,8 @@ class BaseClassifier(BaseModel):
             images = [img_file]
             images = [img_file]
         else:
         else:
             images = img_file
             images = img_file
-        batch_im, batch_origin_shape = self.preprocess(images, transforms,
-                                                       self.model_type)
+        data, _ = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
 
 
         if self.postprocess is None:
         if self.postprocess is None:
             self.build_postprocess_from_labels()
             self.build_postprocess_from_labels()
@@ -488,69 +482,19 @@ class BaseClassifier(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_im = list()
-        batch_ori_shape = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
-            sample = {'image': im}
-            im = transforms(sample)
+            sample = construct_sample(image=im)
+            data = transforms(sample)
+            im = data[0][0]
             batch_im.append(im)
             batch_im.append(im)
-            batch_ori_shape.append(ori_shape)
         if to_tensor:
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
             batch_im = paddle.to_tensor(batch_im)
         else:
         else:
             batch_im = np.asarray(batch_im)
             batch_im = np.asarray(batch_im)
 
 
-        return batch_im, batch_ori_shape
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
+        return batch_im, None
 
 
     def _check_transforms(self, transforms, mode):
     def _check_transforms(self, transforms, mode):
         super()._check_transforms(transforms, mode)
         super()._check_transforms(transforms, mode)
@@ -559,7 +503,11 @@ class BaseClassifier(BaseModel):
             raise TypeError(
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeClassifier object.")
                 "`transforms.arrange` must be an ArrangeClassifier object.")
 
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
         if dataset.num_samples < batch_size:
             raise ValueError(
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -571,7 +519,8 @@ class BaseClassifier(BaseModel):
                 batch_size=batch_size,
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 shuffle=dataset.shuffle,
                 drop_last=False,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 return_list=True,
                 use_shared_memory=False)
                 use_shared_memory=False)

+ 44 - 11
paddlers/tasks/object_detector.py

@@ -24,7 +24,7 @@ from paddle.static import InputSpec
 import paddlers
 import paddlers
 import paddlers.models.ppdet as ppdet
 import paddlers.models.ppdet as ppdet
 from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
 from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
-from paddlers.transforms import decode_image
+from paddlers.transforms import decode_image, construct_sample
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
     _BatchPad, _Gt2YoloTarget
     _BatchPad, _Gt2YoloTarget
@@ -38,6 +38,8 @@ __all__ = [
     "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
     "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
 ]
 ]
 
 
+# TODO: Prune and decoupling
+
 
 
 class BaseDetector(BaseModel):
 class BaseDetector(BaseModel):
     def __init__(self, model_name, num_classes=80, **params):
     def __init__(self, model_name, num_classes=80, **params):
@@ -307,6 +309,8 @@ class BaseDetector(BaseModel):
         self.num_max_boxes = train_dataset.num_max_boxes
         self.num_max_boxes = train_dataset.num_max_boxes
         train_dataset.batch_transforms = self._compose_batch_transform(
         train_dataset.batch_transforms = self._compose_batch_transform(
             train_dataset.transforms, mode='train')
             train_dataset.transforms, mode='train')
+        train_dataset.collate_fn = self._build_collate_fn(
+            train_dataset.batch_transforms)
 
 
         # Build optimizer if not defined
         # Build optimizer if not defined
         if optimizer is None:
         if optimizer is None:
@@ -372,6 +376,17 @@ class BaseDetector(BaseModel):
             early_stop_patience=early_stop_patience,
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
             use_vdl=use_vdl)
 
 
+    def _build_collate_fn(self, compose):
+        def _collate_fn(batch):
+            # We drop `trans_info` as it is not required in detection tasks
+            samples = [s[0] for s in batch]
+            return compose(samples)
+
+        return _collate_fn
+
+    def _compose_batch_transform(self, transforms, mode):
+        raise NotImplementedError
+
     def quant_aware_train(self,
     def quant_aware_train(self,
                           num_epochs,
                           num_epochs,
                           train_dataset,
                           train_dataset,
@@ -534,9 +549,13 @@ class BaseDetector(BaseModel):
 
 
         if nranks < 2 or local_rank == 0:
         if nranks < 2 or local_rank == 0:
             self.eval_data_loader = self.build_data_loader(
             self.eval_data_loader = self.build_data_loader(
-                eval_dataset, batch_size=batch_size, mode='eval')
+                eval_dataset,
+                batch_size=batch_size,
+                mode='eval',
+                collate_fn=self._build_collate_fn(
+                    eval_dataset.batch_transforms))
             is_bbox_normalized = False
             is_bbox_normalized = False
-            if eval_dataset.batch_transforms is not None:
+            if hasattr(eval_dataset, 'batch_transforms'):
                 is_bbox_normalized = any(
                 is_bbox_normalized = any(
                     isinstance(t, _NormalizeBox)
                     isinstance(t, _NormalizeBox)
                     for t in eval_dataset.batch_transforms.batch_transforms)
                     for t in eval_dataset.batch_transforms.batch_transforms)
@@ -604,7 +623,7 @@ class BaseDetector(BaseModel):
         else:
         else:
             images = img_file
             images = img_file
 
 
-        batch_samples = self.preprocess(images, transforms)
+        batch_samples, _ = self.preprocess(images, transforms)
         self.net.eval()
         self.net.eval()
         outputs = self.run(self.net, batch_samples, 'test')
         outputs = self.run(self.net, batch_samples, 'test')
         prediction = self.postprocess(outputs)
         prediction = self.postprocess(outputs)
@@ -619,16 +638,17 @@ class BaseDetector(BaseModel):
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
                 im = decode_image(im, read_raw=True)
-            sample = {'image': im}
+            sample = construct_sample(image=im)
             sample = transforms(sample)
             sample = transforms(sample)
-            batch_samples.append(sample)
+            data = sample[0]
+            batch_samples.append(data)
         batch_transforms = self._compose_batch_transform(transforms, 'test')
         batch_transforms = self._compose_batch_transform(transforms, 'test')
         batch_samples = batch_transforms(batch_samples)
         batch_samples = batch_transforms(batch_samples)
         if to_tensor:
         if to_tensor:
             for k in batch_samples:
             for k in batch_samples:
                 batch_samples[k] = paddle.to_tensor(batch_samples[k])
                 batch_samples[k] = paddle.to_tensor(batch_samples[k])
 
 
-        return batch_samples
+        return batch_samples, None
 
 
     def postprocess(self, batch_pred):
     def postprocess(self, batch_pred):
         infer_result = {}
         infer_result = {}
@@ -705,6 +725,14 @@ class BaseDetector(BaseModel):
             raise TypeError(
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeDetector object.")
                 "`transforms.arrange` must be an ArrangeDetector object.")
 
 
+    def get_pruning_info(self):
+        info = super().get_pruning_info()
+        info['pruner_inputs'] = {
+            k: v.tolist()
+            for k, v in info['pruner_inputs'][0].items()
+        }
+        return info
+
 
 
 class PicoDet(BaseDetector):
 class PicoDet(BaseDetector):
     def __init__(self,
     def __init__(self,
@@ -920,7 +948,11 @@ class PicoDet(BaseDetector):
             in_args['optimizer'] = optimizer
             in_args['optimizer'] = optimizer
         return in_args
         return in_args
 
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
         if dataset.num_samples < batch_size:
             raise ValueError(
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -932,13 +964,14 @@ class PicoDet(BaseDetector):
                 batch_size=batch_size,
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 shuffle=dataset.shuffle,
                 drop_last=False,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 return_list=True,
                 use_shared_memory=False)
                 use_shared_memory=False)
         else:
         else:
-            return super(BaseDetector, self).build_data_loader(dataset,
-                                                               batch_size, mode)
+            return super(BaseDetector, self).build_data_loader(
+                dataset, batch_size, mode, collate_fn)
 
 
 
 
 class YOLOv3(BaseDetector):
 class YOLOv3(BaseDetector):

+ 35 - 82
paddlers/tasks/restorer.py

@@ -28,7 +28,7 @@ import paddlers.models.ppgan.metrics as metrics
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import res_losses
 from paddlers.models import res_losses
 from paddlers.models.ppgan.modules.init import init_weights
 from paddlers.models.ppgan.modules.init import init_weights
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms.functions import calc_hr_shape
 from paddlers.transforms.functions import calc_hr_shape
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
@@ -58,7 +58,6 @@ class BaseRestorer(BaseModel):
         if params.get('with_net', True):
         if params.get('with_net', True):
             params.pop('with_net', None)
             params.pop('with_net', None)
             self.net = self.build_net(**params)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
         if min_max is None:
         if min_max is None:
             self.min_max = self.MIN_MAX
             self.min_max = self.MIN_MAX
 
 
@@ -116,14 +115,13 @@ class BaseRestorer(BaseModel):
         return input_spec
         return input_spec
 
 
     def run(self, net, inputs, mode):
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         outputs = OrderedDict()
         outputs = OrderedDict()
 
 
         if mode == 'test':
         if mode == 'test':
-            tar_shape = inputs[1]
             if self.status == 'Infer':
             if self.status == 'Infer':
                 net_out = net(inputs[0])
                 net_out = net(inputs[0])
-                res_map_list = self.postprocess(
-                    net_out, tar_shape, transforms=inputs[2])
+                res_map_list = self.postprocess(net_out, batch_restore_list)
             else:
             else:
                 if isinstance(net, GANAdapter):
                 if isinstance(net, GANAdapter):
                     net_out = net.generator(inputs[0])
                     net_out = net.generator(inputs[0])
@@ -131,8 +129,7 @@ class BaseRestorer(BaseModel):
                     net_out = net(inputs[0])
                     net_out = net(inputs[0])
                 if self.TEST_OUT_KEY is not None:
                 if self.TEST_OUT_KEY is not None:
                     net_out = net_out[self.TEST_OUT_KEY]
                     net_out = net_out[self.TEST_OUT_KEY]
-                pred = self.postprocess(
-                    net_out, tar_shape, transforms=inputs[2])
+                pred = self.postprocess(net_out, batch_restore_list)
                 res_map_list = []
                 res_map_list = []
                 for res_map in pred:
                 for res_map in pred:
                     res_map = self._tensor_to_images(res_map)
                     res_map = self._tensor_to_images(res_map)
@@ -147,9 +144,7 @@ class BaseRestorer(BaseModel):
             if self.TEST_OUT_KEY is not None:
             if self.TEST_OUT_KEY is not None:
                 net_out = net_out[self.TEST_OUT_KEY]
                 net_out = net_out[self.TEST_OUT_KEY]
             tar = inputs[1]
             tar = inputs[1]
-            tar_shape = [tar.shape[-2:]]
-            pred = self.postprocess(
-                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(net_out, batch_restore_list)[0]  # NCHW
             pred = self._tensor_to_images(pred)
             pred = self._tensor_to_images(pred)
             outputs['pred'] = pred
             outputs['pred'] = pred
             tar = self._tensor_to_images(tar)
             tar = self._tensor_to_images(tar)
@@ -424,7 +419,6 @@ class BaseRestorer(BaseModel):
                     eval_dataset.num_samples, eval_dataset.num_samples))
                     eval_dataset.num_samples, eval_dataset.num_samples))
             with paddle.no_grad():
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
                 for step, data in enumerate(self.eval_data_loader):
-                    data.append(eval_dataset.transforms.transforms)
                     outputs = self.run(self.net, data, 'eval')
                     outputs = self.run(self.net, data, 'eval')
                     psnr.update(outputs['pred'], outputs['tar'])
                     psnr.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
@@ -472,10 +466,8 @@ class BaseRestorer(BaseModel):
             images = [img_file]
             images = [img_file]
         else:
         else:
             images = img_file
             images = img_file
-        batch_im, batch_tar_shape = self.preprocess(images, transforms,
-                                                    self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
         self.net.eval()
-        data = (batch_im, batch_tar_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
         outputs = self.run(self.net, data, 'test')
         res_map_list = outputs['res_map']
         res_map_list = outputs['res_map']
         if isinstance(img_file, list):
         if isinstance(img_file, list):
@@ -487,79 +479,24 @@ class BaseRestorer(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_im = list()
-        batch_tar_shape = list()
+        batch_trans_info = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
-            sample = {'image': im}
-            im = transforms(sample)[0]
+            sample = construct_sample(image=im)
+            data = transforms(sample)
+            im = data[0][0]
+            trans_info = data[1]
             batch_im.append(im)
             batch_im.append(im)
-            batch_tar_shape.append(self._get_target_shape(ori_shape))
+            batch_trans_info.append(trans_info)
         if to_tensor:
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
             batch_im = paddle.to_tensor(batch_im)
         else:
         else:
             batch_im = np.asarray(batch_im)
             batch_im = np.asarray(batch_im)
 
 
-        return batch_im, batch_tar_shape
+        return (batch_im, ), batch_trans_info
 
 
-    def _get_target_shape(self, ori_shape):
-        if self.sr_factor is None:
-            return ori_shape
-        else:
-            return calc_hr_shape(ori_shape, self.sr_factor)
-
-    @staticmethod
-    def get_transforms_shape_info(batch_tar_shape, transforms):
-        batch_restore_list = list()
-        for tar_shape in batch_tar_shape:
-            restore_list = list()
-            h, w = tar_shape[0], tar_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
-
-    def postprocess(self, batch_pred, batch_tar_shape, transforms):
-        batch_restore_list = BaseRestorer.get_transforms_shape_info(
-            batch_tar_shape, transforms)
+    def postprocess(self, batch_pred, batch_restore_list):
         if self.status == 'Infer':
         if self.status == 'Infer':
             return self._infer_postprocess(
             return self._infer_postprocess(
                 batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
                 batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
@@ -572,11 +509,15 @@ class BaseRestorer(BaseModel):
             pred = paddle.unsqueeze(pred, axis=0)
             pred = paddle.unsqueeze(pred, axis=0)
             for item in restore_list[::-1]:
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
                 h, w = item[1][0], item[1][1]
+                if self.sr_factor:
+                    h, w = calc_hr_shape((h, w), self.sr_factor)
                 if item[0] == 'resize':
                 if item[0] == 'resize':
                     pred = F.interpolate(
                     pred = F.interpolate(
                         pred, (h, w), mode=mode, data_format='NCHW')
                         pred, (h, w), mode=mode, data_format='NCHW')
                 elif item[0] == 'padding':
                 elif item[0] == 'padding':
                     x, y = item[2]
                     x, y = item[2]
+                    if self.sr_factor:
+                        x, y = calc_hr_shape((x, y), self.sr_factor)
                     pred = pred[:, :, y:y + h, x:x + w]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                 else:
                     pass
                     pass
@@ -590,6 +531,8 @@ class BaseRestorer(BaseModel):
                 res_map = paddle.unsqueeze(res_map, axis=0)
                 res_map = paddle.unsqueeze(res_map, axis=0)
             for item in restore_list[::-1]:
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
                 h, w = item[1][0], item[1][1]
+                if self.sr_factor:
+                    h, w = calc_hr_shape((h, w), self.sr_factor)
                 if item[0] == 'resize':
                 if item[0] == 'resize':
                     if isinstance(res_map, np.ndarray):
                     if isinstance(res_map, np.ndarray):
                         res_map = cv2.resize(
                         res_map = cv2.resize(
@@ -601,6 +544,8 @@ class BaseRestorer(BaseModel):
                             data_format='NHWC')
                             data_format='NHWC')
                 elif item[0] == 'padding':
                 elif item[0] == 'padding':
                     x, y = item[2]
                     x, y = item[2]
+                    if self.sr_factor:
+                        x, y = calc_hr_shape((x, y), self.sr_factor)
                     if isinstance(res_map, np.ndarray):
                     if isinstance(res_map, np.ndarray):
                         res_map = res_map[y:y + h, x:x + w]
                         res_map = res_map[y:y + h, x:x + w]
                     else:
                     else:
@@ -621,7 +566,11 @@ class BaseRestorer(BaseModel):
             raise TypeError(
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeRestorer object.")
                 "`transforms.arrange` must be an ArrangeRestorer object.")
 
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
+    def build_data_loader(self,
+                          dataset,
+                          batch_size,
+                          mode='train',
+                          collate_fn=None):
         if dataset.num_samples < batch_size:
         if dataset.num_samples < batch_size:
             raise ValueError(
             raise ValueError(
                 'The volume of dataset({}) must be larger than batch size({}).'
                 'The volume of dataset({}) must be larger than batch size({}).'
@@ -633,7 +582,8 @@ class BaseRestorer(BaseModel):
                 batch_size=batch_size,
                 batch_size=batch_size,
                 shuffle=dataset.shuffle,
                 shuffle=dataset.shuffle,
                 drop_last=False,
                 drop_last=False,
-                collate_fn=dataset.batch_transforms,
+                collate_fn=dataset.collate_fn
+                if collate_fn is None else collate_fn,
                 num_workers=dataset.num_workers,
                 num_workers=dataset.num_workers,
                 return_list=True,
                 return_list=True,
                 use_shared_memory=False)
                 use_shared_memory=False)
@@ -758,7 +708,7 @@ class DRN(BaseRestorer):
 
 
     def train_step(self, step, data, net):
     def train_step(self, step, data, net):
         outputs = self.run_gan(
         outputs = self.run_gan(
-            net, data, mode='train', gan_mode='forward_primary')
+            net, data[0], mode='train', gan_mode='forward_primary')
         outputs.update(
         outputs.update(
             self.run_gan(
             self.run_gan(
                 net, (outputs['sr'], outputs['lr']),
                 net, (outputs['sr'], outputs['lr']),
@@ -800,6 +750,9 @@ class LESRCNN(BaseRestorer):
 
 
 
 
 class ESRGAN(BaseRestorer):
 class ESRGAN(BaseRestorer):
+
+    find_unused_parameters = True
+
     def __init__(self,
     def __init__(self,
                  losses=None,
                  losses=None,
                  sr_factor=4,
                  sr_factor=4,
@@ -915,14 +868,14 @@ class ESRGAN(BaseRestorer):
             optim_g, optim_d = self.optimizer
             optim_g, optim_d = self.optimizer
 
 
             outputs = self.run_gan(
             outputs = self.run_gan(
-                net, data, mode='train', gan_mode='forward_g')
+                net, data[0], mode='train', gan_mode='forward_g')
             optim_g.clear_grad()
             optim_g.clear_grad()
             (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
             (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
             optim_g.step()
             optim_g.step()
 
 
             outputs.update(
             outputs.update(
                 self.run_gan(
                 self.run_gan(
-                    net, (outputs['g_pred'], data[1]),
+                    net, (outputs['g_pred'], data[0][1]),
                     mode='train',
                     mode='train',
                     gan_mode='forward_d'))
                     gan_mode='forward_d'))
             optim_d.clear_grad()
             optim_d.clear_grad()

+ 10 - 15
paddlers/tasks/segmenter.py

@@ -27,7 +27,7 @@ import paddlers.models.paddleseg as ppseg
 import paddlers.rs_models.seg as cmseg
 import paddlers.rs_models.seg as cmseg
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.models import seg_losses
-from paddlers.transforms import Resize, decode_image
+from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
@@ -64,7 +64,6 @@ class BaseSegmenter(BaseModel):
         if params.get('with_net', True):
         if params.get('with_net', True):
             params.pop('with_net', None)
             params.pop('with_net', None)
             self.net = self.build_net(**params)
             self.net = self.build_net(**params)
-        self.find_unused_parameters = True
 
 
     def build_net(self, **params):
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
         # TODO: when using paddle.utils.unique_name.guard,
@@ -114,11 +113,11 @@ class BaseSegmenter(BaseModel):
         return input_spec
         return input_spec
 
 
     def run(self, net, inputs, mode):
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         net_out = net(inputs[0])
         net_out = net(inputs[0])
         logit = net_out[0]
         logit = net_out[0]
         outputs = OrderedDict()
         outputs = OrderedDict()
         if mode == 'test':
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
                 label_map_list, score_map_list = self.postprocess(
                     net_out, batch_restore_list)
                     net_out, batch_restore_list)
@@ -139,7 +138,6 @@ class BaseSegmenter(BaseModel):
             outputs['score_map'] = score_map_list
             outputs['score_map'] = score_map_list
 
 
         if mode == 'eval':
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
             else:
@@ -526,10 +524,8 @@ class BaseSegmenter(BaseModel):
             images = [img_file]
             images = [img_file]
         else:
         else:
             images = img_file
             images = img_file
-        batch_im, batch_trans_info = self.preprocess(images, transforms,
-                                                     self.model_type)
+        data = self.preprocess(images, transforms, self.model_type)
         self.net.eval()
         self.net.eval()
-        data = (batch_im, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
         score_map_list = outputs['score_map']
@@ -595,10 +591,10 @@ class BaseSegmenter(BaseModel):
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
                 im = decode_image(im, read_raw=True)
-            sample = {'image': im}
+            sample = construct_sample(image=im)
             data = transforms(sample)
             data = transforms(sample)
-            im = data[0]
-            trans_info = data[-1]
+            im = data[0][0]
+            trans_info = data[1]
             batch_im.append(im)
             batch_im.append(im)
             batch_trans_info.append(trans_info)
             batch_trans_info.append(trans_info)
         if to_tensor:
         if to_tensor:
@@ -606,7 +602,7 @@ class BaseSegmenter(BaseModel):
         else:
         else:
             batch_im = np.asarray(batch_im)
             batch_im = np.asarray(batch_im)
 
 
-        return batch_im, batch_trans_info
+        return (batch_im, ), batch_trans_info
 
 
     def postprocess(self, batch_pred, batch_restore_list):
     def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
@@ -630,7 +626,7 @@ class BaseSegmenter(BaseModel):
                     x, y = item[2]
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                 else:
-                    pass
+                    raise RuntimeError
             results.append(pred)
             results.append(pred)
         return results
         return results
 
 
@@ -669,7 +665,7 @@ class BaseSegmenter(BaseModel):
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         label_map = label_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                         score_map = score_map[:, y:y + h, x:x + w, :]
                 else:
                 else:
-                    pass
+                    raise RuntimeError
             label_map = label_map.squeeze()
             label_map = label_map.squeeze()
             score_map = score_map.squeeze()
             score_map = score_map.squeeze()
             if not isinstance(label_map, np.ndarray):
             if not isinstance(label_map, np.ndarray):
@@ -921,13 +917,13 @@ class C2FNet(BaseSegmenter):
             **params)
             **params)
 
 
     def run(self, net, inputs, mode):
     def run(self, net, inputs, mode):
+        inputs, batch_restore_list = inputs
         with paddle.no_grad():
         with paddle.no_grad():
             pre_coarse = self.coarse_model(inputs[0])
             pre_coarse = self.coarse_model(inputs[0])
             pre_coarse = pre_coarse[0]
             pre_coarse = pre_coarse[0]
             heatmaps = pre_coarse
             heatmaps = pre_coarse
 
 
         if mode == 'test':
         if mode == 'test':
-            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             logit = net_out[0]
             outputs = OrderedDict()
             outputs = OrderedDict()
@@ -952,7 +948,6 @@ class C2FNet(BaseSegmenter):
             outputs['score_map'] = score_map_list
             outputs['score_map'] = score_map_list
 
 
         if mode == 'eval':
         if mode == 'eval':
-            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             logit = net_out[0]
             outputs = OrderedDict()
             outputs = OrderedDict()

+ 11 - 7
paddlers/transforms/batch_operators.py

@@ -27,7 +27,11 @@ from .box_utils import jaccard_overlap
 from paddlers.utils import logging
 from paddlers.utils import logging
 
 
 
 
-class BatchCompose(Transform):
+class BatchTransform(Transform):
+    is_batch_transform = True
+
+
+class BatchCompose(BatchTransform):
     def __init__(self, batch_transforms=None, collate_batch=True):
     def __init__(self, batch_transforms=None, collate_batch=True):
         super(BatchCompose, self).__init__()
         super(BatchCompose, self).__init__()
         self.batch_transforms = batch_transforms
         self.batch_transforms = batch_transforms
@@ -40,14 +44,14 @@ class BatchCompose(Transform):
                     samples = op(samples)
                     samples = op(samples)
                 except Exception as e:
                 except Exception as e:
                     stack_info = traceback.format_exc()
                     stack_info = traceback.format_exc()
-                    logging.warning("fail to map batch transform [{}] "
+                    logging.warning("Fail to map batch transform [{}] "
                                     "with error: {} and stack:\n{}".format(
                                     "with error: {} and stack:\n{}".format(
                                         op, e, str(stack_info)))
                                         op, e, str(stack_info)))
                     raise e
                     raise e
 
 
         samples = _Permute()(samples)
         samples = _Permute()(samples)
 
 
-        extra_key = ['h', 'w', 'flipped']
+        extra_key = ['h', 'w', 'flipped', 'trans_info']
         for k in extra_key:
         for k in extra_key:
             for sample in samples:
             for sample in samples:
                 if k in sample:
                 if k in sample:
@@ -70,7 +74,7 @@ class BatchCompose(Transform):
         return batch_data
         return batch_data
 
 
 
 
-class BatchRandomResize(Transform):
+class BatchRandomResize(BatchTransform):
     """
     """
     Resize a batch of inputs to random sizes.
     Resize a batch of inputs to random sizes.
 
 
@@ -111,7 +115,7 @@ class BatchRandomResize(Transform):
         return samples
         return samples
 
 
 
 
-class BatchRandomResizeByShort(Transform):
+class BatchRandomResizeByShort(BatchTransform):
     """
     """
     Resize a batch of inputs to random sizes while keeping the aspect ratio.
     Resize a batch of inputs to random sizes while keeping the aspect ratio.
 
 
@@ -157,7 +161,7 @@ class BatchRandomResizeByShort(Transform):
         return samples
         return samples
 
 
 
 
-class _BatchPad(Transform):
+class _BatchPad(BatchTransform):
     def __init__(self, pad_to_stride=0):
     def __init__(self, pad_to_stride=0):
         super(_BatchPad, self).__init__()
         super(_BatchPad, self).__init__()
         self.pad_to_stride = pad_to_stride
         self.pad_to_stride = pad_to_stride
@@ -182,7 +186,7 @@ class _BatchPad(Transform):
         return samples
         return samples
 
 
 
 
-class _Gt2YoloTarget(Transform):
+class _Gt2YoloTarget(BatchTransform):
     """
     """
     Generate YOLOv3 targets by groud truth data, this operator is only used in
     Generate YOLOv3 targets by groud truth data, this operator is only used in
         fine grained YOLOv3 loss mode.
         fine grained YOLOv3 loss mode.

+ 29 - 18
paddlers/transforms/operators.py

@@ -18,6 +18,7 @@ import random
 from numbers import Number
 from numbers import Number
 from functools import partial
 from functools import partial
 from operator import methodcaller
 from operator import methodcaller
+from collections import OrderedDict
 from collections.abc import Sequence
 from collections.abc import Sequence
 
 
 import numpy as np
 import numpy as np
@@ -32,6 +33,8 @@ import paddlers.transforms.indices as indices
 import paddlers.transforms.satellites as satellites
 import paddlers.transforms.satellites as satellites
 
 
 __all__ = [
 __all__ = [
+    "construct_sample",
+    "construct_sample_from_dict",
     "Compose",
     "Compose",
     "DecodeImg",
     "DecodeImg",
     "Resize",
     "Resize",
@@ -74,6 +77,19 @@ interp_dict = {
 }
 }
 
 
 
 
+def construct_sample(**kwargs):
+    sample = OrderedDict()
+    for k, v in kwargs.items():
+        sample[k] = v
+    if 'trans_info' not in sample:
+        sample['trans_info'] = []
+    return sample
+
+
+def construct_sample_from_dict(dict_like_obj):
+    return construct_sample(**dict_like_obj)
+
+
 class Compose(object):
 class Compose(object):
     """
     """
     Apply a series of data augmentation strategies to the input.
     Apply a series of data augmentation strategies to the input.
@@ -107,17 +123,17 @@ class Compose(object):
         This is equivalent to sequentially calling compose_obj.apply_transforms() 
         This is equivalent to sequentially calling compose_obj.apply_transforms() 
             and compose_obj.arrange_outputs().
             and compose_obj.arrange_outputs().
         """
         """
-
+        if 'trans_info' not in sample:
+            sample['trans_info'] = []
         sample = self.apply_transforms(sample)
         sample = self.apply_transforms(sample)
+        trans_info = sample['trans_info']
         sample = self.arrange_outputs(sample)
         sample = self.arrange_outputs(sample)
-        return sample
+        return sample, trans_info
 
 
     def apply_transforms(self, sample):
     def apply_transforms(self, sample):
         for op in self.transforms:
         for op in self.transforms:
-            # Skip batch transforms amd mixup
-            if isinstance(op, (paddlers.transforms.BatchRandomResize,
-                               paddlers.transforms.BatchRandomResizeByShort,
-                               MixupImage)):
+            # Skip batch transforms
+            if getattr(op, 'is_batch_transform', False):
                 continue
                 continue
             sample = op(sample)
             sample = op(sample)
         return sample
         return sample
@@ -373,11 +389,6 @@ class DecodeImg(Transform):
             else:
             else:
                 sample['target'] = self.apply_im(sample['target'])
                 sample['target'] = self.apply_im(sample['target'])
 
 
-        # the `trans_info` will save the process of image shape,
-        # and will be used in evaluation and prediction.
-        if 'trans_info' not in sample:
-            sample['trans_info'] = []
-
         sample['im_shape'] = np.array(
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
@@ -1474,6 +1485,8 @@ class Pad(Transform):
 
 
 
 
 class MixupImage(Transform):
 class MixupImage(Transform):
+    is_batch_transform = True
+
     def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
     def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
         """
         """
         Mixup two images and their gt_bbbox/gt_score.
         Mixup two images and their gt_bbbox/gt_score.
@@ -2073,13 +2086,12 @@ class ArrangeSegmenter(Arrange):
             mask = sample['mask']
             mask = sample['mask']
             mask = mask.astype('int64')
             mask = mask.astype('int64')
         image = F.permute(sample['image'], False)
         image = F.permute(sample['image'], False)
-        trans_info = sample['trans_info']
         if self.mode == 'train':
         if self.mode == 'train':
             return image, mask
             return image, mask
         if self.mode == 'eval':
         if self.mode == 'eval':
-            return image, mask, trans_info
+            return image, mask
         if self.mode == 'test':
         if self.mode == 'test':
-            return image, trans_info,
+            return image,
 
 
 
 
 class ArrangeChangeDetector(Arrange):
 class ArrangeChangeDetector(Arrange):
@@ -2089,7 +2101,6 @@ class ArrangeChangeDetector(Arrange):
             mask = mask.astype('int64')
             mask = mask.astype('int64')
         image_t1 = F.permute(sample['image'], False)
         image_t1 = F.permute(sample['image'], False)
         image_t2 = F.permute(sample['image2'], False)
         image_t2 = F.permute(sample['image2'], False)
-        trans_info = sample['trans_info']
         if self.mode == 'train':
         if self.mode == 'train':
             masks = [mask]
             masks = [mask]
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
@@ -2099,9 +2110,9 @@ class ArrangeChangeDetector(Arrange):
                 image_t1,
                 image_t1,
                 image_t2, ) + tuple(masks)
                 image_t2, ) + tuple(masks)
         if self.mode == 'eval':
         if self.mode == 'eval':
-            return image_t1, image_t2, mask, trans_info
+            return image_t1, image_t2, mask
         if self.mode == 'test':
         if self.mode == 'test':
-            return image_t1, image_t2, trans_info
+            return image_t1, image_t2
 
 
 
 
 class ArrangeClassifier(Arrange):
 class ArrangeClassifier(Arrange):
@@ -2110,7 +2121,7 @@ class ArrangeClassifier(Arrange):
         if self.mode in ['train', 'eval']:
         if self.mode in ['train', 'eval']:
             return image, sample['label']
             return image, sample['label']
         else:
         else:
-            return image
+            return image,
 
 
 
 
 class ArrangeDetector(Arrange):
 class ArrangeDetector(Arrange):

+ 0 - 1
paddlers/utils/download.py

@@ -22,7 +22,6 @@ import hashlib
 import tarfile
 import tarfile
 import zipfile
 import zipfile
 
 
-import filelock
 import paddle
 import paddle
 
 
 from . import logging
 from . import logging

+ 1 - 1
paddlers/utils/postprocs/__init__.py

@@ -23,5 +23,5 @@ try:
     from .crf import conditional_random_field
     from .crf import conditional_random_field
 except ImportError:
 except ImportError:
     print(
     print(
-        "Can not use `conditional_random_field`. Please install pydensecrf first!"
+        "Can not use `conditional_random_field`. Please install pydensecrf first."
     )
     )

+ 4 - 2
requirements.txt

@@ -15,12 +15,14 @@ opencv-contrib-python >= 4.3.0
 openpyxl
 openpyxl
 # paddlepaddle >= 2.2.0
 # paddlepaddle >= 2.2.0
 # paddlepaddle-gpu >= 2.2.0
 # paddlepaddle-gpu >= 2.2.0
-paddleslim >= 2.2.1,< 2.3.5
+paddleslim >= 2.2.1, < 2.3.5
 pandas
 pandas
+protobuf >= 3.1.0, <= 3.20.0
 pycocotools
 pycocotools
 # pydensecrf
 # pydensecrf
-scikit-learn == 0.23.2
+scikit-learn
 scikit-image >= 0.14.0
 scikit-image >= 0.14.0
 scipy
 scipy
 shapely
 shapely
+spyndex
 visualdl >= 2.1.1
 visualdl >= 2.1.1

+ 31 - 27
setup.py

@@ -13,34 +13,38 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import setuptools
 import setuptools
-import paddlers
 
 
-DESCRIPTION = "Awesome Remote Sensing Toolkit based on PaddlePaddle"
+if __name__ == '__main__':
+    DESCRIPTION = "Awesome Remote Sensing Toolkit based on PaddlePaddle"
 
 
-with open("README.md", "r", encoding='utf8') as fh:
-    LONG_DESCRIPTION = fh.read()
+    with open("README_EN.md", 'r', encoding='utf8') as fh:
+        LONG_DESCRIPTION = fh.read()
 
 
-with open("requirements.txt") as fin:
-    REQUIRED_PACKAGES = fin.read()
+    with open("requirements.txt", 'r') as fin:
+        REQUIRED_PACKAGES = fin.read()
 
 
-setuptools.setup(
-    name="paddlers",
-    version=paddlers.__version__.replace('-', ''),
-    author='PaddleRS Authors',
-    author_email="",
-    description=DESCRIPTION,
-    long_description=LONG_DESCRIPTION,
-    long_description_content_type="text/plain",
-    url="https://github.com/PaddlePaddle/PaddleRS",
-    packages=setuptools.find_packages(include=['paddlers', 'paddlers.*']),
-    python_requires='>=3.7',
-    setup_requires=['cython', 'numpy'],
-    install_requires=REQUIRED_PACKAGES,
-    classifiers=[
-        "Programming Language :: Python :: 3.7",
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "License :: OSI Approved :: Apache Software License",
-        "Operating System :: OS Independent",
-    ],
-    license='Apache 2.0', )
+    with open("paddlers/.version", 'r') as fv:
+        VERSION = fv.read().rstrip()
+
+    setuptools.setup(
+        name="paddlers",
+        version=VERSION.replace('-', ''),
+        author='PaddleRS Authors',
+        author_email="",
+        description=DESCRIPTION,
+        long_description=LONG_DESCRIPTION,
+        long_description_content_type="text/plain",
+        url="https://github.com/PaddlePaddle/PaddleRS",
+        packages=setuptools.find_packages(include=['paddlers', 'paddlers.*']) +
+        setuptools.find_namespace_packages(include=['paddlers', 'paddlers.*']),
+        python_requires='>=3.7',
+        setup_requires=['cython', 'numpy'],
+        install_requires=REQUIRED_PACKAGES,
+        classifiers=[
+            "Programming Language :: Python :: 3.7",
+            "Programming Language :: Python :: 3.8",
+            "Programming Language :: Python :: 3.9",
+            "License :: OSI Approved :: Apache Software License",
+            "Operating System :: OS Independent",
+        ],
+        license='Apache 2.0', )

+ 30 - 31
test_tipc/common_func.sh

@@ -1,35 +1,35 @@
 #!/bin/bash
 #!/bin/bash
 
 
 function func_parser_key() {
 function func_parser_key() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    tmp=${array[0]}
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local tmp=${array[0]}
     echo ${tmp}
     echo ${tmp}
 }
 }
 
 
 function func_parser_value() {
 function func_parser_value() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    tmp=${array[1]}
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local tmp=${array[1]}
     echo ${tmp}
     echo ${tmp}
 }
 }
 
 
 function func_parser_value_lite() {
 function func_parser_value_lite() {
-    strs=$1
-    IFS=$2
-    array=(${strs})
-    tmp=${array[1]}
+    local strs=$1
+    local IFS=$2
+    local array=(${strs})
+    local tmp=${array[1]}
     echo ${tmp}
     echo ${tmp}
 }
 }
 
 
 function func_set_params() {
 function func_set_params() {
-    key=$1
-    value=$2
-    if [ ${key}x = "null"x ];then
+    local key=$1
+    local value=$2
+    if [ ${key}x = 'null'x ];then
         echo " "
         echo " "
-    elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+    elif [[ ${value} = 'null' ]] || [[ ${value} = ' ' ]] || [ ${#value} -le 0 ];then
         echo " "
         echo " "
     else 
     else 
         echo "${key}=${value}"
         echo "${key}=${value}"
@@ -37,21 +37,20 @@ function func_set_params() {
 }
 }
 
 
 function func_parser_params() {
 function func_parser_params() {
-    strs=$1
-    IFS=":"
-    array=(${strs})
-    key=${array[0]}
-    tmp=${array[1]}
-    IFS="|"
-    res=""
+    local strs=$1
+    local IFS=':'
+    local array=(${strs})
+    local key=${array[0]}
+    local tmp=${array[1]}
+    local IFS='|'
+    local res=''
     for _params in ${tmp[*]}; do
     for _params in ${tmp[*]}; do
-        IFS="="
-        array=(${_params})
-        mode=${array[0]}
-        value=${array[1]}
+        local IFS='='
+        local array=(${_params})
+        local mode=${array[0]}
+        local value=${array[1]}
         if [[ ${mode} = ${MODE} ]]; then
         if [[ ${mode} = ${MODE} ]]; then
-            IFS="|"
-            #echo $(func_set_params "${mode}" "${value}")
+            local IFS='|'
             echo $value
             echo $value
             break
             break
         fi
         fi
@@ -112,14 +111,14 @@ function add_suffix() {
 
 
 function parse_first_value() {
 function parse_first_value() {
     local key_values=$1
     local key_values=$1
-    local IFS=":"
+    local IFS=':'
     local arr=(${key_values})
     local arr=(${key_values})
     echo ${arr[1]}
     echo ${arr[1]}
 }
 }
 
 
 function parse_second_value() {
 function parse_second_value() {
     local key_values=$1
     local key_values=$1
-    local IFS=":"
+    local IFS=':'
     local arr=(${key_values})
     local arr=(${key_values})
     echo ${arr[2]}
     echo ${arr[2]}
 }
 }

+ 2 - 0
test_tipc/configs/seg/_base_/rsseg.yaml

@@ -51,6 +51,8 @@ transforms:
           args:
           args:
             mean: [0.5, 0.5, 0.5]
             mean: [0.5, 0.5, 0.5]
             std: [0.5, 0.5, 0.5]
             std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ReloadMask
         - !Node
         - !Node
           type: ArrangeSegmenter
           type: ArrangeSegmenter
           args: ['eval']
           args: ['eval']

+ 1 - 1
test_tipc/configs/seg/factseg/train_infer_python.txt

@@ -1,7 +1,7 @@
 ===========================train_params===========================
 ===========================train_params===========================
 model_name:seg:factseg
 model_name:seg:factseg
 python:python
 python:python
-gpu_list:0
+gpu_list:0|0,1
 use_gpu:null|null
 use_gpu:null|null
 --precision:null
 --precision:null
 --num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
 --num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20

+ 16 - 37
test_tipc/infer.py

@@ -143,45 +143,32 @@ class TIPCPredictor(object):
         return config
         return config
 
 
     def preprocess(self, images, transforms):
     def preprocess(self, images, transforms):
-        preprocessed_samples = self._model.preprocess(
+        preprocessed_samples, batch_trans_info = self._model.preprocess(
             images, transforms, to_tensor=False)
             images, transforms, to_tensor=False)
         if self._model.model_type == 'classifier':
         if self._model.model_type == 'classifier':
-            preprocessed_samples = {'image': preprocessed_samples[0]}
+            preprocessed_samples = {'image': preprocessed_samples}
         elif self._model.model_type == 'segmenter':
         elif self._model.model_type == 'segmenter':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'ori_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         elif self._model.model_type == 'detector':
         elif self._model.model_type == 'detector':
             pass
             pass
         elif self._model.model_type == 'change_detector':
         elif self._model.model_type == 'change_detector':
             preprocessed_samples = {
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
                 'image': preprocessed_samples[0],
-                'image2': preprocessed_samples[1],
-                'ori_shape': preprocessed_samples[2]
+                'image2': preprocessed_samples[1]
             }
             }
         elif self._model.model_type == 'restorer':
         elif self._model.model_type == 'restorer':
-            preprocessed_samples = {
-                'image': preprocessed_samples[0],
-                'tar_shape': preprocessed_samples[1]
-            }
+            preprocessed_samples = {'image': preprocessed_samples[0]}
         else:
         else:
             logging.error(
             logging.error(
-                "Invalid model type {}".format(self._model.model_type),
-                exit=True)
-        return preprocessed_samples
-
-    def postprocess(self,
-                    net_outputs,
-                    topk=1,
-                    ori_shape=None,
-                    tar_shape=None,
-                    transforms=None):
+                "Invalid model type {}".format(self.model_type), exit=True)
+        return preprocessed_samples, batch_trans_info
+
+    def postprocess(self, net_outputs, batch_restore_list, topk=1):
         if self._model.model_type == 'classifier':
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
             if self._model.postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
                 self._model.build_postprocess_from_labels(topk)
-            # XXX: Convert ndarray to tensor as self._model.postprocess requires
+            # XXX: Convert ndarray to tensor as `self._model.postprocess` requires
             assert len(net_outputs) == 1
             assert len(net_outputs) == 1
             net_outputs = paddle.to_tensor(net_outputs[0])
             net_outputs = paddle.to_tensor(net_outputs[0])
             outputs = self._model.postprocess(net_outputs)
             outputs = self._model.postprocess(net_outputs)
@@ -195,9 +182,7 @@ class TIPCPredictor(object):
             } for l, s, n in zip(class_ids, scores, label_names)]
             } for l, s, n in zip(class_ids, scores, label_names)]
         elif self._model.model_type in ('segmenter', 'change_detector'):
         elif self._model.model_type in ('segmenter', 'change_detector'):
             label_map, score_map = self._model.postprocess(
             label_map, score_map = self._model.postprocess(
-                net_outputs,
-                batch_origin_shape=ori_shape,
-                transforms=transforms.transforms)
+                net_outputs, batch_restore_list=batch_restore_list)
             preds = [{
             preds = [{
                 'label_map': l,
                 'label_map': l,
                 'score_map': s
                 'score_map': s
@@ -210,14 +195,11 @@ class TIPCPredictor(object):
             preds = self._model.postprocess(net_outputs)
             preds = self._model.postprocess(net_outputs)
         elif self._model.model_type == 'restorer':
         elif self._model.model_type == 'restorer':
             res_maps = self._model.postprocess(
             res_maps = self._model.postprocess(
-                net_outputs[0],
-                batch_tar_shape=tar_shape,
-                transforms=transforms.transforms)
+                net_outputs[0], batch_restore_list=batch_restore_list)
             preds = [{'res_map': res_map} for res_map in res_maps]
             preds = [{'res_map': res_map} for res_map in res_maps]
         else:
         else:
             logging.error(
             logging.error(
-                "Invalid model type {}.".format(self._model.model_type),
-                exit=True)
+                "Invalid model type {}.".format(self.model_type), exit=True)
 
 
         return preds
         return preds
 
 
@@ -225,7 +207,8 @@ class TIPCPredictor(object):
         if self.benchmark and time_it:
         if self.benchmark and time_it:
             self.autolog.times.start()
             self.autolog.times.start()
 
 
-        preprocessed_input = self.preprocess(images, transforms)
+        preprocessed_input, batch_trans_info = self.preprocess(images,
+                                                               transforms)
 
 
         input_names = self.predictor.get_input_names()
         input_names = self.predictor.get_input_names()
         for name in input_names:
         for name in input_names:
@@ -247,11 +230,7 @@ class TIPCPredictor(object):
             self.autolog.times.stamp()
             self.autolog.times.stamp()
 
 
         res = self.postprocess(
         res = self.postprocess(
-            net_outputs,
-            topk,
-            ori_shape=preprocessed_input.get('ori_shape', None),
-            tar_shape=preprocessed_input.get('tar_shape', None),
-            transforms=transforms)
+            net_outputs, batch_restore_list=batch_trans_info, topk=topk)
 
 
         if self.benchmark and time_it:
         if self.benchmark and time_it:
             self.autolog.times.end(stamp=True)
             self.autolog.times.end(stamp=True)

+ 16 - 16
tests/data/data_utils.py

@@ -20,6 +20,8 @@ from functools import partial, wraps
 
 
 import numpy as np
 import numpy as np
 
 
+from paddlers.transforms import construct_sample
+
 __all__ = ['build_input_from_file']
 __all__ = ['build_input_from_file']
 
 
 
 
@@ -78,19 +80,17 @@ class ConstrSample(object):
 
 
 class ConstrSegSample(ConstrSample):
 class ConstrSegSample(ConstrSample):
     def __call__(self, im_path, mask_path):
     def __call__(self, im_path, mask_path):
-        return {
-            'image': self.get_full_path(im_path),
-            'mask': self.get_full_path(mask_path)
-        }
+        return construct_sample(
+            image=self.get_full_path(im_path),
+            mask=self.get_full_path(mask_path))
 
 
 
 
 class ConstrCdSample(ConstrSample):
 class ConstrCdSample(ConstrSample):
     def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
     def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
-        sample = {
-            'image_t1': self.get_full_path(im1_path),
-            'image_t2': self.get_full_path(im2_path),
-            'mask': self.get_full_path(mask_path)
-        }
+        sample = construct_sample(
+            image_t1=self.get_full_path(im1_path),
+            image_t2=self.get_full_path(im2_path),
+            mask=self.get_full_path(mask_path))
         if len(aux_mask_paths) > 0:
         if len(aux_mask_paths) > 0:
             sample['aux_masks'] = [
             sample['aux_masks'] = [
                 self.get_full_path(p) for p in aux_mask_paths
                 self.get_full_path(p) for p in aux_mask_paths
@@ -100,7 +100,8 @@ class ConstrCdSample(ConstrSample):
 
 
 class ConstrClasSample(ConstrSample):
 class ConstrClasSample(ConstrSample):
     def __call__(self, im_path, label):
     def __call__(self, im_path, label):
-        return {'image': self.get_full_path(im_path), 'label': int(label)}
+        return construct_sample(
+            image=self.get_full_path(im_path), label=int(label))
 
 
 
 
 class ConstrDetSample(ConstrSample):
 class ConstrDetSample(ConstrSample):
@@ -234,7 +235,7 @@ class ConstrDetSample(ConstrSample):
         }
         }
 
 
         self.ct += 1
         self.ct += 1
-        return {'image': im_path, ** im_info, ** label_info}
+        return construct_sample(image=im_path, **im_info, **label_info)
 
 
     @silent
     @silent
     def _parse_coco_files(self, im_dir, ann_path):
     def _parse_coco_files(self, im_dir, ann_path):
@@ -303,7 +304,7 @@ class ConstrDetSample(ConstrSample):
                 'difficult': np.array(difficults),
                 'difficult': np.array(difficults),
             }
             }
 
 
-            samples.append({ ** im_info, ** label_info})
+            samples.append(construct_sample(**im_info, **label_info))
 
 
         return samples
         return samples
 
 
@@ -314,10 +315,9 @@ class ConstrResSample(ConstrSample):
         self.sr_factor = sr_factor
         self.sr_factor = sr_factor
 
 
     def __call__(self, src_path, tar_path):
     def __call__(self, src_path, tar_path):
-        sample = {
-            'image': self.get_full_path(src_path),
-            'target': self.get_full_path(tar_path)
-        }
+        sample = construct_sample(
+            image=self.get_full_path(src_path),
+            target=self.get_full_path(tar_path))
         if self.sr_factor is not None:
         if self.sr_factor is not None:
             sample['sr_factor'] = self.sr_factor
             sample['sr_factor'] = self.sr_factor
         return sample
         return sample

+ 1 - 1
tests/postpros/test_postpros.py

@@ -16,8 +16,8 @@ import copy
 from PIL import Image
 from PIL import Image
 
 
 import numpy as np
 import numpy as np
-
 import paddle
 import paddle
+
 import paddlers.utils.postprocs as P
 import paddlers.utils.postprocs as P
 from testing_utils import CpuCommonTest
 from testing_utils import CpuCommonTest
 
 

+ 1 - 0
tutorials/train/change_detection/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.zip
 *.tar.gz
 *.tar.gz
 airchange/
 airchange/

+ 1 - 0
tutorials/train/classification/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.zip
 *.tar.gz
 *.tar.gz
 ucmerced/
 ucmerced/

+ 1 - 0
tutorials/train/image_restoration/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.zip
 *.tar.gz
 *.tar.gz
 rssr/
 rssr/

+ 1 - 0
tutorials/train/object_detection/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.zip
 *.tar.gz
 *.tar.gz
 sarship/
 sarship/

+ 1 - 0
tutorials/train/semantic_segmentation/data/.gitignore

@@ -1,3 +1,4 @@
+*.path
 *.zip
 *.zip
 *.tar.gz
 *.tar.gz
 rsseg/
 rsseg/