cd_dataset.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from enum import IntEnum
  16. import os.path as osp
  17. from paddle.io import Dataset
  18. from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
  19. class CDDataset(Dataset):
  20. """
  21. 读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
  22. Args:
  23. data_dir (str): 数据集所在的目录路径。
  24. file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
  25. label_list (str): 描述数据集包含的类别信息文件路径。默认值为None。
  26. transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
  27. num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
  28. shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
  29. with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
  30. """
  31. def __init__(self,
  32. data_dir,
  33. file_list,
  34. label_list=None,
  35. transforms=None,
  36. num_workers='auto',
  37. shuffle=False,
  38. with_seg_labels=False):
  39. super(CDDataset, self).__init__()
  40. DELIMETER = ' '
  41. self.transforms = copy.deepcopy(transforms)
  42. # TODO: batch padding
  43. self.batch_transforms = None
  44. self.num_workers = get_num_workers(num_workers)
  45. self.shuffle = shuffle
  46. self.file_list = list()
  47. self.labels = list()
  48. self.with_seg_labels = with_seg_labels
  49. if self.with_seg_labels:
  50. num_items = 5 # 3+2
  51. else:
  52. num_items = 3
  53. # TODO:非None时,让用户跳转数据集分析生成label_list
  54. # 不要在此处分析label file
  55. if label_list is not None:
  56. with open(label_list, encoding=get_encoding(label_list)) as f:
  57. for line in f:
  58. item = line.strip()
  59. self.labels.append(item)
  60. with open(file_list, encoding=get_encoding(file_list)) as f:
  61. for line in f:
  62. items = line.strip().split(DELIMETER)
  63. if len(items) != num_items:
  64. raise Exception("Line[{}] in file_list[{}] has an incorrect number of file paths.".format(
  65. line.strip(), file_list
  66. ))
  67. items = list(map(path_normalization, items))
  68. if not all(map(is_pic, items)):
  69. continue
  70. full_path_im_t1 = osp.join(data_dir, items[0])
  71. full_path_im_t2 = osp.join(data_dir, items[1])
  72. full_path_label = osp.join(data_dir, items[2])
  73. if not osp.exists(full_path_im_t1):
  74. raise IOError('Image file {} does not exist!'.format(
  75. full_path_im_t1))
  76. if not osp.exists(full_path_im_t2):
  77. raise IOError('Image file {} does not exist!'.format(
  78. full_path_im_t2))
  79. if not osp.exists(full_path_label):
  80. raise IOError('Label file {} does not exist!'.format(
  81. full_path_label))
  82. if with_seg_labels:
  83. full_path_seg_label_t1 = osp.join(data_dir, items[3])
  84. full_path_seg_label_t2 = osp.join(data_dir, items[4])
  85. if not osp.exists(full_path_seg_label_t1):
  86. raise IOError('Label file {} does not exist!'.format(
  87. full_path_seg_label_t1))
  88. if not osp.exists(full_path_seg_label_t2):
  89. raise IOError('Label file {} does not exist!'.format(
  90. full_path_seg_label_t2))
  91. item_dict = dict(
  92. image_t1=full_path_im_t1,
  93. image_t2=full_path_im_t2,
  94. mask=full_path_label
  95. )
  96. if with_seg_labels:
  97. item_dict['aux_masks'] = [full_path_seg_label_t1, full_path_seg_label_t2]
  98. self.file_list.append(item_dict)
  99. self.num_samples = len(self.file_list)
  100. logging.info("{} samples in file {}".format(
  101. len(self.file_list), file_list))
  102. def __getitem__(self, idx):
  103. sample = copy.deepcopy(self.file_list[idx])
  104. outputs = self.transforms(sample)
  105. return outputs
  106. def __len__(self):
  107. return len(self.file_list)
  108. class MaskType(IntEnum):
  109. """Enumeration of the mask types used in the change detection task."""
  110. CD = 0
  111. SEG_T1 = 1
  112. SEG_T2 = 2