Procházet zdrojové kódy

Cherry-pick commits from refactor_data

Bobholamovic před 2 roky
rodič
revize
9c382d82bb

+ 30 - 0
paddlers/datasets/base.py

@@ -0,0 +1,30 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+from paddle.io import Dataset
+
+from paddlers.utils import get_num_workers
+
+
+class BaseDataset(Dataset):
+    def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
+        super(BaseDataset, self).__init__()
+
+        self.data_dir = data_dir
+        self.label_list = label_list
+        self.transforms = deepcopy(transforms)
+        self.num_workers = get_num_workers(num_workers)
+        self.shuffle = shuffle

+ 9 - 10
paddlers/datasets/cd_dataset.py

@@ -16,12 +16,11 @@ import copy
 from enum import IntEnum
 from enum import IntEnum
 import os.path as osp
 import os.path as osp
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class CDDataset(Dataset):
+class CDDataset(BaseDataset):
     """
     """
     读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
     读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
 
 
@@ -31,8 +30,10 @@ class CDDataset(Dataset):
             False(默认设置)时,文件中每一行应依次包含第一时相影像、第二时相影像以及变化检测标签的路径;当`with_seg_labels`为True时,
             False(默认设置)时,文件中每一行应依次包含第一时相影像、第二时相影像以及变化检测标签的路径;当`with_seg_labels`为True时,
             文件中每一行应依次包含第一时相影像、第二时相影像、变化检测标签、第一时相建筑物标签以及第二时相建筑物标签的路径。
             文件中每一行应依次包含第一时相影像、第二时相影像、变化检测标签、第一时相建筑物标签以及第二时相建筑物标签的路径。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
         with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
         binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
         binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作。默认为False。
@@ -47,15 +48,13 @@ class CDDataset(Dataset):
                  shuffle=False,
                  shuffle=False,
                  with_seg_labels=False,
                  with_seg_labels=False,
                  binarize_labels=False):
                  binarize_labels=False):
-        super(CDDataset, self).__init__()
+        super(CDDataset, self).__init__(data_dir, label_list, transforms,
+                                        num_workers, shuffle)
 
 
         DELIMETER = ' '
         DELIMETER = ' '
 
 
-        self.transforms = copy.deepcopy(transforms)
         # TODO: batch padding
         # TODO: batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
         self.with_seg_labels = with_seg_labels
         self.with_seg_labels = with_seg_labels

+ 9 - 11
paddlers/datasets/clas_dataset.py

@@ -15,20 +15,21 @@
 import os.path as osp
 import os.path as osp
 import copy
 import copy
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class ClasDataset(Dataset):
+class ClasDataset(BaseDataset):
     """读取图像分类任务数据集,并对样本进行相应的处理。
     """读取图像分类任务数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注序号(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注序号(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径,文件格式为(类别 说明)。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径,文件格式为(类别 说明)。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 
@@ -39,14 +40,11 @@ class ClasDataset(Dataset):
                  transforms=None,
                  transforms=None,
                  num_workers='auto',
                  num_workers='auto',
                  shuffle=False):
                  shuffle=False):
-        super(ClasDataset, self).__init__()
-        self.transforms = copy.deepcopy(transforms)
+        super(ClasDataset, self).__init__(data_dir, label_list, transforms,
+                                          num_workers, shuffle)
         # TODO batch padding
         # TODO batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
-        self.label_list = label_list
         self.labels = list()
         self.labels = list()
 
 
         # TODO:非None时,让用户跳转数据集分析生成label_list
         # TODO:非None时,让用户跳转数据集分析生成label_list

+ 7 - 9
paddlers/datasets/coco.py

@@ -20,14 +20,14 @@ import random
 from collections import OrderedDict
 from collections import OrderedDict
 
 
 import numpy as np
 import numpy as np
-from paddle.io import Dataset
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.tools import YOLOAnchorCluster
 from paddlers.tools import YOLOAnchorCluster
 
 
 
 
-class COCODetection(Dataset):
+class COCODetection(BaseDataset):
     """读取COCO格式的检测数据集,并对样本进行相应的处理。
     """读取COCO格式的检测数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
@@ -35,7 +35,7 @@ class COCODetection(Dataset):
         image_dir (str): 描述数据集图片文件路径。
         image_dir (str): 描述数据集图片文件路径。
         anno_path (str): COCO标注文件路径。
         anno_path (str): COCO标注文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
-        transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
             一半。
@@ -60,10 +60,10 @@ class COCODetection(Dataset):
         import matplotlib
         import matplotlib
         matplotlib.use('Agg')
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         from pycocotools.coco import COCO
-        super(COCODetection, self).__init__()
-        self.data_dir = data_dir
+        super(COCODetection, self).__init__(data_dir, label_list, transforms,
+                                            num_workers, shuffle)
+
         self.data_fields = None
         self.data_fields = None
-        self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
         self.num_max_boxes = 50
 
 
         self.use_mix = False
         self.use_mix = False
@@ -76,8 +76,6 @@ class COCODetection(Dataset):
                     break
                     break
 
 
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.allow_empty = allow_empty
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.empty_ratio = empty_ratio
         self.file_list = list()
         self.file_list = list()

+ 9 - 10
paddlers/datasets/seg_dataset.py

@@ -15,20 +15,21 @@
 import os.path as osp
 import os.path as osp
 import copy
 import copy
 
 
-from paddle.io import Dataset
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
-
-class SegDataset(Dataset):
+class SegDataset(BaseDataset):
     """读取语义分割任务数据集,并对样本进行相应的处理。
     """读取语义分割任务数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
         label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
-        transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
-        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 
@@ -39,12 +40,10 @@ class SegDataset(Dataset):
                  transforms=None,
                  transforms=None,
                  num_workers='auto',
                  num_workers='auto',
                  shuffle=False):
                  shuffle=False):
-        super(SegDataset, self).__init__()
-        self.transforms = copy.deepcopy(transforms)
+        super(SegDataset, self).__init__(data_dir, label_list, transforms,
+                                         num_workers, shuffle)
         # TODO batch padding
         # TODO batch padding
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.file_list = list()
         self.file_list = list()
         self.labels = list()
         self.labels = list()
 
 

+ 7 - 9
paddlers/datasets/voc.py

@@ -22,21 +22,21 @@ from collections import OrderedDict
 import xml.etree.ElementTree as ET
 import xml.etree.ElementTree as ET
 
 
 import numpy as np
 import numpy as np
-from paddle.io import Dataset
 
 
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
+from .base import BaseDataset
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.transforms import DecodeImg, MixupImage
 from paddlers.tools import YOLOAnchorCluster
 from paddlers.tools import YOLOAnchorCluster
 
 
 
 
-class VOCDetection(Dataset):
+class VOCDetection(BaseDataset):
     """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
     """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
 
 
     Args:
     Args:
         data_dir (str): 数据集所在的目录路径。
         data_dir (str): 数据集所在的目录路径。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
         label_list (str): 描述数据集包含的类别信息文件路径。
         label_list (str): 描述数据集包含的类别信息文件路径。
-        transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
+        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
             一半。
@@ -60,10 +60,10 @@ class VOCDetection(Dataset):
         import matplotlib
         import matplotlib
         matplotlib.use('Agg')
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         from pycocotools.coco import COCO
-        super(VOCDetection, self).__init__()
-        self.data_dir = data_dir
+        super(VOCDetection, self).__init__(data_dir, label_list, transforms,
+                                           num_workers, shuffle)
+
         self.data_fields = None
         self.data_fields = None
-        self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
         self.num_max_boxes = 50
 
 
         self.use_mix = False
         self.use_mix = False
@@ -76,8 +76,6 @@ class VOCDetection(Dataset):
                     break
                     break
 
 
         self.batch_transforms = None
         self.batch_transforms = None
-        self.num_workers = get_num_workers(num_workers)
-        self.shuffle = shuffle
         self.allow_empty = allow_empty
         self.allow_empty = allow_empty
         self.empty_ratio = empty_ratio
         self.empty_ratio = empty_ratio
         self.file_list = list()
         self.file_list = list()