slider_predict.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import math
  17. from abc import ABCMeta, abstractmethod
  18. from collections import Counter, defaultdict
  19. import numpy as np
  20. from tqdm import tqdm
  21. import paddlers.utils.logging as logging
  22. class Cache(metaclass=ABCMeta):
  23. @abstractmethod
  24. def get_block(self, i_st, j_st, h, w):
  25. pass
  26. class SlowCache(Cache):
  27. def __init__(self):
  28. super(SlowCache, self).__init__()
  29. self.cache = defaultdict(Counter)
  30. def push_pixel(self, i, j, l):
  31. self.cache[(i, j)][l] += 1
  32. def push_block(self, i_st, j_st, h, w, data):
  33. for i in range(0, h):
  34. for j in range(0, w):
  35. self.push_pixel(i_st + i, j_st + j, data[i, j])
  36. def pop_pixel(self, i, j):
  37. self.cache.pop((i, j))
  38. def pop_block(self, i_st, j_st, h, w):
  39. for i in range(0, h):
  40. for j in range(0, w):
  41. self.pop_pixel(i_st + i, j_st + j)
  42. def get_pixel(self, i, j):
  43. winners = self.cache[(i, j)].most_common(1)
  44. winner = winners[0]
  45. return winner[0]
  46. def get_block(self, i_st, j_st, h, w):
  47. block = []
  48. for i in range(i_st, i_st + h):
  49. row = []
  50. for j in range(j_st, j_st + w):
  51. row.append(self.get_pixel(i, j))
  52. block.append(row)
  53. return np.asarray(block)
  54. class ProbCache(Cache):
  55. def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
  56. super(ProbCache, self).__init__()
  57. self.cache = None
  58. self.h = h
  59. self.w = w
  60. self.ch = ch
  61. self.cw = cw
  62. self.sh = sh
  63. self.sw = sw
  64. if not issubclass(dtype, np.floating):
  65. raise TypeError("`dtype` must be one of the floating types.")
  66. self.dtype = dtype
  67. order = order.lower()
  68. if order not in ('c', 'f'):
  69. raise ValueError("`order` other than 'c' and 'f' is not supported.")
  70. self.order = order
  71. def _alloc_memory(self, nc):
  72. if self.order == 'c':
  73. # Colomn-first order (C-style)
  74. #
  75. # <-- cw -->
  76. # |--------|---------------------|^ ^
  77. # | || | sh
  78. # |--------|---------------------|| ch v
  79. # | ||
  80. # |--------|---------------------|v
  81. # <------------ w --------------->
  82. self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
  83. elif self.order == 'f':
  84. # Row-first order (Fortran-style)
  85. #
  86. # <-- sw -->
  87. # <---- cw ---->
  88. # |--------|---|^ ^
  89. # | | || |
  90. # | | || ch
  91. # | | || |
  92. # |--------|---|| h v
  93. # | | ||
  94. # | | ||
  95. # | | ||
  96. # |--------|---|v
  97. self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
  98. def update_block(self, i_st, j_st, h, w, prob_map):
  99. if self.cache is None:
  100. nc = prob_map.shape[2]
  101. # Lazy allocation of memory
  102. self._alloc_memory(nc)
  103. self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
  104. def roll_cache(self, shift):
  105. if self.order == 'c':
  106. self.cache[:-shift] = self.cache[shift:]
  107. self.cache[-shift:, :] = 0
  108. elif self.order == 'f':
  109. self.cache[:, :-shift] = self.cache[:, shift:]
  110. self.cache[:, -shift:] = 0
  111. def get_block(self, i_st, j_st, h, w):
  112. return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
  113. class OverlapProcessor(metaclass=ABCMeta):
  114. def __init__(self, h, w, ch, cw, sh, sw):
  115. super(OverlapProcessor, self).__init__()
  116. self.h = h
  117. self.w = w
  118. self.ch = ch
  119. self.cw = cw
  120. self.sh = sh
  121. self.sw = sw
  122. @abstractmethod
  123. def process_pred(self, out, xoff, yoff):
  124. pass
  125. class KeepFirstProcessor(OverlapProcessor):
  126. def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
  127. super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
  128. self.ds = ds
  129. self.inval = inval
  130. def process_pred(self, out, xoff, yoff):
  131. pred = out['label_map']
  132. pred = pred[:self.ch, :self.cw]
  133. rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
  134. mask = rd_block != self.inval
  135. pred = np.where(mask, rd_block, pred)
  136. return pred
  137. class KeepLastProcessor(OverlapProcessor):
  138. def process_pred(self, out, xoff, yoff):
  139. pred = out['label_map']
  140. pred = pred[:self.ch, :self.cw]
  141. return pred
  142. class AccumProcessor(OverlapProcessor):
  143. def __init__(self,
  144. h,
  145. w,
  146. ch,
  147. cw,
  148. sh,
  149. sw,
  150. dtype=np.float16,
  151. assign_weight=True):
  152. super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
  153. self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
  154. self.prev_yoff = None
  155. self.assign_weight = assign_weight
  156. def process_pred(self, out, xoff, yoff):
  157. if self.prev_yoff is not None and yoff != self.prev_yoff:
  158. if yoff < self.prev_yoff:
  159. raise RuntimeError
  160. self.cache.roll_cache(yoff - self.prev_yoff)
  161. pred = out['label_map']
  162. pred = pred[:self.ch, :self.cw]
  163. prob = out['score_map']
  164. prob = prob[:self.ch, :self.cw]
  165. if self.assign_weight:
  166. prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
  167. self.cache.update_block(0, xoff, self.ch, self.cw, prob)
  168. pred = self.cache.get_block(0, xoff, self.ch, self.cw)
  169. self.prev_yoff = yoff
  170. return pred
  171. def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
  172. if not inplace:
  173. array = array.copy()
  174. h, w = array.shape[:2]
  175. hm, wm = int(h * border_ratio), int(w * border_ratio)
  176. array[:hm] *= weight
  177. array[-hm:] *= weight
  178. array[:, :wm] *= weight
  179. array[:, -wm:] *= weight
  180. return array
  181. def read_block(ds,
  182. xoff,
  183. yoff,
  184. xsize,
  185. ysize,
  186. tar_xsize=None,
  187. tar_ysize=None,
  188. pad_val=0):
  189. if tar_xsize is None:
  190. tar_xsize = xsize
  191. if tar_ysize is None:
  192. tar_ysize = ysize
  193. # Read data from dataset
  194. block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
  195. c, real_ysize, real_xsize = block.shape
  196. assert real_ysize == ysize and real_xsize == xsize
  197. # [c, h, w] -> [h, w, c]
  198. block = block.transpose((1, 2, 0))
  199. if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
  200. if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
  201. raise ValueError
  202. padded_block = np.full(
  203. (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
  204. # Fill
  205. padded_block[:real_ysize, :real_xsize] = block
  206. return padded_block
  207. else:
  208. return block
  209. def slider_predict(predict_func,
  210. img_file,
  211. save_dir,
  212. block_size,
  213. overlap,
  214. transforms,
  215. invalid_value,
  216. merge_strategy,
  217. batch_size,
  218. show_progress=False):
  219. """
  220. Do inference using sliding windows.
  221. Args:
  222. predict_func (callable): A callable object that makes the prediction.
  223. img_file (str|tuple[str]): Image path(s).
  224. save_dir (str): Directory that contains saved geotiff file.
  225. block_size (list[int] | tuple[int] | int):
  226. Size of block. If `block_size` is list or tuple, it should be in
  227. (W, H) format.
  228. overlap (list[int] | tuple[int] | int):
  229. Overlap between two blocks. If `overlap` is list or tuple, it should
  230. be in (W, H) format.
  231. transforms (paddlers.transforms.Compose|None): Transforms for inputs. If
  232. None, the transforms for evaluation process will be used.
  233. invalid_value (int): Value that marks invalid pixels in output image.
  234. Defaults to 255.
  235. merge_strategy (str): Strategy to merge overlapping blocks. Choices are
  236. {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
  237. means keeping the values of the first and the last block in
  238. traversal order, respectively. 'accum' means determining the class
  239. of an overlapping pixel according to accumulated probabilities.
  240. batch_size (int): Batch size used in inference.
  241. show_progress (bool, optional): Whether to show prediction progress with a
  242. progress bar. Defaults to True.
  243. """
  244. try:
  245. from osgeo import gdal
  246. except:
  247. import gdal
  248. if isinstance(block_size, int):
  249. block_size = (block_size, block_size)
  250. elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
  251. block_size = tuple(block_size)
  252. else:
  253. raise ValueError(
  254. "`block_size` must be a tuple/list of length 2 or an integer.")
  255. if isinstance(overlap, int):
  256. overlap = (overlap, overlap)
  257. elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
  258. overlap = tuple(overlap)
  259. else:
  260. raise ValueError(
  261. "`overlap` must be a tuple/list of length 2 or an integer.")
  262. step = np.array(
  263. block_size, dtype=np.int32) - np.array(
  264. overlap, dtype=np.int32)
  265. if step[0] == 0 or step[1] == 0:
  266. raise ValueError("`block_size` and `overlap` should not be equal.")
  267. if isinstance(img_file, tuple):
  268. if len(img_file) != 2:
  269. raise ValueError("Tuple `img_file` must have the length of two.")
  270. # Assume that two input images have the same size
  271. src_data = gdal.Open(img_file[0])
  272. src2_data = gdal.Open(img_file[1])
  273. # Output name is the same as the name of the first image
  274. file_name = osp.basename(osp.normpath(img_file[0]))
  275. else:
  276. src_data = gdal.Open(img_file)
  277. file_name = osp.basename(osp.normpath(img_file))
  278. # Get size of original raster
  279. width = src_data.RasterXSize
  280. height = src_data.RasterYSize
  281. bands = src_data.RasterCount
  282. # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
  283. # except for SAR images.
  284. if bands == 1:
  285. logging.warning(
  286. f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
  287. )
  288. if block_size[0] > width or block_size[1] > height:
  289. raise ValueError("`block_size` should not be larger than image size.")
  290. driver = gdal.GetDriverByName("GTiff")
  291. if not osp.exists(save_dir):
  292. os.makedirs(save_dir)
  293. # Replace extension name with '.tif'
  294. file_name = osp.splitext(file_name)[0] + ".tif"
  295. save_file = osp.join(save_dir, file_name)
  296. dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
  297. # Set meta-information
  298. dst_data.SetGeoTransform(src_data.GetGeoTransform())
  299. dst_data.SetProjection(src_data.GetProjection())
  300. # Initialize raster with `invalid_value`
  301. band = dst_data.GetRasterBand(1)
  302. band.WriteArray(
  303. np.full(
  304. (height, width), fill_value=invalid_value, dtype="uint8"))
  305. if overlap == (0, 0) or block_size == (width, height):
  306. # When there is no overlap or the whole image is used as input,
  307. # use 'keep_last' strategy as it introduces least overheads
  308. merge_strategy = 'keep_last'
  309. if merge_strategy == 'keep_first':
  310. overlap_processor = KeepFirstProcessor(
  311. height,
  312. width,
  313. *block_size[::-1],
  314. *step[::-1],
  315. band,
  316. inval=invalid_value)
  317. elif merge_strategy == 'keep_last':
  318. overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
  319. *step[::-1])
  320. elif merge_strategy == 'accum':
  321. overlap_processor = AccumProcessor(height, width, *block_size[::-1],
  322. *step[::-1])
  323. else:
  324. raise ValueError("{} is not a supported stragegy for block merging.".
  325. format(merge_strategy))
  326. xsize, ysize = block_size
  327. num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
  328. cnt = 0
  329. if show_progress:
  330. pb = tqdm(total=num_blocks)
  331. batch_data = []
  332. batch_offsets = []
  333. for yoff in range(0, height, step[1]):
  334. for xoff in range(0, width, step[0]):
  335. if xoff + xsize > width:
  336. xoff = width - xsize
  337. is_end_of_row = True
  338. else:
  339. is_end_of_row = False
  340. if yoff + ysize > height:
  341. yoff = height - ysize
  342. is_end_of_col = True
  343. else:
  344. is_end_of_col = False
  345. # Read
  346. im = read_block(src_data, xoff, yoff, xsize, ysize)
  347. if isinstance(img_file, tuple):
  348. im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
  349. batch_data.append((im, im2))
  350. else:
  351. batch_data.append(im)
  352. batch_offsets.append((xoff, yoff))
  353. len_batch = len(batch_data)
  354. if is_end_of_row and is_end_of_col and len_batch < batch_size:
  355. # Pad `batch_data` by repeating the last element
  356. batch_data = batch_data + [batch_data[-1]] * (batch_size -
  357. len_batch)
  358. # While keeping `len(batch_offsets)` the number of valid elements in the batch
  359. if len(batch_data) == batch_size:
  360. # Predict
  361. batch_out = predict_func(batch_data, transforms=transforms)
  362. for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
  363. # Get processed result
  364. pred = overlap_processor.process_pred(out, xoff_, yoff_)
  365. # Write to file
  366. band.WriteArray(pred, xoff_, yoff_)
  367. dst_data.FlushCache()
  368. batch_data.clear()
  369. batch_offsets.clear()
  370. cnt += 1
  371. if show_progress:
  372. pb.update(1)
  373. pb.set_description("{} out of {} blocks processed.".format(
  374. cnt, num_blocks))
  375. dst_data = None
  376. logging.info("GeoTiff file saved in {}.".format(save_file))