test_slider_predict.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import tempfile
  16. import paddlers as pdrs
  17. import paddlers.transforms as T
  18. from testing_utils import CommonTest
  class TestSegSliderPredict(CommonTest):
      """Tests for `slider_predict()` of a segmentation model.

      Results written to disk by `slider_predict()` are decoded and compared
      (by shape, and by georeferencing info) against whole-image inference
      done with `predict()` on the same input image.
      """

      def setUp(self):
          self.model = pdrs.tasks.seg.UNet(in_channels=10)
          self.transforms = T.Compose([
              T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
              T.ArrangeSegmenter('test')
          ])
          self.image_path = "data/ssst/multispectral.tif"
          # `slider_predict()` output is read back from
          # `<save_dir>/<basename of the input image>`.
          self.basename = osp.basename(self.image_path)

      def _predict_whole(self):
          """Run whole-image inference with `predict()` and return the label map."""
          pred = self.model.predict(self.image_path, self.transforms)
          return pred['label_map']

      def _slider_predict_and_load(self, save_dir, block_size, overlap,
                                   **kwargs):
          """Run `slider_predict()` into `save_dir` and decode the saved result.

          Args:
              save_dir (str): Directory that `slider_predict()` writes into.
              block_size (int|tuple): Passed through to `slider_predict()`.
              overlap (int|tuple): Passed through to `slider_predict()`.
              **kwargs: Extra keyword arguments (e.g. `merge_strategy`)
                  forwarded to `slider_predict()`.

          Returns:
              The decoded prediction image, read without uint8 conversion so
              the raw label values are preserved.
          """
          self.model.slider_predict(self.image_path, save_dir, block_size,
                                    overlap, self.transforms, **kwargs)
          return T.decode_image(
              osp.join(save_dir, self.basename),
              to_uint8=False,
              decode_sar=False)

      def test_blocksize_and_overlap_whole(self):
          # Original image size (256, 256)
          pred_whole = self._predict_whole()
          with tempfile.TemporaryDirectory() as td:
              # Whole-image inference using slider_predict()
              pred1 = self._slider_predict_and_load(
                  osp.join(td, 'pred1'), 256, 0)
              self.check_output_equal(pred1.shape, pred_whole.shape)

              # `block_size` == `overlap` must be rejected
              with self.assertRaises(ValueError):
                  self.model.slider_predict(self.image_path,
                                            osp.join(td, 'pred2'), 128, 128,
                                            self.transforms)

              # `block_size` is a tuple
              pred3 = self._slider_predict_and_load(
                  osp.join(td, 'pred3'), (128, 32), 0)
              self.check_output_equal(pred3.shape, pred_whole.shape)

              # `block_size` and `overlap` are both tuples
              pred4 = self._slider_predict_and_load(
                  osp.join(td, 'pred4'), (128, 100), (10, 5))
              self.check_output_equal(pred4.shape, pred_whole.shape)

              # `block_size` larger than image size must be rejected
              with self.assertRaises(ValueError):
                  self.model.slider_predict(self.image_path,
                                            osp.join(td, 'pred5'), 512, 0,
                                            self.transforms)

      def test_merge_strategy(self):
          pred_whole = self._predict_whole()
          with tempfile.TemporaryDirectory() as td:
              # Overlapping blocks (128 with overlap 64) so every strategy
              # actually has to merge. Each strategy writes into its own
              # sub-directory and is passed by the same name, so a label can
              # no longer disagree with the strategy actually exercised.
              for strategy in ('keep_first', 'keep_last', 'vote', 'accum'):
                  with self.subTest(merge_strategy=strategy):
                      pred = self._slider_predict_and_load(
                          osp.join(td, strategy),
                          128,
                          64,
                          merge_strategy=strategy)
                      self.check_output_equal(pred.shape, pred_whole.shape)

      def test_geo_info(self):
          with tempfile.TemporaryDirectory() as td:
              _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
              self.model.slider_predict(self.image_path, td, 128, 0,
                                        self.transforms)
              _, geo_info_out = T.decode_image(
                  osp.join(td, self.basename), read_geo_info=True)
              # The written prediction must keep the input's georeferencing.
              self.assertEqual(geo_info_out['geo_trans'],
                               geo_info_in['geo_trans'])
              self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
  class TestCDSliderPredict(CommonTest):
      """Tests for `slider_predict()` of a change-detection model.

      The model takes a pair of images. Results written to disk by
      `slider_predict()` are decoded and compared (by shape, and by
      georeferencing info of the first image) against whole-image inference
      done with `predict()` on the same image pair.
      """

      def setUp(self):
          self.model = pdrs.tasks.cd.BIT(in_channels=10)
          self.transforms = T.Compose([
              T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
              T.ArrangeChangeDetector('test')
          ])
          self.image_paths = ("data/ssmt/multispectral_t1.tif",
                              "data/ssmt/multispectral_t2.tif")
          # `slider_predict()` output is read back from
          # `<save_dir>/<basename of the first input image>`.
          self.basename = osp.basename(self.image_paths[0])

      def _predict_whole(self):
          """Run whole-image inference with `predict()` and return the label map."""
          pred = self.model.predict(self.image_paths, self.transforms)
          return pred['label_map']

      def _slider_predict_and_load(self, save_dir, block_size, overlap,
                                   **kwargs):
          """Run `slider_predict()` into `save_dir` and decode the saved result.

          Args:
              save_dir (str): Directory that `slider_predict()` writes into.
              block_size (int|tuple): Passed through to `slider_predict()`.
              overlap (int|tuple): Passed through to `slider_predict()`.
              **kwargs: Extra keyword arguments (e.g. `merge_strategy`)
                  forwarded to `slider_predict()`.

          Returns:
              The decoded prediction image, read without uint8 conversion so
              the raw label values are preserved.
          """
          self.model.slider_predict(self.image_paths, save_dir, block_size,
                                    overlap, self.transforms, **kwargs)
          return T.decode_image(
              osp.join(save_dir, self.basename),
              to_uint8=False,
              decode_sar=False)

      def test_blocksize_and_overlap_whole(self):
          # Original image size (256, 256)
          pred_whole = self._predict_whole()
          with tempfile.TemporaryDirectory() as td:
              # Whole-image inference using slider_predict()
              pred1 = self._slider_predict_and_load(
                  osp.join(td, 'pred1'), 256, 0)
              self.check_output_equal(pred1.shape, pred_whole.shape)

              # `block_size` == `overlap` must be rejected
              with self.assertRaises(ValueError):
                  self.model.slider_predict(self.image_paths,
                                            osp.join(td, 'pred2'), 128, 128,
                                            self.transforms)

              # `block_size` is a tuple
              pred3 = self._slider_predict_and_load(
                  osp.join(td, 'pred3'), (128, 32), 0)
              self.check_output_equal(pred3.shape, pred_whole.shape)

              # `block_size` and `overlap` are both tuples
              pred4 = self._slider_predict_and_load(
                  osp.join(td, 'pred4'), (128, 100), (10, 5))
              self.check_output_equal(pred4.shape, pred_whole.shape)

              # `block_size` larger than image size must be rejected
              with self.assertRaises(ValueError):
                  self.model.slider_predict(self.image_paths,
                                            osp.join(td, 'pred5'), 512, 0,
                                            self.transforms)

      def test_merge_strategy(self):
          pred_whole = self._predict_whole()
          with tempfile.TemporaryDirectory() as td:
              # Overlapping blocks (128 with overlap 64) so every strategy
              # actually has to merge. Each strategy writes into its own
              # sub-directory and is passed by the same name, so a label can
              # no longer disagree with the strategy actually exercised.
              for strategy in ('keep_first', 'keep_last', 'vote', 'accum'):
                  with self.subTest(merge_strategy=strategy):
                      pred = self._slider_predict_and_load(
                          osp.join(td, strategy),
                          128,
                          64,
                          merge_strategy=strategy)
                      self.check_output_equal(pred.shape, pred_whole.shape)

      def test_geo_info(self):
          with tempfile.TemporaryDirectory() as td:
              _, geo_info_in = T.decode_image(
                  self.image_paths[0], read_geo_info=True)
              self.model.slider_predict(self.image_paths, td, 128, 0,
                                        self.transforms)
              _, geo_info_out = T.decode_image(
                  osp.join(td, self.basename), read_geo_info=True)
              # The written prediction must keep the first input's
              # georeferencing.
              self.assertEqual(geo_info_out['geo_trans'],
                               geo_info_in['geo_trans'])
              self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])