data_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import re
  16. import imghdr
  17. import platform
  18. from collections import OrderedDict
  19. from functools import partial
  20. import numpy as np
  21. __all__ = ['build_input_from_file']
  22. def norm_path(path):
  23. win_sep = "\\"
  24. other_sep = "/"
  25. if platform.system() == "Windows":
  26. path = win_sep.join(path.split(other_sep))
  27. else:
  28. path = other_sep.join(path.split(win_sep))
  29. return path
  30. def is_pic(im_path):
  31. valid_suffix = [
  32. 'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
  33. ]
  34. suffix = im_path.split('.')[-1]
  35. if suffix in valid_suffix:
  36. return True
  37. im_format = imghdr.what(im_path)
  38. _, ext = osp.splitext(im_path)
  39. if im_format == 'tiff' or ext == '.img':
  40. return True
  41. return False
  42. def get_full_path(p, prefix=''):
  43. p = norm_path(p)
  44. return osp.join(prefix, p)
  45. class ConstrSample(object):
  46. def __init__(self, prefix, label_list):
  47. super().__init__()
  48. self.prefix = prefix
  49. self.label_list_obj = self.read_label_list(label_list)
  50. self.get_full_path = partial(get_full_path, prefix=self.prefix)
  51. def read_label_list(self, label_list):
  52. if label_list is None:
  53. return None
  54. cname2cid = OrderedDict()
  55. label_id = 0
  56. with open(label_list, 'r') as f:
  57. for line in f:
  58. cname2cid[line.strip()] = label_id
  59. label_id += 1
  60. return cname2cid
  61. def __call__(self, *parts):
  62. raise NotImplementedError
  63. class ConstrSegSample(ConstrSample):
  64. def __call__(self, im_path, mask_path):
  65. return {
  66. 'image': self.get_full_path(im_path),
  67. 'mask': self.get_full_path(mask_path)
  68. }
  69. class ConstrCdSample(ConstrSample):
  70. def __call__(self, im1_path, im2_path, mask_path, *aux_mask_paths):
  71. sample = {
  72. 'image_t1': self.get_full_path(im1_path),
  73. 'image_t2': self.get_full_path(im2_path),
  74. 'mask': self.get_full_path(mask_path)
  75. }
  76. if len(aux_mask_paths) > 0:
  77. sample['aux_masks'] = [
  78. self.get_full_path(p) for p in aux_mask_paths
  79. ]
  80. return sample
  81. class ConstrClasSample(ConstrSample):
  82. def __call__(self, im_path, label):
  83. return {'image': self.get_full_path(im_path), 'label': int(label)}
  84. class ConstrDetSample(ConstrSample):
  85. def __init__(self, prefix, label_list):
  86. super().__init__(prefix, label_list)
  87. self.ct = 0
  88. def __call__(self, im_path, ann_path):
  89. im_path = self.get_full_path(im_path)
  90. ann_path = self.get_full_path(ann_path)
  91. # TODO: Precisely recognize the annotation format
  92. if ann_path.endswith('.json'):
  93. im_dir = im_path
  94. return self._parse_coco_files(im_dir, ann_path)
  95. elif ann_path.endswith('.xml'):
  96. return self._parse_voc_files(im_path, ann_path)
  97. else:
  98. raise ValueError("Cannot recognize the annotation format")
  99. def _parse_voc_files(self, im_path, ann_path):
  100. import xml.etree.ElementTree as ET
  101. cname2cid = self.label_list_obj
  102. tree = ET.parse(ann_path)
  103. # The xml file must contain id.
  104. if tree.find('id') is None:
  105. im_id = np.asarray([self.ct])
  106. else:
  107. self.ct = int(tree.find('id').text)
  108. im_id = np.asarray([int(tree.find('id').text)])
  109. pattern = re.compile('<size>', re.IGNORECASE)
  110. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))
  111. if len(size_tag) > 0:
  112. size_tag = size_tag[0][1:-1]
  113. size_element = tree.find(size_tag)
  114. pattern = re.compile('<width>', re.IGNORECASE)
  115. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  116. 1:-1]
  117. im_w = float(size_element.find(width_tag).text)
  118. pattern = re.compile('<height>', re.IGNORECASE)
  119. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][
  120. 1:-1]
  121. im_h = float(size_element.find(height_tag).text)
  122. else:
  123. im_w = 0
  124. im_h = 0
  125. pattern = re.compile('<object>', re.IGNORECASE)
  126. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  127. if len(obj_match) > 0:
  128. obj_tag = obj_match[0][1:-1]
  129. objs = tree.findall(obj_tag)
  130. else:
  131. objs = list()
  132. num_bbox, i = len(objs), 0
  133. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  134. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  135. gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
  136. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  137. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  138. for obj in objs:
  139. pattern = re.compile('<name>', re.IGNORECASE)
  140. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  141. cname = obj.find(name_tag).text.strip()
  142. pattern = re.compile('<difficult>', re.IGNORECASE)
  143. diff_tag = pattern.findall(str(ET.tostringlist(obj)))
  144. if len(diff_tag) == 0:
  145. _difficult = 0
  146. else:
  147. diff_tag = diff_tag[0][1:-1]
  148. try:
  149. _difficult = int(obj.find(diff_tag).text)
  150. except Exception:
  151. _difficult = 0
  152. pattern = re.compile('<bndbox>', re.IGNORECASE)
  153. box_tag = pattern.findall(str(ET.tostringlist(obj)))
  154. if len(box_tag) == 0:
  155. continue
  156. box_tag = box_tag[0][1:-1]
  157. box_element = obj.find(box_tag)
  158. pattern = re.compile('<xmin>', re.IGNORECASE)
  159. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  160. -1]
  161. x1 = float(box_element.find(xmin_tag).text)
  162. pattern = re.compile('<ymin>', re.IGNORECASE)
  163. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  164. -1]
  165. y1 = float(box_element.find(ymin_tag).text)
  166. pattern = re.compile('<xmax>', re.IGNORECASE)
  167. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  168. -1]
  169. x2 = float(box_element.find(xmax_tag).text)
  170. pattern = re.compile('<ymax>', re.IGNORECASE)
  171. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:
  172. -1]
  173. y2 = float(box_element.find(ymax_tag).text)
  174. x1 = max(0, x1)
  175. y1 = max(0, y1)
  176. if im_w > 0.5 and im_h > 0.5:
  177. x2 = min(im_w - 1, x2)
  178. y2 = min(im_h - 1, y2)
  179. if not (x2 >= x1 and y2 >= y1):
  180. continue
  181. gt_bbox[i, :] = [x1, y1, x2, y2]
  182. gt_class[i, 0] = cname2cid[cname]
  183. gt_score[i, 0] = 1.
  184. is_crowd[i, 0] = 0
  185. difficult[i, 0] = _difficult
  186. i += 1
  187. gt_bbox = gt_bbox[:i, :]
  188. gt_class = gt_class[:i, :]
  189. gt_score = gt_score[:i, :]
  190. is_crowd = is_crowd[:i, :]
  191. difficult = difficult[:i, :]
  192. im_info = {
  193. 'im_id': im_id,
  194. 'image_shape': np.array(
  195. [im_h, im_w], dtype=np.int32)
  196. }
  197. label_info = {
  198. 'is_crowd': is_crowd,
  199. 'gt_class': gt_class,
  200. 'gt_bbox': gt_bbox,
  201. 'gt_score': gt_score,
  202. 'difficult': difficult
  203. }
  204. self.ct += 1
  205. return {'image': im_path, ** im_info, ** label_info}
  206. def _parse_coco_files(self, im_dir, ann_path):
  207. from pycocotools.coco import COCO
  208. coco = COCO(ann_path)
  209. img_ids = coco.getImgIds()
  210. img_ids.sort()
  211. samples = []
  212. for img_id in img_ids:
  213. img_anno = coco.loadImgs([img_id])[0]
  214. im_fname = img_anno['file_name']
  215. im_w = float(img_anno['width'])
  216. im_h = float(img_anno['height'])
  217. im_path = osp.join(im_dir, im_fname) if im_dir else im_fname
  218. im_info = {
  219. 'image': im_path,
  220. 'im_id': np.array([img_id]),
  221. 'image_shape': np.array(
  222. [im_h, im_w], dtype=np.int32)
  223. }
  224. ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
  225. instances = coco.loadAnns(ins_anno_ids)
  226. is_crowds = []
  227. gt_classes = []
  228. gt_bboxs = []
  229. gt_scores = []
  230. difficults = []
  231. for inst in instances:
  232. # Check gt bbox
  233. if inst.get('ignore', False):
  234. continue
  235. if 'bbox' not in inst.keys():
  236. continue
  237. else:
  238. if not any(np.array(inst['bbox'])):
  239. continue
  240. # Read box
  241. x1, y1, box_w, box_h = inst['bbox']
  242. x2 = x1 + box_w
  243. y2 = y1 + box_h
  244. eps = 1e-5
  245. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  246. inst['clean_bbox'] = [
  247. round(float(x), 3) for x in [x1, y1, x2, y2]
  248. ]
  249. is_crowds.append([inst['iscrowd']])
  250. gt_classes.append([inst['category_id']])
  251. gt_bboxs.append(inst['clean_bbox'])
  252. gt_scores.append([1.])
  253. difficults.append([0])
  254. label_info = {
  255. 'is_crowd': np.array(is_crowds),
  256. 'gt_class': np.array(gt_classes),
  257. 'gt_bbox': np.array(gt_bboxs).astype(np.float32),
  258. 'gt_score': np.array(gt_scores).astype(np.float32),
  259. 'difficult': np.array(difficults),
  260. }
  261. samples.append({ ** im_info, ** label_info})
  262. return samples
  263. def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
  264. """
  265. Construct a list of dictionaries from file. Each dict in the list can be used as the input to `paddlers.transforms.Transform` objects.
  266. Args:
  267. file_list (str): Path of file_list.
  268. prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
  269. task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input.
  270. Default: 'auto'.
  271. label_list (str | None, optional): Path of label_list. Default: None.
  272. Returns:
  273. list: List of samples.
  274. """
  275. def _determine_task(parts):
  276. if len(parts) in (3, 5):
  277. task = 'cd'
  278. elif len(parts) == 2:
  279. if parts[1].isdigit():
  280. task = 'clas'
  281. elif is_pic(osp.join(prefix, parts[1])):
  282. task = 'seg'
  283. else:
  284. task = 'det'
  285. else:
  286. raise RuntimeError(
  287. "Cannot automatically determine the task type. Please specify `task` manually."
  288. )
  289. return task
  290. if task not in ('seg', 'det', 'cd', 'clas', 'auto'):
  291. raise ValueError("Invalid value of `task`")
  292. samples = []
  293. ctor = None
  294. with open(file_list, 'r') as f:
  295. for line in f:
  296. line = line.strip()
  297. parts = line.split()
  298. if task == 'auto':
  299. task = _determine_task(parts)
  300. if ctor is None:
  301. # Select and build sample constructor
  302. ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
  303. ctor = ctor_class(prefix, label_list)
  304. sample = ctor(*parts)
  305. if isinstance(sample, list):
  306. samples.extend(sample)
  307. else:
  308. samples.append(sample)
  309. return samples