coco.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import copy
import os
import os.path as osp
import random
from collections import OrderedDict

import numpy as np

from .base import BaseDataset
from paddlers.utils import logging, get_encoding, norm_path, is_pic
from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster


class COCODetection(BaseDataset):
  26. """读取COCO格式的检测数据集,并对样本进行相应的处理。
  27. Args:
  28. data_dir (str): 数据集所在的目录路径。
  29. image_dir (str): 描述数据集图片文件路径。
  30. anno_path (str): COCO标注文件路径。
  31. label_list (str): 描述数据集包含的类别信息文件路径。
  32. transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
  33. num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
  34. 系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
  35. 一半。
  36. shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
  37. allow_empty (bool): 是否加载负样本。默认为False。
  38. empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
  39. """

    def __init__(self,
                 data_dir,
                 image_dir,
                 anno_path,
                 label_list,
                 transforms=None,
                 num_workers='auto',
                 shuffle=False,
                 allow_empty=False,
                 empty_ratio=1.):
        # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
        # or matplotlib.backends is imported for the first time
        # pycocotools import matplotlib
        import matplotlib
        matplotlib.use('Agg')
        from pycocotools.coco import COCO

        super(COCODetection, self).__init__(data_dir, label_list, transforms,
                                            num_workers, shuffle)

        self.data_fields = None
        self.num_max_boxes = 50

        self.use_mix = False
        if self.transforms is not None:
            for op in self.transforms.transforms:
                if isinstance(op, MixupImage):
                    self.mixup_op = copy.deepcopy(op)
                    self.use_mix = True
                    self.num_max_boxes *= 2
                    break

        self.batch_transforms = None
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio
        self.file_list = list()
        neg_file_list = list()
        self.labels = list()

        annotations = dict()
        annotations['images'] = list()
        annotations['categories'] = list()
        annotations['annotations'] = list()

        cname2cid = OrderedDict()
        label_id = 0
        with open(label_list, 'r', encoding=get_encoding(label_list)) as f:
            for line in f.readlines():
                cname2cid[line.strip()] = label_id
                label_id += 1
                self.labels.append(line.strip())

        for k, v in cname2cid.items():
            annotations['categories'].append({
                'supercategory': 'component',
                'id': v + 1,
                'name': k
            })

        anno_path = norm_path(os.path.join(self.data_dir, anno_path))
        image_dir = norm_path(os.path.join(self.data_dir, image_dir))
        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path

        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        ct = 0

        catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in catid2clsid.items()
        })

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            if not os.path.exists(im_path):
                logging.warning('Illegal image file: {}, and it will be '
                                'ignored'.format(im_path))
                continue
            if im_w < 0 or im_h < 0:
                logging.warning(
                    'Illegal width: {} or height: {} in annotation, '
                    'and im_id: {} will be ignored'.format(im_w, im_h, img_id))
                continue

            im_info = {
                'image': im_path,
                'im_id': np.array([img_id]),
                'image_shape': np.array(
                    [im_h, im_w], dtype=np.int32)
            }
            ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
            instances = coco.loadAnns(ins_anno_ids)

            is_crowds = []
            gt_classes = []
            gt_bboxs = []
            gt_scores = []
            difficults = []

            for inst in instances:
                # check gt bbox
                if inst.get('ignore', False):
                    continue
                if 'bbox' not in inst.keys():
                    continue
                else:
                    if not any(np.array(inst['bbox'])):
                        continue

                # read box
                x1, y1, box_w, box_h = inst['bbox']
                x2 = x1 + box_w
                y2 = y1 + box_h
                eps = 1e-5
                if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                    inst['clean_bbox'] = [
                        round(float(x), 3) for x in [x1, y1, x2, y2]
                    ]
                else:
                    logging.warning(
                        'Found an invalid bbox in annotations: im_id: {}, '
                        'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                            img_id, float(inst['area']), x1, y1, x2, y2))
                    # Skip the invalid box; otherwise the 'clean_bbox' lookups
                    # below would raise a KeyError.
                    continue

                is_crowds.append([inst['iscrowd']])
                gt_classes.append([inst['category_id']])
                gt_bboxs.append(inst['clean_bbox'])
                gt_scores.append([1.])
                difficults.append([0])
                annotations['annotations'].append({
                    'iscrowd': inst['iscrowd'],
                    'image_id': int(inst['image_id']),
                    'bbox': inst['clean_bbox'],
                    'area': inst['area'],
                    'category_id': inst['category_id'],
                    'id': inst['id'],
                    'difficult': 0
                })

            label_info = {
                'is_crowd': np.array(is_crowds),
                'gt_class': np.array(gt_classes),
                'gt_bbox': np.array(gt_bboxs).astype(np.float32),
                'gt_score': np.array(gt_scores).astype(np.float32),
                'difficult': np.array(difficults),
            }

            if label_info['gt_bbox'].size > 0:
                self.file_list.append({**im_info, **label_info})
                annotations['images'].append({
                    'height': im_h,
                    'width': im_w,
                    'id': int(im_info['im_id']),
                    'file_name': osp.split(im_info['image'])[1]
                })
            else:
                neg_file_list.append({**im_info, **label_info})

            ct += 1
            if self.use_mix:
                self.num_max_boxes = max(self.num_max_boxes,
                                         2 * len(instances))
            else:
                self.num_max_boxes = max(self.num_max_boxes, len(instances))

        if not ct:
            logging.error(
                "No coco record found in %s." % anno_path, exit=True)

        self.pos_num = len(self.file_list)
        if self.allow_empty and neg_file_list:
            self.file_list += self._sample_empty(neg_file_list)

        logging.info(
            "{} samples in file {}, including {} positive samples and {} negative samples.".
            format(
                len(self.file_list), anno_path, self.pos_num,
                len(self.file_list) - self.pos_num))
        self.num_samples = len(self.file_list)

        self.coco_gt = COCO()
        self.coco_gt.dataset = annotations
        self.coco_gt.createIndex()

        self._epoch = 0

    def __getitem__(self, idx):
        sample = copy.deepcopy(self.file_list[idx])
        if self.data_fields is not None:
            sample = {k: sample[k] for k in self.data_fields}

        if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
                             self._epoch < self.mixup_op.mixup_epoch):
            if self.num_samples > 1:
                mix_idx = random.randint(1, self.num_samples - 1)
                mix_pos = (mix_idx + idx) % self.num_samples
            else:
                mix_pos = 0
            sample_mix = copy.deepcopy(self.file_list[mix_pos])
            if self.data_fields is not None:
                sample_mix = {k: sample_mix[k] for k in self.data_fields}
            sample = self.mixup_op(sample=[
                DecodeImg(to_rgb=False)(sample),
                DecodeImg(to_rgb=False)(sample_mix)
            ])

        sample = self.transforms(sample)
        return sample

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch_id):
        self._epoch = epoch_id

    def cluster_yolo_anchor(self,
                            num_anchors,
                            image_size,
                            cache=True,
                            cache_path=None,
                            iters=300,
                            gen_iters=1000,
                            thresh=.25):
  242. """
  243. Cluster YOLO anchors.
  244. Reference:
  245. https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
  246. Args:
  247. num_anchors (int): number of clusters
  248. image_size (list or int): [h, w], being an int means image height and image width are the same.
  249. cache (bool): whether using cache
  250. cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset.
  251. iters (int, optional): iters of kmeans algorithm
  252. gen_iters (int, optional): iters of genetic algorithm
  253. threshold (float, optional): anchor scale threshold
  254. verbose (bool, optional): whether print results
  255. """
        if cache_path is None:
            cache_path = self.data_dir
        cluster = YOLOAnchorCluster(
            num_anchors=num_anchors,
            dataset=self,
            image_size=image_size,
            cache=cache,
            cache_path=cache_path,
            iters=iters,
            gen_iters=gen_iters,
            thresh=thresh)
        anchors = cluster()
        return anchors
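
    # Usage sketch (hedged): `train_dataset` stands for an already constructed
    # COCODetection instance, and the anchor count and image size below are
    # illustrative values rather than recommended defaults.
    #
    #     anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
    #     # The returned anchors can then be passed to a YOLO-style detector that
    #     # accepts custom anchors.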

    def add_negative_samples(self, image_dir, empty_ratio=1):
        """Add background images (negative samples) to the dataset. See the sketch at the end of this module.

        Args:
            image_dir (str): Directory that contains the background images.
            empty_ratio (float or None): Ratio of negative samples to the total number of samples. If None, keep
                the `empty_ratio` value set when the dataset was initialized; otherwise, override it. If it is
                less than 0 or no less than 1, all negative samples are kept. Defaults to 1.
        """
        import cv2

        if not osp.isdir(image_dir):
            raise Exception("{} is not a valid image directory.".format(
                image_dir))
        if empty_ratio is not None:
            self.empty_ratio = empty_ratio

        image_list = os.listdir(image_dir)
        max_img_id = max(len(self.file_list) - 1, max(self.coco_gt.getImgIds()))
        neg_file_list = list()
        for image in image_list:
            if not is_pic(image):
                continue
            gt_bbox = np.zeros((0, 4), dtype=np.float32)
            gt_class = np.zeros((0, 1), dtype=np.int32)
            gt_score = np.zeros((0, 1), dtype=np.float32)
            is_crowd = np.zeros((0, 1), dtype=np.int32)
            difficult = np.zeros((0, 1), dtype=np.int32)

            max_img_id += 1
            im_fname = osp.join(image_dir, image)
            img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
            im_h, im_w, im_c = img_data.shape
            im_info = {
                'im_id': np.asarray([max_img_id]),
                'image_shape': np.array(
                    [im_h, im_w], dtype=np.int32)
            }
            label_info = {
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_bbox': gt_bbox,
                'gt_score': gt_score,
                'difficult': difficult
            }
            if 'gt_poly' in self.file_list[0]:
                label_info['gt_poly'] = []
            neg_file_list.append({'image': im_fname, **im_info, **label_info})

        if neg_file_list:
            self.allow_empty = True
            self.file_list += self._sample_empty(neg_file_list)

        logging.info(
            "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
            format(
                len(self.file_list) - self.num_samples, self.pos_num,
                len(self.file_list) - self.pos_num))
        self.num_samples = len(self.file_list)

    def _sample_empty(self, neg_file_list):
        if 0. <= self.empty_ratio < 1.:
            total_num = len(self.file_list)
            neg_num = total_num - self.pos_num
            # Choose k negatives so that (neg_num + k) / (total_num + k) equals
            # `empty_ratio`, capped by the number of available candidates. Cast to
            # int because random.sample() requires an integer sample size.
            sample_num = min(
                int((total_num * self.empty_ratio - neg_num) //
                    (1 - self.empty_ratio)), len(neg_file_list))
            return random.sample(neg_file_list, sample_num)
        else:
            return neg_file_list
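

# A hedged, minimal sketch of the workflow described in the class docstring and in
# `add_negative_samples`. The paths below ('data/my_coco_dataset', 'background/',
# 'labels.txt', ...) are hypothetical placeholders, and the single-op transform
# pipeline is only meant to keep the example self-contained.
if __name__ == '__main__':
    import paddlers.transforms as T

    train_dataset = COCODetection(
        data_dir='data/my_coco_dataset',
        image_dir='images',
        anno_path='annotations/instances_train.json',
        label_list='labels.txt',
        transforms=T.Compose([T.DecodeImg()]),
        allow_empty=True,
        empty_ratio=0.5)
    # Mix pure-background images into the sample list. Negatives are sampled so
    # that they make up at most `empty_ratio` of the resulting dataset
    # (see `_sample_empty`).
    train_dataset.add_negative_samples('background/', empty_ratio=0.5)
    print('Samples after adding negatives:', len(train_dataset))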