# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import copy
import os
import os.path as osp
import random
import re
from collections import OrderedDict

import xml.etree.ElementTree as ET
import numpy as np

from .base import BaseDataset
from paddlers.utils import logging, get_encoding, norm_path, is_pic
from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster


class VOCDetDataset(BaseDataset):
    """
    Dataset with PASCAL VOC annotations for detection tasks.

    Args:
        data_dir (str): Root directory of the dataset.
        file_list (str): Path of the file that contains relative paths of images and annotation files.
        transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
        label_list (str): Path of the file that contains the category names.
        num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
            the number of workers will be determined automatically according to the number of CPU cores: if
            there are more than 16 cores, 8 workers will be used; otherwise, the number of workers will be half
            the number of CPU cores. Defaults to 'auto'.
        shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
        allow_empty (bool, optional): Whether to add negative samples. Defaults to False.
        empty_ratio (float, optional): Ratio of negative samples. If `empty_ratio` is smaller than 0 or not less
            than 1, keep all generated negative samples. Defaults to 1.0.
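
    Example:
        A minimal construction sketch. The file paths below are placeholders for a real
        VOC-style dataset, and the transform pipeline shown only decodes images; in
        practice more operators would usually be appended:

            from paddlers.transforms import Compose, DecodeImg

            dataset = VOCDetDataset(
                data_dir='path/to/dataset',
                file_list='path/to/dataset/train_list.txt',
                transforms=Compose([DecodeImg()]),
                label_list='path/to/dataset/labels.txt')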
    """

    def __init__(self,
                 data_dir,
                 file_list,
                 transforms,
                 label_list,
                 num_workers='auto',
                 shuffle=False,
                 allow_empty=False,
                 empty_ratio=1.):
        # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
        # or matplotlib.backends is imported for the first time
        # (pycocotools imports matplotlib).
        import matplotlib
        matplotlib.use('Agg')
        from pycocotools.coco import COCO

        super(VOCDetDataset, self).__init__(data_dir, label_list, transforms,
                                            num_workers, shuffle)
        self.data_fields = None
        self.num_max_boxes = 50
        self.use_mix = False
        if self.transforms is not None:
            for op in self.transforms.transforms:
                if isinstance(op, MixupImage):
                    self.mixup_op = copy.deepcopy(op)
                    self.use_mix = True
                    self.num_max_boxes *= 2
                    break
        self.batch_transforms = None
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio
        self.file_list = list()
        neg_file_list = list()
        self.labels = list()

        annotations = dict()
        annotations['images'] = list()
        annotations['categories'] = list()
        annotations['annotations'] = list()

        cname2cid = OrderedDict()
        label_id = 0
        with open(label_list, 'r', encoding=get_encoding(label_list)) as f:
            for line in f.readlines():
                cname2cid[line.strip()] = label_id
                label_id += 1
                self.labels.append(line.strip())
        logging.info("Starting to read file list from dataset...")

        for k, v in cname2cid.items():
            annotations['categories'].append({
                'supercategory': 'component',
                'id': v + 1,
                'name': k
            })

        ct = 0
        ann_ct = 0
        # Each line of `file_list` holds two space-separated relative paths:
        # "<image file> <VOC XML annotation file>".
        with open(file_list, 'r', encoding=get_encoding(file_list)) as f:
            while True:
                line = f.readline()
                if not line:
                    break
                if len(line.strip().split()) > 2:
                    raise ValueError("A space is defined as the separator, "
                                     "but it exists in image or label name {}."
                                     .format(line))
                img_file, xml_file = [
                    osp.join(data_dir, x) for x in line.strip().split()[:2]
                ]
                img_file = norm_path(img_file)
                xml_file = norm_path(xml_file)
                if not is_pic(img_file):
                    continue
                if not osp.isfile(xml_file):
                    continue
                if not osp.exists(img_file):
                    logging.warning('The image file {} does not exist!'.format(
                        img_file))
                    continue
                if not osp.exists(xml_file):
                    logging.warning('The annotation file {} does not exist!'.
                                    format(xml_file))
                    continue

                tree = ET.parse(xml_file)
                if tree.find('id') is None:
                    im_id = np.asarray([ct])
                else:
                    ct = int(tree.find('id').text)
                    im_id = np.asarray([int(tree.find('id').text)])
                # XML tag names are located case-insensitively via regex, so that
                # e.g. both <size> and <Size> are accepted.
                pattern = re.compile('<size>', re.IGNORECASE)
                size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))
                if len(size_tag) > 0:
                    size_tag = size_tag[0][1:-1]
                    size_element = tree.find(size_tag)
                    pattern = re.compile('<width>', re.IGNORECASE)
                    width_tag = pattern.findall(
                        str(ET.tostringlist(size_element)))[0][1:-1]
                    im_w = float(size_element.find(width_tag).text)
                    pattern = re.compile('<height>', re.IGNORECASE)
                    height_tag = pattern.findall(
                        str(ET.tostringlist(size_element)))[0][1:-1]
                    im_h = float(size_element.find(height_tag).text)
                else:
                    im_w = 0
                    im_h = 0
                pattern = re.compile('<object>', re.IGNORECASE)
                obj_match = pattern.findall(
                    str(ET.tostringlist(tree.getroot())))
                if len(obj_match) > 0:
                    obj_tag = obj_match[0][1:-1]
                    objs = tree.findall(obj_tag)
                else:
                    objs = list()

                num_bbox, i = len(objs), 0
                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                difficult = np.zeros((num_bbox, 1), dtype=np.int32)
                for obj in objs:
                    pattern = re.compile('<name>', re.IGNORECASE)
                    name_tag = pattern.findall(
                        str(ET.tostringlist(obj)))[0][1:-1]
                    cname = obj.find(name_tag).text.strip()
                    pattern = re.compile('<difficult>', re.IGNORECASE)
                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))
                    if len(diff_tag) == 0:
                        _difficult = 0
                    else:
                        diff_tag = diff_tag[0][1:-1]
                        try:
                            _difficult = int(obj.find(diff_tag).text)
                        except Exception:
                            _difficult = 0
                    pattern = re.compile('<bndbox>', re.IGNORECASE)
                    box_tag = pattern.findall(str(ET.tostringlist(obj)))
                    if len(box_tag) == 0:
                        logging.warning(
                            "There is no field '<bndbox>' in the object, "
                            "so this object will be ignored. xml file: {}".
                            format(xml_file))
                        continue
                    box_tag = box_tag[0][1:-1]
                    box_element = obj.find(box_tag)
                    pattern = re.compile('<xmin>', re.IGNORECASE)
                    xmin_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    x1 = float(box_element.find(xmin_tag).text)
                    pattern = re.compile('<ymin>', re.IGNORECASE)
                    ymin_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    y1 = float(box_element.find(ymin_tag).text)
                    pattern = re.compile('<xmax>', re.IGNORECASE)
                    xmax_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    x2 = float(box_element.find(xmax_tag).text)
                    pattern = re.compile('<ymax>', re.IGNORECASE)
                    ymax_tag = pattern.findall(
                        str(ET.tostringlist(box_element)))[0][1:-1]
                    y2 = float(box_element.find(ymax_tag).text)
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    if im_w > 0.5 and im_h > 0.5:
                        x2 = min(im_w - 1, x2)
                        y2 = min(im_h - 1, y2)
                    if not (x2 >= x1 and y2 >= y1):
                        logging.warning(
                            "Bounding box for object {} does not satisfy "
                            "xmin {} <= xmax {} and ymin {} <= ymax {}, "
                            "so this object is skipped. xml file: {}".format(
                                i, x1, x2, y1, y2, xml_file))
                        continue
                    gt_bbox[i, :] = [x1, y1, x2, y2]
                    gt_class[i, 0] = cname2cid[cname]
                    gt_score[i, 0] = 1.
                    is_crowd[i, 0] = 0
                    difficult[i, 0] = _difficult
                    i += 1
                    annotations['annotations'].append({
                        'iscrowd': 0,
                        'image_id': int(im_id[0]),
                        'bbox': [x1, y1, x2 - x1, y2 - y1],
                        'area': float((x2 - x1) * (y2 - y1)),
                        'category_id': cname2cid[cname] + 1,
                        'id': ann_ct,
                        'difficult': _difficult
                    })
                    ann_ct += 1

                gt_bbox = gt_bbox[:i, :]
                gt_class = gt_class[:i, :]
                gt_score = gt_score[:i, :]
                is_crowd = is_crowd[:i, :]
                difficult = difficult[:i, :]
                im_info = {
                    'im_id': im_id,
                    'image_shape': np.array(
                        [im_h, im_w], dtype=np.int32)
                }
                label_info = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    'gt_score': gt_score,
                    'difficult': difficult
                }
                if gt_bbox.size > 0:
                    self.file_list.append({
                        'image': img_file,
                        **im_info,
                        **label_info
                    })
                    annotations['images'].append({
                        'height': im_h,
                        'width': im_w,
                        'id': int(im_id[0]),
                        'file_name': osp.split(img_file)[1]
                    })
                else:
                    neg_file_list.append({
                        'image': img_file,
                        **im_info,
                        **label_info
                    })
                ct += 1
                if self.use_mix:
                    self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
                else:
                    self.num_max_boxes = max(self.num_max_boxes, len(objs))

        if not ct:
            logging.error("No voc record found in %s." % file_list, exit=True)
        self.pos_num = len(self.file_list)
        if self.allow_empty and neg_file_list:
            self.file_list += self._sample_empty(neg_file_list)
        logging.info(
            "{} samples in file {}, including {} positive samples and {} negative samples.".
            format(
                len(self.file_list), file_list, self.pos_num,
                len(self.file_list) - self.pos_num))
        self.num_samples = len(self.file_list)

        self.coco_gt = COCO()
        self.coco_gt.dataset = annotations
        self.coco_gt.createIndex()

        self._epoch = 0

    def __getitem__(self, idx):
        sample = copy.deepcopy(self.file_list[idx])
        if self.data_fields is not None:
            sample = {k: sample[k] for k in self.data_fields}
        if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
                             self._epoch < self.mixup_op.mixup_epoch):
            if self.num_samples > 1:
                mix_idx = random.randint(1, self.num_samples - 1)
                mix_pos = (mix_idx + idx) % self.num_samples
            else:
                mix_pos = 0
            sample_mix = copy.deepcopy(self.file_list[mix_pos])
            if self.data_fields is not None:
                sample_mix = {k: sample_mix[k] for k in self.data_fields}
            sample = self.mixup_op(sample=[
                DecodeImg(to_rgb=False)(sample),
                DecodeImg(to_rgb=False)(sample_mix)
            ])
        sample = self.transforms(sample)
        return sample

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch_id):
        self._epoch = epoch_id

    def cluster_yolo_anchor(self,
                            num_anchors,
                            image_size,
                            cache=True,
                            cache_path=None,
                            iters=300,
                            gen_iters=1000,
                            thresh=.25):
        """
        Cluster YOLO anchors.

        Reference:
            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py

        Args:
            num_anchors (int): Number of clusters.
            image_size (list[int]|int): [h, w] or an int value that corresponds to the shape [image_size, image_size].
            cache (bool, optional): Whether to use cache. Defaults to True.
            cache_path (str|None, optional): Path of cache directory. If None, use `dataset.data_dir`.
                Defaults to None.
            iters (int, optional): Iterations of k-means algorithm. Defaults to 300.
            gen_iters (int, optional): Iterations of genetic algorithm. Defaults to 1000.
            thresh (float, optional): Anchor scale threshold. Defaults to 0.25.
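
        Example:
            A brief usage sketch, assuming `dataset` is an already-constructed
            `VOCDetDataset` instance; `num_anchors=9` and `image_size=608` are
            illustrative values:

                anchors = dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)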
        """
        if cache_path is None:
            cache_path = self.data_dir
        cluster = YOLOAnchorCluster(
            num_anchors=num_anchors,
            dataset=self,
            image_size=image_size,
            cache=cache,
            cache_path=cache_path,
            iters=iters,
            gen_iters=gen_iters,
            thresh=thresh)
        anchors = cluster()
        return anchors

    def add_negative_samples(self, image_dir, empty_ratio=1):
        """
        Generate and add negative samples.

        Args:
            image_dir (str): Directory that contains images.
            empty_ratio (float|None, optional): Ratio of negative samples. If `empty_ratio` is smaller than
                0 or not less than 1, keep all generated negative samples. Defaults to 1.0.
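
        Example:
            A brief usage sketch; `path/to/bg_images` is a hypothetical directory that
            contains only background (object-free) images:

                dataset.add_negative_samples(image_dir='path/to/bg_images', empty_ratio=0.5)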
        """
        import cv2

        if not osp.isdir(image_dir):
            raise ValueError("{} is not a valid image directory.".format(
                image_dir))
        if empty_ratio is not None:
            self.empty_ratio = empty_ratio
        image_list = os.listdir(image_dir)
        max_img_id = max(len(self.file_list) - 1, max(self.coco_gt.getImgIds()))
        neg_file_list = list()
        for image in image_list:
            if not is_pic(image):
                continue
            gt_bbox = np.zeros((0, 4), dtype=np.float32)
            gt_class = np.zeros((0, 1), dtype=np.int32)
            gt_score = np.zeros((0, 1), dtype=np.float32)
            is_crowd = np.zeros((0, 1), dtype=np.int32)
            difficult = np.zeros((0, 1), dtype=np.int32)
            max_img_id += 1
            im_fname = osp.join(image_dir, image)
            img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
            im_h, im_w, im_c = img_data.shape
            im_info = {
                'im_id': np.asarray([max_img_id]),
                'image_shape': np.array(
                    [im_h, im_w], dtype=np.int32)
            }
            label_info = {
                'is_crowd': is_crowd,
                'gt_class': gt_class,
                'gt_bbox': gt_bbox,
                'gt_score': gt_score,
                'difficult': difficult
            }
            if 'gt_poly' in self.file_list[0]:
                label_info['gt_poly'] = []
            neg_file_list.append({'image': im_fname, **im_info, **label_info})

        if neg_file_list:
            self.allow_empty = True
            self.file_list += self._sample_empty(neg_file_list)
        logging.info(
            "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
            format(
                len(self.file_list) - self.num_samples, self.pos_num,
                len(self.file_list) - self.pos_num))
        self.num_samples = len(self.file_list)

    def _sample_empty(self, neg_file_list):
        if 0. <= self.empty_ratio < 1.:
            import random
            total_num = len(self.file_list)
            neg_num = total_num - self.pos_num
            sample_num = min((total_num * self.empty_ratio - neg_num) //
                             (1 - self.empty_ratio), len(neg_file_list))
            # The division above yields a float; random.sample() needs an int count.
            return random.sample(neg_file_list, int(sample_num))
        else:
            return neg_file_list