extract_ms_patches.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import argparse
  17. from collections import deque
  18. from functools import reduce
  19. import paddlers
  20. import numpy as np
  21. import cv2
  22. try:
  23. from osgeo import gdal
  24. except:
  25. import gdal
  26. from tqdm import tqdm
  27. from utils import time_it
  # Mask value treated as "ignore": pixels with this value are excluded from
  # the per-block class histograms (see QuadTree.preprocess).
  IGN_CLS = 255
  # File-name template for saved patches; `idx` is the patch index and `ext`
  # is the extension of the source file the patch is cropped from.
  FMT = "im_{idx}{ext}"
  class QuadTreeNode(object):
      """A node of a quadtree laid over a mask image.

      The node covers the pixel rectangle whose top-left corner is at row
      ``i`` / column ``j`` and whose size is ``h`` x ``w``. ``cls_info`` holds
      per-class pixel counts for that rectangle, or ``None`` when the node is
      a background-only node.
      """

      def __init__(self, i, j, h, w, level, cls_info=None):
          super().__init__()
          self.i, self.j = i, j
          self.h, self.w = h, w
          # Depth in the tree; QuadTree builds its root at level 0.
          self.level = level
          self.cls_info = cls_info
          self.reset_children()

      @property
      def area(self):
          """Number of pixels covered by the node (h * w)."""
          return self.h * self.w

      @property
      def is_bg_node(self):
          """True when the node carries no class statistics (background only)."""
          return self.cls_info is None

      @property
      def coords(self):
          """The tuple ``(i, j, h, w)``."""
          return self.i, self.j, self.h, self.w

      def get_cls_cnt(self, cls):
          """Return the pixel count of class ``cls``; 0 if unknown or absent."""
          info = self.cls_info
          if info is not None and cls < len(info):
              return info[cls]
          return 0

      def get_children(self):
          """Lazily yield the non-empty child slots, in slot order."""
          yield from (kid for kid in self.children if kid is not None)

      def reset_children(self):
          """Clear all four child slots."""
          self.children = [None] * 4

      def __repr__(self):
          return f"{type(self).__name__}({self.i}, {self.j}, {self.h}, {self.w})"
  class QuadTree(object):
      """Quadtree that hierarchically partitions a single-band mask.

      The mask is first cut into a regular grid of ``min_blk_size`` blocks
      (``preprocess``), then the grid is recursively halved along both axes
      (``_build_tree``). A subtree whose blocks are all background-only is
      collapsed into a single childless node.
      """

      def __init__(self, min_blk_size=256):
          super().__init__()
          self.min_blk_size = min_blk_size
          self.h = None
          self.w = None
          self.root = None

      def build_tree(self, mask_band, bg_cls=0):
          """Build the tree from a mask dataset and return its root.

          Args:
              mask_band: Dataset exposing ``RasterYSize``, ``RasterXSize`` and
                  ``ReadAsArray(xoff, yoff, xsize, ysize)`` (e.g. a GDAL dataset).
              bg_cls (int): Index of the background class.

          Returns:
              QuadTreeNode | None: The root node, or None if no blocks were found.
          """
          cls_info_table = self.preprocess(mask_band, bg_cls)
          if not cls_info_table:
              # Do not leave a stale root from a previous build behind.
              self.root = None
              return None
          n_rows = len(cls_info_table)
          n_cols = len(cls_info_table[0])
          self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
                                       n_cols - 1, 0)
          return self.root

      def preprocess(self, mask_ds, bg_cls):
          """Compute class histograms for each grid block of the mask.

          Returns:
              list[list[np.ndarray | None]]: ``table[r][c]`` holds the per-class
              pixel counts of the block at grid row ``r`` / column ``c`` (with
              ``IGN_CLS`` pixels zeroed out), or None if the block contains only
              ``bg_cls`` pixels (ignoring ``IGN_CLS``).

          Raises:
              ValueError: If ``min_blk_size`` is not positive, or is not smaller
                  than both image dimensions.
          """
          h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
          s = self.min_blk_size
          if s <= 0:
              raise ValueError("`min_blk_size` must be positive.")
          if s >= h or s >= w:
              raise ValueError("`min_blk_size` must be smaller than image size.")
          cls_info_table = []
          for i in range(0, h, s):
              # Blocks on the bottom/right border may be smaller than s x s.
              ch = min(s, h - i)
              cls_info_row = []
              for j in range(0, w, s):
                  cw = min(s, w - j)
                  arr = mask_ds.ReadAsArray(j, i, cw, ch)
                  bins = np.bincount(arr.ravel())
                  if len(bins) > IGN_CLS:
                      # Zero the ignored class instead of deleting it, so the
                      # indices of classes above IGN_CLS are not shifted (which
                      # would make histograms of different blocks disagree).
                      bins[IGN_CLS] = 0
                  if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                      cls_info_row.append(None)
                  else:
                      cls_info_row.append(bins)
              cls_info_table.append(cls_info_row)
          return cls_info_table

      def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
          """Recursively build the subtree over grid rows ``i_st..i_ed`` and
          columns ``j_st..j_ed`` (both inclusive).

          Node coordinates are in pixels and assume every grid cell is a full
          ``min_blk_size`` block, so nodes touching the bottom/right border may
          extend past the image.
          """
          if i_ed < i_st or j_ed < j_st:
              # Empty range: a dimension of size 1 cannot be split further.
              return None
          i = i_st * self.min_blk_size
          j = j_st * self.min_blk_size
          h = (i_ed - i_st + 1) * self.min_blk_size
          w = (j_ed - j_st + 1) * self.min_blk_size
          if i_ed == i_st and j_ed == j_st:
              return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])
          i_mid = (i_ed + i_st) // 2
          j_mid = (j_ed + j_st) // 2
          root = QuadTreeNode(i, j, h, w, level)
          # Children order: top-left, top-right, bottom-left, bottom-right.
          root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
                                              j_mid, level + 1)
          root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
                                              j_mid + 1, j_ed, level + 1)
          root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                              j_st, j_mid, level + 1)
          root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                              j_mid + 1, j_ed, level + 1)
          bins_list = [
              node.cls_info for node in root.get_children()
              if node.cls_info is not None
          ]
          if len(bins_list) > 0:
              root.cls_info = reduce(merge_bins, bins_list)
          else:
              # All children are background-only: collapse them into this node.
              root.reset_children()
          return root

      def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
          """Collect descendants of the root in breadth-first order.

          The root itself is never returned. A skipped node's subtree is not
          visited.

          Args:
              tar_cls (int|None): If set, skip nodes with no pixel of this class.
              max_level (int|None): If set, only expand nodes whose level is
                  below ``max_level``.
              include_bg (bool): Whether to keep background-only nodes.

          Returns:
              list[QuadTreeNode]: The collected nodes; empty if the tree has not
              been built.
          """
          if self.root is None:
              return []
          nodes = []
          q = deque([self.root])
          while q:
              node = q.popleft()
              if max_level is not None and node.level >= max_level:
                  continue
              for child in node.get_children():
                  if not include_bg and child.is_bg_node:
                      continue
                  if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
                      continue
                  nodes.append(child)
                  q.append(child)
          return nodes

      def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
          """Draw the rectangles of all nodes on an image and save the result.

          The first three bands are drawn as R, G, B. Returns ``save_path``.

          Raises:
              ValueError: If the image has fewer than 3 bands (when it is
                  multi-band) or an unsupported number of dimensions.
          """
          im = paddlers.transforms.decode_image(im_path)
          if im.ndim == 2:
              im = np.stack([im] * 3, axis=2)
          elif im.ndim == 3:
              c = im.shape[2]
              if c < 3:
                  raise ValueError(
                      "For multi-spectral images, the number of bands should not be less than 3."
                  )
              else:
                  # Take first three bands as R, G, and B
                  im = im[..., :3]
          else:
              raise ValueError("Unrecognized data format.")
          nodes = self.get_nodes(include_bg=True)
          # cv2 drawing functions need a contiguous buffer.
          vis = np.ascontiguousarray(im)
          for node in nodes:
              i, j, h, w = node.coords
              vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (255, 0, 0), 2)
          # Convert RGB to BGR, the channel order cv2.imwrite expects.
          cv2.imwrite(save_path, vis[..., ::-1])
          return save_path

      def print_tree(self, node=None, level=0):
          """Print the subtree rooted at ``node`` (default: the root).

          Prints nothing if the tree has not been built.
          """
          if node is None:
              node = self.root
              if node is None:
                  return
          print(' ' * level + '-', node)
          for child in node.get_children():
              self.print_tree(child, level + 1)
  def merge_bins(bins1, bins2):
      """Return the element-wise sum of two class-count histograms.

      The shorter histogram is zero-padded at the end to the length of the
      longer one. Padding uses the input's own dtype, so integer pixel counts
      stay integers instead of being promoted to float64. Inputs are not
      modified.
      """
      n = max(len(bins1), len(bins2))
      return (np.pad(bins1, (0, n - len(bins1))) +
              np.pad(bins2, (0, n - len(bins2))))
  @time_it
  def extract_ms_patches(im_paths,
                         mask_path,
                         save_dir,
                         min_patch_size=256,
                         bg_class=0,
                         target_class=None,
                         max_level=None,
                         include_bg=False,
                         nonzero_ratio=None,
                         visualize=False):
      """Crop multi-scale patches selected by a quadtree built on a mask.

      Args:
          im_paths (list[str]): Images to crop. Patches of each image are saved
              in ``save_dir/<file stem>``, so the file names must be unique.
          mask_path (str): Single-band mask. Its patches go to ``save_dir/mask``.
          save_dir (str): Output root directory.
          min_patch_size (int): Size (height and width) of the smallest patch.
          bg_class (int): Index of the background class.
          target_class (int|None): If set, keep only patches containing it.
          max_level (int|None): If set, passed to ``QuadTree.get_nodes`` to
              limit the depth of the collected nodes.
          include_bg (bool): Whether to keep background-only patches.
          nonzero_ratio (float|None): If set, drop a patch when, for any image,
              the fraction of non-zero values in it is below this threshold.
          visualize (bool): Whether to draw the quadtree over the first image.

      Raises:
          OSError: If ``gdal.Open`` returns None for the mask (or, when
              ``nonzero_ratio`` is set, for an image).
          ValueError: If the mask does not have exactly one band.
      """

      def _open(path):
          # gdal.Open returns None (instead of raising) by default on failure.
          ds = gdal.Open(path)
          if ds is None:
              raise OSError(f"Failed to open {path}.")
          return ds

      def _is_informative(ds, i, j, h, w):
          # True when the window has at least `nonzero_ratio` non-zero values.
          arr = ds.ReadAsArray(j, i, w, h)
          return np.count_nonzero(arr) / arr.size >= nonzero_ratio

      def _save_patch(src_path, idx, i, j, h, w, subdir=None):
          # Crop window (col j, row i, width w, height h) of src_path into
          # save_dir/<subdir or source stem>/ with a name built from FMT.
          src_path = osp.normpath(src_path)
          src_name, src_ext = osp.splitext(osp.basename(src_path))
          subdir = subdir if subdir is not None else src_name
          dst_dir = osp.join(save_dir, subdir)
          os.makedirs(dst_dir, exist_ok=True)
          dst_path = osp.join(dst_dir, FMT.format(idx=idx, ext=src_ext))
          gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
          return dst_path

      if nonzero_ratio is not None:
          print(
              "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
          )
      mask_ds = _open(mask_path)
      quad_tree = QuadTree(min_blk_size=min_patch_size)
      if mask_ds.RasterCount != 1:
          raise ValueError("The mask image has more than 1 band.")
      print("Start building quad tree...")
      quad_tree.build_tree(mask_ds, bg_class)
      if visualize:
          print("Start drawing rectangles...")
          save_path = quad_tree.visualize_regions(im_paths[0])
          print(f"The visualization result is saved in {save_path} .")
      print("Quad tree has been built. Now start collecting nodes...")
      nodes = quad_tree.get_nodes(
          tar_cls=target_class, max_level=max_level, include_bg=include_bg)
      print("Nodes collected. Saving patches...")
      # Open each image once, not once per node, when filtering is requested.
      im_dss = [_open(p) for p in im_paths] if nonzero_ratio is not None else []
      mask_h, mask_w = mask_ds.RasterYSize, mask_ds.RasterXSize
      for idx, node in enumerate(tqdm(nodes)):
          i, j, h, w = node.coords
          # Nodes on the bottom/right border may extend past the mask;
          # skip such incomplete patches.
          if i + h > mask_h or j + w > mask_w:
              continue
          if nonzero_ratio is not None and not all(
                  _is_informative(ds, i, j, h, w) for ds in im_dss):
              continue
          for src_path in im_paths:
              _save_patch(src_path, idx, i, j, h, w)
          _save_patch(mask_path, idx, i, j, h, w, 'mask')
  if __name__ == '__main__':
      # Command-line entry point: parse options and run extract_ms_patches.
      parser = argparse.ArgumentParser()
      parser.add_argument(
          "--im_paths", type=str, required=True, nargs='+',
          help="Path of images. Different images must have unique file names.")
      parser.add_argument(
          "--mask_path", type=str, required=True, help="Path of mask.")
      parser.add_argument(
          "--save_dir", type=str, default='output',
          help="Path to save the extracted patches.")
      parser.add_argument(
          "--min_patch_size", type=int, default=256,
          help="Minimum patch size (height and width).")
      parser.add_argument(
          "--bg_class", type=int, default=0,
          help="Index of the background category.")
      parser.add_argument(
          "--target_class", type=int, default=None,
          help="Index of the category of interest.")
      parser.add_argument(
          "--max_level", type=int, default=None,
          help="Maximum level of hierarchical patches.")
      parser.add_argument(
          "--include_bg", action='store_true',
          help="Include patches that contain only background pixels.")
      parser.add_argument(
          "--nonzero_ratio", type=float, default=None,
          help="Threshold for filtering out less informative patches.")
      parser.add_argument(
          "--visualize", action='store_true', help="Visualize the quadtree.")
      args = parser.parse_args()
      # Pass optional settings by keyword so they cannot be silently reordered.
      extract_ms_patches(
          args.im_paths,
          args.mask_path,
          args.save_dir,
          min_patch_size=args.min_patch_size,
          bg_class=args.bg_class,
          target_class=args.target_class,
          max_level=args.max_level,
          include_bg=args.include_bg,
          nonzero_ratio=args.nonzero_ratio,
          visualize=args.visualize)