infer.py

#!/usr/bin/env python

import os
import os.path as osp
import argparse
from operator import itemgetter

import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddlers.tasks import load_model
from paddlers.utils import logging

from config_utils import parse_configs
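

# argparse's built-in bool() treats any non-empty string, including "False",
# as True, so boolean CLI options go through this explicit converter instead.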
class _bool(object):
    def __new__(cls, x):
        if isinstance(x, str):
            if x.lower() == 'false':
                return False
            elif x.lower() == 'true':
                return True
        return bool(x)


class TIPCPredictor(object):
    def __init__(self,
                 model_dir,
                 device='cpu',
                 gpu_id=0,
                 cpu_thread_num=1,
                 use_mkl=True,
                 mkl_thread_num=4,
                 use_trt=False,
                 memory_optimize=True,
                 trt_precision_mode='fp32',
                 benchmark=False,
                 model_name='',
                 batch_size=1):
        self.model_dir = model_dir
        self._model = load_model(model_dir, with_net=False)

        if trt_precision_mode.lower() == 'fp32':
            trt_precision_mode = PrecisionType.Float32
        elif trt_precision_mode.lower() == 'fp16':
            trt_precision_mode = PrecisionType.Float16
        else:
            logging.error(
                "TensorRT precision mode {} is invalid. Supported modes are fp32 and fp16."
                .format(trt_precision_mode),
                exit=True)

        self.config = self.get_config(
            device=device,
            gpu_id=gpu_id,
            cpu_thread_num=cpu_thread_num,
            use_mkl=use_mkl,
            mkl_thread_num=mkl_thread_num,
            use_trt=use_trt,
            use_glog=False,
            memory_optimize=memory_optimize,
            max_trt_batch_size=1,
            trt_precision_mode=trt_precision_mode)
        self.predictor = create_predictor(self.config)
        self.batch_size = batch_size

        if benchmark:
            import auto_log
            pid = os.getpid()
            self.autolog = auto_log.AutoLogger(
                model_name=model_name,
                model_precision=trt_precision_mode,
                batch_size=batch_size,
                data_shape='dynamic',
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=0,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logging)
        self.benchmark = benchmark
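
    # Build a paddle.inference.Config, selecting the GPU (optionally with
    # TensorRT) or CPU (optionally with MKL-DNN) execution path.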
    def get_config(self, device, gpu_id, cpu_thread_num, use_mkl,
                   mkl_thread_num, use_trt, use_glog, memory_optimize,
                   max_trt_batch_size, trt_precision_mode):
        config = Config(
            osp.join(self.model_dir, 'model.pdmodel'),
            osp.join(self.model_dir, 'model.pdiparams'))

        if device == 'gpu':
            config.enable_use_gpu(200, gpu_id)
            config.switch_ir_optim(True)
            if use_trt:
                if self._model.model_type == 'segmenter':
                    logging.warning(
                        "Semantic segmentation models do not support TensorRT acceleration, "
                        "TensorRT is forcibly disabled.")
                elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
                    logging.warning(
                        "RCNN models do not support TensorRT acceleration, "
                        "TensorRT is forcibly disabled.")
                else:
                    config.enable_tensorrt_engine(
                        workspace_size=1 << 10,
                        max_batch_size=max_trt_batch_size,
                        min_subgraph_size=3,
                        precision_mode=trt_precision_mode,
                        use_static=False,
                        use_calib_mode=False)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(cpu_thread_num)
            if use_mkl:
                if self._model.__class__.__name__ == 'MaskRCNN':
                    logging.warning(
                        "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled."
                    )
                else:
                    try:
                        # Cache 10 different shapes for MKL-DNN to avoid a memory leak.
                        config.set_mkldnn_cache_capacity(10)
                        config.enable_mkldnn()
                        config.set_cpu_math_library_num_threads(mkl_thread_num)
                    except Exception:
                        logging.warning(
                            "The current environment does not support MKL-DNN, MKL-DNN is disabled."
                        )

        if not use_glog:
            config.disable_glog_info()
        if memory_optimize:
            config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        return config
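
    # Run the model's preprocessing pipeline and normalize the result into
    # the input dict expected by the exported inference program.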
    def preprocess(self, images, transforms):
        preprocessed_samples, batch_trans_info = self._model.preprocess(
            images, transforms, to_tensor=False)
        if self._model.model_type == 'classifier':
            preprocessed_samples = {'image': preprocessed_samples}
        elif self._model.model_type == 'segmenter':
            preprocessed_samples = {'image': preprocessed_samples[0]}
        elif self._model.model_type == 'detector':
            pass
        elif self._model.model_type == 'change_detector':
            preprocessed_samples = {
                'image': preprocessed_samples[0],
                'image2': preprocessed_samples[1]
            }
        elif self._model.model_type == 'restorer':
            preprocessed_samples = {'image': preprocessed_samples[0]}
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preprocessed_samples, batch_trans_info
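
    # Convert raw network outputs into per-sample prediction dicts,
    # dispatching on the model type.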
    def postprocess(self, net_outputs, batch_restore_list, topk=1):
        if self._model.model_type == 'classifier':
            # Clamp top-k to the number of classes.
            true_topk = min(self._model.num_classes, topk)
            if self._model.postprocess is None:
                self._model.build_postprocess_from_labels(true_topk)
            # XXX: Convert ndarray to tensor, as `self._model.postprocess` requires.
            assert len(net_outputs) == 1
            net_outputs = paddle.to_tensor(net_outputs[0])
            outputs = self._model.postprocess(net_outputs)
            class_ids = map(itemgetter('class_ids'), outputs)
            scores = map(itemgetter('scores'), outputs)
            label_names = map(itemgetter('label_names'), outputs)
            preds = [{
                'class_ids_map': l,
                'scores_map': s,
                'label_names_map': n,
            } for l, s, n in zip(class_ids, scores, label_names)]
        elif self._model.model_type in ('segmenter', 'change_detector'):
            label_map, score_map = self._model.postprocess(
                net_outputs, batch_restore_list=batch_restore_list)
            preds = [{
                'label_map': l,
                'score_map': s
            } for l, s in zip(label_map, score_map)]
        elif self._model.model_type == 'detector':
            net_outputs = {
                k: v
                for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
            }
            preds = self._model.postprocess(net_outputs)
        elif self._model.model_type == 'restorer':
            res_maps = self._model.postprocess(
                net_outputs[0], batch_restore_list=batch_restore_list)
            preds = [{'res_map': res_map} for res_map in res_maps]
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preds
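
    # One inference pass: preprocess, copy inputs to the predictor, run,
    # fetch outputs, postprocess; optionally record timings for auto_log.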
    def _run(self, images, topk=1, transforms=None, time_it=False):
        if self.benchmark and time_it:
            self.autolog.times.start()

        preprocessed_input, batch_trans_info = self.preprocess(images,
                                                               transforms)

        input_names = self.predictor.get_input_names()
        for name in input_names:
            input_tensor = self.predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(preprocessed_input[name])

        if self.benchmark and time_it:
            self.autolog.times.stamp()

        self.predictor.run()

        output_names = self.predictor.get_output_names()
        net_outputs = []
        for name in output_names:
            output_tensor = self.predictor.get_output_handle(name)
            net_outputs.append(output_tensor.copy_to_cpu())

        if self.benchmark and time_it:
            self.autolog.times.stamp()

        res = self.postprocess(
            net_outputs, batch_restore_list=batch_trans_info, topk=topk)

        if self.benchmark and time_it:
            self.autolog.times.end(stamp=True)
        return res
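
    # Run untimed warm-up iterations, then a timed pass over the full
    # file list.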
    def predict(self, data_dir, file_list, topk=1, warmup_iters=5):
        transforms = self._model.test_transforms

        # Warm up, re-iterating the file list until `warmup_iters` is reached.
        iters = 0
        while True:
            for images in self._parse_lines(data_dir, file_list):
                if iters >= warmup_iters:
                    break
                self._run(
                    images=images,
                    topk=topk,
                    transforms=transforms,
                    time_it=False)
                iters += 1
            else:
                continue
            break

        results = []
        for images in self._parse_lines(data_dir, file_list):
            res = self._run(
                images=images, topk=topk, transforms=transforms, time_it=True)
            results.append(res)
        return results
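
    # Yield batches of image paths (path pairs for change detection) from a
    # whitespace-separated file list; the final batch may be smaller.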
    def _parse_lines(self, data_dir, file_list):
        with open(file_list, 'r') as f:
            batch = []
            for line in f:
                items = line.strip().split()
                items = [osp.join(data_dir, item) for item in items]
                if self._model.model_type == 'change_detector':
                    batch.append((items[0], items[1]))
                else:
                    batch.append(items[0])
                if len(batch) == self.batch_size:
                    yield batch
                    # Start a new list rather than clearing the old one, so
                    # already-yielded batches remain valid.
                    batch = []
            if 0 < len(batch) < self.batch_size:
                yield batch
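

# Command-line entry point: parse the TIPC config, pick up the evaluation
# dataset's data_dir and file_list, and run (optionally benchmarked) inference.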
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str)
    parser.add_argument('--inherit_off', action='store_true')
    parser.add_argument('--model_dir', type=str, default='./')
    parser.add_argument(
        '--device', type=str, choices=['cpu', 'gpu'], default='cpu')
    parser.add_argument('--enable_mkldnn', type=_bool, default=False)
    parser.add_argument('--cpu_threads', type=int, default=10)
    parser.add_argument('--use_trt', type=_bool, default=False)
    parser.add_argument(
        '--precision', type=str, choices=['fp32', 'fp16'], default='fp16')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--benchmark', type=_bool, default=False)
    parser.add_argument('--model_name', type=str, default='')

    args = parser.parse_args()

    cfg = parse_configs(args.config, not args.inherit_off)
    eval_dataset = cfg['datasets']['eval']
    data_dir = eval_dataset.args['data_dir']
    file_list = eval_dataset.args['file_list']

    predictor = TIPCPredictor(
        args.model_dir,
        device=args.device,
        cpu_thread_num=args.cpu_threads,
        use_mkl=args.enable_mkldnn,
        mkl_thread_num=args.cpu_threads,
        use_trt=args.use_trt,
        trt_precision_mode=args.precision,
        benchmark=args.benchmark,
        model_name=args.model_name,
        batch_size=args.batch_size)

    predictor.predict(data_dir, file_list)

    if args.benchmark:
        predictor.autolog.report()
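
# Example invocation (the config path, model directory, and model name below
# are hypothetical; substitute your own TIPC config and exported model):
#
#   python infer.py --config test_tipc/configs/example.yaml \
#       --model_dir ./inference_model --device gpu --use_trt True \
#       --precision fp16 --batch_size 1 --benchmark True --model_name demo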