data_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import re
  16. import platform
  17. from collections import OrderedDict
  18. from functools import partial, wraps
  19. import numpy as np
  20. from paddlers.transforms import construct_sample
  21. __all__ = ['build_input_from_file']
  22. def norm_path(path):
  23. win_sep = "\\"
  24. other_sep = "/"
  25. if platform.system() == "Windows":
  26. path = win_sep.join(path.split(other_sep))
  27. else:
  28. path = other_sep.join(path.split(win_sep))
  29. return path
  30. def get_full_path(p, prefix=''):
  31. p = norm_path(p)
  32. return osp.join(prefix, p)
  33. def silent(func):
  34. def _do_nothing(*args, **kwargs):
  35. pass
  36. @wraps(func)
  37. def _wrapper(*args, **kwargs):
  38. import builtins
  39. print = builtins.print
  40. builtins.print = _do_nothing
  41. ret = func(*args, **kwargs)
  42. builtins.print = print
  43. return ret
  44. return _wrapper
  45. class ConstrSample(object):
  46. def __init__(self, prefix, label_list):
  47. super().__init__()
  48. self.prefix = prefix
  49. self.label_list_obj = self.read_label_list(label_list)
  50. self.get_full_path = partial(get_full_path, prefix=self.prefix)
  51. def read_label_list(self, label_list):
  52. if label_list is None:
  53. return None
  54. cname2cid = OrderedDict()
  55. label_id = 0
  56. with open(label_list, 'r') as f:
  57. for line in f:
  58. cname2cid[line.strip()] = label_id
  59. label_id += 1
  60. return cname2cid
  61. def __call__(self, *parts):
  62. raise NotImplementedError
  63. class ConstrSegSample(ConstrSample):
  64. def __call__(self, im_path, mask_path):
  65. return construct_sample(
  66. image=self.get_full_path(im_path),
  67. mask=self.get_full_path(mask_path))
  68. class ConstrCdSample(ConstrSample):
  69. def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
  70. sample = construct_sample(
  71. image_t1=self.get_full_path(im1_path),
  72. image_t2=self.get_full_path(im2_path),
  73. mask=self.get_full_path(mask_path))
  74. if len(aux_mask_paths) > 0:
  75. sample['aux_masks'] = [
  76. self.get_full_path(p) for p in aux_mask_paths
  77. ]
  78. return sample
  79. class ConstrClasSample(ConstrSample):
  80. def __call__(self, im_path, label):
  81. return construct_sample(
  82. image=self.get_full_path(im_path), label=int(label))
  83. class ConstrDetSample(ConstrSample):
  84. def __init__(self, prefix, label_list):
  85. super().__init__(prefix, label_list)
  86. self.ct = 0
  87. def __call__(self, im_path, ann_path):
  88. im_path = self.get_full_path(im_path)
  89. ann_path = self.get_full_path(ann_path)
  90. # TODO: Precisely recognize the annotation format
  91. if ann_path.endswith('.json'):
  92. im_dir = im_path
  93. return self._parse_coco_files(im_dir, ann_path)
  94. elif ann_path.endswith('.xml'):
  95. return self._parse_voc_files(im_path, ann_path)
  96. else:
  97. raise ValueError("Cannot recognize the annotation format")
  98. def _parse_voc_files(self, im_path, ann_path):
  99. import xml.etree.ElementTree as ET
  100. cname2cid = self.label_list_obj
  101. tree = ET.parse(ann_path)
  102. # The xml file must contain id.
  103. if tree.find('id') is None:
  104. im_id = np.asarray([self.ct])
  105. else:
  106. self.ct = int(tree.find('id').text)
  107. im_id = np.asarray([int(tree.find('id').text)])
  108. pattern = re.compile('<size>', re.IGNORECASE)
  109. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))
  110. if len(size_tag) > 0:
  111. size_tag = size_tag[0][1:-1]
  112. size_element = tree.find(size_tag)
  113. pattern = re.compile('<width>', re.IGNORECASE)
  114. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  115. 1:-1]
  116. im_w = float(size_element.find(width_tag).text)
  117. pattern = re.compile('<height>', re.IGNORECASE)
  118. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  119. 1:-1]
  120. im_h = float(size_element.find(height_tag).text)
  121. else:
  122. im_w = 0
  123. im_h = 0
  124. pattern = re.compile('<object>', re.IGNORECASE)
  125. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  126. if len(obj_match) > 0:
  127. obj_tag = obj_match[0][1:-1]
  128. objs = tree.findall(obj_tag)
  129. else:
  130. objs = list()
  131. num_bbox, i = len(objs), 0
  132. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  133. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  134. gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
  135. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  136. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  137. for obj in objs:
  138. pattern = re.compile('<name>', re.IGNORECASE)
  139. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  140. cname = obj.find(name_tag).text.strip()
  141. pattern = re.compile('<difficult>', re.IGNORECASE)
  142. diff_tag = pattern.findall(str(ET.tostringlist(obj)))
  143. if len(diff_tag) == 0:
  144. _difficult = 0
  145. else:
  146. diff_tag = diff_tag[0][1:-1]
  147. try:
  148. _difficult = int(obj.find(diff_tag).text)
  149. except Exception:
  150. _difficult = 0
  151. pattern = re.compile('<bndbox>', re.IGNORECASE)
  152. box_tag = pattern.findall(str(ET.tostringlist(obj)))
  153. if len(box_tag) == 0:
  154. continue
  155. box_tag = box_tag[0][1:-1]
  156. box_element = obj.find(box_tag)
  157. pattern = re.compile('<xmin>', re.IGNORECASE)
  158. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  159. -1]
  160. x1 = float(box_element.find(xmin_tag).text)
  161. pattern = re.compile('<ymin>', re.IGNORECASE)
  162. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  163. -1]
  164. y1 = float(box_element.find(ymin_tag).text)
  165. pattern = re.compile('<xmax>', re.IGNORECASE)
  166. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  167. -1]
  168. x2 = float(box_element.find(xmax_tag).text)
  169. pattern = re.compile('<ymax>', re.IGNORECASE)
  170. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  171. -1]
  172. y2 = float(box_element.find(ymax_tag).text)
  173. x1 = max(0, x1)
  174. y1 = max(0, y1)
  175. if im_w > 0.5 and im_h > 0.5:
  176. x2 = min(im_w - 1, x2)
  177. y2 = min(im_h - 1, y2)
  178. if not (x2 >= x1 and y2 >= y1):
  179. continue
  180. gt_bbox[i, :] = [x1, y1, x2, y2]
  181. gt_class[i, 0] = cname2cid[cname]
  182. gt_score[i, 0] = 1.
  183. is_crowd[i, 0] = 0
  184. difficult[i, 0] = _difficult
  185. i += 1
  186. gt_bbox = gt_bbox[:i, :]
  187. gt_class = gt_class[:i, :]
  188. gt_score = gt_score[:i, :]
  189. is_crowd = is_crowd[:i, :]
  190. difficult = difficult[:i, :]
  191. im_info = {
  192. 'im_id': im_id,
  193. 'image_shape': np.array(
  194. [im_h, im_w], dtype=np.int32)
  195. }
  196. label_info = {
  197. 'is_crowd': is_crowd,
  198. 'gt_class': gt_class,
  199. 'gt_bbox': gt_bbox,
  200. 'gt_score': gt_score,
  201. 'difficult': difficult
  202. }
  203. self.ct += 1
  204. return construct_sample(image=im_path, **im_info, **label_info)
  205. @silent
  206. def _parse_coco_files(self, im_dir, ann_path):
  207. from pycocotools.coco import COCO
  208. coco = COCO(ann_path)
  209. img_ids = coco.getImgIds()
  210. img_ids.sort()
  211. samples = []
  212. for img_id in img_ids:
  213. img_anno = coco.loadImgs([img_id])[0]
  214. im_fname = img_anno['file_name']
  215. im_w = float(img_anno['width'])
  216. im_h = float(img_anno['height'])
  217. im_path = osp.join(im_dir, im_fname) if im_dir else im_fname
  218. im_info = {
  219. 'image': im_path,
  220. 'im_id': np.array([img_id]),
  221. 'image_shape': np.array(
  222. [im_h, im_w], dtype=np.int32)
  223. }
  224. ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
  225. instances = coco.loadAnns(ins_anno_ids)
  226. is_crowds = []
  227. gt_classes = []
  228. gt_bboxs = []
  229. gt_scores = []
  230. difficults = []
  231. for inst in instances:
  232. # Check gt bbox
  233. if inst.get('ignore', False):
  234. continue
  235. if 'bbox' not in inst.keys():
  236. continue
  237. else:
  238. if not any(np.array(inst['bbox'])):
  239. continue
  240. # Read box
  241. x1, y1, box_w, box_h = inst['bbox']
  242. x2 = x1 + box_w
  243. y2 = y1 + box_h
  244. eps = 1e-5
  245. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  246. inst['clean_bbox'] = [
  247. round(float(x), 3) for x in [x1, y1, x2, y2]
  248. ]
  249. is_crowds.append([inst['iscrowd']])
  250. gt_classes.append([inst['category_id']])
  251. gt_bboxs.append(inst['clean_bbox'])
  252. gt_scores.append([1.])
  253. difficults.append([0])
  254. label_info = {
  255. 'is_crowd': np.array(is_crowds),
  256. 'gt_class': np.array(gt_classes),
  257. 'gt_bbox': np.array(gt_bboxs).astype(np.float32),
  258. 'gt_score': np.array(gt_scores).astype(np.float32),
  259. 'difficult': np.array(difficults),
  260. }
  261. samples.append(construct_sample(**im_info, **label_info))
  262. return samples
  263. class ConstrResSample(ConstrSample):
  264. def __init__(self, prefix, label_list, sr_factor=None):
  265. super().__init__(prefix, label_list)
  266. self.sr_factor = sr_factor
  267. def __call__(self, src_path, tar_path):
  268. sample = construct_sample(
  269. image=self.get_full_path(src_path),
  270. target=self.get_full_path(tar_path))
  271. if self.sr_factor is not None:
  272. sample['sr_factor'] = self.sr_factor
  273. return sample
  274. def build_input_from_file(file_list,
  275. prefix='',
  276. task='auto',
  277. label_list=None,
  278. **kwargs):
  279. """
  280. Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
  281. Args:
  282. file_list (str): Path of file list.
  283. prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
  284. task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto',
  285. automatically determine the task based on the input. Default: 'auto'.
  286. label_list (str|None, optional): Path of label_list. Default: None.
  287. Returns:
  288. list: List of samples.
  289. """
  290. def _determine_task(parts):
  291. task = 'unknown'
  292. if len(parts) in (3, 5):
  293. task = 'cd'
  294. elif len(parts) == 2:
  295. if parts[1].isdigit():
  296. task = 'clas'
  297. elif parts[1].endswith('.xml'):
  298. task = 'det'
  299. if task == 'unknown':
  300. raise RuntimeError(
  301. "Cannot automatically determine the task type. Please specify `task` manually."
  302. )
  303. return task
  304. if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'):
  305. raise ValueError("Invalid value of `task`")
  306. samples = []
  307. ctor = None
  308. with open(file_list, 'r') as f:
  309. for line in f:
  310. line = line.strip()
  311. parts = line.split()
  312. if task == 'auto':
  313. task = _determine_task(parts)
  314. if ctor is None:
  315. ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
  316. ctor = ctor_class(prefix, label_list, **kwargs)
  317. sample = ctor(*parts)
  318. if isinstance(sample, list):
  319. samples.extend(sample)
  320. else:
  321. samples.append(sample)
  322. return samples