# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import copy
import os
import os.path as osp
import random
from collections import OrderedDict

import numpy as np

from .base import BaseDataset
from paddlers.utils import logging, get_encoding, norm_path, is_pic
from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster


class COCODetDataset(BaseDataset):
  26. """
  27. Dataset with COCO annotations for detection tasks.
  28. Args:
  29. data_dir (str): Root directory of the dataset.
  30. image_dir (str): Directory that contains the images.
  31. ann_path (str): Path to COCO annotations.
  32. transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
  33. label_list (str, optional): Path of the file that contains the category names. Defaults to None.
  34. num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
  35. the number of workers will be automatically determined according to the number of CPU cores: If
  36. there are more than 16 cores,8 workers will be used. Otherwise, the number of workers will be half
  37. the number of CPU cores. Defaults: 'auto'.
  38. shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
  39. allow_empty (bool, optional): Whether to add negative samples. Defaults to False.
  40. empty_ratio (float, optional): Ratio of negative samples. If `empty_ratio` is smaller than 0 or not less
  41. than 1, keep all generated negative samples. Defaults to 1.0.
  42. """
    def __init__(self,
                 data_dir,
                 image_dir,
                 anno_path,
                 transforms,
                 label_list,
                 num_workers='auto',
                 shuffle=False,
                 allow_empty=False,
                 empty_ratio=1.):
        # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
        # or matplotlib.backends is imported for the first time.
        import matplotlib
        matplotlib.use('Agg')
        from pycocotools.coco import COCO

        super(COCODetDataset, self).__init__(data_dir, label_list, transforms,
                                             num_workers, shuffle)

        self.data_fields = None
        self.num_max_boxes = 50

        self.use_mix = False
        if self.transforms is not None:
            for op in self.transforms.transforms:
                if isinstance(op, MixupImage):
                    self.mixup_op = copy.deepcopy(op)
                    self.use_mix = True
                    self.num_max_boxes *= 2
                    break

        self.batch_transforms = None
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio
        self.file_list = list()
        neg_file_list = list()
        self.labels = list()

        # Build the category info from the label list file.
        annotations = dict()
        annotations['images'] = list()
        annotations['categories'] = list()
        annotations['annotations'] = list()
        cname2cid = OrderedDict()
        label_id = 0
        with open(label_list, 'r', encoding=get_encoding(label_list)) as f:
            for line in f.readlines():
                cname2cid[line.strip()] = label_id
                label_id += 1
                self.labels.append(line.strip())
        for k, v in cname2cid.items():
            annotations['categories'].append({
                'supercategory': 'component',
                'id': v + 1,
                'name': k
            })

        anno_path = norm_path(os.path.join(self.data_dir, anno_path))
        image_dir = norm_path(os.path.join(self.data_dir, image_dir))
        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        ct = 0
        catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in catid2clsid.items()
        })

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])
            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            if not os.path.exists(im_path):
                logging.warning('Illegal image file: {}, and it will be '
                                'ignored'.format(im_path))
                continue
            if im_w < 0 or im_h < 0:
                logging.warning(
                    'Illegal width: {} or height: {} in annotation, '
                    'and im_id: {} will be ignored'.format(im_w, im_h, img_id))
                continue

            im_info = {
                'image': im_path,
                'im_id': np.array([img_id]),
                'image_shape': np.array(
                    [im_h, im_w], dtype=np.int32)
            }
            ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
            instances = coco.loadAnns(ins_anno_ids)

            is_crowds = []
            gt_classes = []
            gt_bboxs = []
            gt_scores = []
            difficults = []
            for inst in instances:
                # Check gt bbox
                if inst.get('ignore', False):
                    continue
                if 'bbox' not in inst.keys():
                    continue
                else:
                    if not any(np.array(inst['bbox'])):
                        continue

                # Read the box
                x1, y1, box_w, box_h = inst['bbox']
                x2 = x1 + box_w
                y2 = y1 + box_h
                eps = 1e-5
                if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                    inst['clean_bbox'] = [
                        round(float(x), 3) for x in [x1, y1, x2, y2]
                    ]
                else:
                    logging.warning(
                        'Found an invalid bbox in annotations: im_id: {}, '
                        'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                            img_id, float(inst['area']), x1, y1, x2, y2))
                    # Skip this annotation: it has no 'clean_bbox', so the
                    # appends below would otherwise raise a KeyError.
                    continue

                is_crowds.append([inst['iscrowd']])
                gt_classes.append([inst['category_id']])
                gt_bboxs.append(inst['clean_bbox'])
                gt_scores.append([1.])
                difficults.append([0])

                annotations['annotations'].append({
                    'iscrowd': inst['iscrowd'],
                    'image_id': int(inst['image_id']),
                    'bbox': inst['clean_bbox'],
                    'area': inst['area'],
                    'category_id': inst['category_id'],
                    'id': inst['id'],
                    'difficult': 0
                })

            label_info = {
                'is_crowd': np.array(is_crowds),
                'gt_class': np.array(gt_classes),
                'gt_bbox': np.array(gt_bboxs).astype(np.float32),
                'gt_score': np.array(gt_scores).astype(np.float32),
                'difficult': np.array(difficults),
            }

            if label_info['gt_bbox'].size > 0:
                self.file_list.append({**im_info, **label_info})
                annotations['images'].append({
                    'height': im_h,
                    'width': im_w,
                    'id': int(im_info['im_id']),
                    'file_name': osp.split(im_info['image'])[1]
                })
            else:
                neg_file_list.append({**im_info, **label_info})

            ct += 1
            if self.use_mix:
                self.num_max_boxes = max(self.num_max_boxes,
                                         2 * len(instances))
            else:
                self.num_max_boxes = max(self.num_max_boxes, len(instances))

        if not ct:
            logging.error(
                'No coco record found in %s.' % anno_path, exit=True)

        self.pos_num = len(self.file_list)
        if self.allow_empty and neg_file_list:
            self.file_list += self._sample_empty(neg_file_list)

        logging.info(
            "{} samples in file {}, including {} positive samples and {} negative samples.".
            format(
                len(self.file_list), anno_path, self.pos_num,
                len(self.file_list) - self.pos_num))
        self.num_samples = len(self.file_list)
        self.coco_gt = COCO()
        self.coco_gt.dataset = annotations
        self.coco_gt.createIndex()

        self._epoch = 0

    def __getitem__(self, idx):
        sample = copy.deepcopy(self.file_list[idx])
        if self.data_fields is not None:
            sample = {k: sample[k] for k in self.data_fields}

        if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
                             self._epoch < self.mixup_op.mixup_epoch):
            if self.num_samples > 1:
                mix_idx = random.randint(1, self.num_samples - 1)
                mix_pos = (mix_idx + idx) % self.num_samples
            else:
                mix_pos = 0
            sample_mix = copy.deepcopy(self.file_list[mix_pos])
            if self.data_fields is not None:
                sample_mix = {k: sample_mix[k] for k in self.data_fields}
            sample = self.mixup_op(sample=[
                DecodeImg(to_rgb=False)(sample),
                DecodeImg(to_rgb=False)(sample_mix)
            ])

        sample = self.transforms(sample)

        return sample

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch_id):
        self._epoch = epoch_id
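
    # Mixup in `__getitem__` stays active until `self._epoch` reaches the
    # MixupImage op's `mixup_epoch`, so a training loop is expected to call
    # `set_epoch` each epoch. A hedged sketch (loop variables are
    # illustrative, not part of this module):
    #
    #     for epoch in range(num_epochs):
    #         dataset.set_epoch(epoch)
    #         for sample in dataset:
    #             ...
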
    def cluster_yolo_anchor(self,
                            num_anchors,
                            image_size,
                            cache=True,
                            cache_path=None,
                            iters=300,
                            gen_iters=1000,
                            thresh=.25):
        """
        Cluster YOLO anchors.

        Reference:
            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py

        Args:
            num_anchors (int): Number of clusters.
            image_size (list[int]|int): [h, w] or an int value that corresponds to the shape
                [image_size, image_size].
            cache (bool, optional): Whether to use cache. Defaults to True.
            cache_path (str|None, optional): Path of cache directory. If None, use `dataset.data_dir`.
                Defaults to None.
            iters (int, optional): Iterations of k-means algorithm. Defaults to 300.
            gen_iters (int, optional): Iterations of genetic algorithm. Defaults to 1000.
            thresh (float, optional): Anchor scale threshold. Defaults to 0.25.
        """
        if cache_path is None:
            cache_path = self.data_dir
        cluster = YOLOAnchorCluster(
            num_anchors=num_anchors,
            dataset=self,
            image_size=image_size,
            cache=cache,
            cache_path=cache_path,
            iters=iters,
            gen_iters=gen_iters,
            thresh=thresh)
        anchors = cluster()
        return anchors
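
    # A usage sketch for anchor clustering (the anchor count and image size
    # below are illustrative assumptions, e.g. for a 9-anchor YOLO head):
    #
    #     anchors = dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
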
    def add_negative_samples(self, image_dir, empty_ratio=1):
        """
        Generate and add negative samples.

        Args:
            image_dir (str): Directory that contains images.
            empty_ratio (float|None, optional): Ratio of negative samples. If `empty_ratio` is smaller than
                0 or not less than 1, keep all generated negative samples. Defaults to 1.0.
        """
        import cv2

        if not osp.isdir(image_dir):
            raise ValueError("{} is not a valid image directory.".format(
                image_dir))
        if empty_ratio is not None:
            self.empty_ratio = empty_ratio

        image_list = os.listdir(image_dir)
        max_img_id = max(len(self.file_list) - 1, max(self.coco_gt.getImgIds()))
        neg_file_list = list()
        for image in image_list:
            if not is_pic(image):
                continue
            gt_bbox = np.zeros((0, 4), dtype=np.float32)
            gt_class = np.zeros((0, 1), dtype=np.int32)
            gt_score = np.zeros((0, 1), dtype=np.float32)
            is_crowd = np.zeros((0, 1), dtype=np.int32)
            difficult = np.zeros((0, 1), dtype=np.int32)

            max_img_id += 1
            im_fname = osp.join(image_dir, image)
            img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
            im_h, im_w, im_c = img_data.shape
            im_info = {
                'im_id': np.asarray([max_img_id]),
                'image_shape': np.array(
                    [im_h, im_w], dtype=np.int32)
            }
            label_info = {
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_bbox': gt_bbox,
                'gt_score': gt_score,
                'difficult': difficult
            }
            if 'gt_poly' in self.file_list[0]:
                label_info['gt_poly'] = []
            neg_file_list.append({'image': im_fname, **im_info, **label_info})

        if neg_file_list:
            self.allow_empty = True
            self.file_list += self._sample_empty(neg_file_list)
            logging.info(
                "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
                format(
                    len(self.file_list) - self.num_samples, self.pos_num,
                    len(self.file_list) - self.pos_num))
            self.num_samples = len(self.file_list)
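
    # A sketch of extending an existing dataset with background-only images
    # (the directory name and ratio are illustrative assumptions):
    #
    #     dataset.add_negative_samples(image_dir='background_images',
    #                                  empty_ratio=0.2)
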
    def _sample_empty(self, neg_file_list):
        if 0. <= self.empty_ratio < 1.:
            import random
            total_num = len(self.file_list)
            neg_num = total_num - self.pos_num
            # Number of extra negatives needed so that negatives make up
            # `empty_ratio` of the final dataset, capped by the available pool.
            # Cast to int because `random.sample` expects an integer count.
            sample_num = min((total_num * self.empty_ratio - neg_num) //
                             (1 - self.empty_ratio), len(neg_file_list))
            return random.sample(neg_file_list, int(sample_num))
        else:
            return neg_file_list
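

# Worked example for `_sample_empty` (numbers are illustrative): with 900
# positive samples, no negatives yet (total_num=900, neg_num=0) and
# empty_ratio=0.25, the formula gives
#     (900 * 0.25 - 0) // (1 - 0.25) = 300,
# and 300 / (900 + 300) = 0.25, i.e. negatives end up being 25% of the final
# dataset, provided at least 300 candidate negatives are available.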