predictor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. from operator import itemgetter
  16. from functools import partial
  17. import numpy as np
  18. import paddle
  19. from paddle.inference import Config
  20. from paddle.inference import create_predictor
  21. from paddle.inference import PrecisionType
  22. from paddlers.tasks import load_model
  23. from paddlers.utils import logging, Timer
  24. from paddlers.tasks.utils.slider_predict import slider_predict
  25. class Predictor(object):
  26. def __init__(self,
  27. model_dir,
  28. use_gpu=False,
  29. gpu_id=0,
  30. cpu_thread_num=1,
  31. use_mkl=True,
  32. mkl_thread_num=4,
  33. use_trt=False,
  34. use_glog=False,
  35. memory_optimize=True,
  36. max_trt_batch_size=1,
  37. trt_precision_mode='float32'):
  38. """
  39. Args:
  40. model_dir (str): Path of the exported model.
  41. use_gpu (bool, optional): Whether to use a GPU. Defaults to False.
  42. gpu_id (int, optional): GPU ID. Defaults to 0.
  43. cpu_thread_num (int, optional): Number of threads to use when making predictions using CPUs.
  44. Defaults to 1.
  45. use_mkl (bool, optional): Whether to use MKL-DNN. Defaults to False.
  46. mkl_thread_num (int, optional): Number of MKL threads. Defaults to 4.
  47. use_trt (bool, optional): Whether to use TensorRT. Defaults to False.
  48. use_glog (bool, optional): Whether to enable glog logs. Defaults to False.
  49. memory_optimize (bool, optional): Whether to enable memory optimization. Defaults to True.
  50. max_trt_batch_size (int, optional): Maximum batch size when configured with TensorRT. Defaults to 1.
  51. trt_precision_mode (str, optional):Precision to use when configured with TensorRT. Possible values
  52. are {'float32', 'float16'}. Defaults to 'float32'.
  53. """
  54. self.model_dir = model_dir
  55. self._model = load_model(model_dir, with_net=False)
  56. if trt_precision_mode.lower() == 'float32':
  57. trt_precision_mode = PrecisionType.Float32
  58. elif trt_precision_mode.lower() == 'float16':
  59. trt_precision_mode = PrecisionType.Float16
  60. else:
  61. logging.error(
  62. "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
  63. .format(trt_precision_mode),
  64. exit=True)
  65. self.predictor = self.create_predictor(
  66. use_gpu=use_gpu,
  67. gpu_id=gpu_id,
  68. cpu_thread_num=cpu_thread_num,
  69. use_mkl=use_mkl,
  70. mkl_thread_num=mkl_thread_num,
  71. use_trt=use_trt,
  72. use_glog=use_glog,
  73. memory_optimize=memory_optimize,
  74. max_trt_batch_size=max_trt_batch_size,
  75. trt_precision_mode=trt_precision_mode)
  76. self.timer = Timer()
  77. def create_predictor(self,
  78. use_gpu=True,
  79. gpu_id=0,
  80. cpu_thread_num=1,
  81. use_mkl=True,
  82. mkl_thread_num=4,
  83. use_trt=False,
  84. use_glog=False,
  85. memory_optimize=True,
  86. max_trt_batch_size=1,
  87. trt_precision_mode=PrecisionType.Float32):
  88. config = Config(
  89. osp.join(self.model_dir, 'model.pdmodel'),
  90. osp.join(self.model_dir, 'model.pdiparams'))
  91. if use_gpu:
  92. # Set memory on GPUs (in MB) and device ID
  93. config.enable_use_gpu(200, gpu_id)
  94. config.switch_ir_optim(True)
  95. if use_trt:
  96. if self._model.model_type == 'segmenter':
  97. logging.warning(
  98. "Semantic segmentation models do not support TensorRT acceleration, "
  99. "TensorRT is forcibly disabled.")
  100. elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
  101. logging.warning(
  102. "RCNN models do not support TensorRT acceleration, "
  103. "TensorRT is forcibly disabled.")
  104. else:
  105. config.enable_tensorrt_engine(
  106. workspace_size=1 << 10,
  107. max_batch_size=max_trt_batch_size,
  108. min_subgraph_size=3,
  109. precision_mode=trt_precision_mode,
  110. use_static=False,
  111. use_calib_mode=False)
  112. else:
  113. config.disable_gpu()
  114. config.set_cpu_math_library_num_threads(cpu_thread_num)
  115. if use_mkl:
  116. if self._model.__class__.__name__ == 'MaskRCNN':
  117. logging.warning(
  118. "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled"
  119. )
  120. else:
  121. try:
  122. # Cache 10 different shapes for mkldnn to avoid memory leak.
  123. config.set_mkldnn_cache_capacity(10)
  124. config.enable_mkldnn()
  125. config.set_cpu_math_library_num_threads(mkl_thread_num)
  126. except Exception as e:
  127. logging.warning(
  128. "The current environment does not support MKL-DNN, MKL-DNN is disabled."
  129. )
  130. pass
  131. if not use_glog:
  132. config.disable_glog_info()
  133. if memory_optimize:
  134. config.enable_memory_optim()
  135. config.switch_use_feed_fetch_ops(False)
  136. predictor = create_predictor(config)
  137. return predictor
  138. def preprocess(self, images, transforms):
  139. preprocessed_samples = self._model.preprocess(
  140. images, transforms, to_tensor=False)
  141. if self._model.model_type == 'classifier':
  142. preprocessed_samples = {'image': preprocessed_samples[0]}
  143. elif self._model.model_type == 'segmenter':
  144. preprocessed_samples = {
  145. 'image': preprocessed_samples[0],
  146. 'ori_shape': preprocessed_samples[1]
  147. }
  148. elif self._model.model_type == 'detector':
  149. pass
  150. elif self._model.model_type == 'change_detector':
  151. preprocessed_samples = {
  152. 'image': preprocessed_samples[0],
  153. 'image2': preprocessed_samples[1],
  154. 'ori_shape': preprocessed_samples[2]
  155. }
  156. elif self._model.model_type == 'restorer':
  157. preprocessed_samples = {
  158. 'image': preprocessed_samples[0],
  159. 'tar_shape': preprocessed_samples[1]
  160. }
  161. else:
  162. logging.error(
  163. "Invalid model type {}".format(self._model.model_type),
  164. exit=True)
  165. return preprocessed_samples
  166. def postprocess(self,
  167. net_outputs,
  168. topk=1,
  169. ori_shape=None,
  170. tar_shape=None,
  171. transforms=None):
  172. if self._model.model_type == 'classifier':
  173. true_topk = min(self._model.num_classes, topk)
  174. if self._model.postprocess is None:
  175. self._model.build_postprocess_from_labels(topk)
  176. # XXX: Convert ndarray to tensor as self._model.postprocess requires
  177. assert len(net_outputs) == 1
  178. net_outputs = paddle.to_tensor(net_outputs[0])
  179. outputs = self._model.postprocess(net_outputs)
  180. class_ids = map(itemgetter('class_ids'), outputs)
  181. scores = map(itemgetter('scores'), outputs)
  182. label_names = map(itemgetter('label_names'), outputs)
  183. preds = [{
  184. 'class_ids_map': l,
  185. 'scores_map': s,
  186. 'label_names_map': n,
  187. } for l, s, n in zip(class_ids, scores, label_names)]
  188. elif self._model.model_type in ('segmenter', 'change_detector'):
  189. label_map, score_map = self._model.postprocess(
  190. net_outputs,
  191. batch_origin_shape=ori_shape,
  192. transforms=transforms.transforms)
  193. preds = [{
  194. 'label_map': l,
  195. 'score_map': s
  196. } for l, s in zip(label_map, score_map)]
  197. elif self._model.model_type == 'detector':
  198. net_outputs = {
  199. k: v
  200. for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
  201. }
  202. preds = self._model.postprocess(net_outputs)
  203. elif self._model.model_type == 'restorer':
  204. res_maps = self._model.postprocess(
  205. net_outputs[0],
  206. batch_tar_shape=tar_shape,
  207. transforms=transforms.transforms)
  208. preds = [{'res_map': res_map} for res_map in res_maps]
  209. else:
  210. logging.error(
  211. "Invalid model type {}.".format(self._model.model_type),
  212. exit=True)
  213. return preds
  214. def raw_predict(self, inputs):
  215. """
  216. Predict according to preprocessed inputs.
  217. Args:
  218. inputs (dict): Preprocessed inputs.
  219. """
  220. input_names = self.predictor.get_input_names()
  221. for name in input_names:
  222. input_tensor = self.predictor.get_input_handle(name)
  223. input_tensor.copy_from_cpu(inputs[name])
  224. self.predictor.run()
  225. output_names = self.predictor.get_output_names()
  226. net_outputs = list()
  227. for name in output_names:
  228. output_tensor = self.predictor.get_output_handle(name)
  229. net_outputs.append(output_tensor.copy_to_cpu())
  230. return net_outputs
  231. def _run(self, images, topk=1, transforms=None):
  232. self.timer.preprocess_time_s.start()
  233. preprocessed_input = self.preprocess(images, transforms)
  234. self.timer.preprocess_time_s.end(iter_num=len(images))
  235. self.timer.inference_time_s.start()
  236. net_outputs = self.raw_predict(preprocessed_input)
  237. self.timer.inference_time_s.end(iter_num=1)
  238. self.timer.postprocess_time_s.start()
  239. results = self.postprocess(
  240. net_outputs,
  241. topk,
  242. ori_shape=preprocessed_input.get('ori_shape', None),
  243. tar_shape=preprocessed_input.get('tar_shape', None),
  244. transforms=transforms)
  245. self.timer.postprocess_time_s.end(iter_num=len(images))
  246. return results
  247. def predict(self,
  248. img_file,
  249. topk=1,
  250. transforms=None,
  251. warmup_iters=0,
  252. repeats=1,
  253. quiet=False):
  254. """
  255. Do inference.
  256. Args:
  257. img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
  258. object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
  259. a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
  260. paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change
  261. detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
  262. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
  263. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
  264. from `model.yml`. Defaults to None.
  265. warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
  266. repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
  267. 1, the reported time consumption is the average of all repeats. Defaults to 1.
  268. quiet (bool, optional): If True, do not display the timing information. Defaults to False.
  269. """
  270. if repeats < 1:
  271. logging.error("`repeats` must be greater than 1.", exit=True)
  272. if transforms is None and not hasattr(self._model, 'test_transforms'):
  273. raise ValueError("Transforms need to be defined, now is None.")
  274. if transforms is None:
  275. transforms = self._model.test_transforms
  276. if isinstance(img_file, tuple) and len(img_file) != 2:
  277. raise ValueError(
  278. f"A change detection model accepts exactly two input images, but there are {len(img_file)}."
  279. )
  280. if isinstance(img_file, (str, np.ndarray, tuple)):
  281. images = [img_file]
  282. else:
  283. images = img_file
  284. for _ in range(warmup_iters):
  285. self._run(images=images, topk=topk, transforms=transforms)
  286. self.timer.reset()
  287. for _ in range(repeats):
  288. results = self._run(images=images, topk=topk, transforms=transforms)
  289. self.timer.repeats = repeats
  290. self.timer.img_num = len(images)
  291. if not quiet:
  292. self.timer.info(average=True)
  293. if isinstance(img_file, (str, np.ndarray, tuple)):
  294. results = results[0]
  295. return results
  296. def slider_predict(self,
  297. img_file,
  298. save_dir,
  299. block_size,
  300. overlap=36,
  301. transforms=None,
  302. invalid_value=255,
  303. merge_strategy='keep_last',
  304. batch_size=1,
  305. quiet=False):
  306. """
  307. Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the
  308. sliding-predicting mode.
  309. Args:
  310. img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file`
  311. should be either the path of the image to predict, a decoded image (a np.ndarray, which should be
  312. consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)),
  313. or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of
  314. image paths, a tuple of decoded images, or a list of tuples.
  315. save_dir (str): Directory that contains saved geotiff file.
  316. block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in
  317. (W, H) format.
  318. overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple,
  319. it should be in (W, H) format. Defaults to 36.
  320. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
  321. from `model.yml`. Defaults to None.
  322. invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255.
  323. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are
  324. {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and
  325. the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel
  326. according to accumulated probabilities. Defaults to 'keep_last'.
  327. batch_size (int, optional): Batch size used in inference. Defaults to 1.
  328. quiet (bool, optional): If True, disable the progress bar. Defaults to False.
  329. """
  330. slider_predict(
  331. partial(
  332. self.predict, quiet=True),
  333. img_file,
  334. save_dir,
  335. block_size,
  336. overlap,
  337. transforms,
  338. invalid_value,
  339. merge_strategy,
  340. batch_size,
  341. not quiet)
  342. def batch_predict(self, image_list, **params):
  343. return self.predict(img_file=image_list, **params)