infer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. #!/usr/bin/env python
  2. import os
  3. import os.path as osp
  4. import argparse
  5. from operator import itemgetter
  6. import numpy as np
  7. import paddle
  8. from paddle.inference import Config
  9. from paddle.inference import create_predictor
  10. from paddle.inference import PrecisionType
  11. from paddlers.tasks import load_model
  12. from paddlers.utils import logging
  13. from config_utils import parse_configs
  14. class _bool(object):
  15. def __new__(cls, x):
  16. if isinstance(x, str):
  17. if x.lower() == 'false':
  18. return False
  19. elif x.lower() == 'true':
  20. return True
  21. return bool.__new__(x)
  22. class TIPCPredictor(object):
  23. def __init__(self,
  24. model_dir,
  25. device='cpu',
  26. gpu_id=0,
  27. cpu_thread_num=1,
  28. use_mkl=True,
  29. mkl_thread_num=4,
  30. use_trt=False,
  31. memory_optimize=True,
  32. trt_precision_mode='fp32',
  33. benchmark=False,
  34. model_name='',
  35. batch_size=1):
  36. self.model_dir = model_dir
  37. self._model = load_model(model_dir, with_net=False)
  38. if trt_precision_mode.lower() == 'fp32':
  39. trt_precision_mode = PrecisionType.Float32
  40. elif trt_precision_mode.lower() == 'fp16':
  41. trt_precision_mode = PrecisionType.Float16
  42. else:
  43. logging.error(
  44. "TensorRT precision mode {} is invalid. Supported modes are fp32 and fp16."
  45. .format(trt_precision_mode),
  46. exit=True)
  47. self.config = self.get_config(
  48. device=device,
  49. gpu_id=gpu_id,
  50. cpu_thread_num=cpu_thread_num,
  51. use_mkl=use_mkl,
  52. mkl_thread_num=mkl_thread_num,
  53. use_trt=use_trt,
  54. use_glog=False,
  55. memory_optimize=memory_optimize,
  56. max_trt_batch_size=1,
  57. trt_precision_mode=trt_precision_mode)
  58. self.predictor = create_predictor(self.config)
  59. self.batch_size = batch_size
  60. if benchmark:
  61. import auto_log
  62. pid = os.getpid()
  63. self.autolog = auto_log.AutoLogger(
  64. model_name=model_name,
  65. model_precision=trt_precision_mode,
  66. batch_size=batch_size,
  67. data_shape='dynamic',
  68. save_path=None,
  69. inference_config=self.config,
  70. pids=pid,
  71. process_name=None,
  72. gpu_ids=0,
  73. time_keys=[
  74. 'preprocess_time', 'inference_time', 'postprocess_time'
  75. ],
  76. warmup=0,
  77. logger=logging)
  78. self.benchmark = benchmark
  79. def get_config(self, device, gpu_id, cpu_thread_num, use_mkl,
  80. mkl_thread_num, use_trt, use_glog, memory_optimize,
  81. max_trt_batch_size, trt_precision_mode):
  82. config = Config(
  83. osp.join(self.model_dir, 'model.pdmodel'),
  84. osp.join(self.model_dir, 'model.pdiparams'))
  85. if device == 'gpu':
  86. config.enable_use_gpu(200, gpu_id)
  87. config.switch_ir_optim(True)
  88. if use_trt:
  89. if self._model.model_type == 'segmenter':
  90. logging.warning(
  91. "Semantic segmentation models do not support TensorRT acceleration, "
  92. "TensorRT is forcibly disabled.")
  93. elif 'RCNN' in self._model.__class__.__name__:
  94. logging.warning(
  95. "RCNN models do not support TensorRT acceleration, "
  96. "TensorRT is forcibly disabled.")
  97. else:
  98. config.enable_tensorrt_engine(
  99. workspace_size=1 << 10,
  100. max_batch_size=max_trt_batch_size,
  101. min_subgraph_size=3,
  102. precision_mode=trt_precision_mode,
  103. use_static=False,
  104. use_calib_mode=False)
  105. else:
  106. config.disable_gpu()
  107. config.set_cpu_math_library_num_threads(cpu_thread_num)
  108. if use_mkl:
  109. if self._model.__class__.__name__ == 'MaskRCNN':
  110. logging.warning(
  111. "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled"
  112. )
  113. else:
  114. try:
  115. # Cache 10 different shapes for mkldnn to avoid memory leak
  116. config.set_mkldnn_cache_capacity(10)
  117. config.enable_mkldnn()
  118. config.set_cpu_math_library_num_threads(mkl_thread_num)
  119. except Exception as e:
  120. logging.warning(
  121. "The current environment does not support MKL-DNN, MKL-DNN is disabled."
  122. )
  123. pass
  124. if not use_glog:
  125. config.disable_glog_info()
  126. if memory_optimize:
  127. config.enable_memory_optim()
  128. config.switch_use_feed_fetch_ops(False)
  129. return config
  130. def preprocess(self, images, transforms):
  131. preprocessed_samples = self._model.preprocess(
  132. images, transforms, to_tensor=False)
  133. if self._model.model_type == 'classifier':
  134. preprocessed_samples = {'image': preprocessed_samples[0]}
  135. elif self._model.model_type == 'segmenter':
  136. preprocessed_samples = {
  137. 'image': preprocessed_samples[0],
  138. 'ori_shape': preprocessed_samples[1]
  139. }
  140. elif self._model.model_type == 'detector':
  141. pass
  142. elif self._model.model_type == 'change_detector':
  143. preprocessed_samples = {
  144. 'image': preprocessed_samples[0],
  145. 'image2': preprocessed_samples[1],
  146. 'ori_shape': preprocessed_samples[2]
  147. }
  148. else:
  149. logging.error(
  150. "Invalid model type {}".format(self._model.model_type),
  151. exit=True)
  152. return preprocessed_samples
  153. def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
  154. if self._model.model_type == 'classifier':
  155. true_topk = min(self._model.num_classes, topk)
  156. if self._model.postprocess is None:
  157. self._model.build_postprocess_from_labels(topk)
  158. # XXX: Convert ndarray to tensor as self._model.postprocess requires
  159. assert len(net_outputs) == 1
  160. net_outputs = paddle.to_tensor(net_outputs[0])
  161. outputs = self._model.postprocess(net_outputs)
  162. class_ids = map(itemgetter('class_ids'), outputs)
  163. scores = map(itemgetter('scores'), outputs)
  164. label_names = map(itemgetter('label_names'), outputs)
  165. preds = [{
  166. 'class_ids_map': l,
  167. 'scores_map': s,
  168. 'label_names_map': n,
  169. } for l, s, n in zip(class_ids, scores, label_names)]
  170. elif self._model.model_type in ('segmenter', 'change_detector'):
  171. label_map, score_map = self._model.postprocess(
  172. net_outputs,
  173. batch_origin_shape=ori_shape,
  174. transforms=transforms.transforms)
  175. preds = [{
  176. 'label_map': l,
  177. 'score_map': s
  178. } for l, s in zip(label_map, score_map)]
  179. elif self._model.model_type == 'detector':
  180. net_outputs = {
  181. k: v
  182. for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
  183. }
  184. preds = self._model.postprocess(net_outputs)
  185. else:
  186. logging.error(
  187. "Invalid model type {}.".format(self._model.model_type),
  188. exit=True)
  189. return preds
  190. def _run(self, images, topk=1, transforms=None, time_it=False):
  191. if self.benchmark and time_it:
  192. self.autolog.times.start()
  193. preprocessed_input = self.preprocess(images, transforms)
  194. input_names = self.predictor.get_input_names()
  195. for name in input_names:
  196. input_tensor = self.predictor.get_input_handle(name)
  197. input_tensor.copy_from_cpu(preprocessed_input[name])
  198. if self.benchmark and time_it:
  199. self.autolog.times.stamp()
  200. self.predictor.run()
  201. output_names = self.predictor.get_output_names()
  202. net_outputs = []
  203. for name in output_names:
  204. output_tensor = self.predictor.get_output_handle(name)
  205. net_outputs.append(output_tensor.copy_to_cpu())
  206. if self.benchmark and time_it:
  207. self.autolog.times.stamp()
  208. res = self.postprocess(
  209. net_outputs,
  210. topk,
  211. ori_shape=preprocessed_input.get('ori_shape', None),
  212. transforms=transforms)
  213. if self.benchmark and time_it:
  214. self.autolog.times.end(stamp=True)
  215. return res
  216. def predict(self, data_dir, file_list, topk=1, warmup_iters=5):
  217. transforms = self._model.test_transforms
  218. # Warm up
  219. iters = 0
  220. while True:
  221. for images in self._parse_lines(data_dir, file_list):
  222. if iters >= warmup_iters:
  223. break
  224. self._run(
  225. images=images,
  226. topk=topk,
  227. transforms=transforms,
  228. time_it=False)
  229. iters += 1
  230. else:
  231. continue
  232. break
  233. results = []
  234. for images in self._parse_lines(data_dir, file_list):
  235. res = self._run(
  236. images=images, topk=topk, transforms=transforms, time_it=True)
  237. results.append(res)
  238. return results
  239. def _parse_lines(self, data_dir, file_list):
  240. with open(file_list, 'r') as f:
  241. batch = []
  242. for line in f:
  243. items = line.strip().split()
  244. items = [osp.join(data_dir, item) for item in items]
  245. if self._model.model_type == 'change_detector':
  246. batch.append((items[0], items[1]))
  247. else:
  248. batch.append(items[0])
  249. if len(batch) == self.batch_size:
  250. yield batch
  251. batch.clear()
  252. if 0 < len(batch) < self.batch_size:
  253. yield batch
  254. if __name__ == '__main__':
  255. parser = argparse.ArgumentParser()
  256. parser.add_argument('--config', type=str)
  257. parser.add_argument('--inherit_off', action='store_true')
  258. parser.add_argument('--model_dir', type=str, default='./')
  259. parser.add_argument(
  260. '--device', type=str, choices=['cpu', 'gpu'], default='cpu')
  261. parser.add_argument('--enable_mkldnn', type=_bool, default=False)
  262. parser.add_argument('--cpu_threads', type=int, default=10)
  263. parser.add_argument('--use_trt', type=_bool, default=False)
  264. parser.add_argument(
  265. '--precision', type=str, choices=['fp32', 'fp16'], default='fp16')
  266. parser.add_argument('--batch_size', type=int, default=1)
  267. parser.add_argument('--benchmark', type=_bool, default=False)
  268. parser.add_argument('--model_name', type=str, default='')
  269. args = parser.parse_args()
  270. cfg = parse_configs(args.config, not args.inherit_off)
  271. eval_dataset = cfg['datasets']['eval']
  272. data_dir = eval_dataset.args['data_dir']
  273. file_list = eval_dataset.args['file_list']
  274. predictor = TIPCPredictor(
  275. args.model_dir,
  276. device=args.device,
  277. cpu_thread_num=args.cpu_threads,
  278. use_mkl=args.enable_mkldnn,
  279. mkl_thread_num=args.cpu_threads,
  280. use_trt=args.use_trt,
  281. trt_precision_mode=args.precision,
  282. benchmark=args.benchmark)
  283. predictor.predict(data_dir, file_list)
  284. if args.benchmark:
  285. predictor.autolog.report()