prepare_dota.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import json
import math
import os
from multiprocessing import Pool
from numbers import Number

import cv2
import numpy as np
import shapely.geometry as shgeo
from tqdm import tqdm

from common import add_crop_options

# DOTA-v1.0 has 15 object categories; DOTA-v1.5 adds 'container-crane',
# and DOTA-v2.0 further adds 'airport' and 'helipad'.
wordname_15 = [
    'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
    'harbor', 'swimming-pool', 'helicopter'
]
wordname_16 = wordname_15 + ['container-crane']
wordname_18 = wordname_16 + ['airport', 'helipad']

DATA_CLASSES = {
    'dota10': wordname_15,
    'dota15': wordname_16,
    'dota20': wordname_18
}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--in_dataset_dir',
        type=str,
        nargs='+',
        required=True,
        help="Input dataset directories.")
    parser.add_argument(
        '--out_dataset_dir', type=str, help="Output dataset directory.")
    parser = add_crop_options(parser)
    parser.add_argument(
        '--coco_json_file',
        type=str,
        default='',
        help="Name of the output COCO-format JSON annotation file.")
    parser.add_argument(
        '--rates',
        nargs='+',
        type=float,
        default=[1.],
        help="Scales for cropping multi-scale samples.")
    parser.add_argument(
        '--nproc', type=int, default=8, help="Number of processes to use.")
    parser.add_argument(
        '--iof_thr',
        type=float,
        default=0.5,
        help="Minimum IoF between an object and a window.")
    parser.add_argument(
        '--image_only',
        action='store_true',
        default=False,
        help="Process images only (skip annotations).")
    parser.add_argument(
        '--data_type',
        type=str,
        default='dota10',
        help="Type of dataset ('dota10', 'dota15', or 'dota20').")
    args = parser.parse_args()
    return args
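
# Each line of a DOTA labelTxt annotation file stores one oriented object as
# eight polygon coordinates followed by the category name and an optional
# difficulty flag:
#   x1 y1 x2 y2 x3 y3 x4 y4 category [difficult]
# load_dota_info() below parses lines of this form.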


def load_dota_info(image_dir, anno_dir, file_name, ext=None):
    """Parse one image and its DOTA annotation file into an info dict."""
    base_name, extension = os.path.splitext(file_name)
    # `ext` may be a single extension or a collection of accepted extensions.
    if ext and (extension != ext and extension not in ext):
        return None
    info = {'image_file': os.path.join(image_dir, file_name), 'annotation': []}
    anno_file = os.path.join(anno_dir, base_name + '.txt')
    if not os.path.exists(anno_file):
        return info
    with open(anno_file, 'r') as f:
        for line in f:
            items = line.strip().split()
            # A valid line has at least 8 coordinates and a category name.
            if len(items) < 9:
                continue
            anno = {
                'poly': list(map(float, items[:8])),
                'name': items[8],
                'difficult': '0' if len(items) == 9 else items[9],
            }
            info['annotation'].append(anno)
    return info


def load_dota_infos(root_dir, num_process=8, ext=None):
    """Load info dicts for every image under `root_dir/images`."""
    image_dir = os.path.join(root_dir, 'images')
    anno_dir = os.path.join(root_dir, 'labelTxt')
    data_infos = []
    if num_process > 1:
        pool = Pool(num_process)
        results = []
        for file_name in os.listdir(image_dir):
            results.append(
                pool.apply_async(load_dota_info, (image_dir, anno_dir,
                                                  file_name, ext)))
        pool.close()
        pool.join()
        for result in results:
            info = result.get()
            if info:
                data_infos.append(info)
    else:
        for file_name in os.listdir(image_dir):
            info = load_dota_info(image_dir, anno_dir, file_name, ext)
            if info:
                data_infos.append(info)
    return data_infos


def process_single_sample(info, image_id, class_names):
    """Convert one info dict into a COCO image record and its object list."""
    image_file = info['image_file']
    single_image = dict()
    single_image['file_name'] = os.path.split(image_file)[-1]
    single_image['id'] = image_id
    image = cv2.imread(image_file)
    height, width, _ = image.shape
    single_image['width'] = width
    single_image['height'] = height

    # Process the annotation field.
    single_objs = []
    objects = info['annotation']
    for obj in objects:
        poly, name, difficult = obj['poly'], obj['name'], obj['difficult']
        # Objects marked difficult ('2') are dropped.
        if difficult == '2':
            continue
        single_obj = dict()
        single_obj['category_id'] = class_names.index(name) + 1
        single_obj['segmentation'] = [poly]
        single_obj['iscrowd'] = 0
        xmin, ymin = min(poly[0::2]), min(poly[1::2])
        xmax, ymax = max(poly[0::2]), max(poly[1::2])
        width, height = xmax - xmin, ymax - ymin
        single_obj['bbox'] = [xmin, ymin, width, height]
        single_obj['area'] = height * width
        single_obj['image_id'] = image_id
        single_objs.append(single_obj)
    return (single_image, single_objs)


def data_to_coco(infos, output_path, class_names, num_process):
    """Dump info dicts into a COCO-format JSON file at `output_path`."""
    data_dict = dict()
    data_dict['categories'] = []
    for i, name in enumerate(class_names):
        data_dict['categories'].append({
            'id': i + 1,
            'name': name,
            'supercategory': name
        })
    pbar = tqdm(total=len(infos), desc='data to coco')
    images, annotations = [], []
    if num_process > 1:
        pool = Pool(num_process)
        results = []
        for i, info in enumerate(infos):
            image_id = i + 1
            results.append(
                pool.apply_async(
                    process_single_sample, (info, image_id, class_names),
                    callback=lambda x: pbar.update()))
        pool.close()
        pool.join()
        for result in results:
            single_image, single_anno = result.get()
            images.append(single_image)
            annotations += single_anno
    else:
        for i, info in enumerate(infos):
            image_id = i + 1
            single_image, single_anno = process_single_sample(info, image_id,
                                                              class_names)
            images.append(single_image)
            annotations += single_anno
            pbar.update()
    pbar.close()
    # Assign globally unique annotation ids.
    for i, anno in enumerate(annotations):
        anno['id'] = i + 1
    data_dict['images'] = images
    data_dict['annotations'] = annotations
    with open(output_path, 'w') as f:
        json.dump(data_dict, f)


def choose_best_pointorder_fit_another(poly1, poly2):
    """Cyclically shift the vertex order of poly1 to the rotation that
    minimizes the summed squared distance to poly2's vertices."""
    x1, y1, x2, y2, x3, y3, x4, y4 = poly1
    combinate = [
        np.array([x1, y1, x2, y2, x3, y3, x4, y4]),
        np.array([x2, y2, x3, y3, x4, y4, x1, y1]),
        np.array([x3, y3, x4, y4, x1, y1, x2, y2]),
        np.array([x4, y4, x1, y1, x2, y2, x3, y3])
    ]
    dst_coordinate = np.array(poly2)
    distances = np.array(
        [np.sum((coord - dst_coordinate)**2) for coord in combinate])
    return combinate[distances.argmin()]
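
# For example, if poly2 is poly1 with its vertices listed starting from the
# second corner, the second cyclic shift matches exactly (distance 0):
#   choose_best_pointorder_fit_another(
#       [0, 0, 1, 0, 1, 1, 0, 1],    # unit square starting at (0, 0)
#       [1, 0, 1, 1, 0, 1, 0, 0])    # same square starting at (1, 0)
#   -> array([1, 0, 1, 1, 0, 1, 0, 0])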


def cal_line_length(point1, point2):
    """Euclidean distance between two points."""
    return math.hypot(point1[0] - point2[0], point1[1] - point2[1])


class SliceBase(object):
    """Slice large images, and optionally their DOTA annotations, into
    fixed-size overlapping windows."""

    def __init__(self,
                 gap=512,
                 subsize=1024,
                 thresh=0.7,
                 choosebestpoint=True,
                 ext='.png',
                 padding=True,
                 num_process=8,
                 image_only=False):
        self.gap = gap
        self.subsize = subsize
        self.slide = subsize - gap
        self.thresh = thresh
        self.choosebestpoint = choosebestpoint
        self.ext = ext
        self.padding = padding
        self.num_process = num_process
        self.image_only = image_only

    def get_windows(self, height, width):
        """Compute (left, up, right, down) sliding windows that tile the
        image with step `self.slide`."""
        windows = []
        left, up = 0, 0
        while left < width:
            if left + self.subsize >= width:
                left = max(width - self.subsize, 0)
            up = 0
            while up < height:
                if up + self.subsize >= height:
                    up = max(height - self.subsize, 0)
                right = min(left + self.subsize, width - 1)
                down = min(up + self.subsize, height - 1)
                windows.append((left, up, right, down))
                if up + self.subsize >= height:
                    break
                else:
                    up = up + self.slide
            if left + self.subsize >= width:
                break
            else:
                left = left + self.slide
        return windows
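
    # For instance, with the defaults subsize=1024 and gap=512 (slide=512),
    # a 1500 x 1500 image yields window origins (left, up) at (0, 0),
    # (0, 476), (476, 0), and (476, 476): the last window along each axis is
    # shifted back so that it ends flush with the image border.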

    def slice_image_single(self, image, windows, output_dir, output_name):
        """Crop and save one sub-image per window."""
        image_dir = os.path.join(output_dir, 'images')
        for (left, up, right, down) in windows:
            image_name = output_name + str(left) + '___' + str(up) + self.ext
            subimg = copy.deepcopy(image[up:up + self.subsize, left:left +
                                         self.subsize])
            h, w, c = subimg.shape
            if self.padding:
                # Pad border crops to subsize x subsize. Match the source
                # dtype so that cv2.imwrite() saves the result correctly.
                outimg = np.zeros(
                    (self.subsize, self.subsize, 3), dtype=subimg.dtype)
                outimg[0:h, 0:w, :] = subimg
                cv2.imwrite(os.path.join(image_dir, image_name), outimg)
            else:
                cv2.imwrite(os.path.join(image_dir, image_name), subimg)

    def iof(self, poly1, poly2):
        """Intersection over foreground: the fraction of poly1's area that
        lies inside poly2."""
        inter_poly = poly1.intersection(poly2)
        inter_area = inter_poly.area
        poly1_area = poly1.area
        half_iou = inter_area / poly1_area
        return inter_poly, half_iou

    def translate(self, poly, left, up):
        """Shift polygon coordinates into window-local coordinates."""
        n = len(poly)
        out_poly = np.zeros(n)
        for i in range(n // 2):
            out_poly[i * 2] = int(poly[i * 2] - left)
            out_poly[i * 2 + 1] = int(poly[i * 2 + 1] - up)
        return out_poly

    def get_poly4_from_poly5(self, poly):
        """Collapse a 5-point polygon to 4 points by replacing the two
        endpoints of its shortest edge with their midpoint."""
        distances = [
            cal_line_length((poly[i * 2], poly[i * 2 + 1]),
                            (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1]))
            for i in range(int(len(poly) / 2 - 1))
        ]
        distances.append(
            cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
        pos = np.array(distances).argsort()[0]
        count = 0
        out_poly = []
        while count < 5:
            if count == pos:
                # Emit the midpoint of the shortest edge.
                out_poly.append(
                    (poly[count * 2] + poly[(count * 2 + 2) % 10]) / 2)
                out_poly.append(
                    (poly[(count * 2 + 1) % 10] + poly[(count * 2 + 3) % 10])
                    / 2)
                count = count + 1
            elif count == (pos + 1) % 5:
                # The second endpoint of the shortest edge is dropped.
                count = count + 1
                continue
            else:
                out_poly.append(poly[count * 2])
                out_poly.append(poly[count * 2 + 1])
                count = count + 1
        return out_poly
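
    # Example: for the 5-point polygon [0, 0, 10, 0, 10, 10, 1, 10, 0, 9],
    # the shortest edge joins (1, 10) and (0, 9), so those two vertices are
    # merged into their midpoint (0.5, 9.5), giving
    # [0, 0, 10, 0, 10, 10, 0.5, 9.5].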

    def slice_anno_single(self, annos, windows, output_dir, output_name):
        """Write one labelTxt file per window, clipping polygons to it."""
        anno_dir = os.path.join(output_dir, 'labelTxt')
        for (left, up, right, down) in windows:
            image_poly = shgeo.Polygon(
                [(left, up), (right, up), (right, down), (left, down)])
            anno_file = output_name + str(left) + '___' + str(up) + '.txt'
            with open(os.path.join(anno_dir, anno_file), 'w') as f:
                for anno in annos:
                    gt_poly = shgeo.Polygon(
                        [(anno['poly'][0], anno['poly'][1]),
                         (anno['poly'][2], anno['poly'][3]),
                         (anno['poly'][4], anno['poly'][5]),
                         (anno['poly'][6], anno['poly'][7])])
                    if gt_poly.area <= 0:
                        continue
                    inter_poly, iof = self.iof(gt_poly, image_poly)
                    if iof == 1:
                        # The object lies entirely inside the window.
                        final_poly = self.translate(anno['poly'], left, up)
                    elif iof > 0:
                        # The object is clipped by the window border.
                        inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
                        out_poly = list(inter_poly.exterior.coords)[0:-1]
                        if len(out_poly) < 4 or len(out_poly) > 5:
                            continue
                        final_poly = []
                        for p in out_poly:
                            final_poly.append(p[0])
                            final_poly.append(p[1])
                        if len(out_poly) == 5:
                            final_poly = self.get_poly4_from_poly5(final_poly)
                        if self.choosebestpoint:
                            final_poly = choose_best_pointorder_fit_another(
                                final_poly, anno['poly'])
                        final_poly = self.translate(final_poly, left, up)
                        final_poly = np.clip(final_poly, 1, self.subsize)
                    else:
                        continue
                    outline = ' '.join(list(map(str, final_poly)))
                    if iof >= self.thresh:
                        outline = outline + ' ' + anno['name'] + ' ' + str(
                            anno['difficult'])
                    else:
                        # Heavily clipped objects are kept but marked
                        # difficult ('2') so they can be filtered downstream.
                        outline = outline + ' ' + anno['name'] + ' ' + '2'
                    f.write(outline + '\n')

    def slice_data_single(self, info, rate, output_dir):
        """Slice one image (and its annotations) at the given scale rate."""
        file_name = info['image_file']
        base_name = os.path.splitext(os.path.split(file_name)[-1])[0]
        base_name = base_name + '__' + str(rate) + '__'
        img = cv2.imread(file_name)
        if img is None:
            # cv2.imread() returns None for unreadable images.
            return
        if rate != 1:
            resize_img = cv2.resize(
                img, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
        else:
            resize_img = img
        height, width, _ = resize_img.shape
        windows = self.get_windows(height, width)
        self.slice_image_single(resize_img, windows, output_dir, base_name)
        if not self.image_only:
            annos = info['annotation']
            for anno in annos:
                # Scale polygon coordinates along with the image.
                anno['poly'] = list(map(lambda x: rate * x, anno['poly']))
            self.slice_anno_single(annos, windows, output_dir, base_name)

    def check_or_mkdirs(self, path):
        # `exist_ok=True` already makes this a no-op for existing paths.
        os.makedirs(path, exist_ok=True)

    def slice_data(self, infos, rates, output_dir):
        """
        Args:
            infos (list[dict]): Data infos.
            rates (float|list[float]): Scale rates.
            output_dir (str): Output directory.
        """
        if isinstance(rates, Number):
            rates = [rates, ]
        self.check_or_mkdirs(output_dir)
        self.check_or_mkdirs(os.path.join(output_dir, 'images'))
        if not self.image_only:
            self.check_or_mkdirs(os.path.join(output_dir, 'labelTxt'))
        pbar = tqdm(total=len(rates) * len(infos), desc='slicing data')
        if self.num_process <= 1:
            for rate in rates:
                for info in infos:
                    self.slice_data_single(info, rate, output_dir)
                    pbar.update()
        else:
            pool = Pool(self.num_process)
            for rate in rates:
                for info in infos:
                    pool.apply_async(
                        self.slice_data_single, (info, rate, output_dir),
                        callback=lambda x: pbar.update())
            pool.close()
            pool.join()
        pbar.close()


def load_dataset(input_dir, nproc, data_type):
    if 'dota' in data_type.lower():
        infos = load_dota_infos(input_dir, nproc)
    else:
        raise ValueError("Only DOTA datasets are supported for now.")
    return infos


def main():
    args = parse_args()
    infos = []
    for input_dir in args.in_dataset_dir:
        infos += load_dataset(input_dir, args.nproc, args.data_type)
    slicer = SliceBase(
        gap=args.crop_stride,
        subsize=args.crop_size,
        thresh=args.iof_thr,
        num_process=args.nproc,
        image_only=args.image_only)
    slicer.slice_data(infos, args.rates, args.out_dataset_dir)
    if args.coco_json_file:
        infos = load_dota_infos(args.out_dataset_dir, args.nproc)
        coco_json_file = os.path.join(args.out_dataset_dir,
                                      args.coco_json_file)
        class_names = DATA_CLASSES[args.data_type]
        data_to_coco(infos, coco_json_file, class_names, args.nproc)


if __name__ == '__main__':
    main()
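
# A minimal usage sketch, assuming that common.add_crop_options() registers
# `--crop_size` and `--crop_stride` (main() reads args.crop_size and
# args.crop_stride, but the actual flag names live in the external `common`
# module):
#
#   python prepare_dota.py \
#       --in_dataset_dir /path/to/DOTA/train \
#       --out_dataset_dir /path/to/DOTA_sliced/train \
#       --crop_size 1024 \
#       --crop_stride 512 \
#       --rates 0.5 1.0 1.5 \
#       --coco_json_file annotation.json \
#       --data_type dota10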