
[Refactor] Refactor segmentation's output based on trans_info (#97)

* [Refactor] Refactor segmentation's output based on trans_info

* [Refactor] Refactor change detection's output based on trans_info

* [Fix] Update restorer's model_name

* [Fix] Restore restorer's trans_info

* [Fix] Restore detection's trans_info

* [Fix] Fix trans_info

* [Fix] Update about dataset

* [Fix] Update trans_info
Yizhou Chen 2 years ago
parent
commit
359b59769c
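
This squashed commit replaces shape-based output restoration with a `trans_info` mechanism: each shape-changing transform records what it did to the sample, and the tasks later undo those records instead of re-deriving them from the transform objects. A minimal sketch of what a sample's `trans_info` might hold after decoding, resizing, and padding (values are illustrative, not from a real run):

    # Illustrative only. Each entry names the operation and the image
    # shape *before* it ran; Pad additionally stores its offsets.
    trans_info = [
        ('resize', (512, 768)),           # appended by Resize
        ('padding', (256, 384), [0, 0]),  # appended by Pad
    ]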

+ 1 - 1
paddlers/datasets/cd_dataset.py

@@ -35,7 +35,7 @@ class CDDataset(BaseDataset):
         label_list (str|None, optional): Path of the file that contains the category names. Defaults to None.
         num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
             the number of workers will be automatically determined according to the number of CPU cores: If 
-            there are more than 16 cores8 workers will be used. Otherwise, the number of workers will be half 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
             the number of CPU cores. Defaults: 'auto'.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
         with_seg_labels (bool, optional): Set `with_seg_labels` to True if the datasets provides segmentation 

+ 1 - 1
paddlers/datasets/clas_dataset.py

@@ -29,7 +29,7 @@ class ClasDataset(BaseDataset):
         label_list (str|None, optional): Path of the file that contains the category names. Defaults to None.
         num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
             the number of workers will be automatically determined according to the number of CPU cores: If 
-            there are more than 16 cores8 workers will be used. Otherwise, the number of workers will be half 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
             the number of CPU cores. Defaults: 'auto'.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
     """

+ 1 - 1
paddlers/datasets/coco.py

@@ -39,7 +39,7 @@ class COCODetDataset(BaseDataset):
         label_list (str|None, optional): Path of the file that contains the category names. Defaults to None.
         num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
             the number of workers will be automatically determined according to the number of CPU cores: If 
-            there are more than 16 cores8 workers will be used. Otherwise, the number of workers will be half 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
             the number of CPU cores. Defaults: 'auto'.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
         allow_empty (bool, optional): Whether to add negative samples. Defaults to False.

+ 1 - 1
paddlers/datasets/res_dataset.py

@@ -29,7 +29,7 @@ class ResDataset(BaseDataset):
         transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
         num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
             the number of workers will be automatically determined according to the number of CPU cores: If 
-            there are more than 16 cores8 workers will be used. Otherwise, the number of workers will be half 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
             the number of CPU cores. Defaults: 'auto'.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
         sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image 

+ 1 - 1
paddlers/datasets/seg_dataset.py

@@ -30,7 +30,7 @@ class SegDataset(BaseDataset):
         label_list (str|None, optional): Path of the file that contains the category names. Defaults to None.
         num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
             the number of workers will be automatically determined according to the number of CPU cores: If 
-            there are more than 16 cores8 workers will be used. Otherwise, the number of workers will be half 
+            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half 
             the number of CPU cores. Defaults: 'auto'.
         shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
     """

+ 14 - 18
paddlers/tasks/change_detector.py

@@ -116,13 +116,12 @@ class BaseChangeDetector(BaseModel):
         logit = net_out[0]
         outputs = OrderedDict()
         if mode == 'test':
-            origin_shape = inputs[2]
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
-                    net_out, origin_shape, transforms=inputs[3])
+                    net_out, batch_restore_list)
             else:
-                logit_list = self.postprocess(
-                    logit, origin_shape, transforms=inputs[3])
+                logit_list = self.postprocess(logit, batch_restore_list)
                 label_map_list = []
                 score_map_list = []
                 for logit in logit_list:
@@ -138,6 +137,7 @@ class BaseChangeDetector(BaseModel):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
@@ -148,9 +148,7 @@ class BaseChangeDetector(BaseModel):
             if label.ndim != 4:
                 raise ValueError("Expected label.ndim == 4 but got {}".format(
                     label.ndim))
-            origin_shape = [label.shape[-2:]]
-            pred = self.postprocess(
-                pred, origin_shape, transforms=inputs[3])[0]  # NCHW
+            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
             intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
@@ -464,7 +462,6 @@ class BaseChangeDetector(BaseModel):
                 math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
                 outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
@@ -563,10 +560,10 @@ class BaseChangeDetector(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im1, batch_im2, batch_origin_shape = self.preprocess(
+        batch_im1, batch_im2, batch_trans_info = self.preprocess(
             images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
+        data = (batch_im1, batch_im2, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
@@ -628,18 +625,19 @@ class BaseChangeDetector(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         batch_im1, batch_im2 = list(), list()
-        batch_ori_shape = list()
+        batch_trans_info = list()
         for im1, im2 in images:
             if isinstance(im1, str) or isinstance(im2, str):
                 im1 = decode_image(im1, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
-            ori_shape = im1.shape[:2]
            # XXX: sample does not contain 'image_t1' and 'image_t2'.
             sample = {'image': im1, 'image2': im2}
-            im1, im2 = transforms(sample)[:2]
+            data = transforms(sample)
+            im1, im2 = data[:2]
+            trans_info = data[-1]
             batch_im1.append(im1)
             batch_im2.append(im2)
-            batch_ori_shape.append(ori_shape)
+            batch_trans_info.append(trans_info)
         if to_tensor:
             batch_im1 = paddle.to_tensor(batch_im1)
             batch_im2 = paddle.to_tensor(batch_im2)
@@ -647,7 +645,7 @@ class BaseChangeDetector(BaseModel):
             batch_im1 = np.asarray(batch_im1)
             batch_im2 = np.asarray(batch_im2)
 
-        return batch_im1, batch_im2, batch_ori_shape
+        return batch_im1, batch_im2, batch_trans_info
 
     @staticmethod
     def get_transforms_shape_info(batch_ori_shape, transforms):
@@ -697,9 +695,7 @@ class BaseChangeDetector(BaseModel):
             batch_restore_list.append(restore_list)
         return batch_restore_list
 
-    def postprocess(self, batch_pred, batch_origin_shape, transforms):
-        batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
-            batch_origin_shape, transforms)
+    def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
                 batch_label_map=batch_pred[0],
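
With these changes, `preprocess` hands back the per-sample `trans_info` gathered by the operators, and `postprocess` consumes the ready-made restore list rather than rebuilding it from the transform objects. A rough sketch of the resulting test-time flow, with placeholder inputs:

    # Sketch of the refactored flow; `model` is a trained BaseChangeDetector
    # and `transforms` its test-time Compose (placeholders, not a full script).
    batch_im1, batch_im2, batch_trans_info = model.preprocess(
        images, transforms)
    data = (batch_im1, batch_im2, batch_trans_info)  # transforms.transforms no longer passed
    outputs = model.run(model.net, data, 'test')     # run() reads inputs[-1]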

+ 1 - 1
paddlers/tasks/restorer.py

@@ -67,7 +67,7 @@ class BaseRestorer(BaseModel):
         # Currently, only use models from cmres.
         if not hasattr(cmres, self.model_name):
             raise ValueError("ERROR: There is no model named {}.".format(
-                model_name))
+                self.model_name))
         net = dict(**cmres.__dict__)[self.model_name](**params)
         return net
 

+ 21 - 75
paddlers/tasks/segmenter.py

@@ -118,13 +118,12 @@ class BaseSegmenter(BaseModel):
         logit = net_out[0]
         outputs = OrderedDict()
         if mode == 'test':
-            origin_shape = inputs[1]
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
-                    net_out, origin_shape, transforms=inputs[2])
+                    net_out, batch_restore_list)
             else:
-                logit_list = self.postprocess(
-                    logit, origin_shape, transforms=inputs[2])
+                logit_list = self.postprocess(logit, batch_restore_list)
                 label_map_list = []
                 score_map_list = []
                 for logit in logit_list:
@@ -140,6 +139,7 @@ class BaseSegmenter(BaseModel):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
@@ -150,9 +150,7 @@ class BaseSegmenter(BaseModel):
             if label.ndim != 4:
                 raise ValueError("Expected label.ndim == 4 but got {}".format(
                     label.ndim))
-            origin_shape = [label.shape[-2:]]
-            pred = self.postprocess(
-                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
             intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
@@ -441,7 +439,6 @@ class BaseSegmenter(BaseModel):
                 math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
                 outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
@@ -529,10 +526,10 @@ class BaseSegmenter(BaseModel):
             images = [img_file]
         else:
             images = img_file
-        batch_im, batch_origin_shape = self.preprocess(images, transforms,
-                                                       self.model_type)
+        batch_im, batch_trans_info = self.preprocess(images, transforms,
+                                                     self.model_type)
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
+        data = (batch_im, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
@@ -594,75 +591,24 @@ class BaseSegmenter(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         batch_im = list()
-        batch_ori_shape = list()
+        batch_trans_info = list()
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
             sample = {'image': im}
-            im = transforms(sample)[0]
+            data = transforms(sample)
+            im = data[0]
+            trans_info = data[-1]
             batch_im.append(im)
-            batch_ori_shape.append(ori_shape)
+            batch_trans_info.append(trans_info)
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
         else:
             batch_im = np.asarray(batch_im)
 
-        return batch_im, batch_ori_shape
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        # TODO: Store transform meta info when applying transforms
-        # and not here
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
+        return batch_im, batch_trans_info
 
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
-
-    def postprocess(self, batch_pred, batch_origin_shape, transforms):
-        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
-            batch_origin_shape, transforms)
+    def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
                 batch_label_map=batch_pred[0],
@@ -979,17 +925,18 @@ class C2FNet(BaseSegmenter):
             pre_coarse = self.coarse_model(inputs[0])
             pre_coarse = pre_coarse[0]
             heatmaps = pre_coarse
+
         if mode == 'test':
+            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()
             origin_shape = inputs[1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
-                    net_out, origin_shape, transforms=inputs[2])
+                    net_out, batch_restore_list)
             else:
-                logit_list = self.postprocess(
-                    logit, origin_shape, transforms=inputs[2])
+                logit_list = self.postprocess(logit, batch_restore_list)
                 label_map_list = []
                 score_map_list = []
                 for logit in logit_list:
@@ -1005,6 +952,7 @@ class C2FNet(BaseSegmenter):
             outputs['score_map'] = score_map_list
 
         if mode == 'eval':
+            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()
@@ -1018,9 +966,7 @@ class C2FNet(BaseSegmenter):
             if label.ndim != 4:
                 raise ValueError("Expected label.ndim == 4 but got {}".format(
                     label.ndim))
-            origin_shape = [label.shape[-2:]]
-            pred = self.postprocess(
-                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
             intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
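
The deleted `get_transforms_shape_info` replayed the transform objects to reconstruct the restore list; after this commit the list arrives ready-made from `trans_info`. Conceptually, undoing it looks like the following hedged sketch (not the library's actual `postprocess` body; offsets are assumed to be [top, left]):

    import paddle.nn.functional as F

    def undo_restore_list(pred, restore_list):
        # Sketch: walk the recorded ops in reverse, undoing each one.
        # `pred` is an NCHW tensor; entries match those the operators record.
        for item in reversed(restore_list):
            if item[0] == 'resize':
                h, w = item[1]  # shape before the resize
                pred = F.interpolate(pred, size=(h, w), mode='nearest')
            elif item[0] == 'padding':
                (h, w), (top, left) = item[1], item[2]
                pred = pred[:, :, top:top + h, left:left + w]  # crop padding off
        return pred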

+ 15 - 7
paddlers/transforms/operators.py

@@ -373,6 +373,11 @@ class DecodeImg(Transform):
             else:
                 sample['target'] = self.apply_im(sample['target'])
 
+        # `trans_info` records the changes applied to the image shape
+        # and is used in evaluation and prediction.
+        if 'trans_info' not in sample:
+            sample['trans_info'] = []
+
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
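
Initializing `trans_info` in `DecodeImg` guarantees the key exists before any shape-changing operator runs, so `Resize` and `Pad` below can append to it unconditionally. An illustrative eval pipeline (arguments are examples, not prescribed values):

    import paddlers.transforms as T

    # DecodeImg must come first so trans_info exists; Resize then appends
    # a ('resize', ...) entry, and a Pad op would likewise append
    # a ('padding', ...) entry.
    eval_transforms = T.Compose([
        T.DecodeImg(),
        T.Resize(target_size=256),
        T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.ArrangeSegmenter('eval'),
    ])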
@@ -459,6 +464,7 @@ class Resize(Transform):
         return resized_segms
 
     def apply(self, sample):
+        sample['trans_info'].append(('resize', sample['image'].shape[0:2]))
         if self.interp == "RANDOM":
             interp = random.choice(list(interp_dict.values()))
         else:
@@ -677,7 +683,7 @@ class RandomFlipOrRotate(Transform):
         train_transforms = T.Compose([
             T.DecodeImg(),
             T.RandomFlipOrRotate(
-                probs  = [0.3, 0.2]             # p=0.3 to flip the image,p=0.2 to rotate the image,p=0.5 to keep the image unchanged.
+                probs  = [0.3, 0.2]             # p=0.3 to flip the image, p=0.2 to rotate the image, p=0.5 to keep the image unchanged.
                probsf = [0.3, 0.25, 0, 0, 0]   # p=0.3 and p=0.25 to perform horizontal and vertical flipping; probability of no-flipping is 0.45.
                probsr = [0, 0.65, 0]),         # p=0.65 to rotate the image by 180°; probability of no-rotation is 0.35.
             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -1437,6 +1443,8 @@ class Pad(Transform):
 
         offsets = self._get_offsets(im_h, im_w, h, w)
 
+        sample['trans_info'].append(
+            ('padding', sample['image'].shape[0:2], offsets))
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
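
Because each padding entry stores the pre-pad shape together with the offsets, the inverse operation is a plain crop. A quick numeric check with invented values:

    import numpy as np

    # Invented example: a 100x150 image padded to 128x160, centered.
    h, w = 100, 150
    target_h, target_w = 128, 160
    offsets = [(target_h - h) // 2, (target_w - w) // 2]  # [14, 5]
    padded = np.zeros((target_h, target_w), dtype=np.float32)
    restored = padded[offsets[0]:offsets[0] + h, offsets[1]:offsets[1] + w]
    assert restored.shape == (h, w)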
@@ -2064,14 +2072,14 @@ class ArrangeSegmenter(Arrange):
         if 'mask' in sample:
             mask = sample['mask']
             mask = mask.astype('int64')
-
         image = F.permute(sample['image'], False)
+        trans_info = sample['trans_info']
         if self.mode == 'train':
             return image, mask
         if self.mode == 'eval':
-            return image, mask
+            return image, mask, trans_info
         if self.mode == 'test':
-            return image,
+            return image, trans_info,
 
 
 class ArrangeChangeDetector(Arrange):
@@ -2079,9 +2087,9 @@ class ArrangeChangeDetector(Arrange):
         if 'mask' in sample:
             mask = sample['mask']
             mask = mask.astype('int64')
-
         image_t1 = F.permute(sample['image'], False)
         image_t2 = F.permute(sample['image2'], False)
+        trans_info = sample['trans_info']
         if self.mode == 'train':
             masks = [mask]
             if 'aux_masks' in sample:
@@ -2091,9 +2099,9 @@ class ArrangeChangeDetector(Arrange):
                 image_t1,
                 image_t2, ) + tuple(masks)
         if self.mode == 'eval':
-            return image_t1, image_t2, mask
+            return image_t1, image_t2, mask, trans_info
         if self.mode == 'test':
-            return image_t1, image_t2,
+            return image_t1, image_t2, trans_info
 
 
 class ArrangeClassifier(Arrange):
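
Appending `trans_info` as the last element of every arranged tuple is what lets the task code above read the restore list uniformly as `inputs[-1]`, whatever the task or mode. A tiny sketch with a hypothetical sample:

    # Hypothetical sample after the preceding operators have run;
    # image_chw stands in for a real image array.
    sample = {'image': image_chw, 'trans_info': [('resize', (512, 768))]}
    image, trans_info = ArrangeSegmenter('test')(sample)
    # run(..., mode='test') then picks it up as batch_restore_list = inputs[-1]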