data_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import re
  16. import platform
  17. from collections import OrderedDict
  18. from functools import partial, wraps
  19. import numpy as np
  20. __all__ = ['build_input_from_file']
  21. def norm_path(path):
  22. win_sep = "\\"
  23. other_sep = "/"
  24. if platform.system() == "Windows":
  25. path = win_sep.join(path.split(other_sep))
  26. else:
  27. path = other_sep.join(path.split(win_sep))
  28. return path
  29. def get_full_path(p, prefix=''):
  30. p = norm_path(p)
  31. return osp.join(prefix, p)
  32. def silent(func):
  33. def _do_nothing(*args, **kwargs):
  34. pass
  35. @wraps(func)
  36. def _wrapper(*args, **kwargs):
  37. import builtins
  38. print = builtins.print
  39. builtins.print = _do_nothing
  40. ret = func(*args, **kwargs)
  41. builtins.print = print
  42. return ret
  43. return _wrapper
  44. class ConstrSample(object):
  45. def __init__(self, prefix, label_list):
  46. super().__init__()
  47. self.prefix = prefix
  48. self.label_list_obj = self.read_label_list(label_list)
  49. self.get_full_path = partial(get_full_path, prefix=self.prefix)
  50. def read_label_list(self, label_list):
  51. if label_list is None:
  52. return None
  53. cname2cid = OrderedDict()
  54. label_id = 0
  55. with open(label_list, 'r') as f:
  56. for line in f:
  57. cname2cid[line.strip()] = label_id
  58. label_id += 1
  59. return cname2cid
  60. def __call__(self, *parts):
  61. raise NotImplementedError
  62. class ConstrSegSample(ConstrSample):
  63. def __call__(self, im_path, mask_path):
  64. return {
  65. 'image': self.get_full_path(im_path),
  66. 'mask': self.get_full_path(mask_path)
  67. }
  68. class ConstrCdSample(ConstrSample):
  69. def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
  70. sample = {
  71. 'image_t1': self.get_full_path(im1_path),
  72. 'image_t2': self.get_full_path(im2_path),
  73. 'mask': self.get_full_path(mask_path)
  74. }
  75. if len(aux_mask_paths) > 0:
  76. sample['aux_masks'] = [
  77. self.get_full_path(p) for p in aux_mask_paths
  78. ]
  79. return sample
  80. class ConstrClasSample(ConstrSample):
  81. def __call__(self, im_path, label):
  82. return {'image': self.get_full_path(im_path), 'label': int(label)}
  83. class ConstrDetSample(ConstrSample):
  84. def __init__(self, prefix, label_list):
  85. super().__init__(prefix, label_list)
  86. self.ct = 0
  87. def __call__(self, im_path, ann_path):
  88. im_path = self.get_full_path(im_path)
  89. ann_path = self.get_full_path(ann_path)
  90. # TODO: Precisely recognize the annotation format
  91. if ann_path.endswith('.json'):
  92. im_dir = im_path
  93. return self._parse_coco_files(im_dir, ann_path)
  94. elif ann_path.endswith('.xml'):
  95. return self._parse_voc_files(im_path, ann_path)
  96. else:
  97. raise ValueError("Cannot recognize the annotation format")
  98. def _parse_voc_files(self, im_path, ann_path):
  99. import xml.etree.ElementTree as ET
  100. cname2cid = self.label_list_obj
  101. tree = ET.parse(ann_path)
  102. # The xml file must contain id.
  103. if tree.find('id') is None:
  104. im_id = np.asarray([self.ct])
  105. else:
  106. self.ct = int(tree.find('id').text)
  107. im_id = np.asarray([int(tree.find('id').text)])
  108. pattern = re.compile('<size>', re.IGNORECASE)
  109. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))
  110. if len(size_tag) > 0:
  111. size_tag = size_tag[0][1:-1]
  112. size_element = tree.find(size_tag)
  113. pattern = re.compile('<width>', re.IGNORECASE)
  114. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  115. 1:-1]
  116. im_w = float(size_element.find(width_tag).text)
  117. pattern = re.compile('<height>', re.IGNORECASE)
  118. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  119. 1:-1]
  120. im_h = float(size_element.find(height_tag).text)
  121. else:
  122. im_w = 0
  123. im_h = 0
  124. pattern = re.compile('<object>', re.IGNORECASE)
  125. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  126. if len(obj_match) > 0:
  127. obj_tag = obj_match[0][1:-1]
  128. objs = tree.findall(obj_tag)
  129. else:
  130. objs = list()
  131. num_bbox, i = len(objs), 0
  132. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  133. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  134. gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
  135. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  136. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  137. for obj in objs:
  138. pattern = re.compile('<name>', re.IGNORECASE)
  139. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  140. cname = obj.find(name_tag).text.strip()
  141. pattern = re.compile('<difficult>', re.IGNORECASE)
  142. diff_tag = pattern.findall(str(ET.tostringlist(obj)))
  143. if len(diff_tag) == 0:
  144. _difficult = 0
  145. else:
  146. diff_tag = diff_tag[0][1:-1]
  147. try:
  148. _difficult = int(obj.find(diff_tag).text)
  149. except Exception:
  150. _difficult = 0
  151. pattern = re.compile('<bndbox>', re.IGNORECASE)
  152. box_tag = pattern.findall(str(ET.tostringlist(obj)))
  153. if len(box_tag) == 0:
  154. continue
  155. box_tag = box_tag[0][1:-1]
  156. box_element = obj.find(box_tag)
  157. pattern = re.compile('<xmin>', re.IGNORECASE)
  158. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  159. -1]
  160. x1 = float(box_element.find(xmin_tag).text)
  161. pattern = re.compile('<ymin>', re.IGNORECASE)
  162. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  163. -1]
  164. y1 = float(box_element.find(ymin_tag).text)
  165. pattern = re.compile('<xmax>', re.IGNORECASE)
  166. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  167. -1]
  168. x2 = float(box_element.find(xmax_tag).text)
  169. pattern = re.compile('<ymax>', re.IGNORECASE)
  170. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  171. -1]
  172. y2 = float(box_element.find(ymax_tag).text)
  173. x1 = max(0, x1)
  174. y1 = max(0, y1)
  175. if im_w > 0.5 and im_h > 0.5:
  176. x2 = min(im_w - 1, x2)
  177. y2 = min(im_h - 1, y2)
  178. if not (x2 >= x1 and y2 >= y1):
  179. continue
  180. gt_bbox[i, :] = [x1, y1, x2, y2]
  181. gt_class[i, 0] = cname2cid[cname]
  182. gt_score[i, 0] = 1.
  183. is_crowd[i, 0] = 0
  184. difficult[i, 0] = _difficult
  185. i += 1
  186. gt_bbox = gt_bbox[:i, :]
  187. gt_class = gt_class[:i, :]
  188. gt_score = gt_score[:i, :]
  189. is_crowd = is_crowd[:i, :]
  190. difficult = difficult[:i, :]
  191. im_info = {
  192. 'im_id': im_id,
  193. 'image_shape': np.array(
  194. [im_h, im_w], dtype=np.int32)
  195. }
  196. label_info = {
  197. 'is_crowd': is_crowd,
  198. 'gt_class': gt_class,
  199. 'gt_bbox': gt_bbox,
  200. 'gt_score': gt_score,
  201. 'difficult': difficult
  202. }
  203. self.ct += 1
  204. return {'image': im_path, ** im_info, ** label_info}
  205. @silent
  206. def _parse_coco_files(self, im_dir, ann_path):
  207. from pycocotools.coco import COCO
  208. coco = COCO(ann_path)
  209. img_ids = coco.getImgIds()
  210. img_ids.sort()
  211. samples = []
  212. for img_id in img_ids:
  213. img_anno = coco.loadImgs([img_id])[0]
  214. im_fname = img_anno['file_name']
  215. im_w = float(img_anno['width'])
  216. im_h = float(img_anno['height'])
  217. im_path = osp.join(im_dir, im_fname) if im_dir else im_fname
  218. im_info = {
  219. 'image': im_path,
  220. 'im_id': np.array([img_id]),
  221. 'image_shape': np.array(
  222. [im_h, im_w], dtype=np.int32)
  223. }
  224. ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
  225. instances = coco.loadAnns(ins_anno_ids)
  226. is_crowds = []
  227. gt_classes = []
  228. gt_bboxs = []
  229. gt_scores = []
  230. difficults = []
  231. for inst in instances:
  232. # Check gt bbox
  233. if inst.get('ignore', False):
  234. continue
  235. if 'bbox' not in inst.keys():
  236. continue
  237. else:
  238. if not any(np.array(inst['bbox'])):
  239. continue
  240. # Read box
  241. x1, y1, box_w, box_h = inst['bbox']
  242. x2 = x1 + box_w
  243. y2 = y1 + box_h
  244. eps = 1e-5
  245. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  246. inst['clean_bbox'] = [
  247. round(float(x), 3) for x in [x1, y1, x2, y2]
  248. ]
  249. is_crowds.append([inst['iscrowd']])
  250. gt_classes.append([inst['category_id']])
  251. gt_bboxs.append(inst['clean_bbox'])
  252. gt_scores.append([1.])
  253. difficults.append([0])
  254. label_info = {
  255. 'is_crowd': np.array(is_crowds),
  256. 'gt_class': np.array(gt_classes),
  257. 'gt_bbox': np.array(gt_bboxs).astype(np.float32),
  258. 'gt_score': np.array(gt_scores).astype(np.float32),
  259. 'difficult': np.array(difficults),
  260. }
  261. samples.append({ ** im_info, ** label_info})
  262. return samples
  263. class ConstrResSample(ConstrSample):
  264. def __init__(self, prefix, label_list, sr_factor=None):
  265. super().__init__(prefix, label_list)
  266. self.sr_factor = sr_factor
  267. def __call__(self, src_path, tar_path):
  268. sample = {
  269. 'image': self.get_full_path(src_path),
  270. 'target': self.get_full_path(tar_path)
  271. }
  272. if self.sr_factor is not None:
  273. sample['sr_factor'] = self.sr_factor
  274. return sample
  275. def build_input_from_file(file_list,
  276. prefix='',
  277. task='auto',
  278. label_list=None,
  279. **kwargs):
  280. """
  281. Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
  282. Args:
  283. file_list (str): Path of file list.
  284. prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
  285. task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto',
  286. automatically determine the task based on the input. Default: 'auto'.
  287. label_list (str|None, optional): Path of label_list. Default: None.
  288. Returns:
  289. list: List of samples.
  290. """
  291. def _determine_task(parts):
  292. task = 'unknown'
  293. if len(parts) in (3, 5):
  294. task = 'cd'
  295. elif len(parts) == 2:
  296. if parts[1].isdigit():
  297. task = 'clas'
  298. elif parts[1].endswith('.xml'):
  299. task = 'det'
  300. if task == 'unknown':
  301. raise RuntimeError(
  302. "Cannot automatically determine the task type. Please specify `task` manually."
  303. )
  304. return task
  305. if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'):
  306. raise ValueError("Invalid value of `task`")
  307. samples = []
  308. ctor = None
  309. with open(file_list, 'r') as f:
  310. for line in f:
  311. line = line.strip()
  312. parts = line.split()
  313. if task == 'auto':
  314. task = _determine_task(parts)
  315. if ctor is None:
  316. ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
  317. ctor = ctor_class(prefix, label_list, **kwargs)
  318. sample = ctor(*parts)
  319. if isinstance(sample, list):
  320. samples.extend(sample)
  321. else:
  322. samples.append(sample)
  323. return samples