- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import copy
- import json
- import math
- import os
- from multiprocessing import Pool
- from numbers import Number
- import cv2
- import numpy as np
- import shapely.geometry as shgeo
- from tqdm import tqdm
- from common import add_crop_options
# Class names for DOTA v1.0 (15 categories).
wordname_15 = [
    'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
    'harbor', 'swimming-pool', 'helicopter'
]
# DOTA v1.5 adds 'container-crane' (16 categories).
wordname_16 = wordname_15 + ['container-crane']
# DOTA v2.0 additionally adds 'airport' and 'helipad' (18 categories).
wordname_18 = wordname_16 + ['airport', 'helipad']
# Maps the --data_type CLI value to the matching class-name list.
DATA_CLASSES = {
    'dota10': wordname_15,
    'dota15': wordname_16,
    'dota20': wordname_18
}
def parse_args():
    """Build and parse the command-line arguments for DOTA slicing."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        '--in_dataset_dir',
        type=str,
        nargs='+',
        required=True,
        help="Input dataset directories.")
    ap.add_argument(
        '--out_dataset_dir', type=str, help="Output dataset directory.")
    # Adds the shared cropping options (e.g. crop size/stride).
    ap = add_crop_options(ap)
    ap.add_argument(
        '--coco_json_file',
        type=str,
        default='',
        help="COCO JSON annotation files.")
    ap.add_argument(
        '--rates',
        nargs='+',
        type=float,
        default=[1.],
        help="Scales for cropping multi-scale samples.")
    ap.add_argument(
        '--nproc', type=int, default=8, help="Number of processes to use.")
    ap.add_argument(
        '--iof_thr',
        type=float,
        default=0.5,
        help="Minimal IoF between an object and a window.")
    ap.add_argument(
        '--image_only',
        action='store_true',
        default=False,
        help="To process images only.")
    ap.add_argument(
        '--data_type', type=str, default='dota10', help="Type of dataset.")
    return ap.parse_args()
def load_dota_info(image_dir, anno_dir, file_name, ext=None):
    """Build the info dict for one image and its DOTA label file.

    Returns None when `ext` is given and the file's extension matches
    neither `ext` nor any element of it; returns an info dict with an
    empty 'annotation' list when no label file exists.
    """
    stem, suffix = os.path.splitext(file_name)
    if ext and suffix != ext and suffix not in ext:
        return None
    info = {'image_file': os.path.join(image_dir, file_name), 'annotation': []}
    anno_path = os.path.join(anno_dir, stem + '.txt')
    if not os.path.exists(anno_path):
        return info
    with open(anno_path, 'r') as fp:
        for raw_line in fp:
            fields = raw_line.strip().split()
            # A valid DOTA record has 8 coords + name (+ optional difficulty).
            if len(fields) < 9:
                continue
            info['annotation'].append({
                'poly': [float(v) for v in fields[:8]],
                'name': fields[8],
                'difficult': fields[9] if len(fields) > 9 else '0',
            })
    return info
def load_dota_infos(root_dir, num_process=8, ext=None):
    """Collect info dicts for every image under `<root_dir>/images`,
    reading labels from `<root_dir>/labelTxt`.

    With num_process > 1 the per-file loading is fanned out to a
    process pool; entries filtered out by `ext` are dropped.
    """
    image_dir = os.path.join(root_dir, 'images')
    anno_dir = os.path.join(root_dir, 'labelTxt')
    file_names = os.listdir(image_dir)
    if num_process > 1:
        pool = Pool(num_process)
        async_results = [
            pool.apply_async(load_dota_info, (image_dir, anno_dir, name, ext))
            for name in file_names
        ]
        pool.close()
        pool.join()
        raw_infos = [r.get() for r in async_results]
    else:
        raw_infos = [
            load_dota_info(image_dir, anno_dir, name, ext)
            for name in file_names
        ]
    # load_dota_info returns None for extension-filtered files.
    return [info for info in raw_infos if info]
def process_single_sample(info, image_id, class_names):
    """Convert one image info dict into COCO records.

    Returns a tuple (image_record, annotation_records); annotation ids
    are assigned later by the caller.
    """
    image_file = info['image_file']
    img = cv2.imread(image_file)
    img_h, img_w, _ = img.shape
    coco_image = {
        'file_name': os.path.split(image_file)[-1],
        'id': image_id,
        'width': img_w,
        'height': img_h,
    }
    coco_annos = []
    for obj in info['annotation']:
        poly, name, difficult = obj['poly'], obj['name'], obj['difficult']
        # Difficulty '2' is assigned during slicing to heavily clipped
        # objects (see slice_anno_single); drop them from the COCO set.
        if difficult == '2':
            continue
        xs, ys = poly[0::2], poly[1::2]
        xmin, xmax = min(xs), max(xs)
        ymin, ymax = min(ys), max(ys)
        box_w, box_h = xmax - xmin, ymax - ymin
        coco_annos.append({
            'category_id': class_names.index(name) + 1,
            'segmentation': [poly],
            'iscrowd': 0,
            'bbox': [xmin, ymin, box_w, box_h],
            'area': box_h * box_w,
            'image_id': image_id,
        })
    return (coco_image, coco_annos)
def data_to_coco(infos, output_path, class_names, num_process):
    """Serialize `infos` as a COCO-format JSON file at `output_path`.

    Image ids are 1-based positions in `infos`; annotation ids are
    assigned consecutively after all samples are processed.
    """
    categories = [{
        'id': idx + 1,
        'name': cls_name,
        'supercategory': cls_name
    } for idx, cls_name in enumerate(class_names)]
    pbar = tqdm(total=len(infos), desc='data to coco')
    images, annotations = [], []
    if num_process > 1:
        pool = Pool(num_process)
        async_results = []
        for idx, info in enumerate(infos):
            async_results.append(
                pool.apply_async(
                    process_single_sample, (info, idx + 1, class_names),
                    callback=lambda _: pbar.update()))
        pool.close()
        pool.join()
        for res in async_results:
            image_rec, anno_recs = res.get()
            images.append(image_rec)
            annotations.extend(anno_recs)
    else:
        for idx, info in enumerate(infos):
            image_rec, anno_recs = process_single_sample(info, idx + 1,
                                                         class_names)
            images.append(image_rec)
            annotations.extend(anno_recs)
            pbar.update()
    pbar.close()
    # Annotation ids can only be assigned once every worker has finished,
    # since pool results may complete out of order.
    for idx, anno in enumerate(annotations):
        anno['id'] = idx + 1
    data_dict = {
        'categories': categories,
        'images': images,
        'annotations': annotations,
    }
    with open(output_path, 'w') as fp:
        json.dump(data_dict, fp)
def choose_best_pointorder_fit_another(poly1, poly2):
    """Return the cyclic rotation of poly1's vertices closest to poly2.

    Both polygons are flat sequences [x1, y1, ..., x4, y4]. All four
    cyclic orderings of poly1 are scored by squared Euclidean distance
    to poly2, and the best one is returned as a numpy array.
    """
    x1, y1, x2, y2, x3, y3, x4, y4 = poly1
    candidates = [
        np.array([x1, y1, x2, y2, x3, y3, x4, y4]),
        np.array([x2, y2, x3, y3, x4, y4, x1, y1]),
        np.array([x3, y3, x4, y4, x1, y1, x2, y2]),
        np.array([x4, y4, x1, y1, x2, y2, x3, y3])
    ]
    target = np.array(poly2)
    distances = np.array(
        [np.sum((cand - target)**2) for cand in candidates])
    # argmin() is O(n) and avoids the original's full argsort, which also
    # shadowed the builtin `sorted`.
    return candidates[int(distances.argmin())]
def cal_line_length(point1, point2):
    """Return the Euclidean distance between two 2-D points.

    Uses math.hypot, which is numerically robust against intermediate
    overflow/underflow compared to sqrt(pow(dx, 2) + pow(dy, 2)).
    """
    return math.hypot(point1[0] - point2[0], point1[1] - point2[1])
class SliceBase(object):
    """Slice large images (and optionally their DOTA polygon labels)
    into fixed-size, overlapping square windows.

    Attributes:
        gap (int): Overlap in pixels between adjacent windows.
        subsize (int): Side length of each square window.
        slide (int): Stride between window origins (subsize - gap).
        thresh (float): Minimal IoF for a clipped object to keep its
            original difficulty; below it the object is written with
            difficulty '2'.
        choosebestpoint (bool): Whether to reorder a clipped polygon's
            points to best match the original point order.
        ext (str): File extension used for output image tiles.
        padding (bool): Whether to zero-pad border tiles to exactly
            subsize x subsize.
        num_process (int): Number of worker processes for slicing.
        image_only (bool): If True, slice images but skip annotations.
    """

    def __init__(self,
                 gap=512,
                 subsize=1024,
                 thresh=0.7,
                 choosebestpoint=True,
                 ext='.png',
                 padding=True,
                 num_process=8,
                 image_only=False):
        self.gap = gap
        self.subsize = subsize
        # Stride between the origins of adjacent windows.
        self.slide = subsize - gap
        self.thresh = thresh
        self.choosebestpoint = choosebestpoint
        self.ext = ext
        self.padding = padding
        self.num_process = num_process
        self.image_only = image_only

    def get_windows(self, height, width):
        """Enumerate (left, up, right, down) crop windows covering an
        image of the given size, advancing by `self.slide`; the last
        window in each direction is shifted back inside the image.
        """
        windows = []
        left, up = 0, 0
        while (left < width):
            if (left + self.subsize >= width):
                # Clamp the final column of windows to the image edge.
                left = max(width - self.subsize, 0)
            up = 0
            while (up < height):
                if (up + self.subsize >= height):
                    # Clamp the final row of windows to the image edge.
                    up = max(height - self.subsize, 0)
                right = min(left + self.subsize, width - 1)
                down = min(up + self.subsize, height - 1)
                windows.append((left, up, right, down))
                if (up + self.subsize >= height):
                    break
                else:
                    up = up + self.slide
            if (left + self.subsize >= width):
                break
            else:
                left = left + self.slide
        return windows

    def slice_image_single(self, image, windows, output_dir, output_name):
        """Write one cropped sub-image per window into
        `<output_dir>/images`, naming each tile
        `<output_name><left>___<up><ext>`.
        """
        image_dir = os.path.join(output_dir, 'images')
        for (left, up, right, down) in windows:
            image_name = output_name + str(left) + '___' + str(up) + self.ext
            # Deep-copy the crop so later writes cannot alias the source.
            subimg = copy.deepcopy(image[up:up + self.subsize, left:left +
                                         self.subsize])
            h, w, c = subimg.shape
            if (self.padding):
                # Zero-pad border crops to a full subsize x subsize tile.
                outimg = np.zeros((self.subsize, self.subsize, 3))
                outimg[0:h, 0:w, :] = subimg
                cv2.imwrite(os.path.join(image_dir, image_name), outimg)
            else:
                cv2.imwrite(os.path.join(image_dir, image_name), subimg)

    def iof(self, poly1, poly2):
        """Return (intersection polygon, intersection-over-foreground),
        i.e. inter_area / poly1.area. Callers must ensure poly1 has a
        non-zero area.
        """
        inter_poly = poly1.intersection(poly2)
        inter_area = inter_poly.area
        poly1_area = poly1.area
        half_iou = inter_area / poly1_area
        return inter_poly, half_iou

    def translate(self, poly, left, up):
        """Shift a flat [x0, y0, x1, y1, ...] polygon into window-local
        coordinates (origin at the window's top-left corner), truncating
        each coordinate to an integer value.
        """
        n = len(poly)
        out_poly = np.zeros(n)
        for i in range(n // 2):
            out_poly[i * 2] = int(poly[i * 2] - left)
            out_poly[i * 2 + 1] = int(poly[i * 2 + 1] - up)
        return out_poly

    def get_poly4_from_poly5(self, poly):
        """Reduce a 5-point clipped polygon (flat list of 10 coords) to
        4 points by replacing the two endpoints of its shortest edge
        with their midpoint.
        """
        distances = [
            cal_line_length((poly[i * 2], poly[i * 2 + 1]),
                            (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1]))
            for i in range(int(len(poly) / 2 - 1))
        ]
        # Closing edge: last point back to the first.
        distances.append(
            cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
        pos = np.array(distances).argsort()[0]
        count = 0
        out_poly = []
        while count < 5:
            if (count == pos):
                # Emit the midpoint of the shortest edge...
                out_poly.append(
                    (poly[count * 2] + poly[(count * 2 + 2) % 10]) / 2)
                out_poly.append(
                    (poly[(count * 2 + 1) % 10] + poly[(count * 2 + 3) % 10]) /
                    2)
                count = count + 1
            elif (count == (pos + 1) % 5):
                # ...and skip the edge's second endpoint.
                count = count + 1
                continue
            else:
                out_poly.append(poly[count * 2])
                out_poly.append(poly[count * 2 + 1])
                count = count + 1
        return out_poly

    def slice_anno_single(self, annos, windows, output_dir, output_name):
        """Write one DOTA label file per window into
        `<output_dir>/labelTxt`, clipping polygons to the window.

        Objects fully inside a window are only translated; partially
        covered objects are clipped and, when their IoF drops below
        `self.thresh`, written with difficulty '2'.
        """
        anno_dir = os.path.join(output_dir, 'labelTxt')
        for (left, up, right, down) in windows:
            image_poly = shgeo.Polygon(
                [(left, up), (right, up), (right, down), (left, down)])
            anno_file = output_name + str(left) + '___' + str(up) + '.txt'
            with open(os.path.join(anno_dir, anno_file), 'w') as f:
                for anno in annos:
                    gt_poly = shgeo.Polygon(
                        [(anno['poly'][0], anno['poly'][1]),
                         (anno['poly'][2], anno['poly'][3]),
                         (anno['poly'][4], anno['poly'][5]),
                         (anno['poly'][6], anno['poly'][7])])
                    # Degenerate polygons would break the IoF division.
                    if gt_poly.area <= 0:
                        continue
                    inter_poly, iof = self.iof(gt_poly, image_poly)
                    if iof == 1:
                        # Fully inside the window: translate only.
                        final_poly = self.translate(anno['poly'], left, up)
                    elif iof > 0:
                        # Partially covered: clip to the window with a
                        # consistently (counter-clockwise) oriented ring.
                        inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
                        out_poly = list(inter_poly.exterior.coords)[0:-1]
                        # Only 4- or 5-vertex clips can be reduced to a
                        # quadrilateral; drop any other shape.
                        if len(out_poly) < 4 or len(out_poly) > 5:
                            continue
                        final_poly = []
                        for p in out_poly:
                            final_poly.append(p[0])
                            final_poly.append(p[1])
                        if len(out_poly) == 5:
                            final_poly = self.get_poly4_from_poly5(final_poly)
                        if self.choosebestpoint:
                            final_poly = choose_best_pointorder_fit_another(
                                final_poly, anno['poly'])
                        final_poly = self.translate(final_poly, left, up)
                        final_poly = np.clip(final_poly, 1, self.subsize)
                    else:
                        # No overlap with this window.
                        continue
                    outline = ' '.join(list(map(str, final_poly)))
                    if iof >= self.thresh:
                        outline = outline + ' ' + anno['name'] + ' ' + str(anno[
                            'difficult'])
                    else:
                        # Low overlap: downgrade to difficulty '2'.
                        outline = outline + ' ' + anno['name'] + ' ' + '2'
                    f.write(outline + '\n')

    def slice_data_single(self, info, rate, output_dir):
        """Slice one image (and, unless image_only, its annotations) at
        scale `rate`. NOTE: mutates info['annotation'] polys in place
        when scaling.
        """
        file_name = info['image_file']
        base_name = os.path.splitext(os.path.split(file_name)[-1])[0]
        # Encode the scale into the output stem: <stem>__<rate>__.
        base_name = base_name + '__' + str(rate) + '__'
        img = cv2.imread(file_name)
        if img.shape == ():
            # Unreadable image; skip silently.
            return
        if (rate != 1):
            resize_img = cv2.resize(
                img, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
        else:
            resize_img = img
        height, width, _ = resize_img.shape
        windows = self.get_windows(height, width)
        self.slice_image_single(resize_img, windows, output_dir, base_name)
        if not self.image_only:
            annos = info['annotation']
            for anno in annos:
                # Scale polygons to match the resized image.
                anno['poly'] = list(map(lambda x: rate * x, anno['poly']))
            self.slice_anno_single(annos, windows, output_dir, base_name)

    def check_or_mkdirs(self, path):
        """Create `path` (including parents) if it does not exist."""
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)

    def slice_data(self, infos, rates, output_dir):
        """Slice every image in `infos` at every scale in `rates`.

        Args:
            infos (list[dict]): data_infos
            rates (float, list): scale rates
            output_dir (str): output directory
        """
        if isinstance(rates, Number):
            rates = [rates, ]
        self.check_or_mkdirs(output_dir)
        self.check_or_mkdirs(os.path.join(output_dir, 'images'))
        if not self.image_only:
            self.check_or_mkdirs(os.path.join(output_dir, 'labelTxt'))
        pbar = tqdm(total=len(rates) * len(infos), desc='slicing data')
        if self.num_process <= 1:
            for rate in rates:
                for info in infos:
                    self.slice_data_single(info, rate, output_dir)
                    pbar.update()
        else:
            pool = Pool(self.num_process)
            for rate in rates:
                for info in infos:
                    pool.apply_async(
                        self.slice_data_single, (info, rate, output_dir),
                        callback=lambda x: pbar.update())
            pool.close()
            pool.join()
        pbar.close()
def load_dataset(input_dir, nproc, data_type):
    """Load dataset infos from `input_dir`.

    Raises ValueError for anything other than a DOTA-style data_type.
    """
    if 'dota' not in data_type.lower():
        raise ValueError('only dota dataset is supported now')
    return load_dota_infos(input_dir, nproc)
def main():
    """CLI entry point: slice the input datasets and optionally export
    the sliced result as a COCO JSON file.
    """
    args = parse_args()
    infos = []
    for input_dir in args.in_dataset_dir:
        infos += load_dataset(input_dir, args.nproc, args.data_type)
    slicer = SliceBase(
        args.crop_stride,
        args.crop_size,
        args.iof_thr,
        num_process=args.nproc,
        image_only=args.image_only)
    slicer.slice_data(infos, args.rates, args.out_dataset_dir)
    if args.coco_json_file:
        # Re-read the sliced dataset so the COCO export reflects it.
        sliced_infos = load_dota_infos(args.out_dataset_dir, args.nproc)
        coco_json_path = os.path.join(args.out_dataset_dir,
                                      args.coco_json_file)
        data_to_coco(sliced_infos, coco_json_path,
                     DATA_CLASSES[args.data_type], args.nproc)


if __name__ == '__main__':
    main()