test_predictor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os.path as osp
  15. import tempfile
  16. import unittest.mock as mock
  17. import cv2
  18. import paddle
  19. import paddlers as pdrs
  20. from testing_utils import CommonTest, run_script
  class TestPredictor(CommonTest):
      """Base class for tests comparing exported predictors with trainers.

      A concrete subclass sets ``MODULE`` to a ``pdrs.tasks`` submodule,
      implements ``check_predictor()`` and is decorated with
      ``TestPredictor.add_tests``. The decorator generates one
      ``test_<TrainerName>`` method per name in ``MODULE.__all__``.

      ``TRAINER_NAME_TO_EXPORT_OPTS`` maps a trainer name to extra command-line
      options for ``export_model.py``. The key ``'_default'`` is used for
      trainers without their own entry.
      """

      MODULE = pdrs.tasks
      TRAINER_NAME_TO_EXPORT_OPTS = {}

      @staticmethod
      def add_tests(cls):
          """
          Automatically patch testing functions to cls.

          For every trainer name in ``cls.MODULE.__all__``, attach a
          ``test_<name>`` method. Each generated test does the following:

          1. Constructs the trainer with default parameters and saves it.
          2. Exports it by running ``export_model.py`` from
             ``../deploy/export``.
          3. Loads the exported model with ``pdrs.deploy.Predictor``.
          4. Hands the predictor and trainer to ``check_predictor()``.

          Args:
              cls: The test class to patch.

          Returns:
              ``cls`` itself, so this can be used as a class decorator.
          """

          def _test_predictor(trainer_name):
              # The factory gives every generated test its own `trainer_name`
              # binding (avoids the late-binding closure pitfall in the loop).
              def _test_predictor_impl(self):
                  trainer_class = getattr(self.MODULE, trainer_name)
                  # Construct trainer with default parameters
                  trainer = trainer_class()
                  with tempfile.TemporaryDirectory() as td:
                      dynamic_model_dir = osp.join(td, "dynamic")
                      static_model_dir = osp.join(td, "static")
                      # HACK: BaseModel.save_model() requires BaseModel().optimizer to be set
                      optimizer = mock.Mock()
                      optimizer.state_dict.return_value = {'foo': 'bar'}
                      trainer.optimizer = optimizer
                      trainer.save_model(dynamic_model_dir)
                      export_cmd = f"python export_model.py --model_dir {dynamic_model_dir} --save_dir {static_model_dir} "
                      # Per-trainer options take precedence over '_default'.
                      if trainer_name in self.TRAINER_NAME_TO_EXPORT_OPTS:
                          export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
                              trainer_name]
                      elif '_default' in self.TRAINER_NAME_TO_EXPORT_OPTS:
                          export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
                              '_default']
                      run_script(export_cmd, wd="../deploy/export")
                      # Construct predictor
                      # TODO: Test trt and mkl
                      predictor = pdrs.deploy.Predictor(
                          static_model_dir,
                          use_gpu=paddle.device.get_device().startswith('gpu'))
                      # The check runs inside the `with` block so the
                      # exported files still exist while being used.
                      self.check_predictor(predictor, trainer)

              return _test_predictor_impl

          for trainer_name in cls.MODULE.__all__:
              setattr(cls, 'test_' + trainer_name, _test_predictor(trainer_name))
          return cls

      def check_predictor(self, predictor, trainer):
          """Compare ``predictor`` outputs with ``trainer`` outputs.

          Must be overridden by subclasses.

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError

      def check_dict_equal(self, dict_, expected_dict):
          """Recursively assert that two prediction results are equal.

          Args:
              dict_: A dict of outputs, or a list of such dicts.
              expected_dict: The reference result, with the same structure.

          Leaf values are compared with ``self.check_output_equal`` (presumably
          provided by ``CommonTest``, which is not visible here).
          """
          if isinstance(dict_, list):
              self.assertIsInstance(expected_dict, list)
              self.assertEqual(len(dict_), len(expected_dict))
              for d1, d2 in zip(dict_, expected_dict):
                  self.check_dict_equal(d1, d2)
          else:
              # Use unittest assertions instead of bare `assert`: the latter
              # is removed under `python -O` and gives poor failure messages.
              self.assertIsInstance(dict_, dict)
              self.assertIsInstance(expected_dict, dict)
              self.assertEqual(dict_.keys(), expected_dict.keys())
              for key in dict_.keys():
                  self.check_output_equal(dict_[key], expected_dict[key])
  @TestPredictor.add_tests
  class TestCDPredictor(TestPredictor):
      """Predictor/trainer consistency tests for change detection models."""

      MODULE = pdrs.tasks.change_detector
      TRAINER_NAME_TO_EXPORT_OPTS = {
          'BIT': "--fixed_input_shape [1,3,256,256]",
          '_default': "--fixed_input_shape [-1,3,256,256]"
      }

      def check_predictor(self, predictor, trainer):
          path_a = "data/ssmt/optical_t1.bmp"
          path_b = "data/ssmt/optical_t2.bmp"
          path_pair = (path_a, path_b)
          batch_size = 2
          tfs = pdrs.transforms.Compose([pdrs.transforms.Normalize()])

          def load_pair():
              """Read both images as float32 arrays (first image read first)."""
              first = cv2.imread(path_a).astype('float32')
              second = cv2.imread(path_b).astype('float32')
              return (first, second)

          def compare_one(sample, reference=None):
              """Check one sample, bare and wrapped in a one-element list.

              If ``reference`` is given, the predictor's bare output must also
              match it. Returns the predictor's bare output.
              """
              pred_out = predictor.predict(sample, transforms=tfs)
              if reference is not None:
                  self.check_dict_equal(pred_out, reference)
              self.check_dict_equal(pred_out,
                                    trainer.predict(sample, transforms=tfs))
              pred_wrapped = predictor.predict([sample], transforms=tfs)
              self.assertEqual(len(pred_wrapped), 1)
              self.check_dict_equal(pred_wrapped[0], pred_out)
              train_wrapped = trainer.predict([sample], transforms=tfs)
              self.check_dict_equal(pred_wrapped[0], train_wrapped[0])
              return pred_out

          def compare_many(samples):
              """Check a batch of ``batch_size`` samples against the trainer."""
              pred_outs = predictor.predict(samples, transforms=tfs)
              self.assertEqual(len(pred_outs), batch_size)
              self.check_dict_equal(pred_outs,
                                    trainer.predict(samples, transforms=tfs))

          # A lone image instead of an image pair must be rejected.
          with self.assertRaises(ValueError):
              predictor.predict(path_a, transforms=tfs)

          # Single sample given as file paths, then as in-memory arrays. The
          # array result must match the file-path result.
          from_paths = compare_one(path_pair)
          compare_one(load_pair(), reference=from_paths)

          # BIT is exported with a fixed input shape of batch 1, so the
          # batched checks below are skipped for it.
          if isinstance(trainer, pdrs.tasks.change_detector.BIT):
              return

          compare_many([path_pair] * batch_size)
          compare_many([load_pair()] * batch_size)
  @TestPredictor.add_tests
  class TestClasPredictor(TestPredictor):
      """Predictor/trainer consistency tests for image classifiers."""

      MODULE = pdrs.tasks.classifier
      TRAINER_NAME_TO_EXPORT_OPTS = {
          '_default': "--fixed_input_shape [-1,3,256,256]"
      }

      def check_predictor(self, predictor, trainer):
          image_path = "data/ssmt/optical_t1.bmp"
          batch_size = 2
          tfs = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
          # Trainer and predictor share the same label list object.
          class_ids = list(range(2))
          trainer.labels = class_ids
          predictor._model.labels = class_ids

          def compare_one(sample, reference=None):
              """Check one sample, bare and wrapped in a one-element list.

              If ``reference`` is given, the predictor's bare output must also
              match it. Returns the predictor's bare output.
              """
              pred_out = predictor.predict(sample, transforms=tfs)
              if reference is not None:
                  self.check_dict_equal(pred_out, reference)
              self.check_dict_equal(pred_out,
                                    trainer.predict(sample, transforms=tfs))
              pred_wrapped = predictor.predict([sample], transforms=tfs)
              self.assertEqual(len(pred_wrapped), 1)
              self.check_dict_equal(pred_wrapped[0], pred_out)
              train_wrapped = trainer.predict([sample], transforms=tfs)
              self.check_dict_equal(pred_wrapped[0], train_wrapped[0])
              return pred_out

          def compare_many(samples):
              """Check a batch of ``batch_size`` samples against the trainer."""
              pred_outs = predictor.predict(samples, transforms=tfs)
              self.assertEqual(len(pred_outs), batch_size)
              train_outs = trainer.predict(samples, transforms=tfs)
              self.assertEqual(len(pred_outs), len(train_outs))
              self.check_dict_equal(pred_outs, train_outs)

          # Single image: file path first, then the same image as an array,
          # which must reproduce the file-path result.
          from_path = compare_one(image_path)
          compare_one(
              cv2.imread(image_path).astype('float32'), reference=from_path)

          # Batches of identical samples.
          compare_many([image_path] * batch_size)
          compare_many([cv2.imread(image_path).astype('float32')] * batch_size)
  @TestPredictor.add_tests
  class TestDetPredictor(TestPredictor):
      """Predictor/trainer consistency tests for object detectors."""

      MODULE = pdrs.tasks.object_detector
      TRAINER_NAME_TO_EXPORT_OPTS = {
          '_default': "--fixed_input_shape [-1,3,256,256]"
      }

      def check_predictor(self, predictor, trainer):
          img_file = "data/ssmt/optical_t1.bmp"
          n_samples = 2
          pipeline = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
          # Trainer and predictor share the same 80-entry label list object.
          names = list(range(80))
          trainer.labels = names
          predictor._model.labels = names

          def verify_single(item, expected=None):
              """Verify one sample, bare and as a one-element list.

              If ``expected`` is given, the predictor's bare output must also
              match it. Returns the predictor's bare output.
              """
              bare_p = predictor.predict(item, transforms=pipeline)
              if expected is not None:
                  self.check_dict_equal(bare_p, expected)
              bare_t = trainer.predict(item, transforms=pipeline)
              self.check_dict_equal(bare_p, bare_t)
              listed_p = predictor.predict([item], transforms=pipeline)
              self.assertEqual(len(listed_p), 1)
              self.check_dict_equal(listed_p[0], bare_p)
              listed_t = trainer.predict([item], transforms=pipeline)
              self.check_dict_equal(listed_p[0], listed_t[0])
              return bare_p

          def verify_batch(items):
              """Verify a batch of ``n_samples`` samples against the trainer."""
              outs_p = predictor.predict(items, transforms=pipeline)
              self.assertEqual(len(outs_p), n_samples)
              outs_t = trainer.predict(items, transforms=pipeline)
              self.assertEqual(len(outs_p), len(outs_t))
              self.check_dict_equal(outs_p, outs_t)

          # Single image: file path first, then the same image as an array,
          # which must reproduce the file-path result.
          path_result = verify_single(img_file)
          image = cv2.imread(img_file).astype('float32')
          verify_single(image, expected=path_result)

          # Batches of identical samples (the image is re-read for the
          # array batch).
          verify_batch([img_file] * n_samples)
          verify_batch([cv2.imread(img_file).astype('float32')] * n_samples)
  @TestPredictor.add_tests
  class TestSegPredictor(TestPredictor):
      """Predictor/trainer consistency tests for segmentation models."""

      MODULE = pdrs.tasks.segmenter
      TRAINER_NAME_TO_EXPORT_OPTS = {
          '_default': "--fixed_input_shape [-1,3,256,256]"
      }

      def check_predictor(self, predictor, trainer):
          src = "data/ssmt/optical_t1.bmp"
          count = 2
          norm = pdrs.transforms.Compose([pdrs.transforms.Normalize()])

          def as_array():
              """Load the test image as a float32 array."""
              return cv2.imread(src).astype('float32')

          def single_case(x, baseline=None):
              """Compare one sample, bare and as a one-element list.

              If ``baseline`` is given, the predictor's bare output must also
              match it. Returns the predictor's bare output.
              """
              p_out = predictor.predict(x, transforms=norm)
              if baseline is not None:
                  self.check_dict_equal(p_out, baseline)
              t_out = trainer.predict(x, transforms=norm)
              self.check_dict_equal(p_out, t_out)
              p_list = predictor.predict([x], transforms=norm)
              self.assertEqual(len(p_list), 1)
              self.check_dict_equal(p_list[0], p_out)
              t_list = trainer.predict([x], transforms=norm)
              self.check_dict_equal(p_list[0], t_list[0])
              return p_out

          def multi_case(xs):
              """Compare a batch of ``count`` samples against the trainer."""
              p_outs = predictor.predict(xs, transforms=norm)
              self.assertEqual(len(p_outs), count)
              t_outs = trainer.predict(xs, transforms=norm)
              self.assertEqual(len(p_outs), len(t_outs))
              self.check_dict_equal(p_outs, t_outs)

          # Single image: file path first, then the same image as an array,
          # which must reproduce the file-path result.
          baseline_out = single_case(src)
          single_case(as_array(), baseline=baseline_out)

          # Batches of identical samples.
          multi_case([src] * count)
          multi_case([as_array()] * count)