prepare_isaid_c2fnet.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Adapted from https://github.com/open-mmlab/mmsegmentation/blob/master/tools/convert_datasets/isaid.py
  15. #
  16. # Original copyright info:
  17. # Copyright (c) OpenMMLab. All rights reserved.
  18. #
  19. # See original LICENSE at https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE
  20. import argparse
  21. import glob
  22. import os
  23. import os.path as osp
  24. import shutil
  25. import tempfile
  26. import zipfile
  27. import cv2
  28. import numpy as np
  29. from tqdm import tqdm
  30. from PIL import Image
  31. iSAID_palette = \
  32. {
  33. 0: (0, 0, 0),
  34. 1: (0, 0, 63),
  35. 2: (0, 63, 63),
  36. 3: (0, 63, 0),
  37. 4: (0, 63, 127),
  38. 5: (0, 63, 191),
  39. 6: (0, 63, 255),
  40. 7: (0, 127, 63),
  41. 8: (0, 127, 127),
  42. 9: (0, 0, 127),
  43. 10: (0, 0, 191),
  44. 11: (0, 0, 255),
  45. 12: (0, 191, 127),
  46. 13: (0, 127, 191),
  47. 14: (0, 127, 255),
  48. 15: (0, 100, 155)
  49. }
  50. iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
  51. def mkdir_or_exist(dir_name, mode=0o777):
  52. if dir_name == '':
  53. return
  54. dir_name = osp.expanduser(dir_name)
  55. os.makedirs(dir_name, mode=mode, exist_ok=True)
  56. def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
  57. """RGB-color encoding to grayscale labels."""
  58. arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
  59. for c, i in palette.items():
  60. m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
  61. arr_2d[m] = i
  62. return arr_2d
  63. def pad(img, shape=None, padding=None, pad_val=0, padding_mode='constant'):
  64. assert (shape is not None) ^ (padding is not None)
  65. if shape is not None:
  66. width = max(shape[1] - img.shape[1], 0)
  67. height = max(shape[0] - img.shape[0], 0)
  68. padding = (0, 0, width, height)
  69. # Check pad_val
  70. if isinstance(pad_val, tuple):
  71. assert len(pad_val) == img.shape[-1]
  72. elif not isinstance(pad_val, numbers.Number):
  73. raise TypeError('pad_val must be a int or a tuple. '
  74. f'But received {type(pad_val)}')
  75. # Check padding
  76. if isinstance(padding, tuple) and len(padding) in [2, 4]:
  77. if len(padding) == 2:
  78. padding = (padding[0], padding[1], padding[0], padding[1])
  79. elif isinstance(padding, numbers.Number):
  80. padding = (padding, padding, padding, padding)
  81. else:
  82. raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
  83. f'But received {padding}')
  84. # Check padding mode
  85. assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
  86. border_type = {
  87. 'constant': cv2.BORDER_CONSTANT,
  88. 'edge': cv2.BORDER_REPLICATE,
  89. 'reflect': cv2.BORDER_REFLECT_101,
  90. 'symmetric': cv2.BORDER_REFLECT
  91. }
  92. img = cv2.copyMakeBorder(
  93. img,
  94. padding[1],
  95. padding[3],
  96. padding[0],
  97. padding[2],
  98. border_type[padding_mode],
  99. value=pad_val)
  100. return img
  101. def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
  102. img = np.asarray(Image.open(src_path).convert('RGB'))
  103. img_H, img_W, _ = img.shape
  104. if img_H < patch_H and img_W > patch_W:
  105. img = pad(img, shape=(patch_H, img_W), pad_val=0)
  106. img_H, img_W, _ = img.shape
  107. elif img_H > patch_H and img_W < patch_W:
  108. img = pad(img, shape=(img_H, patch_W), pad_val=0)
  109. img_H, img_W, _ = img.shape
  110. elif img_H < patch_H and img_W < patch_W:
  111. img = pad(img, shape=(patch_H, patch_W), pad_val=0)
  112. img_H, img_W, _ = img.shape
  113. for x in range(0, img_W, patch_W - overlap):
  114. for y in range(0, img_H, patch_H - overlap):
  115. x_str = x
  116. x_end = x + patch_W
  117. if x_end > img_W:
  118. diff_x = x_end - img_W
  119. x_str -= diff_x
  120. x_end = img_W
  121. y_str = y
  122. y_end = y + patch_H
  123. if y_end > img_H:
  124. diff_y = y_end - img_H
  125. y_str -= diff_y
  126. y_end = img_H
  127. img_patch = img[y_str:y_end, x_str:x_end, :]
  128. img_patch = Image.fromarray(img_patch.astype(np.uint8))
  129. image = osp.basename(src_path).split('.')[0] + '_' + str(
  130. y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
  131. x_end) + '.png'
  132. # print(image)
  133. save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
  134. img_patch.save(save_path_image)
  135. def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
  136. label = Image.open(src_path).convert('RGB')
  137. label = np.asarray(label)
  138. label = iSAID_convert_from_color(label)
  139. img_H, img_W = label.shape
  140. if img_H < patch_H and img_W > patch_W:
  141. label = pad(label, shape=(patch_H, img_W), pad_val=255)
  142. img_H = patch_H
  143. elif img_H > patch_H and img_W < patch_W:
  144. label = pad(label, shape=(img_H, patch_W), pad_val=255)
  145. img_W = patch_W
  146. elif img_H < patch_H and img_W < patch_W:
  147. label = pad(label, shape=(patch_H, patch_W), pad_val=255)
  148. img_H = patch_H
  149. img_W = patch_W
  150. for x in range(0, img_W, patch_W - overlap):
  151. for y in range(0, img_H, patch_H - overlap):
  152. x_str = x
  153. x_end = x + patch_W
  154. if x_end > img_W:
  155. diff_x = x_end - img_W
  156. x_str -= diff_x
  157. x_end = img_W
  158. y_str = y
  159. y_end = y + patch_H
  160. if y_end > img_H:
  161. diff_y = y_end - img_H
  162. y_str -= diff_y
  163. y_end = img_H
  164. lab_patch = label[y_str:y_end, x_str:x_end]
  165. lab_patch = Image.fromarray(lab_patch.astype(np.uint8))
  166. image = osp.basename(src_path).split('.')[0].split('_')[
  167. 0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
  168. x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
  169. lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
  170. def parse_args():
  171. parser = argparse.ArgumentParser()
  172. parser.add_argument('dataset_path', help='Path of raw iSAID dataset.')
  173. parser.add_argument('--tmp_dir', help='Path of the temporary directory.')
  174. parser.add_argument('-o', '--out_dir', help='Output path.')
  175. parser.add_argument(
  176. '--patch_width',
  177. default=896,
  178. type=int,
  179. help='Width of the cropped image patch.')
  180. parser.add_argument(
  181. '--patch_height',
  182. default=896,
  183. type=int,
  184. help='Height of the cropped image patch.')
  185. parser.add_argument(
  186. '--overlap_area', default=384, type=int, help='Overlap area.')
  187. args = parser.parse_args()
  188. return args
  189. if __name__ == '__main__':
  190. args = parse_args()
  191. dataset_path = args.dataset_path
  192. # image patch width and height
  193. patch_H, patch_W = args.patch_width, args.patch_height
  194. overlap = args.overlap_area # overlap area
  195. if args.out_dir is None:
  196. out_dir = osp.join('data', 'iSAID')
  197. else:
  198. out_dir = args.out_dir
  199. print('Creating directories...')
  200. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
  201. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
  202. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
  203. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
  204. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
  205. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
  206. assert os.path.exists(os.path.join(dataset_path, 'train')), \
  207. 'train is not in {}'.format(dataset_path)
  208. assert os.path.exists(os.path.join(dataset_path, 'val')), \
  209. 'val is not in {}'.format(dataset_path)
  210. assert os.path.exists(os.path.join(dataset_path, 'test')), \
  211. 'test is not in {}'.format(dataset_path)
  212. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  213. for dataset_mode in ['train', 'val', 'test']:
  214. # for dataset_mode in [ 'test']:
  215. print('Extracting {}ing.zip...'.format(dataset_mode))
  216. img_zipp_list = glob.glob(
  217. os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
  218. print('Find the data', img_zipp_list)
  219. for img_zipp in img_zipp_list:
  220. zip_file = zipfile.ZipFile(img_zipp)
  221. zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
  222. src_path_list = glob.glob(
  223. os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
  224. for i, img_path in enumerate(tqdm(src_path_list)):
  225. if dataset_mode != 'test':
  226. slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
  227. patch_W, overlap)
  228. else:
  229. shutil.move(img_path,
  230. os.path.join(out_dir, 'img_dir', dataset_mode))
  231. if dataset_mode != 'test':
  232. label_zipp_list = glob.glob(
  233. os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
  234. '*.zip'))
  235. for label_zipp in label_zipp_list:
  236. zip_file = zipfile.ZipFile(label_zipp)
  237. zip_file.extractall(
  238. os.path.join(tmp_dir, dataset_mode, 'lab'))
  239. lab_path_list = glob.glob(
  240. os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
  241. '*.png'))
  242. for i, lab_path in enumerate(tqdm(lab_path_list)):
  243. slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
  244. patch_W, overlap)
  245. print('Removing the temporary files...')
  246. print('Done!')