Browse Source

Add apply_im_only in dataset

Bobholamovic 2 years ago
parent
commit
d351f12c85

+ 5 - 0
paddlers/datasets/base.py

@@ -28,3 +28,8 @@ class BaseDataset(Dataset):
         self.transforms = deepcopy(transforms)
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
+
+    def __getitem__(self, idx):
+        sample = deepcopy(self.file_list[idx])
+        outputs = self.transforms(sample)
+        return outputs

+ 32 - 2
paddlers/datasets/cd_dataset.py

@@ -18,6 +18,7 @@ import os.path as osp
 
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
+from paddlers.transforms import decode_seg_mask
 
 
 class CDDataset(BaseDataset):
@@ -35,6 +36,7 @@ class CDDataset(BaseDataset):
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+        apply_im_only (bool, optional): 是否绕过对标签的数据增强和预处理。在模型验证和推理阶段一般指定此选项为True。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
         binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
     """
@@ -46,6 +48,7 @@ class CDDataset(BaseDataset):
                  transforms=None,
                  num_workers='auto',
                  shuffle=False,
+                 apply_im_only=False,
                  with_seg_labels=False,
                  binarize_labels=False):
         super(CDDataset, self).__init__(data_dir, label_list, transforms,
@@ -58,6 +61,7 @@ class CDDataset(BaseDataset):
         self.file_list = list()
         self.labels = list()
         self.with_seg_labels = with_seg_labels
+        self.apply_im_only = apply_im_only
         if self.with_seg_labels:
             num_items = 5  # RGB1, RGB2, CD, Seg1, Seg2
         else:
@@ -127,9 +131,35 @@ class CDDataset(BaseDataset):
 
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
+
+        if self.apply_im_only:
+            has_mask, has_aux_masks = False, False
+            if 'mask' in sample:
+                has_mask = True
+                mask = decode_seg_mask(sample['mask'])
+                del sample['mask']
+            if 'aux_masks' in sample:
+                has_aux_masks = True
+                aux_masks = list(map(decode_seg_mask, sample['aux_masks']))
+                del sample['aux_masks']
+
+        sample = self.transforms.apply_transforms(sample)
+
+        if self.apply_im_only:
+            if has_mask:
+                sample['mask'] = mask
+            if has_aux_masks:
+                sample['aux_masks'] = aux_masks
+
         if self.binarize_labels:
-            outputs = outputs[:2] + tuple(map(self._binarize, outputs[2:]))
+            # Requires 'mask' to exist
+            sample['mask'] = self._binarize(sample['mask'])
+            if 'aux_masks' in sample:
+                sample['aux_masks'] = list(
+                    map(self._binarize, sample['aux_masks']))
+
+        outputs = self.transforms.arrange_outputs(sample)
+
         return outputs
 
     def __len__(self):

+ 0 - 6
paddlers/datasets/clas_dataset.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os.path as osp
-import copy
 
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
@@ -82,10 +81,5 @@ class ClasDataset(BaseDataset):
         logging.info("{} samples in file {}".format(
             len(self.file_list), file_list))
 
-    def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
-        return outputs
-
     def __len__(self):
         return len(self.file_list)

+ 17 - 2
paddlers/datasets/seg_dataset.py

@@ -17,6 +17,7 @@ import copy
 
 from .base import BaseDataset
 from paddlers.utils import logging, get_encoding, norm_path, is_pic
+from paddlers.transforms import decode_seg_mask
 
 
 class SegDataset(BaseDataset):
@@ -31,6 +32,7 @@ class SegDataset(BaseDataset):
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+        apply_im_only (bool, optional): 是否绕过对标签的数据增强和预处理。在模型验证和推理阶段一般指定此选项为True。默认为False。
     """
 
     def __init__(self,
@@ -39,13 +41,15 @@ class SegDataset(BaseDataset):
                  label_list=None,
                  transforms=None,
                  num_workers='auto',
-                 shuffle=False):
+                 shuffle=False,
+                 apply_im_only=False):
         super(SegDataset, self).__init__(data_dir, label_list, transforms,
                                          num_workers, shuffle)
         # TODO batch padding
         self.batch_transforms = None
         self.file_list = list()
         self.labels = list()
+        self.apply_im_only = apply_im_only
 
         # TODO:非None时,让用户跳转数据集分析生成label_list
         # 不要在此处分析label file
@@ -84,7 +88,18 @@ class SegDataset(BaseDataset):
 
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
+        if self.apply_im_only:
+            has_mask = False
+            if 'mask' in sample:
+                has_mask = True
+                mask = decode_seg_mask(sample['mask'])
+                del sample['mask']
+            sample = self.transforms.apply_transforms(sample)
+            if has_mask:
+                sample['mask'] = mask
+            outputs = self.transforms.arrange_outputs(sample)
+        else:
+            outputs = super().__getitem__(idx)
         return outputs
 
     def __len__(self):

+ 18 - 23
paddlers/transforms/__init__.py

@@ -15,6 +15,9 @@
 import copy
 import os.path as osp
 
+import numpy as np
+from PIL import Image
+
 from .operators import *
 from .batch_operators import BatchRandomResize, BatchRandomResizeByShort, _BatchPad
 from paddlers import transforms as T
@@ -29,6 +32,7 @@ def decode_image(im_path,
     Decode an image.
     
     Args:
+        im_path (str): Path of the image to decode.
         to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
         to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. 
@@ -37,7 +41,7 @@ def decode_image(im_path,
             SAR image, set this argument to True. Defaults to True.
     """
 
-    # Do a presence check. `osp.exists` assumes `im_path` is a path-like object.
+    # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
     decoder = T.DecodeImg(
@@ -51,27 +55,17 @@ def decode_image(im_path,
     return sample['image']
 
 
-def arrange_transforms(model_type, transforms, mode='train'):
-    # 给transforms添加arrange操作
-    if model_type == 'segmenter':
-        if mode == 'eval':
-            transforms.apply_im_only = True
-        else:
-            transforms.apply_im_only = False
-        arrange_transform = ArrangeSegmenter(mode)
-    elif model_type == 'changedetector':
-        if mode == 'eval':
-            transforms.apply_im_only = True
-        else:
-            transforms.apply_im_only = False
-        arrange_transform = ArrangeChangeDetector(mode)
-    elif model_type == 'classifier':
-        arrange_transform = ArrangeClassifier(mode)
-    elif model_type == 'detector':
-        arrange_transform = ArrangeDetector(mode)
-    else:
-        raise Exception("Unrecognized model type: {}".format(model_type))
-    transforms.arrange_outputs = arrange_transform
+def decode_seg_mask(mask_path):
+    """
+    Decode a segmentation mask image.
+    
+    Args:
+        mask_path (str): Path of the mask image to decode.
+    """
+
+    mask = np.asarray(Image.open(mask_path))
+    mask = mask.astype('int64')
+    return mask
 
 
 def build_transforms(transforms_info):
@@ -80,7 +74,8 @@ def build_transforms(transforms_info):
         op_name = list(op_info.keys())[0]
         op_attr = op_info[op_name]
         if not hasattr(T, op_name):
-            raise Exception("There's no transform named '{}'".format(op_name))
+            raise ValueError(
+                "There is no transform operator named '{}'.".format(op_name))
         transforms.append(getattr(T, op_name)(**op_attr))
     eval_transforms = T.Compose(transforms)
     return eval_transforms