data_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import re
  16. import imghdr
  17. import platform
  18. from collections import OrderedDict
  19. from functools import partial, wraps
  20. import numpy as np
  21. __all__ = ['build_input_from_file']
  22. def norm_path(path):
  23. win_sep = "\\"
  24. other_sep = "/"
  25. if platform.system() == "Windows":
  26. path = win_sep.join(path.split(other_sep))
  27. else:
  28. path = other_sep.join(path.split(win_sep))
  29. return path
  30. def is_pic(im_path):
  31. valid_suffix = [
  32. 'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
  33. ]
  34. suffix = im_path.split('.')[-1]
  35. if suffix in valid_suffix:
  36. return True
  37. im_format = imghdr.what(im_path)
  38. _, ext = osp.splitext(im_path)
  39. if im_format == 'tiff' or ext == '.img':
  40. return True
  41. return False
  42. def get_full_path(p, prefix=''):
  43. p = norm_path(p)
  44. return osp.join(prefix, p)
  45. def silent(func):
  46. def _do_nothing(*args, **kwargs):
  47. pass
  48. @wraps(func)
  49. def _wrapper(*args, **kwargs):
  50. import builtins
  51. print = builtins.print
  52. builtins.print = _do_nothing
  53. ret = func(*args, **kwargs)
  54. builtins.print = print
  55. return ret
  56. return _wrapper
  57. class ConstrSample(object):
  58. def __init__(self, prefix, label_list):
  59. super().__init__()
  60. self.prefix = prefix
  61. self.label_list_obj = self.read_label_list(label_list)
  62. self.get_full_path = partial(get_full_path, prefix=self.prefix)
  63. def read_label_list(self, label_list):
  64. if label_list is None:
  65. return None
  66. cname2cid = OrderedDict()
  67. label_id = 0
  68. with open(label_list, 'r') as f:
  69. for line in f:
  70. cname2cid[line.strip()] = label_id
  71. label_id += 1
  72. return cname2cid
  73. def __call__(self, *parts):
  74. raise NotImplementedError
  75. class ConstrSegSample(ConstrSample):
  76. def __call__(self, im_path, mask_path):
  77. return {
  78. 'image': self.get_full_path(im_path),
  79. 'mask': self.get_full_path(mask_path)
  80. }
  81. class ConstrCdSample(ConstrSample):
  82. def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
  83. sample = {
  84. 'image_t1': self.get_full_path(im1_path),
  85. 'image_t2': self.get_full_path(im2_path),
  86. 'mask': self.get_full_path(mask_path)
  87. }
  88. if len(aux_mask_paths) > 0:
  89. sample['aux_masks'] = [
  90. self.get_full_path(p) for p in aux_mask_paths
  91. ]
  92. return sample
  93. class ConstrClasSample(ConstrSample):
  94. def __call__(self, im_path, label):
  95. return {'image': self.get_full_path(im_path), 'label': int(label)}
  96. class ConstrDetSample(ConstrSample):
  97. def __init__(self, prefix, label_list):
  98. super().__init__(prefix, label_list)
  99. self.ct = 0
  100. def __call__(self, im_path, ann_path):
  101. im_path = self.get_full_path(im_path)
  102. ann_path = self.get_full_path(ann_path)
  103. # TODO: Precisely recognize the annotation format
  104. if ann_path.endswith('.json'):
  105. im_dir = im_path
  106. return self._parse_coco_files(im_dir, ann_path)
  107. elif ann_path.endswith('.xml'):
  108. return self._parse_voc_files(im_path, ann_path)
  109. else:
  110. raise ValueError("Cannot recognize the annotation format")
  111. def _parse_voc_files(self, im_path, ann_path):
  112. import xml.etree.ElementTree as ET
  113. cname2cid = self.label_list_obj
  114. tree = ET.parse(ann_path)
  115. # The xml file must contain id.
  116. if tree.find('id') is None:
  117. im_id = np.asarray([self.ct])
  118. else:
  119. self.ct = int(tree.find('id').text)
  120. im_id = np.asarray([int(tree.find('id').text)])
  121. pattern = re.compile('<size>', re.IGNORECASE)
  122. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))
  123. if len(size_tag) > 0:
  124. size_tag = size_tag[0][1:-1]
  125. size_element = tree.find(size_tag)
  126. pattern = re.compile('<width>', re.IGNORECASE)
  127. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  128. 1:-1]
  129. im_w = float(size_element.find(width_tag).text)
  130. pattern = re.compile('<height>', re.IGNORECASE)
  131. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  132. 1:-1]
  133. im_h = float(size_element.find(height_tag).text)
  134. else:
  135. im_w = 0
  136. im_h = 0
  137. pattern = re.compile('<object>', re.IGNORECASE)
  138. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  139. if len(obj_match) > 0:
  140. obj_tag = obj_match[0][1:-1]
  141. objs = tree.findall(obj_tag)
  142. else:
  143. objs = list()
  144. num_bbox, i = len(objs), 0
  145. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  146. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  147. gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
  148. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  149. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  150. for obj in objs:
  151. pattern = re.compile('<name>', re.IGNORECASE)
  152. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  153. cname = obj.find(name_tag).text.strip()
  154. pattern = re.compile('<difficult>', re.IGNORECASE)
  155. diff_tag = pattern.findall(str(ET.tostringlist(obj)))
  156. if len(diff_tag) == 0:
  157. _difficult = 0
  158. else:
  159. diff_tag = diff_tag[0][1:-1]
  160. try:
  161. _difficult = int(obj.find(diff_tag).text)
  162. except Exception:
  163. _difficult = 0
  164. pattern = re.compile('<bndbox>', re.IGNORECASE)
  165. box_tag = pattern.findall(str(ET.tostringlist(obj)))
  166. if len(box_tag) == 0:
  167. continue
  168. box_tag = box_tag[0][1:-1]
  169. box_element = obj.find(box_tag)
  170. pattern = re.compile('<xmin>', re.IGNORECASE)
  171. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  172. -1]
  173. x1 = float(box_element.find(xmin_tag).text)
  174. pattern = re.compile('<ymin>', re.IGNORECASE)
  175. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  176. -1]
  177. y1 = float(box_element.find(ymin_tag).text)
  178. pattern = re.compile('<xmax>', re.IGNORECASE)
  179. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  180. -1]
  181. x2 = float(box_element.find(xmax_tag).text)
  182. pattern = re.compile('<ymax>', re.IGNORECASE)
  183. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  184. -1]
  185. y2 = float(box_element.find(ymax_tag).text)
  186. x1 = max(0, x1)
  187. y1 = max(0, y1)
  188. if im_w > 0.5 and im_h > 0.5:
  189. x2 = min(im_w - 1, x2)
  190. y2 = min(im_h - 1, y2)
  191. if not (x2 >= x1 and y2 >= y1):
  192. continue
  193. gt_bbox[i, :] = [x1, y1, x2, y2]
  194. gt_class[i, 0] = cname2cid[cname]
  195. gt_score[i, 0] = 1.
  196. is_crowd[i, 0] = 0
  197. difficult[i, 0] = _difficult
  198. i += 1
  199. gt_bbox = gt_bbox[:i, :]
  200. gt_class = gt_class[:i, :]
  201. gt_score = gt_score[:i, :]
  202. is_crowd = is_crowd[:i, :]
  203. difficult = difficult[:i, :]
  204. im_info = {
  205. 'im_id': im_id,
  206. 'image_shape': np.array(
  207. [im_h, im_w], dtype=np.int32)
  208. }
  209. label_info = {
  210. 'is_crowd': is_crowd,
  211. 'gt_class': gt_class,
  212. 'gt_bbox': gt_bbox,
  213. 'gt_score': gt_score,
  214. 'difficult': difficult
  215. }
  216. self.ct += 1
  217. return {'image': im_path, ** im_info, ** label_info}
  218. @silent
  219. def _parse_coco_files(self, im_dir, ann_path):
  220. from pycocotools.coco import COCO
  221. coco = COCO(ann_path)
  222. img_ids = coco.getImgIds()
  223. img_ids.sort()
  224. samples = []
  225. for img_id in img_ids:
  226. img_anno = coco.loadImgs([img_id])[0]
  227. im_fname = img_anno['file_name']
  228. im_w = float(img_anno['width'])
  229. im_h = float(img_anno['height'])
  230. im_path = osp.join(im_dir, im_fname) if im_dir else im_fname
  231. im_info = {
  232. 'image': im_path,
  233. 'im_id': np.array([img_id]),
  234. 'image_shape': np.array(
  235. [im_h, im_w], dtype=np.int32)
  236. }
  237. ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
  238. instances = coco.loadAnns(ins_anno_ids)
  239. is_crowds = []
  240. gt_classes = []
  241. gt_bboxs = []
  242. gt_scores = []
  243. difficults = []
  244. for inst in instances:
  245. # Check gt bbox
  246. if inst.get('ignore', False):
  247. continue
  248. if 'bbox' not in inst.keys():
  249. continue
  250. else:
  251. if not any(np.array(inst['bbox'])):
  252. continue
  253. # Read box
  254. x1, y1, box_w, box_h = inst['bbox']
  255. x2 = x1 + box_w
  256. y2 = y1 + box_h
  257. eps = 1e-5
  258. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  259. inst['clean_bbox'] = [
  260. round(float(x), 3) for x in [x1, y1, x2, y2]
  261. ]
  262. is_crowds.append([inst['iscrowd']])
  263. gt_classes.append([inst['category_id']])
  264. gt_bboxs.append(inst['clean_bbox'])
  265. gt_scores.append([1.])
  266. difficults.append([0])
  267. label_info = {
  268. 'is_crowd': np.array(is_crowds),
  269. 'gt_class': np.array(gt_classes),
  270. 'gt_bbox': np.array(gt_bboxs).astype(np.float32),
  271. 'gt_score': np.array(gt_scores).astype(np.float32),
  272. 'difficult': np.array(difficults),
  273. }
  274. samples.append({ ** im_info, ** label_info})
  275. return samples
  276. def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
  277. """
  278. Construct a list of dictionaries from file. Each dict in the list can be used as the input to `paddlers.transforms.Transform` objects.
  279. Args:
  280. file_list (str): Path of file_list.
  281. prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
  282. task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input.
  283. Default: 'auto'.
  284. label_list (str | None, optional): Path of label_list. Default: None.
  285. Returns:
  286. list: List of samples.
  287. """
  288. def _determine_task(parts):
  289. if len(parts) in (3, 5):
  290. task = 'cd'
  291. elif len(parts) == 2:
  292. if parts[1].isdigit():
  293. task = 'clas'
  294. elif is_pic(osp.join(prefix, parts[1])):
  295. task = 'seg'
  296. else:
  297. task = 'det'
  298. else:
  299. raise RuntimeError(
  300. "Cannot automatically determine the task type. Please specify `task` manually."
  301. )
  302. return task
  303. if task not in ('seg', 'det', 'cd', 'clas', 'auto'):
  304. raise ValueError("Invalid value of `task`")
  305. samples = []
  306. ctor = None
  307. with open(file_list, 'r') as f:
  308. for line in f:
  309. line = line.strip()
  310. parts = line.split()
  311. if task == 'auto':
  312. task = _determine_task(parts)
  313. if ctor is None:
  314. # Select and build sample constructor
  315. ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
  316. ctor = ctor_class(prefix, label_list)
  317. sample = ctor(*parts)
  318. if isinstance(sample, list):
  319. samples.extend(sample)
  320. else:
  321. samples.append(sample)
  322. return samples