Pārlūkot izejas kodu

[Re-request][Feature] Add RandomSwap and label binarization (#24)

Lin Manhui 3 gadi atpakaļ
vecāks
revīzija
29c904bc5b
2 mainīti faili ar 46 papildinājumiem un 14 dzēšanām
  1. 23 13
      paddlers/datasets/cd_dataset.py
  2. 23 1
      paddlers/transforms/operators.py

+ 23 - 13
paddlers/datasets/cd_dataset.py

@@ -26,12 +26,15 @@ class CDDataset(Dataset):
 
     Args:
         data_dir (str): 数据集所在的目录路径。
-        file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
+        file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路径)。当`with_seg_labels`为
+            False(默认设置)时,文件中每一行应依次包含第一时相影像、第二时相影像以及变化检测标签的路径;当`with_seg_labels`为True时,
+            文件中每一行应依次包含第一时相影像、第二时相影像、变化检测标签、第一时相建筑物标签以及第二时相建筑物标签的路径。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
         transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
+        binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
     """
 
     def __init__(self,
@@ -41,7 +44,8 @@ class CDDataset(Dataset):
                  transforms=None,
                  num_workers='auto',
                  shuffle=False,
-                 with_seg_labels=False):
+                 with_seg_labels=False,
+                 binarize_labels=False):
         super(CDDataset, self).__init__()
 
         DELIMETER = ' '
@@ -55,9 +59,10 @@ class CDDataset(Dataset):
         self.labels = list()
         self.with_seg_labels = with_seg_labels
         if self.with_seg_labels:
-            num_items = 5   # 3+2
+            num_items = 5  # RGB1, RGB2, CD, Seg1, Seg2
         else:
-            num_items = 3
+            num_items = 3  # RGB1, RGB2, CD
+        self.binarize_labels = binarize_labels
 
         # TODO:非None时,让用户跳转数据集分析生成label_list
         # 不要在此处分析label file
@@ -66,15 +71,15 @@ class CDDataset(Dataset):
                 for line in f:
                     item = line.strip()
                     self.labels.append(item)
-                    
+
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
                 items = line.strip().split(DELIMETER)
 
                 if len(items) != num_items:
-                    raise Exception("Line[{}] in file_list[{}] has an incorrect number of file paths.".format(
-                        line.strip(), file_list
-                    ))
+                    raise Exception(
+                        "Line[{}] in file_list[{}] has an incorrect number of file paths.".
+                        format(line.strip(), file_list))
 
                 items = list(map(path_normalization, items))
                 if not all(map(is_pic, items)):
@@ -106,10 +111,11 @@ class CDDataset(Dataset):
                 item_dict = dict(
                     image_t1=full_path_im_t1,
                     image_t2=full_path_im_t2,
-                    mask=full_path_label
-                )
+                    mask=full_path_label)
                 if with_seg_labels:
-                    item_dict['aux_masks'] = [full_path_seg_label_t1, full_path_seg_label_t2]
+                    item_dict['aux_masks'] = [
+                        full_path_seg_label_t1, full_path_seg_label_t2
+                    ]
 
                 self.file_list.append(item_dict)
 
@@ -120,15 +126,19 @@ class CDDataset(Dataset):
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.file_list[idx])
         outputs = self.transforms(sample)
-
+        if self.binarize_labels:
+            outputs = outputs[:2] + tuple(map(self._binarize, outputs[2:]))
         return outputs
 
     def __len__(self):
         return len(self.file_list)
 
+    def _binarize(self, mask, threshold=127):
+        return (mask > threshold).astype('int64')
+
 
 class MaskType(IntEnum):
     """Enumeration of the mask types used in the change detection task."""
     CD = 0
     SEG_T1 = 1
-    SEG_T2 = 2
+    SEG_T2 = 2

+ 23 - 1
paddlers/transforms/operators.py

@@ -39,7 +39,9 @@ __all__ = [
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
-    "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeChangeDetector", 
+    "RandomDistort", "RandomBlur", 
+    "RandomSwap",
+    "ArrangeSegmenter", "ArrangeChangeDetector", 
     "ArrangeClassifier", "ArrangeDetector"
 ]
 
@@ -1462,6 +1464,26 @@ class _Permute(Transform):
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
         return sample
+        
+
+class RandomSwap(Transform):
+    """
+    Randomly swap multi-temporal images.
+
+    Args:
+        prob (float, optional): Probability of swapping the input images. Default: 0.2.
+    """
+
+    def __init__(self, prob=0.2):
+        super(RandomSwap, self).__init__()
+        self.prob = prob
+
+    def apply(self, sample):
+        if 'image2' not in sample:
+            raise ValueError('image2 is not found in the sample.')
+        if random.random() < self.prob:
+            sample['image'], sample['image2'] = sample['image2'], sample['image']
+        return sample
 
 
 class ArrangeSegmenter(Transform):