prepare_isaid.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python
  2. import os.path as osp
  3. from glob import glob
  4. from PIL import Image
  5. from tqdm import tqdm
  6. from common import (get_default_parser, add_crop_options, crop_patches,
  7. create_file_list, copy_dataset, create_label_list,
  8. get_path_tuples)
# According to the official doc (https://github.com/CAPTAIN-WHU/iSAID_Devkit),
# the files should be organized as follows:
#
# iSAID
# ├── test
# │   └── images
# │       ├── P0006.png
# │       ├── ...
# │       └── P0009.png
# ├── train
# │   └── images
# │       ├── P0002_instance_color_RGB.png
# │       ├── P0002_instance_id_RGB.png
# │       ├── P0002.png
# │       ├── ...
# │       ├── P0010_instance_color_RGB.png
# │       ├── P0010_instance_id_RGB.png
# │       └── P0010.png
# └── val
#     └── images
#         ├── P0003_instance_color_RGB.png
#         ├── P0003_instance_id_RGB.png
#         ├── P0003.png
#         ├── ...
#         ├── P0004_instance_color_RGB.png
#         ├── P0004_instance_id_RGB.png
#         └── P0004.png
  36. CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond',
  37. 'tennis_court', 'basketball_court', 'ground_track_field', 'bridge',
  38. 'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool',
  39. 'roundabout', 'soccer_ball_field', 'plane', 'harbor')
  40. # Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py
  41. COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127],
  42. [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
  43. [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191],
  44. [0, 127, 255], [0, 100, 155]]
  45. SUBSETS = ('train', 'val')
  46. SUBDIR = 'images'
  47. FILE_LIST_PATTERN = "{subset}.txt"
  48. LABEL_LIST_NAME = "labels.txt"
  49. URL = ""
  50. def flatten(nested_list):
  51. flattened_list = []
  52. for ele in nested_list:
  53. if isinstance(ele, list):
  54. flattened_list.extend(flatten(ele))
  55. else:
  56. flattened_list.append(ele)
  57. return flattened_list
  58. def rgb2mask(rgb):
  59. palette = flatten(COLOR_MAP)
  60. # Pad with zero
  61. palette = palette + [0] * (256 * 3 - len(palette))
  62. ref = Image.new(mode='P', size=(1, 1))
  63. ref.putpalette(palette)
  64. mask = rgb.quantize(palette=ref, dither=0)
  65. return mask
  66. if __name__ == '__main__':
  67. parser = get_default_parser()
  68. parser.add_argument(
  69. '--crop_size', type=int, help="Size of cropped patches.", default=800)
  70. parser.add_argument(
  71. '--crop_stride',
  72. type=int,
  73. help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
  74. default=600)
  75. args = parser.parse_args()
  76. out_dir = osp.join(args.out_dataset_dir,
  77. osp.basename(osp.normpath(args.in_dataset_dir)))
  78. assert args.crop_size is not None
  79. # According to https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py
  80. # Set keep_last=True
  81. crop_patches(
  82. args.crop_size,
  83. args.crop_stride,
  84. data_dir=args.in_dataset_dir,
  85. out_dir=out_dir,
  86. subsets=SUBSETS,
  87. subdirs=(SUBDIR, ),
  88. glob_pattern='*.png',
  89. max_workers=8,
  90. keep_last=True)
  91. for subset in SUBSETS:
  92. path_tuples = []
  93. print(f"Processing {subset} labels...")
  94. for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))):
  95. im_name = osp.basename(im_subdir[:-1]) # Strip trailing '/'
  96. if '_' in im_name:
  97. # Do not process labels
  98. continue
  99. mask_subdir = osp.join(out_dir, subset, SUBDIR,
  100. im_name + '_instance_color_RGB')
  101. for mask_path in glob(osp.join(mask_subdir, '*.png')):
  102. # Convert RGB files to mask files (pseudo color)
  103. rgb = Image.open(mask_path).convert('RGB')
  104. mask = rgb2mask(rgb)
  105. # Write to the original location
  106. mask.save(mask_path)
  107. path_tuples.extend(
  108. get_path_tuples(
  109. im_subdir,
  110. mask_subdir,
  111. glob_pattern='*.png',
  112. data_dir=args.out_dataset_dir))
  113. path_tuples.sort()
  114. file_list = osp.join(
  115. args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
  116. create_file_list(file_list, path_tuples)
  117. print(f"Write file list to {file_list}.")
  118. label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
  119. create_label_list(label_list, CLASSES)
  120. print(f"Write label list to {label_list}.")