prepare_isaid.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. #!/usr/bin/env python
  2. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os.path as osp
  16. from glob import glob
  17. from PIL import Image
  18. from tqdm import tqdm
  19. from common import (get_default_parser, add_crop_options, crop_patches,
  20. create_file_list, copy_dataset, create_label_list,
  21. get_path_tuples)
# According to the official doc (https://github.com/CAPTAIN-WHU/iSAID_Devkit),
# the files should be organized as follows:
#
# iSAID
# ├── test
# │   └── images
# │       ├── P0006.png
# │       ├── ...
# │       └── P0009.png
# ├── train
# │   └── images
# │       ├── P0002_instance_color_RGB.png
# │       ├── P0002_instance_id_RGB.png
# │       ├── P0002.png
# │       ├── ...
# │       ├── P0010_instance_color_RGB.png
# │       ├── P0010_instance_id_RGB.png
# │       └── P0010.png
# └── val
#     └── images
#         ├── P0003_instance_color_RGB.png
#         ├── P0003_instance_id_RGB.png
#         ├── P0003.png
#         ├── ...
#         ├── P0004_instance_color_RGB.png
#         ├── P0004_instance_id_RGB.png
#         └── P0004.png
  49. CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond',
  50. 'tennis_court', 'basketball_court', 'ground_track_field', 'bridge',
  51. 'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool',
  52. 'roundabout', 'soccer_ball_field', 'plane', 'harbor')
  53. # Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py
  54. COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127],
  55. [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
  56. [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191],
  57. [0, 127, 255], [0, 100, 155]]
  58. SUBSETS = ('train', 'val')
  59. SUBDIR = 'images'
  60. FILE_LIST_PATTERN = "{subset}.txt"
  61. LABEL_LIST_NAME = "labels.txt"
  62. URL = ""
  63. def flatten(nested_list):
  64. flattened_list = []
  65. for ele in nested_list:
  66. if isinstance(ele, list):
  67. flattened_list.extend(flatten(ele))
  68. else:
  69. flattened_list.append(ele)
  70. return flattened_list
  71. def rgb2mask(rgb):
  72. palette = flatten(COLOR_MAP)
  73. # Pad with zero
  74. palette = palette + [0] * (256 * 3 - len(palette))
  75. ref = Image.new(mode='P', size=(1, 1))
  76. ref.putpalette(palette)
  77. mask = rgb.quantize(palette=ref, dither=0)
  78. return mask
  79. if __name__ == '__main__':
  80. parser = get_default_parser()
  81. parser.add_argument(
  82. '--crop_size', type=int, help="Size of cropped patches.", default=800)
  83. parser.add_argument(
  84. '--crop_stride',
  85. type=int,
  86. help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
  87. default=600)
  88. args = parser.parse_args()
  89. out_dir = osp.join(args.out_dataset_dir,
  90. osp.basename(osp.normpath(args.in_dataset_dir)))
  91. assert args.crop_size is not None
  92. # According to https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py
  93. # Set keep_last=True
  94. crop_patches(
  95. args.crop_size,
  96. args.crop_stride,
  97. data_dir=args.in_dataset_dir,
  98. out_dir=out_dir,
  99. subsets=SUBSETS,
  100. subdirs=(SUBDIR, ),
  101. glob_pattern='*.png',
  102. max_workers=8,
  103. keep_last=True)
  104. for subset in SUBSETS:
  105. path_tuples = []
  106. print(f"Processing {subset} labels...")
  107. for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))):
  108. im_name = osp.basename(im_subdir[:-1]) # Strip trailing '/'
  109. if '_' in im_name:
  110. # Do not process labels
  111. continue
  112. mask_subdir = osp.join(out_dir, subset, SUBDIR,
  113. im_name + '_instance_color_RGB')
  114. for mask_path in glob(osp.join(mask_subdir, '*.png')):
  115. # Convert RGB files to mask files (pseudo color)
  116. rgb = Image.open(mask_path).convert('RGB')
  117. mask = rgb2mask(rgb)
  118. # Write to the original location
  119. mask.save(mask_path)
  120. path_tuples.extend(
  121. get_path_tuples(
  122. im_subdir,
  123. mask_subdir,
  124. glob_pattern='*.png',
  125. data_dir=args.out_dataset_dir))
  126. path_tuples.sort()
  127. file_list = osp.join(
  128. args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
  129. create_file_list(file_list, path_tuples)
  130. print(f"Write file list to {file_list}.")
  131. label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
  132. create_label_list(label_list, CLASSES)
  133. print(f"Write label list to {label_list}.")