Browse Source

Merge pull request #1 from Bobholamovic/refactor_data

[Refactor] All about paddlers.transforms.Compose
cc 2 years ago
parent
commit
1932543c37
48 changed files with 459 additions and 310 deletions
  1. 1 1
      README.md
  2. 2 0
      docs/apis/transforms.md
  3. 1 1
      docs/cases/csc_cd_cn.md
  4. 35 0
      paddlers/datasets/base.py
  5. 20 13
      paddlers/datasets/cd_dataset.py
  6. 10 18
      paddlers/datasets/clas_dataset.py
  7. 9 11
      paddlers/datasets/coco.py
  8. 11 17
      paddlers/datasets/seg_dataset.py
  9. 9 11
      paddlers/datasets/voc.py
  10. 3 3
      paddlers/deploy/predictor.py
  11. 14 9
      paddlers/tasks/base.py
  12. 14 7
      paddlers/tasks/change_detector.py
  13. 9 7
      paddlers/tasks/classifier.py
  14. 9 7
      paddlers/tasks/object_detector.py
  15. 14 7
      paddlers/tasks/segmenter.py
  16. 7 25
      paddlers/transforms/__init__.py
  17. 17 0
      paddlers/transforms/functions.py
  18. 123 143
      paddlers/transforms/operators.py
  19. 2 2
      paddlers/utils/__init__.py
  20. 1 1
      paddlers/utils/utils.py
  21. 1 1
      tests/data/data_utils.py
  22. 16 4
      tests/deploy/test_predictor.py
  23. 1 1
      tests/download_test_data.sh
  24. 2 2
      tests/test_tutorials.py
  25. 1 1
      tools/raster2geotiff.py
  26. 2 2
      tools/raster2vector.py
  27. 1 1
      tools/split.py
  28. 2 2
      tools/utils/raster.py
  29. 6 0
      tutorials/train/change_detection/bit.py
  30. 6 0
      tutorials/train/change_detection/cdnet.py
  31. 6 0
      tutorials/train/change_detection/dsamnet.py
  32. 6 0
      tutorials/train/change_detection/dsifn.py
  33. 6 0
      tutorials/train/change_detection/fc_ef.py
  34. 6 0
      tutorials/train/change_detection/fc_siam_conc.py
  35. 6 0
      tutorials/train/change_detection/fc_siam_diff.py
  36. 6 0
      tutorials/train/change_detection/snunet.py
  37. 6 0
      tutorials/train/change_detection/stanet.py
  38. 6 3
      tutorials/train/classification/condensenetv2_b_rs_mul.py
  39. 5 0
      tutorials/train/classification/hrnet.py
  40. 5 0
      tutorials/train/classification/mobilenetv3.py
  41. 5 0
      tutorials/train/classification/resnet50_vd.py
  42. 7 2
      tutorials/train/object_detection/faster_rcnn.py
  43. 7 2
      tutorials/train/object_detection/ppyolo.py
  44. 7 2
      tutorials/train/object_detection/ppyolotiny.py
  45. 7 2
      tutorials/train/object_detection/ppyolov2.py
  46. 7 2
      tutorials/train/object_detection/yolov3.py
  47. 6 0
      tutorials/train/semantic_segmentation/deeplabv3p.py
  48. 6 0
      tutorials/train/semantic_segmentation/unet.py

+ 1 - 1
README.md

@@ -8,7 +8,7 @@
 
 
   <!-- [![version](https://img.shields.io/github/release/PaddlePaddle/PaddleRS.svg)](https://github.com/PaddlePaddle/PaddleRS/releases) -->
   <!-- [![version](https://img.shields.io/github/release/PaddlePaddle/PaddleRS.svg)](https://github.com/PaddlePaddle/PaddleRS/releases) -->
   [![license](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
   [![license](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
-  [![build status](https://github.com/PaddlePaddle/PaddleRS/workflows/build_and_test.yaml/badge.svg?branch=develop)](https://github.com/PaddlePaddle/PaddleRS/actions)
+  [![build status](https://github.com/PaddlePaddle/PaddleRS/actions/workflows/build_and_test.yaml/badge.svg?branch=develop)](https://github.com/PaddlePaddle/PaddleRS/actions)
   ![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
   ![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
   ![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
   ![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
 </div>
 </div>

+ 2 - 0
docs/apis/transforms.md

@@ -36,10 +36,12 @@ from paddlers.datasets import CDDataset
 
 
 
 
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=512),
     T.Resize(target_size=512),
     T.RandomHorizontalFlip(),
     T.RandomHorizontalFlip(),
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 train_dataset = CDDataset(
 train_dataset = CDDataset(

+ 1 - 1
docs/cases/csc_cd_cn.md

@@ -365,7 +365,7 @@ class InferDataset(paddle.io.Dataset):
         names = []
         names = []
         for line in lines:
         for line in lines:
             items = line.strip().split(' ')
             items = line.strip().split(' ')
-            items = list(map(pdrs.utils.path_normalization, items))
+            items = list(map(pdrs.utils.norm_path, items))
             item_dict = {
             item_dict = {
                 'image_t1': osp.join(data_dir, items[0]),
                 'image_t1': osp.join(data_dir, items[0]),
                 'image_t2': osp.join(data_dir, items[1])
                 'image_t2': osp.join(data_dir, items[1])

+ 35 - 0
paddlers/datasets/base.py

@@ -0,0 +1,35 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+from paddle.io import Dataset
+
+from paddlers.utils import get_num_workers
+
+
+class BaseDataset(Dataset):
+    def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
+        super(BaseDataset, self).__init__()
+
+        self.data_dir = data_dir
+        self.label_list = label_list
+        self.transforms = deepcopy(transforms)
+        self.num_workers = get_num_workers(num_workers)
+        self.shuffle = shuffle
+
+    def __getitem__(self, idx):
+        sample = deepcopy(self.file_list[idx])
+        outputs = self.transforms(sample)
+        return outputs

+ 20 - 13
paddlers/datasets/cd_dataset.py

@@ -16,12 +16,11 @@ import copy
 from enum import IntEnum
 from enum import IntEnum
 import os.path as osp
 import os.path as osp
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class CDDataset(Dataset):
+class CDDataset(BaseDataset):
     """
     """
     读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
     读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
 
 
@@ -31,8 +30,10 @@ class CDDataset(Dataset):
             False(默认设置)时,文件中每一行应依次包含第一时相影像、第二时相影像以及变化检测标签的路径;当`with_seg_labels`为True时,
             False(默认设置)时,文件中每一行应依次包含第一时相影像、第二时相影像以及变化检测标签的路径;当`with_seg_labels`为True时,
             文件中每一行应依次包含第一时相影像、第二时相影像、变化检测标签、第一时相建筑物标签以及第二时相建筑物标签的路径。
             文件中每一行应依次包含第一时相影像、第二时相影像、变化检测标签、第一时相建筑物标签以及第二时相建筑物标签的路径。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
         binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
         binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
@@ -47,15 +48,13 @@ class CDDataset(Dataset):
                  shuffle=False,
                  shuffle=False,
                  with_seg_labels=False,
                  with_seg_labels=False,
                  binarize_labels=False):
                  binarize_labels=False):
-        super(CDDataset, self).__init__()
+        super(CDDataset, self).__init__(data_dir, label_list, transforms,
+                                        num_workers, shuffle)
 
 
         DELIMETER = ' '
         DELIMETER = ' '
 
 
-        self.transforms = copy.deepcopy(transforms)
         # TODO: batch padding
         # TODO: batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
         self.with_seg_labels = with_seg_labels
         self.with_seg_labels = with_seg_labels
@@ -82,7 +81,7 @@ class CDDataset(Dataset):
                         "Line[{}] in file_list[{}] has an incorrect number of file paths.".
                         "Line[{}] in file_list[{}] has an incorrect number of file paths.".
                         format(line.strip(), file_list))
                         format(line.strip(), file_list))
 
 
-                items = list(map(path_normalization, items))
+                items = list(map(norm_path, items))
 
 
                 full_path_im_t1 = osp.join(data_dir, items[0])
                 full_path_im_t1 = osp.join(data_dir, items[0])
                 full_path_im_t2 = osp.join(data_dir, items[1])
                 full_path_im_t2 = osp.join(data_dir, items[1])
@@ -128,9 +127,17 @@ class CDDataset(Dataset):
 
 
     def __getitem__(self, idx):
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.file_list[idx])
         sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
+        sample = self.transforms.apply_transforms(sample)
+
         if self.binarize_labels:
         if self.binarize_labels:
-            outputs = outputs[:2] + tuple(map(self._binarize, outputs[2:]))
+            # Requires 'mask' to exist
+            sample['mask'] = self._binarize(sample['mask'])
+            if 'aux_masks' in sample:
+                sample['aux_masks'] = list(
+                    map(self._binarize, sample['aux_masks']))
+
+        outputs = self.transforms.arrange_outputs(sample)
+
         return outputs
         return outputs
 
 
     def __len__(self):
     def __len__(self):

+ 10 - 18
paddlers/datasets/clas_dataset.py

@@ -13,22 +13,22 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import os.path as osp
 import os.path as osp
-import copy
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class ClasDataset(Dataset):
+class ClasDataset(BaseDataset):
     """读取图像分类任务数据集,并对样本进行相应的处理。
     """读取图像分类任务数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注序号(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注序号(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径,文件格式为(类别 说明)。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径,文件格式为(类别 说明)。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 
@@ -39,14 +39,11 @@ class ClasDataset(Dataset):
                  transforms=None,
                  transforms=None,
                  num_workers='auto',
                  num_workers='auto',
                  shuffle=False):
                  shuffle=False):
-        super(ClasDataset, self).__init__()
-        self.transforms = copy.deepcopy(transforms)
+        super(ClasDataset, self).__init__(data_dir, label_list, transforms,
+                                          num_workers, shuffle)
         # TODO batch padding
         # TODO batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
-        self.label_list = label_list
         self.labels = list()
         self.labels = list()
 
 
         # TODO:非None时,让用户跳转数据集分析生成label_list
         # TODO:非None时,让用户跳转数据集分析生成label_list
@@ -64,7 +61,7 @@ class ClasDataset(Dataset):
                         "A space is defined as the delimiter to separate the image and label path, " \
                         "A space is defined as the delimiter to separate the image and label path, " \
                         "so the space cannot be in the image or label path, but the line[{}] of " \
                         "so the space cannot be in the image or label path, but the line[{}] of " \
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
-                items[0] = path_normalization(items[0])
+                items[0] = norm_path(items[0])
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_im = osp.join(data_dir, items[0])
                 label = items[1]
                 label = items[1]
                 if not is_pic(full_path_im):
                 if not is_pic(full_path_im):
@@ -84,10 +81,5 @@ class ClasDataset(Dataset):
         logging.info("{} samples in file {}".format(
         logging.info("{} samples in file {}".format(
             len(self.file_list), file_list))
             len(self.file_list), file_list))
 
 
-    def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
-        return outputs
-
     def __len__(self):
     def __len__(self):
         return len(self.file_list)
         return len(self.file_list)

+ 9 - 11
paddlers/datasets/coco.py

@@ -20,14 +20,14 @@ import random
 from collections import OrderedDict
 from collections import OrderedDict
 
 
 import numpy as np
 import numpy as np
-from paddle.io import Dataset
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.tools import YOLOAnchorCluster
 from paddlers.tools import YOLOAnchorCluster
 
 
 
 
-class COCODetection(Dataset):
+class COCODetection(BaseDataset):
     """读取COCO格式的检测数据集,并对样本进行相应的处理。
     """读取COCO格式的检测数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
@@ -35,7 +35,7 @@ class COCODetection(Dataset):
         image_dir (str): 描述数据集图片文件路径。
         image_dir (str): 描述数据集图片文件路径。
         anno_path (str): COCO标注文件路径。
         anno_path (str): COCO标注文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
-        transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
             一半。
@@ -60,10 +60,10 @@ class COCODetection(Dataset):
         import matplotlib
         import matplotlib
         matplotlib.use('Agg')
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         from pycocotools.coco import COCO
-        super(COCODetection, self).__init__()
-        self.data_dir = data_dir
+        super(COCODetection, self).__init__(data_dir, label_list, transforms,
+                                            num_workers, shuffle)
+
         self.data_fields = None
         self.data_fields = None
-        self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
         self.num_max_boxes = 50
 
 
         self.use_mix = False
         self.use_mix = False
@@ -76,8 +76,6 @@ class COCODetection(Dataset):
                     break
                     break
 
 
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.allow_empty = allow_empty
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.empty_ratio = empty_ratio
         self.file_list = list()
         self.file_list = list()
@@ -104,8 +102,8 @@ class COCODetection(Dataset):
                 'name': k
                 'name': k
             })
             })
 
 
-        anno_path = path_normalization(os.path.join(self.data_dir, anno_path))
-        image_dir = path_normalization(os.path.join(self.data_dir, image_dir))
+        anno_path = norm_path(os.path.join(self.data_dir, anno_path))
+        image_dir = norm_path(os.path.join(self.data_dir, image_dir))
 
 
         assert anno_path.endswith('.json'), \
         assert anno_path.endswith('.json'), \
             'invalid coco annotation file: ' + anno_path
             'invalid coco annotation file: ' + anno_path

+ 11 - 17
paddlers/datasets/seg_dataset.py

@@ -15,20 +15,21 @@
 import os.path as osp
 import os.path as osp
 import copy
 import copy
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class SegDataset(Dataset):
+class SegDataset(BaseDataset):
     """读取语义分割任务数据集,并对样本进行相应的处理。
     """读取语义分割任务数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 
@@ -39,12 +40,10 @@ class SegDataset(Dataset):
                  transforms=None,
                  transforms=None,
                  num_workers='auto',
                  num_workers='auto',
                  shuffle=False):
                  shuffle=False):
-        super(SegDataset, self).__init__()
-        self.transforms = copy.deepcopy(transforms)
+        super(SegDataset, self).__init__(data_dir, label_list, transforms,
+                                         num_workers, shuffle)
         # TODO batch padding
         # TODO batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
 
 
@@ -63,8 +62,8 @@ class SegDataset(Dataset):
                         "A space is defined as the delimiter to separate the image and label path, " \
                         "A space is defined as the delimiter to separate the image and label path, " \
                         "so the space cannot be in the image or label path, but the line[{}] of " \
                         "so the space cannot be in the image or label path, but the line[{}] of " \
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
-                items[0] = path_normalization(items[0])
-                items[1] = path_normalization(items[1])
+                items[0] = norm_path(items[0])
+                items[1] = norm_path(items[1])
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
                 full_path_label = osp.join(data_dir, items[1])
                 if not is_pic(full_path_im) or not is_pic(full_path_label):
                 if not is_pic(full_path_im) or not is_pic(full_path_label):
@@ -83,10 +82,5 @@ class SegDataset(Dataset):
         logging.info("{} samples in file {}".format(
         logging.info("{} samples in file {}".format(
             len(self.file_list), file_list))
             len(self.file_list), file_list))
 
 
-    def __getitem__(self, idx):
-        sample = copy.deepcopy(self.file_list[idx])
-        outputs = self.transforms(sample)
-        return outputs
-
     def __len__(self):
     def __len__(self):
         return len(self.file_list)
         return len(self.file_list)

+ 9 - 11
paddlers/datasets/voc.py

@@ -22,21 +22,21 @@ from collections import OrderedDict
 import xml.etree.ElementTree as ET
 import xml.etree.ElementTree as ET
 
 
 import numpy as np
 import numpy as np
-from paddle.io import Dataset
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, norm_path, is_pic
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.tools import YOLOAnchorCluster
 from paddlers.tools import YOLOAnchorCluster
 
 
 
 
-class VOCDetection(Dataset):
+class VOCDetection(BaseDataset):
     """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
     """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
-        transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
             一半。
@@ -60,10 +60,10 @@ class VOCDetection(Dataset):
         import matplotlib
         import matplotlib
         matplotlib.use('Agg')
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         from pycocotools.coco import COCO
-        super(VOCDetection, self).__init__()
-        self.data_dir = data_dir
+        super(VOCDetection, self).__init__(data_dir, label_list, transforms,
+                                           num_workers, shuffle)
+
         self.data_fields = None
         self.data_fields = None
-        self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
         self.num_max_boxes = 50
 
 
         self.use_mix = False
         self.use_mix = False
@@ -76,8 +76,6 @@ class VOCDetection(Dataset):
                     break
                     break
 
 
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.allow_empty = allow_empty
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.empty_ratio = empty_ratio
         self.file_list = list()
         self.file_list = list()
@@ -117,8 +115,8 @@ class VOCDetection(Dataset):
                 img_file, xml_file = [
                 img_file, xml_file = [
                     osp.join(data_dir, x) for x in line.strip().split()[:2]
                     osp.join(data_dir, x) for x in line.strip().split()[:2]
                 ]
                 ]
-                img_file = path_normalization(img_file)
-                xml_file = path_normalization(xml_file)
+                img_file = norm_path(img_file)
+                xml_file = norm_path(xml_file)
                 if not is_pic(img_file):
                 if not is_pic(img_file):
                     continue
                     continue
                 if not osp.isfile(xml_file):
                 if not osp.isfile(xml_file):

+ 3 - 3
paddlers/deploy/predictor.py

@@ -258,9 +258,9 @@ class Predictor(object):
             Args:
             Args:
                 img_file(list[str | tuple | np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                 img_file(list[str | tuple | np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                     object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict
                     object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict
-                    , a decoded image (a `np.ndarray`, which should be consistent with what you get from passing image path to
-                    `paddlers.transforms.decode_image()`), or a list of image paths or decoded images. For change detection tasks,
-                    `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
+                    , a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
+                    paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
+                    img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
                 topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
                 topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
                 transforms (paddlers.transforms.Compose | None, optional): Pipeline of data preprocessing. If None, load transforms
                 transforms (paddlers.transforms.Compose | None, optional): Pipeline of data preprocessing. If None, load transforms
                     from `model.yml`. Defaults to None.
                     from `model.yml`. Defaults to None.

+ 14 - 9
paddlers/tasks/base.py

@@ -30,7 +30,6 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 
 
 import paddlers
 import paddlers
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                             get_pretrain_weights, load_pretrain_weights,
                             get_pretrain_weights, load_pretrain_weights,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             load_checkpoint, SmoothedValue, TrainingStats,
@@ -302,10 +301,7 @@ class BaseModel(metaclass=ModelMeta):
                    early_stop=False,
                    early_stop=False,
                    early_stop_patience=5,
                    early_stop_patience=5,
                    use_vdl=True):
                    use_vdl=True):
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=train_dataset.transforms,
-            mode='train')
+        self._check_transforms(train_dataset.transforms, 'train')
 
 
         if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
         if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
                 train_dataset.file_list):
@@ -488,10 +484,7 @@ class BaseModel(metaclass=ModelMeta):
 
 
         assert criterion in {'l1_norm', 'fpgm'}, \
         assert criterion in {'l1_norm', 'fpgm'}, \
             "Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
             "Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=dataset.transforms,
-            mode='eval')
+        self._check_transforms(dataset.transforms, 'eval')
         if self.model_type == 'detector':
         if self.model_type == 'detector':
             self.net.eval()
             self.net.eval()
         else:
         else:
@@ -670,3 +663,15 @@ class BaseModel(metaclass=ModelMeta):
         open(osp.join(save_dir, '.success'), 'w').close()
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("The model for the inference deployment is saved in {}.".
         logging.info("The model for the inference deployment is saved in {}.".
                      format(save_dir))
                      format(save_dir))
+
+    def _check_transforms(self, transforms, mode):
+        # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
+        if not isinstance(transforms, paddlers.transforms.Compose):
+            raise TypeError("`transforms` must be paddlers.transforms.Compose.")
+        arrange_obj = transforms.arrange
+        if not isinstance(arrange_obj, paddlers.transforms.operators.Arrange):
+            raise TypeError("`transforms.arrange` must be an Arrange object.")
+        if arrange_obj.mode != mode:
+            raise ValueError(
+                f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
+            )

+ 14 - 7
paddlers/tasks/change_detector.py

@@ -28,7 +28,6 @@ import paddlers
 import paddlers.custom_models.cd as cmcd
 import paddlers.custom_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 import paddlers.models.ppseg as paddleseg
 import paddlers.models.ppseg as paddleseg
-from paddlers.transforms import arrange_transforms
 from paddlers.transforms import Resize, decode_image
 from paddlers.transforms import Resize, decode_image
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
@@ -137,6 +136,11 @@ class BaseChangeDetector(BaseModel):
             else:
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[2]
             label = inputs[2]
+            if label.ndim == 3:
+                paddle.unsqueeze_(label, axis=1)
+            if label.ndim != 4:
+                raise ValueError("Expected label.ndim == 4 but got {}".format(
+                    label.ndim))
             origin_shape = [label.shape[-2:]]
             origin_shape = [label.shape[-2:]]
             pred = self._postprocess(
             pred = self._postprocess(
                 pred, origin_shape, transforms=inputs[3])[0]  # NCHW
                 pred, origin_shape, transforms=inputs[3])[0]  # NCHW
@@ -396,10 +400,7 @@ class BaseChangeDetector(BaseModel):
                  "category_F1-score": `F1 score`}.
                  "category_F1-score": `F1 score`}.
 
 
         """
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
 
         self.net.eval()
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
@@ -641,8 +642,7 @@ class BaseChangeDetector(BaseModel):
         print("GeoTiff saved in {}.".format(save_file))
         print("GeoTiff saved in {}.".format(save_file))
 
 
     def _preprocess(self, images, transforms, to_tensor=True):
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im1, batch_im2 = list(), list()
         batch_im1, batch_im2 = list(), list()
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im1, im2 in images:
         for im1, im2 in images:
@@ -786,6 +786,13 @@ class BaseChangeDetector(BaseModel):
             score_maps.append(score_map.squeeze())
             score_maps.append(score_map.squeeze())
         return label_maps, score_maps
         return label_maps, score_maps
 
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeChangeDetector):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeChangeDetector object.")
+
 
 
 class CDNet(BaseChangeDetector):
 class CDNet(BaseChangeDetector):
     def __init__(self,
     def __init__(self,

+ 9 - 7
paddlers/tasks/classifier.py

@@ -25,7 +25,6 @@ from paddle.static import InputSpec
 import paddlers.models.ppcls as paddleclas
 import paddlers.models.ppcls as paddleclas
 import paddlers.custom_models.cls as cmcls
 import paddlers.custom_models.cls as cmcls
 import paddlers
 import paddlers
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils import get_single_card_bs, DisablePrint
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from .base import BaseModel
 from .base import BaseModel
@@ -358,10 +357,7 @@ class BaseClassifier(BaseModel):
                  "top5": `acc of top5`}.
                  "top5": `acc of top5`}.
 
 
         """
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
 
         self.net.eval()
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
@@ -460,8 +456,7 @@ class BaseClassifier(BaseModel):
         return prediction
         return prediction
 
 
     def _preprocess(self, images, transforms, to_tensor=True):
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_im = list()
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im in images:
         for im in images:
@@ -527,6 +522,13 @@ class BaseClassifier(BaseModel):
             batch_restore_list.append(restore_list)
             batch_restore_list.append(restore_list)
         return batch_restore_list
         return batch_restore_list
 
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeClassifier):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeClassifier object.")
+
 
 
 class ResNet50_vd(BaseClassifier):
 class ResNet50_vd(BaseClassifier):
     def __init__(self, num_classes=2, use_mixed_loss=False, **params):
     def __init__(self, num_classes=2, use_mixed_loss=False, **params):

+ 9 - 7
paddlers/tasks/object_detector.py

@@ -31,7 +31,6 @@ from paddlers.transforms import decode_image
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
     _BatchPad, _Gt2YoloTarget
     _BatchPad, _Gt2YoloTarget
-from paddlers.transforms import arrange_transforms
 from .base import BaseModel
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric
 from .utils.det_metrics import VOCMetric, COCOMetric
 from paddlers.models.ppdet.optimizer import ModelEMA
 from paddlers.models.ppdet.optimizer import ModelEMA
@@ -452,10 +451,7 @@ class BaseDetector(BaseModel):
                 }
                 }
         eval_dataset.batch_transforms = self._compose_batch_transform(
         eval_dataset.batch_transforms = self._compose_batch_transform(
             eval_dataset.transforms, mode='eval')
             eval_dataset.transforms, mode='eval')
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
 
         self.net.eval()
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
@@ -545,8 +541,7 @@ class BaseDetector(BaseModel):
         return prediction
         return prediction
 
 
     def _preprocess(self, images, transforms, to_tensor=True):
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_samples = list()
         batch_samples = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
@@ -630,6 +625,13 @@ class BaseDetector(BaseModel):
 
 
         return results
         return results
 
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeDetector):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeDetector object.")
+
 
 
 class PicoDet(BaseDetector):
 class PicoDet(BaseDetector):
     def __init__(self,
     def __init__(self,

+ 14 - 7
paddlers/tasks/segmenter.py

@@ -26,7 +26,6 @@ from paddle.static import InputSpec
 import paddlers.models.ppseg as paddleseg
 import paddlers.models.ppseg as paddleseg
 import paddlers.custom_models.seg as cmseg
 import paddlers.custom_models.seg as cmseg
 import paddlers
 import paddlers
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils import get_single_card_bs, DisablePrint
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from .base import BaseModel
 from .base import BaseModel
@@ -136,6 +135,11 @@ class BaseSegmenter(BaseModel):
             else:
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[1]
             label = inputs[1]
+            if label.ndim == 3:
+                paddle.unsqueeze_(label, axis=1)
+            if label.ndim != 4:
+                raise ValueError("Expected label.ndim == 4 but got {}".format(
+                    label.ndim))
             origin_shape = [label.shape[-2:]]
             origin_shape = [label.shape[-2:]]
             pred = self._postprocess(
             pred = self._postprocess(
                 pred, origin_shape, transforms=inputs[2])[0]  # NCHW
                 pred, origin_shape, transforms=inputs[2])[0]  # NCHW
@@ -380,10 +384,7 @@ class BaseSegmenter(BaseModel):
                  "category_F1-score": `F1 score`}.
                  "category_F1-score": `F1 score`}.
 
 
         """
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
 
         self.net.eval()
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
@@ -606,8 +607,7 @@ class BaseSegmenter(BaseModel):
         print("GeoTiff saved in {}.".format(save_file))
         print("GeoTiff saved in {}.".format(save_file))
 
 
     def _preprocess(self, images, transforms, to_tensor=True):
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_im = list()
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im in images:
         for im in images:
@@ -746,6 +746,13 @@ class BaseSegmenter(BaseModel):
             score_maps.append(score_map.squeeze())
             score_maps.append(score_map.squeeze())
         return label_maps, score_maps
         return label_maps, score_maps
 
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeSegmenter):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeSegmenter object.")
+
 
 
 class UNet(BaseSegmenter):
 class UNet(BaseSegmenter):
     def __init__(self,
     def __init__(self,

+ 7 - 25
paddlers/transforms/__init__.py

@@ -29,15 +29,19 @@ def decode_image(im_path,
     Decode an image.
     Decode an image.
     
     
     Args:
     Args:
+        im_path (str): Path of the image to decode.
         to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
         to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
         to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
         to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. 
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. 
             Defaults to True.
             Defaults to True.
         decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a 
         decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a 
             SAR image, set this argument to True. Defaults to True.
             SAR image, set this argument to True. Defaults to True.
+
+    Returns:
+        np.ndarray: Decoded image.
     """
     """
 
 
-    # Do a presence check. `osp.exists` assumes `im_path` is a path-like object.
+    # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
     if not osp.exists(im_path):
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
         raise ValueError(f"{im_path} does not exist!")
     decoder = T.DecodeImg(
     decoder = T.DecodeImg(
@@ -51,36 +55,14 @@ def decode_image(im_path,
     return sample['image']
     return sample['image']
 
 
 
 
-def arrange_transforms(model_type, transforms, mode='train'):
-    # 给transforms添加arrange操作
-    if model_type == 'segmenter':
-        if mode == 'eval':
-            transforms.apply_im_only = True
-        else:
-            transforms.apply_im_only = False
-        arrange_transform = ArrangeSegmenter(mode)
-    elif model_type == 'changedetector':
-        if mode == 'eval':
-            transforms.apply_im_only = True
-        else:
-            transforms.apply_im_only = False
-        arrange_transform = ArrangeChangeDetector(mode)
-    elif model_type == 'classifier':
-        arrange_transform = ArrangeClassifier(mode)
-    elif model_type == 'detector':
-        arrange_transform = ArrangeDetector(mode)
-    else:
-        raise Exception("Unrecognized model type: {}".format(model_type))
-    transforms.arrange_outputs = arrange_transform
-
-
 def build_transforms(transforms_info):
 def build_transforms(transforms_info):
     transforms = list()
     transforms = list()
     for op_info in transforms_info:
     for op_info in transforms_info:
         op_name = list(op_info.keys())[0]
         op_name = list(op_info.keys())[0]
         op_attr = op_info[op_name]
         op_attr = op_info[op_name]
         if not hasattr(T, op_name):
         if not hasattr(T, op_name):
-            raise Exception("There's no transform named '{}'".format(op_name))
+            raise ValueError(
+                "There is no transform operator named '{}'.".format(op_name))
         transforms.append(getattr(T, op_name)(**op_attr))
         transforms.append(getattr(T, op_name)(**op_attr))
     eval_transforms = T.Compose(transforms)
     eval_transforms = T.Compose(transforms)
     return eval_transforms
     return eval_transforms

+ 17 - 0
paddlers/transforms/functions.py

@@ -21,6 +21,7 @@ from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
 from sklearn.linear_model import LinearRegression
 from sklearn.linear_model import LinearRegression
 from skimage import exposure
 from skimage import exposure
 from joblib import load
 from joblib import load
+from PIL import Image
 
 
 
 
 def normalize(im, mean, std, min_value=[0, 0, 0], max_value=[255, 255, 255]):
 def normalize(im, mean, std, min_value=[0, 0, 0], max_value=[255, 255, 255]):
@@ -623,3 +624,19 @@ def inv_pca(im, joblib_path):
     r_im = pca.inverse_transform(n_im)
     r_im = pca.inverse_transform(n_im)
     r_im = np.reshape(r_im, (H, W, -1))
     r_im = np.reshape(r_im, (H, W, -1))
     return r_im
     return r_im
+
+
+def decode_seg_mask(mask_path):
+    """
+    Decode a segmentation mask image.
+    
+    Args:
+        mask_path (str): Path of the mask image to decode.
+
+    Returns:
+        np.ndarray: Decoded mask image.
+    """
+
+    mask = np.asarray(Image.open(mask_path))
+    mask = mask.astype('int64')
+    return mask

+ 123 - 143
paddlers/transforms/operators.py

@@ -30,39 +30,21 @@ from PIL import Image
 from joblib import load
 from joblib import load
 
 
 import paddlers
 import paddlers
-from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
-    horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
-    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, \
-    to_intensity, to_uint8, img_flip, img_simple_rotate
+from .functions import (
+    normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
+    horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
+    vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
+    resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
+    img_flip, img_simple_rotate, decode_seg_mask)
 
 
 __all__ = [
 __all__ = [
-    "Compose",
-    "DecodeImg",
-    "Resize",
-    "RandomResize",
-    "ResizeByShort",
-    "RandomResizeByShort",
-    "ResizeByLong",
-    "RandomHorizontalFlip",
-    "RandomVerticalFlip",
-    "Normalize",
-    "CenterCrop",
-    "RandomCrop",
-    "RandomScaleAspect",
-    "RandomExpand",
-    "Pad",
-    "MixupImage",
-    "RandomDistort",
-    "RandomBlur",
-    "RandomSwap",
-    "Dehaze",
-    "ReduceDim",
-    "SelectBand",
-    "ArrangeSegmenter",
-    "ArrangeChangeDetector",
-    "ArrangeClassifier",
-    "ArrangeDetector",
-    "RandomFlipOrRotate",
+    "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
+    "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
+    "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
+    "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
+    "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
+    "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
+    "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
 ]
 ]
 
 
 interp_dict = {
 interp_dict = {
@@ -74,6 +56,71 @@ interp_dict = {
 }
 }
 
 
 
 
+class Compose(object):
+    """
+    Apply a series of data augmentation strategies to the input.
+    All input images should be in Height-Width-Channel ([H, W, C]) format.
+
+    Args:
+        transforms (list[paddlers.transforms.Transform]): List of data preprocess or
+            augmentation operators.
+
+    Raises:
+        TypeError: Invalid type of transforms.
+        ValueError: Invalid length of transforms.
+    """
+
+    def __init__(self, transforms):
+        super(Compose, self).__init__()
+        if not isinstance(transforms, list):
+            raise TypeError(
+                "Type of transforms is invalid. Must be a list, but received is {}."
+                .format(type(transforms)))
+        if len(transforms) < 1:
+            raise ValueError(
+                "Length of transforms must not be less than 1, but received is {}."
+                .format(len(transforms)))
+        transforms = copy.deepcopy(transforms)
+        self.arrange = self._pick_arrange(transforms)
+        self.transforms = transforms
+
+    def __call__(self, sample):
+        """
+        This is equivalent to sequentially calling compose_obj.apply_transforms() 
+            and compose_obj.arrange_outputs().
+        """
+
+        sample = self.apply_transforms(sample)
+        sample = self.arrange_outputs(sample)
+        return sample
+
+    def apply_transforms(self, sample):
+        for op in self.transforms:
+            # Skip batch transforms and mixup
+            if isinstance(op, (paddlers.transforms.BatchRandomResize,
+                               paddlers.transforms.BatchRandomResizeByShort,
+                               MixupImage)):
+                continue
+            sample = op(sample)
+        return sample
+
+    def arrange_outputs(self, sample):
+        if self.arrange is not None:
+            sample = self.arrange(sample)
+        return sample
+
+    def _pick_arrange(self, transforms):
+        arrange = None
+        for idx, op in enumerate(transforms):
+            if isinstance(op, Arrange):
+                if idx != len(transforms) - 1:
+                    raise ValueError(
+                        "Arrange operator must be placed at the end of the list."
+                    )
+                arrange = transforms.pop(idx)
+        return arrange
+
+
 class Transform(object):
 class Transform(object):
     """
     """
     Parent class of all data augmentation operations
     Parent class of all data augmentation operations
@@ -178,14 +225,14 @@ class DecodeImg(Transform):
         elif ext == '.npy':
         elif ext == '.npy':
             return np.load(img_path)
             return np.load(img_path)
         else:
         else:
-            raise TypeError('Image format {} is not supported!'.format(ext))
+            raise TypeError("Image format {} is not supported!".format(ext))
 
 
     def apply_im(self, im_path):
     def apply_im(self, im_path):
         if isinstance(im_path, str):
         if isinstance(im_path, str):
             try:
             try:
                 image = self.read_img(im_path)
                 image = self.read_img(im_path)
             except:
             except:
-                raise ValueError('Cannot read the image file {}!'.format(
+                raise ValueError("Cannot read the image file {}!".format(
                     im_path))
                     im_path))
         else:
         else:
             image = im_path
             image = im_path
@@ -217,7 +264,9 @@ class DecodeImg(Transform):
         Returns:
         Returns:
             dict: Decoded sample.
             dict: Decoded sample.
         """
         """
+
         if 'image' in sample:
         if 'image' in sample:
+            sample['image_ori'] = copy.deepcopy(sample['image'])
             sample['image'] = self.apply_im(sample['image'])
             sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = self.apply_im(sample['image2'])
             sample['image2'] = self.apply_im(sample['image2'])
@@ -227,6 +276,7 @@ class DecodeImg(Transform):
             sample['image'] = self.apply_im(sample['image_t1'])
             sample['image'] = self.apply_im(sample['image_t1'])
             sample['image2'] = self.apply_im(sample['image_t2'])
             sample['image2'] = self.apply_im(sample['image_t2'])
         if 'mask' in sample:
         if 'mask' in sample:
+            sample['mask_ori'] = copy.deepcopy(sample['mask'])
             sample['mask'] = self.apply_mask(sample['mask'])
             sample['mask'] = self.apply_mask(sample['mask'])
             im_height, im_width, _ = sample['image'].shape
             im_height, im_width, _ = sample['image'].shape
             se_height, se_width = sample['mask'].shape
             se_height, se_width = sample['mask'].shape
@@ -234,6 +284,7 @@ class DecodeImg(Transform):
                 raise ValueError(
                 raise ValueError(
                     "The height or width of the image is not same as the mask.")
                     "The height or width of the image is not same as the mask.")
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
+            sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
             sample['aux_masks'] = list(
             sample['aux_masks'] = list(
                 map(self.apply_mask, sample['aux_masks']))
                 map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
             # TODO: check the shape of auxiliary masks
@@ -244,61 +295,6 @@ class DecodeImg(Transform):
         return sample
         return sample
 
 
 
 
-class Compose(Transform):
-    """
-    Apply a series of data augmentation to the input.
-    All input images are in Height-Width-Channel ([H, W, C]) format.
-
-    Args:
-        transforms (list[paddlers.transforms.Transform]): List of data preprocess or augmentations.
-    Raises:
-        TypeError: Invalid type of transforms.
-        ValueError: Invalid length of transforms.
-    """
-
-    def __init__(self, transforms, to_uint8=True):
-        super(Compose, self).__init__()
-        if not isinstance(transforms, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be a list, but received is {}'
-                .format(type(transforms)))
-        if len(transforms) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(transforms)))
-        self.transforms = transforms
-        self.decode_image = DecodeImg(to_uint8=to_uint8)
-        self.arrange_outputs = None
-        self.apply_im_only = False
-
-    def __call__(self, sample):
-        if self.apply_im_only:
-            if 'mask' in sample:
-                mask_backup = copy.deepcopy(sample['mask'])
-                del sample['mask']
-            if 'aux_masks' in sample:
-                aux_masks = copy.deepcopy(sample['aux_masks'])
-
-        sample = self.decode_image(sample)
-
-        for op in self.transforms:
-            # skip batch transforms amd mixup
-            if isinstance(op, (paddlers.transforms.BatchRandomResize,
-                               paddlers.transforms.BatchRandomResizeByShort,
-                               MixupImage)):
-                continue
-            sample = op(sample)
-
-        if self.arrange_outputs is not None:
-            if self.apply_im_only:
-                sample['mask'] = mask_backup
-                if 'aux_masks' in locals():
-                    sample['aux_masks'] = aux_masks
-            sample = self.arrange_outputs(sample)
-
-        return sample
-
-
 class Resize(Transform):
 class Resize(Transform):
     """
     """
     Resize input.
     Resize input.
@@ -323,7 +319,7 @@ class Resize(Transform):
     def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
     def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
         super(Resize, self).__init__()
         super(Resize, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
                 interp_dict.keys()))
         if isinstance(target_size, int):
         if isinstance(target_size, int):
             target_size = (target_size, target_size)
             target_size = (target_size, target_size)
@@ -331,7 +327,7 @@ class Resize(Transform):
             if not (isinstance(target_size,
             if not (isinstance(target_size,
                                (list, tuple)) and len(target_size) == 2):
                                (list, tuple)) and len(target_size) == 2):
                 raise TypeError(
                 raise TypeError(
-                    "target_size should be an int or a list of length 2, but received {}".
+                    "`target_size` should be an int or a list of length 2, but received {}.".
                     format(target_size))
                     format(target_size))
         # (height, width)
         # (height, width)
         self.target_size = target_size
         self.target_size = target_size
@@ -443,11 +439,11 @@ class RandomResize(Transform):
     def __init__(self, target_sizes, interp='LINEAR'):
     def __init__(self, target_sizes, interp='LINEAR'):
         super(RandomResize, self).__init__()
         super(RandomResize, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
                 interp_dict.keys()))
         self.interp = interp
         self.interp = interp
         assert isinstance(target_sizes, list), \
         assert isinstance(target_sizes, list), \
-            "target_size must be a list."
+            "`target_size` must be a list."
         for i, item in enumerate(target_sizes):
         for i, item in enumerate(target_sizes):
             if isinstance(item, int):
             if isinstance(item, int):
                 target_sizes[i] = (item, item)
                 target_sizes[i] = (item, item)
@@ -478,7 +474,7 @@ class ResizeByShort(Transform):
 
 
     def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
     def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
         if not (interp == "RANDOM" or interp in interp_dict):
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}".format(
                 interp_dict.keys()))
                 interp_dict.keys()))
         super(ResizeByShort, self).__init__()
         super(ResizeByShort, self).__init__()
         self.short_size = short_size
         self.short_size = short_size
@@ -522,11 +518,11 @@ class RandomResizeByShort(Transform):
     def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
     def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
         super(RandomResizeByShort, self).__init__()
         super(RandomResizeByShort, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}".format(
                 interp_dict.keys()))
                 interp_dict.keys()))
         self.interp = interp
         self.interp = interp
         assert isinstance(short_sizes, list), \
         assert isinstance(short_sizes, list), \
-            "short_sizes must be a list."
+            "`short_sizes` must be a list."
 
 
         self.short_sizes = short_sizes
         self.short_sizes = short_sizes
         self.max_size = max_size
         self.max_size = max_size
@@ -574,6 +570,7 @@ class RandomFlipOrRotate(Transform):
 
 
         # 定义数据增强
         # 定义数据增强
         train_transforms = T.Compose([
         train_transforms = T.Compose([
+            T.DecodeImg(),
             T.RandomFlipOrRotate(
             T.RandomFlipOrRotate(
                 probs  = [0.3, 0.2]             # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
                 probs  = [0.3, 0.2]             # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
                 probsf = [0.3, 0.25, 0, 0, 0]   # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
                 probsf = [0.3, 0.25, 0, 0, 0]   # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
@@ -609,12 +606,12 @@ class RandomFlipOrRotate(Transform):
 
 
     def apply_bbox(self, bbox, mode_id, flip_mode=True):
     def apply_bbox(self, bbox, mode_id, flip_mode=True):
         raise TypeError(
         raise TypeError(
-            "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
+            "Currently, RandomFlipOrRotate is not available for object detection tasks."
         )
         )
 
 
     def apply_segm(self, bbox, mode_id, flip_mode=True):
     def apply_segm(self, bbox, mode_id, flip_mode=True):
         raise TypeError(
         raise TypeError(
-            "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
+            "Currently, RandomFlipOrRotate is not available for object detection tasks."
         )
         )
 
 
     def get_probs_range(self, probs):
     def get_probs_range(self, probs):
@@ -845,11 +842,11 @@ class Normalize(Transform):
         from functools import reduce
         from functools import reduce
         if reduce(lambda x, y: x * y, std) == 0:
         if reduce(lambda x, y: x * y, std) == 0:
             raise ValueError(
             raise ValueError(
-                'Std should not contain 0, but received is {}.'.format(std))
+                "`std` should not contain 0, but received is {}.".format(std))
         if reduce(lambda x, y: x * y,
         if reduce(lambda x, y: x * y,
                   [a - b for a, b in zip(max_val, min_val)]) == 0:
                   [a - b for a, b in zip(max_val, min_val)]) == 0:
             raise ValueError(
             raise ValueError(
-                '(max_val - min_val) should not contain 0, but received is {}.'.
+                "(`max_val` - `min_val`) should not contain 0, but received is {}.".
                 format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
                 format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
 
 
         self.mean = mean
         self.mean = mean
@@ -1153,11 +1150,11 @@ class RandomExpand(Transform):
                  im_padding_value=127.5,
                  im_padding_value=127.5,
                  label_padding_value=255):
                  label_padding_value=255):
         super(RandomExpand, self).__init__()
         super(RandomExpand, self).__init__()
-        assert upper_ratio > 1.01, "expand ratio must be larger than 1.01"
+        assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
         self.upper_ratio = upper_ratio
         self.upper_ratio = upper_ratio
         self.prob = prob
         self.prob = prob
         assert isinstance(im_padding_value, (Number, Sequence)), \
         assert isinstance(im_padding_value, (Number, Sequence)), \
-            "fill value must be either float or sequence"
+            "Value to fill must be either float or sequence."
         self.im_padding_value = im_padding_value
         self.im_padding_value = im_padding_value
         self.label_padding_value = label_padding_value
         self.label_padding_value = label_padding_value
 
 
@@ -1204,16 +1201,16 @@ class Pad(Transform):
         if isinstance(target_size, (list, tuple)):
         if isinstance(target_size, (list, tuple)):
             if len(target_size) != 2:
             if len(target_size) != 2:
                 raise ValueError(
                 raise ValueError(
-                    '`target_size` should include 2 elements, but it is {}'.
+                    "`target_size` should contain 2 elements, but it is {}.".
                     format(target_size))
                     format(target_size))
         if isinstance(target_size, int):
         if isinstance(target_size, int):
             target_size = [target_size] * 2
             target_size = [target_size] * 2
 
 
         assert pad_mode in [
         assert pad_mode in [
             -1, 0, 1, 2
             -1, 0, 1, 2
-        ], 'currently only supports four modes [-1, 0, 1, 2]'
+        ], "Currently only four modes are supported: [-1, 0, 1, 2]."
         if pad_mode == -1:
         if pad_mode == -1:
-            assert offsets, 'if pad_mode is -1, offsets should not be None'
+            assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
 
 
         self.target_size = target_size
         self.target_size = target_size
         self.size_divisor = size_divisor
         self.size_divisor = size_divisor
@@ -1314,9 +1311,9 @@ class MixupImage(Transform):
         """
         """
         super(MixupImage, self).__init__()
         super(MixupImage, self).__init__()
         if alpha <= 0.0:
         if alpha <= 0.0:
-            raise ValueError("alpha should be positive in {}".format(self))
+            raise ValueError("`alpha` should be positive in MixupImage.")
         if beta <= 0.0:
         if beta <= 0.0:
-            raise ValueError("beta should be positive in {}".format(self))
+            raise ValueError("`beta` should be positive in MixupImage.")
         self.alpha = alpha
         self.alpha = alpha
         self.beta = beta
         self.beta = beta
         self.mixup_epoch = mixup_epoch
         self.mixup_epoch = mixup_epoch
@@ -1753,55 +1750,56 @@ class RandomSwap(Transform):
 
 
     def apply(self, sample):
     def apply(self, sample):
         if 'image2' not in sample:
         if 'image2' not in sample:
-            raise ValueError('image2 is not found in the sample.')
+            raise ValueError("'image2' is not found in the sample.")
         if random.random() < self.prob:
         if random.random() < self.prob:
             sample['image'], sample['image2'] = sample['image2'], sample[
             sample['image'], sample['image2'] = sample['image2'], sample[
                 'image']
                 'image']
         return sample
         return sample
 
 
 
 
-class ArrangeSegmenter(Transform):
+class ReloadMask(Transform):
+    def apply(self, sample):
+        sample['mask'] = decode_seg_mask(sample['mask_ori'])
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(
+                map(decode_seg_mask, sample['aux_masks_ori']))
+        return sample
+
+
+class Arrange(Transform):
     def __init__(self, mode):
     def __init__(self, mode):
-        super(ArrangeSegmenter, self).__init__()
+        super().__init__()
         if mode not in ['train', 'eval', 'test', 'quant']:
         if mode not in ['train', 'eval', 'test', 'quant']:
             raise ValueError(
             raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
+                "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
             )
             )
         self.mode = mode
         self.mode = mode
 
 
+
+class ArrangeSegmenter(Arrange):
     def apply(self, sample):
     def apply(self, sample):
         if 'mask' in sample:
         if 'mask' in sample:
             mask = sample['mask']
             mask = sample['mask']
+            mask = mask.astype('int64')
 
 
         image = permute(sample['image'], False)
         image = permute(sample['image'], False)
         if self.mode == 'train':
         if self.mode == 'train':
-            mask = mask.astype('int64')
             return image, mask
             return image, mask
         if self.mode == 'eval':
         if self.mode == 'eval':
-            mask = np.asarray(Image.open(mask))
-            mask = mask[np.newaxis, :, :].astype('int64')
             return image, mask
             return image, mask
         if self.mode == 'test':
         if self.mode == 'test':
             return image,
             return image,
 
 
 
 
-class ArrangeChangeDetector(Transform):
-    def __init__(self, mode):
-        super(ArrangeChangeDetector, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeChangeDetector(Arrange):
     def apply(self, sample):
     def apply(self, sample):
         if 'mask' in sample:
         if 'mask' in sample:
             mask = sample['mask']
             mask = sample['mask']
+            mask = mask.astype('int64')
 
 
         image_t1 = permute(sample['image'], False)
         image_t1 = permute(sample['image'], False)
         image_t2 = permute(sample['image2'], False)
         image_t2 = permute(sample['image2'], False)
         if self.mode == 'train':
         if self.mode == 'train':
-            mask = mask.astype('int64')
             masks = [mask]
             masks = [mask]
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
                 masks.extend(
                 masks.extend(
@@ -1810,22 +1808,12 @@ class ArrangeChangeDetector(Transform):
                 image_t1,
                 image_t1,
                 image_t2, ) + tuple(masks)
                 image_t2, ) + tuple(masks)
         if self.mode == 'eval':
         if self.mode == 'eval':
-            mask = np.asarray(Image.open(mask))
-            mask = mask[np.newaxis, :, :].astype('int64')
             return image_t1, image_t2, mask
             return image_t1, image_t2, mask
         if self.mode == 'test':
         if self.mode == 'test':
             return image_t1, image_t2,
             return image_t1, image_t2,
 
 
 
 
-class ArrangeClassifier(Transform):
-    def __init__(self, mode):
-        super(ArrangeClassifier, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeClassifier(Arrange):
     def apply(self, sample):
     def apply(self, sample):
         image = permute(sample['image'], False)
         image = permute(sample['image'], False)
         if self.mode in ['train', 'eval']:
         if self.mode in ['train', 'eval']:
@@ -1834,15 +1822,7 @@ class ArrangeClassifier(Transform):
             return image
             return image
 
 
 
 
-class ArrangeDetector(Transform):
-    def __init__(self, mode):
-        super(ArrangeDetector, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeDetector(Arrange):
     def apply(self, sample):
     def apply(self, sample):
         if self.mode == 'eval' and 'gt_poly' in sample:
         if self.mode == 'eval' and 'gt_poly' in sample:
             del sample['gt_poly']
             del sample['gt_poly']

+ 2 - 2
paddlers/utils/__init__.py

@@ -15,8 +15,8 @@
 from . import logging
 from . import logging
 from . import utils
 from . import utils
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
-                    EarlyStop, path_normalization, is_pic, MyEncoder,
-                    DisablePrint, Timer)
+                    EarlyStop, norm_path, is_pic, MyEncoder, DisablePrint,
+                    Timer)
 from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .download import download_and_decompress, decompress

+ 1 - 1
paddlers/utils/utils.py

@@ -69,7 +69,7 @@ def dict2str(dict_input):
     return out.strip(', ')
     return out.strip(', ')
 
 
 
 
-def path_normalization(path):
+def norm_path(path):
     win_sep = "\\"
     win_sep = "\\"
     other_sep = "/"
     other_sep = "/"
     if platform.system() == "Windows":
     if platform.system() == "Windows":

+ 1 - 1
tests/data/data_utils.py

@@ -325,7 +325,7 @@ class ConstrDetSample(ConstrSample):
 
 
 def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
 def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
     """
     """
-    Construct a list of dictionaries from file. Each dict in the list can be used as the input to `paddlers.transforms.Transform` objects.
+    Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
 
 
     Args:
     Args:
         file_list (str): Path of file_list.
         file_list (str): Path of file_list.

+ 16 - 4
tests/deploy/test_predictor.py

@@ -120,7 +120,10 @@ class TestCDPredictor(TestPredictor):
         t2_path = "data/ssmt/optical_t2.bmp"
         t2_path = "data/ssmt/optical_t2.bmp"
         single_input = (t1_path, t2_path)
         single_input = (t1_path, t2_path)
         num_inputs = 2
         num_inputs = 2
-        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeChangeDetector('test')
+        ])
 
 
         # Expected failure
         # Expected failure
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
@@ -184,7 +187,10 @@ class TestClasPredictor(TestPredictor):
     def check_predictor(self, predictor, trainer):
     def check_predictor(self, predictor, trainer):
         single_input = "data/ssmt/optical_t1.bmp"
         single_input = "data/ssmt/optical_t1.bmp"
         num_inputs = 2
         num_inputs = 2
-        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeClassifier('test')
+        ])
         labels = list(range(2))
         labels = list(range(2))
         trainer.labels = labels
         trainer.labels = labels
         predictor._model.labels = labels
         predictor._model.labels = labels
@@ -249,7 +255,10 @@ class TestDetPredictor(TestPredictor):
         # given that the network is (partially?) randomly initialized.
         # given that the network is (partially?) randomly initialized.
         single_input = "data/ssmt/optical_t1.bmp"
         single_input = "data/ssmt/optical_t1.bmp"
         num_inputs = 2
         num_inputs = 2
-        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeDetector('test')
+        ])
         labels = list(range(80))
         labels = list(range(80))
         trainer.labels = labels
         trainer.labels = labels
         predictor._model.labels = labels
         predictor._model.labels = labels
@@ -303,7 +312,10 @@ class TestSegPredictor(TestPredictor):
     def check_predictor(self, predictor, trainer):
     def check_predictor(self, predictor, trainer):
         single_input = "data/ssmt/optical_t1.bmp"
         single_input = "data/ssmt/optical_t1.bmp"
         num_inputs = 2
         num_inputs = 2
-        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        transforms = pdrs.transforms.Compose([
+            pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
+            pdrs.transforms.ArrangeSegmenter('test')
+        ])
 
 
         # Single input (file path)
         # Single input (file path)
         input_ = single_input
         input_ = single_input

+ 1 - 1
tests/download_test_data.sh

@@ -4,7 +4,7 @@ function remove_dir_if_exist() {
     local dir="$1"
     local dir="$1"
     if [ -d "${dir}" ]; then
     if [ -d "${dir}" ]; then
         rm -rf "${dir}"
         rm -rf "${dir}"
-        echo "\033[0;31mDirectory ${dir} has been removed.\033[0m"
+        echo -e "\033[0;31mDirectory ${dir} has been removed.\033[0m"
     fi
     fi
 }
 }
 
 

+ 2 - 2
tests/test_tutorials.py

@@ -29,7 +29,7 @@ class TestTutorial(CpuCommonTest):
     @classmethod
     @classmethod
     def setUpClass(cls):
     def setUpClass(cls):
         cls._td = tempfile.TemporaryDirectory(dir='./')
         cls._td = tempfile.TemporaryDirectory(dir='./')
-        # Recursively copy the content of `cls.SUBDIR` to td.
+        # Recursively copy the content of cls.SUBDIR to td.
         # This is necessary for running scripts in td.
         # This is necessary for running scripts in td.
         cls._TSUBDIR = osp.join(cls._td.name, osp.basename(cls.SUBDIR))
         cls._TSUBDIR = osp.join(cls._td.name, osp.basename(cls.SUBDIR))
         shutil.copytree(cls.SUBDIR, cls._TSUBDIR)
         shutil.copytree(cls.SUBDIR, cls._TSUBDIR)
@@ -47,7 +47,7 @@ class TestTutorial(CpuCommonTest):
 
 
         def _test_tutorial(script_name):
         def _test_tutorial(script_name):
             def _test_tutorial_impl(self):
             def _test_tutorial_impl(self):
-                # Set working directory to `cls._TSUBDIR` such that the 
+                # Set working directory to cls._TSUBDIR such that the 
                 # files generated by the script will be automatically cleaned.
                 # files generated by the script will be automatically cleaned.
                 run_script(f"python {script_name}", wd=cls._TSUBDIR)
                 run_script(f"python {script_name}", wd=cls._TSUBDIR)
 
 

+ 1 - 1
tools/raster2geotiff.py

@@ -46,7 +46,7 @@ def convert_data(image_path, geojson_path):
             geo_points = geo["coordinates"][0][0]
             geo_points = geo["coordinates"][0][0]
         else:
         else:
             raise TypeError(
             raise TypeError(
-                "Geometry type must be `Polygon` or `MultiPolygon`, not {}.".
+                "Geometry type must be 'Polygon' or 'MultiPolygon', not {}.".
                 format(geo["type"]))
                 format(geo["type"]))
         xy_points = np.array([
         xy_points = np.array([
             _gt_convert(point[0], point[1], raster.geot) for point in geo_points
             _gt_convert(point[0], point[1], raster.geot) for point in geo_points

+ 2 - 2
tools/raster2vector.py

@@ -76,7 +76,7 @@ def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
     vec_ext = save_path.split(".")[-1].lower()
     vec_ext = save_path.split(".")[-1].lower()
     if vec_ext not in ["json", "geojson", "shp"]:
     if vec_ext not in ["json", "geojson", "shp"]:
         raise ValueError(
         raise ValueError(
-            "The ext of `save_path` must be `json/geojson` or `shp`, not {}.".
+            "The extension of `save_path` must be 'json/geojson' or 'shp', not {}.".
             format(vec_ext))
             format(vec_ext))
     ras_ext = srcimg_path.split(".")[-1].lower()
     ras_ext = srcimg_path.split(".")[-1].lower()
     if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
     if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
@@ -93,7 +93,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--mask_path", type=str, required=True, \
 parser.add_argument("--mask_path", type=str, required=True, \
                     help="Path of mask data.")
                     help="Path of mask data.")
 parser.add_argument("--save_path", type=str, required=True, \
 parser.add_argument("--save_path", type=str, required=True, \
-                    help="Path to save the shape file (the file suffix is `.json/geojson` or `.shp`).")
+                    help="Path to save the shape file (the extension is .json/geojson or .shp).")
 parser.add_argument("--srcimg_path", type=str, default="", \
 parser.add_argument("--srcimg_path", type=str, default="", \
                     help="Path of original data with geoinfo. Default to empty.")
                     help="Path of original data with geoinfo. Default to empty.")
 parser.add_argument("--ignore_index", type=int, default=255, \
 parser.add_argument("--ignore_index", type=int, default=255, \

+ 1 - 1
tools/split.py

@@ -75,7 +75,7 @@ parser.add_argument("--mask_path", type=str, default=None, \
 parser.add_argument("--block_size", type=int, default=512, \
 parser.add_argument("--block_size", type=int, default=512, \
                     help="Size of image block. Default value is 512.")
                     help="Size of image block. Default value is 512.")
 parser.add_argument("--save_dir", type=str, default="dataset", \
 parser.add_argument("--save_dir", type=str, default="dataset", \
-                    help="Directory to save the results. Default value is `dataset`.")
+                    help="Directory to save the results. Default value is 'dataset'.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()

+ 2 - 2
tools/utils/raster.py

@@ -42,7 +42,7 @@ def _get_type(type_name: str) -> int:
     elif type_name == "complex64":
     elif type_name == "complex64":
         gdal_type = gdal.GDT_CFloat64
         gdal_type = gdal.GDT_CFloat64
     else:
     else:
-        raise TypeError("Non-suported data type `{}`.".format(type_name))
+        raise TypeError("Non-suported data type {}.".format(type_name))
     return gdal_type
     return gdal_type
 
 
 
 
@@ -76,7 +76,7 @@ class Raster:
                         # https://www.osgeo.cn/gdal/drivers/raster/index.html
                         # https://www.osgeo.cn/gdal/drivers/raster/index.html
                         self._src_data = gdal.Open(path)
                         self._src_data = gdal.Open(path)
                     except:
                     except:
-                        raise TypeError("Unsupported data format: `{}`".format(
+                        raise TypeError("Unsupported data format: {}".format(
                             self.ext_type))
                             self.ext_type))
             else:
             else:
                 raise ValueError("The path {0} not exists.".format(path))
                 raise ValueError("The path {0} not exists.".format(path))

+ 6 - 0
tutorials/train/change_detection/bit.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/cdnet.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/dsamnet.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/dsifn.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/fc_ef.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/fc_siam_conc.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/fc_siam_diff.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/snunet.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/change_detection/stanet.py

@@ -23,6 +23,8 @@ pdrs.utils.download_and_decompress(airchange_dataset, path=DATA_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 随机裁剪
     # 随机裁剪
     T.RandomCrop(
     T.RandomCrop(
         # 裁剪区域将被缩放到256x256
         # 裁剪区域将被缩放到256x256
@@ -36,12 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 3
tutorials/train/classification/condensenetv2_b_rs_mul.py

@@ -3,18 +3,21 @@ from paddlers import transforms as T
 
 
 # 定义训练和验证时的transforms
 # 定义训练和验证时的transforms
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     T.SelectBand([5, 10, 15, 20, 25]),  # for tet
     T.SelectBand([5, 10, 15, 20, 25]),  # for tet
     T.Resize(target_size=224),
     T.Resize(target_size=224),
     T.RandomHorizontalFlip(),
     T.RandomHorizontalFlip(),
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
-    T.SelectBand([5, 10, 15, 20, 25]),
-    T.Resize(target_size=224),
+    T.DecodeImg(), T.SelectBand([5, 10, 15, 20, 25]), T.Resize(target_size=224),
     T.Normalize(
     T.Normalize(
-        mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
+        mean=[0.5, 0.5, 0.5, 0.5, 0.5],
+        std=[0.5, 0.5, 0.5, 0.5, 0.5]), T.ArrangeClassifier('eval')
 ])
 ])
 
 
 # 定义训练和验证所用的数据集
 # 定义训练和验证所用的数据集

+ 5 - 0
tutorials/train/classification/hrnet.py

@@ -27,6 +27,8 @@ pdrs.utils.download_and_decompress(ucmerced_dataset, path=DOWNLOAD_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 将影像缩放到256x256大小
     # 将影像缩放到256x256大小
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 以50%的概率实施随机水平翻转
     # 以50%的概率实施随机水平翻转
@@ -36,13 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 5 - 0
tutorials/train/classification/mobilenetv3.py

@@ -27,6 +27,8 @@ pdrs.utils.download_and_decompress(ucmerced_dataset, path=DOWNLOAD_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 将影像缩放到256x256大小
     # 将影像缩放到256x256大小
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 以50%的概率实施随机水平翻转
     # 以50%的概率实施随机水平翻转
@@ -36,13 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 5 - 0
tutorials/train/classification/resnet50_vd.py

@@ -27,6 +27,8 @@ pdrs.utils.download_and_decompress(ucmerced_dataset, path=DOWNLOAD_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 将影像缩放到256x256大小
     # 将影像缩放到256x256大小
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 以50%的概率实施随机水平翻转
     # 以50%的概率实施随机水平翻转
@@ -36,13 +38,16 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=256),
     T.Resize(target_size=256),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 7 - 2
tutorials/train/object_detection/faster_rcnn.py

@@ -30,6 +30,8 @@ if not os.path.exists(DATA_DIR):
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 对输入影像施加随机色彩扰动
     # 对输入影像施加随机色彩扰动
     T.RandomDistort(),
     T.RandomDistort(),
     # 在影像边界进行随机padding
     # 在影像边界进行随机padding
@@ -44,16 +46,19 @@ train_transforms = T.Compose([
         interp='RANDOM'),
         interp='RANDOM'),
     # 影像归一化
     # 影像归一化
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 使用双三次插值将输入影像缩放到固定大小
     # 使用双三次插值将输入影像缩放到固定大小
     T.Resize(
     T.Resize(
         target_size=608, interp='CUBIC'),
         target_size=608, interp='CUBIC'),
     # 验证阶段与训练阶段的归一化方式必须相同
     # 验证阶段与训练阶段的归一化方式必须相同
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 7 - 2
tutorials/train/object_detection/ppyolo.py

@@ -31,6 +31,8 @@ if not os.path.exists(DATA_DIR):
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 对输入影像施加随机色彩扰动
     # 对输入影像施加随机色彩扰动
     T.RandomDistort(),
     T.RandomDistort(),
     # 在影像边界进行随机padding
     # 在影像边界进行随机padding
@@ -45,16 +47,19 @@ train_transforms = T.Compose([
         interp='RANDOM'),
         interp='RANDOM'),
     # 影像归一化
     # 影像归一化
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 使用双三次插值将输入影像缩放到固定大小
     # 使用双三次插值将输入影像缩放到固定大小
     T.Resize(
     T.Resize(
         target_size=608, interp='CUBIC'),
         target_size=608, interp='CUBIC'),
     # 验证阶段与训练阶段的归一化方式必须相同
     # 验证阶段与训练阶段的归一化方式必须相同
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 7 - 2
tutorials/train/object_detection/ppyolotiny.py

@@ -31,6 +31,8 @@ if not os.path.exists(DATA_DIR):
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 对输入影像施加随机色彩扰动
     # 对输入影像施加随机色彩扰动
     T.RandomDistort(),
     T.RandomDistort(),
     # 在影像边界进行随机padding
     # 在影像边界进行随机padding
@@ -45,16 +47,19 @@ train_transforms = T.Compose([
         interp='RANDOM'),
         interp='RANDOM'),
     # 影像归一化
     # 影像归一化
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 使用双三次插值将输入影像缩放到固定大小
     # 使用双三次插值将输入影像缩放到固定大小
     T.Resize(
     T.Resize(
         target_size=608, interp='CUBIC'),
         target_size=608, interp='CUBIC'),
     # 验证阶段与训练阶段的归一化方式必须相同
     # 验证阶段与训练阶段的归一化方式必须相同
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 7 - 2
tutorials/train/object_detection/ppyolov2.py

@@ -31,6 +31,8 @@ if not os.path.exists(DATA_DIR):
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 对输入影像施加随机色彩扰动
     # 对输入影像施加随机色彩扰动
     T.RandomDistort(),
     T.RandomDistort(),
     # 在影像边界进行随机padding
     # 在影像边界进行随机padding
@@ -45,16 +47,19 @@ train_transforms = T.Compose([
         interp='RANDOM'),
         interp='RANDOM'),
     # 影像归一化
     # 影像归一化
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 使用双三次插值将输入影像缩放到固定大小
     # 使用双三次插值将输入影像缩放到固定大小
     T.Resize(
     T.Resize(
         target_size=608, interp='CUBIC'),
         target_size=608, interp='CUBIC'),
     # 验证阶段与训练阶段的归一化方式必须相同
     # 验证阶段与训练阶段的归一化方式必须相同
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 7 - 2
tutorials/train/object_detection/yolov3.py

@@ -31,6 +31,8 @@ if not os.path.exists(DATA_DIR):
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 对输入影像施加随机色彩扰动
     # 对输入影像施加随机色彩扰动
     T.RandomDistort(),
     T.RandomDistort(),
     # 在影像边界进行随机padding
     # 在影像边界进行随机padding
@@ -45,16 +47,19 @@ train_transforms = T.Compose([
         interp='RANDOM'),
         interp='RANDOM'),
     # 影像归一化
     # 影像归一化
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     # 使用双三次插值将输入影像缩放到固定大小
     # 使用双三次插值将输入影像缩放到固定大小
     T.Resize(
     T.Resize(
         target_size=608, interp='CUBIC'),
         target_size=608, interp='CUBIC'),
     # 验证阶段与训练阶段的归一化方式必须相同
     # 验证阶段与训练阶段的归一化方式必须相同
     T.Normalize(
     T.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    T.ArrangeDetector('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/semantic_segmentation/deeplabv3p.py

@@ -30,6 +30,8 @@ pdrs.utils.download_and_decompress(seg_dataset, path=DOWNLOAD_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 将影像缩放到512x512大小
     # 将影像缩放到512x512大小
     T.Resize(target_size=512),
     T.Resize(target_size=512),
     # 以50%的概率实施随机水平翻转
     # 以50%的概率实施随机水平翻转
@@ -37,13 +39,17 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ArrangeSegmenter('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=512),
     T.Resize(target_size=512),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ReloadMask(),
+    T.ArrangeSegmenter('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集

+ 6 - 0
tutorials/train/semantic_segmentation/unet.py

@@ -30,6 +30,8 @@ pdrs.utils.download_and_decompress(seg_dataset, path=DOWNLOAD_DIR)
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 # API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
 train_transforms = T.Compose([
 train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
     # 将影像缩放到512x512大小
     # 将影像缩放到512x512大小
     T.Resize(target_size=512),
     T.Resize(target_size=512),
     # 以50%的概率实施随机水平翻转
     # 以50%的概率实施随机水平翻转
@@ -37,13 +39,17 @@ train_transforms = T.Compose([
     # 将数据归一化到[-1,1]
     # 将数据归一化到[-1,1]
     T.Normalize(
     T.Normalize(
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ArrangeSegmenter('train')
 ])
 ])
 
 
 eval_transforms = T.Compose([
 eval_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=512),
     T.Resize(target_size=512),
     # 验证阶段与训练阶段的数据归一化方式必须相同
     # 验证阶段与训练阶段的数据归一化方式必须相同
     T.Normalize(
     T.Normalize(
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
         mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ReloadMask(),
+    T.ArrangeSegmenter('eval')
 ])
 ])
 
 
 # 分别构建训练和验证所用的数据集
 # 分别构建训练和验证所用的数据集