extract_ms_patches.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import argparse
  17. from collections import deque
  18. from functools import reduce
  19. import paddlers
  20. import numpy as np
  21. try:
  22. from osgeo import gdal
  23. except:
  24. import gdal
  25. from tqdm import tqdm
  26. from utils import time_it
# Label value that is excluded from the per-block class histograms: when a
# histogram is long enough to contain this index, the entry is removed before
# the block is tested for being background-only.
IGN_CLS = 255
# File name template for saved patches; `idx` is the patch index and `ext` is
# the extension of the source file the patch is cut from.
FMT = "im_{idx}{ext}"
  29. class QuadTreeNode(object):
  30. def __init__(self, i, j, h, w, level, cls_info=None):
  31. super().__init__()
  32. self.i = i
  33. self.j = j
  34. self.h = h
  35. self.w = w
  36. self.level = level
  37. self.cls_info = cls_info
  38. self.reset_children()
  39. @property
  40. def area(self):
  41. return self.h * self.w
  42. @property
  43. def is_bg_node(self):
  44. return self.cls_info is None
  45. @property
  46. def coords(self):
  47. return (self.i, self.j, self.h, self.w)
  48. def get_cls_cnt(self, cls):
  49. if self.cls_info is None or cls >= len(self.cls_info):
  50. return 0
  51. return self.cls_info[cls]
  52. def get_children(self):
  53. for child in self.children:
  54. if child is not None:
  55. yield child
  56. def reset_children(self):
  57. self.children = [None, None, None, None]
  58. def __repr__(self):
  59. return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
  60. class QuadTree(object):
  61. def __init__(self, min_blk_size=256):
  62. super().__init__()
  63. self.min_blk_size = min_blk_size
  64. self.h = None
  65. self.w = None
  66. self.root = None
  67. def build_tree(self, mask_band, bg_cls=0):
  68. cls_info_table = self.preprocess(mask_band, bg_cls)
  69. n_rows = len(cls_info_table)
  70. if n_rows == 0:
  71. return None
  72. n_cols = len(cls_info_table[0])
  73. self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
  74. n_cols - 1, 0)
  75. return self.root
  76. def preprocess(self, mask_ds, bg_cls):
  77. h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
  78. s = self.min_blk_size
  79. if s >= h or s >= w:
  80. raise ValueError("`min_blk_size` must be smaller than image size.")
  81. cls_info_table = []
  82. for i in range(0, h, s):
  83. cls_info_row = []
  84. for j in range(0, w, s):
  85. if i + s > h:
  86. ch = h - i
  87. else:
  88. ch = s
  89. if j + s > w:
  90. cw = w - j
  91. else:
  92. cw = s
  93. arr = mask_ds.ReadAsArray(j, i, cw, ch)
  94. bins = np.bincount(arr.ravel())
  95. if len(bins) > IGN_CLS:
  96. bins = np.delete(bins, IGN_CLS)
  97. if bins.sum() == bins[bg_cls]:
  98. cls_info_row.append(None)
  99. else:
  100. cls_info_row.append(bins)
  101. cls_info_table.append(cls_info_row)
  102. return cls_info_table
  103. def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
  104. if i_ed < i_st or j_ed < j_st:
  105. return None
  106. i = i_st * self.min_blk_size
  107. j = j_st * self.min_blk_size
  108. h = (i_ed - i_st + 1) * self.min_blk_size
  109. w = (j_ed - j_st + 1) * self.min_blk_size
  110. if i_ed == i_st and j_ed == j_st:
  111. return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])
  112. i_mid = (i_ed + i_st) // 2
  113. j_mid = (j_ed + j_st) // 2
  114. root = QuadTreeNode(i, j, h, w, level)
  115. root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
  116. j_mid, level + 1)
  117. root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
  118. j_mid + 1, j_ed, level + 1)
  119. root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
  120. j_st, j_mid, level + 1)
  121. root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
  122. j_mid + 1, j_ed, level + 1)
  123. bins_list = [
  124. node.cls_info for node in root.get_children()
  125. if node.cls_info is not None
  126. ]
  127. if len(bins_list) > 0:
  128. merged_bins = reduce(merge_bins, bins_list)
  129. root.cls_info = merged_bins
  130. else:
  131. # Merge nodes
  132. root.reset_children()
  133. return root
  134. def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
  135. nodes = []
  136. q = deque()
  137. q.append(self.root)
  138. while q:
  139. node = q.popleft()
  140. if max_level is None or node.level < max_level:
  141. for child in node.get_children():
  142. if not include_bg and child.is_bg_node:
  143. continue
  144. if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
  145. continue
  146. nodes.append(child)
  147. q.append(child)
  148. return nodes
  149. def print_tree(self, node=None, level=0):
  150. if node is None:
  151. node = self.root
  152. print(' ' * level + '-', node)
  153. for child in node.get_children():
  154. self.print_tree(child, level + 1)
  155. def merge_bins(bins1, bins2):
  156. if len(bins1) < len(bins2):
  157. return merge_bins(bins2, bins1)
  158. elif len(bins1) == len(bins2):
  159. return bins1 + bins2
  160. else:
  161. return bins1 + np.concatenate(
  162. [bins2, np.zeros(len(bins1) - len(bins2))])
  163. @time_it
  164. def extract_ms_patches(im_paths,
  165. mask_path,
  166. save_dir,
  167. min_patch_size=256,
  168. bg_class=0,
  169. target_class=None,
  170. max_level=None,
  171. include_bg=False,
  172. nonzero_ratio=None):
  173. def _save_patch(src_path, i, j, h, w, subdir=None):
  174. src_path = osp.normpath(src_path)
  175. src_name, src_ext = osp.splitext(osp.basename(src_path))
  176. subdir = subdir if subdir is not None else src_name
  177. dst_dir = osp.join(save_dir, subdir)
  178. if not osp.exists(dst_dir):
  179. os.makedirs(dst_dir)
  180. dst_name = FMT.format(idx=idx, ext=src_ext)
  181. dst_path = osp.join(dst_dir, dst_name)
  182. gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
  183. return dst_path
  184. if nonzero_ratio is not None:
  185. print(
  186. "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
  187. )
  188. mask_ds = gdal.Open(mask_path)
  189. quad_tree = QuadTree(min_blk_size=min_patch_size)
  190. if mask_ds.RasterCount != 1:
  191. raise ValueError("The mask image has more than 1 band.")
  192. print("Start building quad tree...")
  193. quad_tree.build_tree(mask_ds, bg_class)
  194. print("Quad tree has been built. Now start collecting nodes...")
  195. nodes = quad_tree.get_nodes(
  196. tar_cls=target_class, max_level=max_level, include_bg=include_bg)
  197. print("Nodes collected. Saving patches...")
  198. for idx, node in enumerate(tqdm(nodes)):
  199. i, j, h, w = node.coords
  200. h = min(h, mask_ds.RasterYSize - i)
  201. w = min(w, mask_ds.RasterXSize - j)
  202. is_valid = True
  203. if nonzero_ratio is not None:
  204. for src_path in im_paths:
  205. im_ds = gdal.Open(src_path)
  206. arr = im_ds.ReadAsArray(j, i, w, h)
  207. if np.count_nonzero(arr) / arr.size < nonzero_ratio:
  208. is_valid = False
  209. break
  210. if is_valid:
  211. for src_path in im_paths:
  212. _save_patch(src_path, i, j, h, w)
  213. _save_patch(mask_path, i, j, h, w, 'mask')
  214. if __name__ == '__main__':
  215. parser = argparse.ArgumentParser()
  216. parser.add_argument("--im_paths", type=str, required=True, nargs='+', \
  217. help="Path of images. Different images must have unique file names.")
  218. parser.add_argument("--mask_path", type=str, required=True, \
  219. help="Path of mask.")
  220. parser.add_argument("--save_dir", type=str, required=True, \
  221. help="Path to save the extracted patches.")
  222. parser.add_argument("--min_patch_size", type=int, default=256, \
  223. help="Minimum patch size (height and width).")
  224. parser.add_argument("--bg_class", type=int, default=0, \
  225. help="Index of the background category.")
  226. parser.add_argument("--target_class", type=int, default=None, \
  227. help="Index of the category of interest.")
  228. parser.add_argument("--max_level", type=int, default=None, \
  229. help="Maximum level of hierarchical patches.")
  230. parser.add_argument("--include_bg", action='store_true', \
  231. help="Include patches that contains only background pixels.")
  232. parser.add_argument("--nonzero_ratio", type=float, default=None, \
  233. help="Threshold for filtering out less informative patches.")
  234. args = parser.parse_args()
  235. extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
  236. args.min_patch_size, args.bg_class, args.target_class,
  237. args.max_level, args.include_bg, args.nonzero_ratio)