extract_ms_patches.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import os.path as osp
  17. import argparse
  18. from collections import deque
  19. from functools import reduce
  20. import paddlers
  21. import numpy as np
  22. import cv2
  23. try:
  24. from osgeo import gdal
  25. except:
  26. import gdal
  27. from tqdm import tqdm
  28. from utils import time_it
# Mask label value treated as "ignore": its histogram bin is excluded from the
# per-block class counts computed in `QuadTree.preprocess`.
IGN_CLS = 255
# File-name template for saved patches; `idx` is the node index and `ext` is
# the extension of the source file (see `_save_patch` in `extract_ms_patches`).
FMT = "im_{idx}{ext}"
  31. class QuadTreeNode(object):
  32. def __init__(self, i, j, h, w, level, cls_info=None):
  33. super().__init__()
  34. self.i = i
  35. self.j = j
  36. self.h = h
  37. self.w = w
  38. self.level = level
  39. self.cls_info = cls_info
  40. self.reset_children()
  41. @property
  42. def area(self):
  43. return self.h * self.w
  44. @property
  45. def is_bg_node(self):
  46. return self.cls_info is None
  47. @property
  48. def coords(self):
  49. return (self.i, self.j, self.h, self.w)
  50. def get_cls_cnt(self, cls):
  51. if self.cls_info is None or cls >= len(self.cls_info):
  52. return 0
  53. return self.cls_info[cls]
  54. def get_children(self):
  55. for child in self.children:
  56. if child is not None:
  57. yield child
  58. def reset_children(self):
  59. self.children = [None, None, None, None]
  60. def __repr__(self):
  61. return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
  62. class QuadTree(object):
  63. def __init__(self, min_blk_size=256):
  64. super().__init__()
  65. self.min_blk_size = min_blk_size
  66. self.h = None
  67. self.w = None
  68. self.root = None
  69. def build_tree(self, mask_band, bg_cls=0):
  70. cls_info_table = self.preprocess(mask_band, bg_cls)
  71. n_rows = len(cls_info_table)
  72. if n_rows == 0:
  73. return None
  74. n_cols = len(cls_info_table[0])
  75. self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
  76. n_cols - 1, 0)
  77. return self.root
  78. def preprocess(self, mask_ds, bg_cls):
  79. h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
  80. s = self.min_blk_size
  81. if s >= h or s >= w:
  82. raise ValueError("`min_blk_size` must be smaller than image size.")
  83. cls_info_table = []
  84. for i in range(0, h, s):
  85. cls_info_row = []
  86. for j in range(0, w, s):
  87. if i + s > h:
  88. ch = h - i
  89. else:
  90. ch = s
  91. if j + s > w:
  92. cw = w - j
  93. else:
  94. cw = s
  95. arr = mask_ds.ReadAsArray(j, i, cw, ch)
  96. bins = np.bincount(arr.ravel())
  97. if len(bins) > IGN_CLS:
  98. bins = np.delete(bins, IGN_CLS)
  99. if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
  100. cls_info_row.append(None)
  101. else:
  102. cls_info_row.append(bins)
  103. cls_info_table.append(cls_info_row)
  104. return cls_info_table
  105. def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
  106. if i_ed < i_st or j_ed < j_st:
  107. return None
  108. i = i_st * self.min_blk_size
  109. j = j_st * self.min_blk_size
  110. h = (i_ed - i_st + 1) * self.min_blk_size
  111. w = (j_ed - j_st + 1) * self.min_blk_size
  112. if i_ed == i_st and j_ed == j_st:
  113. return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])
  114. i_mid = (i_ed + i_st) // 2
  115. j_mid = (j_ed + j_st) // 2
  116. root = QuadTreeNode(i, j, h, w, level)
  117. root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
  118. j_mid, level + 1)
  119. root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
  120. j_mid + 1, j_ed, level + 1)
  121. root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
  122. j_st, j_mid, level + 1)
  123. root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
  124. j_mid + 1, j_ed, level + 1)
  125. bins_list = [
  126. node.cls_info for node in root.get_children()
  127. if node.cls_info is not None
  128. ]
  129. if len(bins_list) > 0:
  130. merged_bins = reduce(merge_bins, bins_list)
  131. root.cls_info = merged_bins
  132. else:
  133. # Merge nodes
  134. root.reset_children()
  135. return root
  136. def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
  137. nodes = []
  138. q = deque()
  139. q.append(self.root)
  140. while q:
  141. node = q.popleft()
  142. if max_level is None or node.level < max_level:
  143. for child in node.get_children():
  144. if not include_bg and child.is_bg_node:
  145. continue
  146. if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
  147. continue
  148. nodes.append(child)
  149. q.append(child)
  150. return nodes
  151. def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
  152. im = paddlers.transforms.decode_image(im_path)
  153. if im.ndim == 2:
  154. im = np.stack([im] * 3, axis=2)
  155. elif im.ndim == 3:
  156. c = im.shape[2]
  157. if c < 3:
  158. raise ValueError(
  159. "For multi-spectral images, the number of bands should not be less than 3."
  160. )
  161. else:
  162. # Take first three bands as R, G, and B
  163. im = im[..., :3]
  164. else:
  165. raise ValueError("Unrecognized data format.")
  166. nodes = self.get_nodes(include_bg=True)
  167. vis = np.ascontiguousarray(im)
  168. for node in nodes:
  169. i, j, h, w = node.coords
  170. vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (255, 0, 0), 2)
  171. cv2.imwrite(save_path, vis[..., ::-1])
  172. return save_path
  173. def print_tree(self, node=None, level=0):
  174. if node is None:
  175. node = self.root
  176. print(' ' * level + '-', node)
  177. for child in node.get_children():
  178. self.print_tree(child, level + 1)
  179. def merge_bins(bins1, bins2):
  180. if len(bins1) < len(bins2):
  181. return merge_bins(bins2, bins1)
  182. elif len(bins1) == len(bins2):
  183. return bins1 + bins2
  184. else:
  185. return bins1 + np.concatenate(
  186. [bins2, np.zeros(len(bins1) - len(bins2))])
  187. @time_it
  188. def extract_ms_patches(image_paths,
  189. mask_path,
  190. save_dir,
  191. min_patch_size=256,
  192. bg_class=0,
  193. target_class=None,
  194. max_level=None,
  195. include_bg=False,
  196. nonzero_ratio=None,
  197. visualize=False):
  198. def _save_patch(src_path, i, j, h, w, subdir=None):
  199. src_path = osp.normpath(src_path)
  200. src_name, src_ext = osp.splitext(osp.basename(src_path))
  201. subdir = subdir if subdir is not None else src_name
  202. dst_dir = osp.join(save_dir, subdir)
  203. if not osp.exists(dst_dir):
  204. os.makedirs(dst_dir)
  205. dst_name = FMT.format(idx=idx, ext=src_ext)
  206. dst_path = osp.join(dst_dir, dst_name)
  207. gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
  208. return dst_path
  209. if nonzero_ratio is not None:
  210. print(
  211. "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
  212. )
  213. mask_ds = gdal.Open(mask_path)
  214. quad_tree = QuadTree(min_blk_size=min_patch_size)
  215. if mask_ds.RasterCount != 1:
  216. raise ValueError("The mask image has more than 1 band.")
  217. print("Start building quad tree...")
  218. quad_tree.build_tree(mask_ds, bg_class)
  219. if visualize:
  220. print("Start drawing rectangles...")
  221. save_path = quad_tree.visualize_regions(image_paths[0])
  222. print(f"The visualization result is saved in {save_path} .")
  223. print("Quad tree has been built. Now start collecting nodes...")
  224. nodes = quad_tree.get_nodes(
  225. tar_cls=target_class, max_level=max_level, include_bg=include_bg)
  226. print("Nodes collected. Saving patches...")
  227. for idx, node in enumerate(tqdm(nodes)):
  228. i, j, h, w = node.coords
  229. real_h = min(h, mask_ds.RasterYSize - i)
  230. real_w = min(w, mask_ds.RasterXSize - j)
  231. if real_h < h or real_w < w:
  232. # Skip incomplete patches
  233. continue
  234. is_valid = True
  235. if nonzero_ratio is not None:
  236. for src_path in image_paths:
  237. im_ds = gdal.Open(src_path)
  238. arr = im_ds.ReadAsArray(j, i, real_w, real_h)
  239. if np.count_nonzero(arr) / arr.size < nonzero_ratio:
  240. is_valid = False
  241. break
  242. if is_valid:
  243. for src_path in image_paths:
  244. _save_patch(src_path, i, j, real_h, real_w)
  245. _save_patch(mask_path, i, j, real_h, real_w, 'mask')
  246. if __name__ == '__main__':
  247. parser = argparse.ArgumentParser()
  248. parser.add_argument("--image_paths", type=str, required=True, nargs='+', \
  249. help="Path of images. Different images must have unique file names.")
  250. parser.add_argument("--mask_path", type=str, required=True, \
  251. help="Path of mask.")
  252. parser.add_argument("--save_dir", type=str, default='output', \
  253. help="Path to save the extracted patches.")
  254. parser.add_argument("--min_patch_size", type=int, default=256, \
  255. help="Minimum patch size (height and width).")
  256. parser.add_argument("--bg_class", type=int, default=0, \
  257. help="Index of the background category.")
  258. parser.add_argument("--target_class", type=int, default=None, \
  259. help="Index of the category of interest.")
  260. parser.add_argument("--max_level", type=int, default=None, \
  261. help="Maximum level of hierarchical patches.")
  262. parser.add_argument("--include_bg", action='store_true', \
  263. help="Include patches that contains only background pixels.")
  264. parser.add_argument("--nonzero_ratio", type=float, default=None, \
  265. help="Threshold for filtering out less informative patches.")
  266. parser.add_argument("--visualize", action='store_true', \
  267. help="Visualize the quadtree.")
  268. args = parser.parse_args()
  269. extract_ms_patches(args.image_paths, args.mask_path, args.save_dir,
  270. args.min_patch_size, args.bg_class, args.target_class,
  271. args.max_level, args.include_bg, args.nonzero_ratio,
  272. args.visualize)