# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
from collections import deque
from functools import reduce

import paddlers
import numpy as np
import cv2
try:
    from osgeo import gdal
except ImportError:
    # Older GDAL builds expose the bindings as a top-level module.
    import gdal
from tqdm import tqdm

from utils import time_it
# Label value whose pixel count is discarded when building class histograms.
IGN_CLS = 255
# File-name template for saved patches: patch index followed by the source
# file's extension.
FMT = "im_{idx}{ext}"
class QuadTreeNode(object):
    """A rectangular image region stored in a quadtree.

    Args:
        i: Row offset of the region's top-left corner.
        j: Column offset of the region's top-left corner.
        h: Height of the region.
        w: Width of the region.
        level: Depth of the node in the tree.
        cls_info: Per-class pixel counts indexed by class id, or None when
            the region holds no class information (a background node).
    """

    def __init__(self, i, j, h, w, level, cls_info=None):
        super().__init__()
        self.i = i
        self.j = j
        self.h = h
        self.w = w
        self.level = level
        self.cls_info = cls_info
        self.reset_children()

    @property
    def area(self):
        """Number of pixels covered by the region (`h * w`)."""
        return self.h * self.w

    @property
    def is_bg_node(self):
        """True when the node carries no class histogram."""
        return self.cls_info is None

    @property
    def coords(self):
        """The region as an `(i, j, h, w)` tuple."""
        return (self.i, self.j, self.h, self.w)

    def get_cls_cnt(self, cls):
        """Return the pixel count recorded for class `cls`.

        Returns 0 for background nodes and for class ids outside the
        histogram. Negative ids are treated as out of range rather than
        being allowed to index from the end of the histogram.
        """
        if self.cls_info is None or cls < 0 or cls >= len(self.cls_info):
            return 0
        return self.cls_info[cls]

    def get_children(self):
        """Yield the child nodes that exist, skipping empty slots."""
        for child in self.children:
            if child is not None:
                yield child

    def reset_children(self):
        """Drop all children, leaving four empty slots."""
        self.children = [None, None, None, None]

    def __repr__(self):
        return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
class QuadTree(object):
    """Quadtree built over a label mask.

    The mask is cut into square cells of `min_blk_size` pixels (cells on the
    bottom/right border may be smaller). Each leaf of the tree corresponds to
    one cell, and subtrees that contain only background are collapsed into a
    single childless background node.

    Args:
        min_blk_size: Side length of the smallest block, in pixels.
    """

    def __init__(self, min_blk_size=256):
        super().__init__()
        self.min_blk_size = min_blk_size
        # Mask height/width; filled in by `preprocess`.
        self.h = None
        self.w = None
        self.root = None

    def build_tree(self, mask_band, bg_cls=0):
        """Build the tree from a mask and return its root node.

        Args:
            mask_band: Object exposing `RasterYSize`, `RasterXSize` and
                `ReadAsArray(xoff, yoff, xsize, ysize)`.
            bg_cls: Class id regarded as background.

        Returns:
            The root node, or None if the mask produced no cells.

        Raises:
            ValueError: Propagated from `preprocess` for an invalid
                `min_blk_size`.
        """
        cls_info_table = self.preprocess(mask_band, bg_cls)
        n_rows = len(cls_info_table)
        if n_rows == 0:
            return None
        n_cols = len(cls_info_table[0])
        self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
                                     n_cols - 1, 0)
        return self.root

    def preprocess(self, mask_ds, bg_cls):
        """Compute a per-cell class histogram table for the mask.

        Args:
            mask_ds: Object exposing `RasterYSize`, `RasterXSize` and
                `ReadAsArray(xoff, yoff, xsize, ysize)`.
            bg_cls: Class id regarded as background.

        Returns:
            A list of rows; each entry is None for a cell that contains only
            background (and/or ignored) pixels, otherwise the cell's
            `np.bincount` histogram with the `IGN_CLS` bin zeroed.

        Raises:
            ValueError: If `min_blk_size` is not positive or is not smaller
                than both image dimensions.
        """
        h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
        s = self.min_blk_size
        if s <= 0:
            raise ValueError("`min_blk_size` must be a positive integer.")
        if s >= h or s >= w:
            raise ValueError("`min_blk_size` must be smaller than image size.")
        self.h, self.w = h, w

        cls_info_table = []
        for i in range(0, h, s):
            # Cells in the last row/column are clipped to the image border.
            ch = min(s, h - i)
            cls_info_row = []
            for j in range(0, w, s):
                cw = min(s, w - j)
                arr = mask_ds.ReadAsArray(j, i, cw, ch)
                bins = np.bincount(arr.ravel())
                if len(bins) > IGN_CLS:
                    # Zero the ignored bin in place instead of deleting it,
                    # so that counts of class ids above IGN_CLS keep their
                    # index.
                    bins[IGN_CLS] = 0
                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                    cls_info_row.append(None)
                else:
                    cls_info_row.append(bins)
            cls_info_table.append(cls_info_row)
        return cls_info_table

    def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
        """Recursively build the subtree covering a range of table cells.

        `i_st..i_ed` and `j_st..j_ed` are inclusive row/column indices into
        `cls_info_table`. Node sizes are computed as whole multiples of
        `min_blk_size`, so nodes touching the bottom/right border may extend
        past the image.

        Returns:
            The subtree root, or None when the index range is empty.
        """
        if i_ed < i_st or j_ed < j_st:
            return None

        i = i_st * self.min_blk_size
        j = j_st * self.min_blk_size
        h = (i_ed - i_st + 1) * self.min_blk_size
        w = (j_ed - j_st + 1) * self.min_blk_size

        if i_ed == i_st and j_ed == j_st:
            # A single cell: leaf node carrying that cell's histogram.
            return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])

        i_mid = (i_ed + i_st) // 2
        j_mid = (j_ed + j_st) // 2

        root = QuadTreeNode(i, j, h, w, level)
        # Children order: top-left, top-right, bottom-left, bottom-right.
        root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
                                            j_mid, level + 1)
        root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
                                            j_mid + 1, j_ed, level + 1)
        root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                            j_st, j_mid, level + 1)
        root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                            j_mid + 1, j_ed, level + 1)

        bins_list = [
            node.cls_info for node in root.get_children()
            if node.cls_info is not None
        ]
        if len(bins_list) > 0:
            root.cls_info = reduce(merge_bins, bins_list)
        else:
            # Every child is background: collapse them into this node.
            root.reset_children()
        return root

    def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
        """Collect descendants of the root in breadth-first order.

        The root itself is never returned. A child that is filtered out is
        not descended into.

        Args:
            tar_cls: If given, keep only nodes with a non-zero pixel count
                for this class id.
            max_level: If given, nodes at this level are not expanded, so no
                returned node is deeper than `max_level`.
            include_bg: If False, background nodes are dropped.

        Returns:
            A list of nodes; empty if the tree has not been built.
        """
        nodes = []
        if self.root is None:
            return nodes

        q = deque()
        q.append(self.root)
        while q:
            node = q.popleft()
            if max_level is None or node.level < max_level:
                for child in node.get_children():
                    if not include_bg and child.is_bg_node:
                        continue
                    if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
                        continue
                    nodes.append(child)
                    q.append(child)
        return nodes

    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
        """Draw every node's rectangle on an image and save the result.

        Args:
            im_path: Path of the image, decoded with
                `paddlers.transforms.decode_image`.
            save_path: Where to write the visualization.

        Returns:
            `save_path`.

        Raises:
            ValueError: If the decoded image has fewer than 3 bands or is
                neither 2-D nor 3-D.
        """
        im = paddlers.transforms.decode_image(im_path)
        if im.ndim == 2:
            im = np.stack([im] * 3, axis=2)
        elif im.ndim == 3:
            c = im.shape[2]
            if c < 3:
                raise ValueError(
                    "For multi-spectral images, the number of bands should not be less than 3."
                )
            else:
                # Take first three bands as R, G, and B
                im = im[..., :3]
        else:
            raise ValueError("Unrecognized data format.")

        nodes = self.get_nodes(include_bg=True)
        vis = np.ascontiguousarray(im)
        for node in nodes:
            i, j, h, w = node.coords
            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (255, 0, 0), 2)
        # Channel order is reversed before writing (presumably RGB -> BGR for
        # OpenCV; confirm against decode_image's output order).
        cv2.imwrite(save_path, vis[..., ::-1])
        return save_path

    def print_tree(self, node=None, level=0):
        """Print the subtree rooted at `node` (default: the whole tree).

        Does nothing if the tree has not been built.
        """
        if node is None:
            node = self.root
        if node is None:
            return
        print(' ' * level + '-', node)
        for child in node.get_children():
            self.print_tree(child, level + 1)
def merge_bins(bins1, bins2):
    """Add two per-class count histograms that may differ in length.

    The shorter histogram is zero-padded on the right to the length of the
    longer one. Padding uses the shorter histogram's own dtype, so integer
    counts stay integers instead of being promoted to float.

    Args:
        bins1: 1-D array of counts indexed by class id.
        bins2: 1-D array of counts indexed by class id.

    Returns:
        A new array holding the element-wise sum; the inputs are not
        modified.
    """
    if len(bins1) < len(bins2):
        bins1, bins2 = bins2, bins1
    pad = len(bins1) - len(bins2)
    if pad > 0:
        bins2 = np.asarray(bins2)
        bins2 = np.concatenate([bins2, np.zeros(pad, dtype=bins2.dtype)])
    return bins1 + bins2
@time_it
def extract_ms_patches(im_paths,
                       mask_path,
                       save_dir,
                       min_patch_size=256,
                       bg_class=0,
                       target_class=None,
                       max_level=None,
                       include_bg=False,
                       nonzero_ratio=None,
                       visualize=False):
    """Extract multi-scale patches from images guided by a label mask.

    A quadtree is built on the mask; for every selected node the same window
    is cropped from each image and from the mask with `gdal.Translate`.
    Patches of each image go to `save_dir/<image file name without
    extension>/`, and mask patches go to `save_dir/mask/`. Nodes that extend
    past the mask border are skipped.

    Args:
        im_paths: Paths of the images. File names must be unique, since they
            name the output sub-directories.
        mask_path: Path of the single-band mask.
        save_dir: Root directory for the extracted patches.
        min_patch_size: Side length of the smallest patch, in pixels.
        bg_class: Class id regarded as background.
        target_class: If given, keep only patches containing this class.
        max_level: If given, maximum quadtree level to extract from.
        include_bg: Whether to keep patches containing only background.
        nonzero_ratio: If given, a patch is dropped when, in any image, its
            fraction of non-zero values is below this threshold.
        visualize: Whether to save a visualization of the quadtree drawn on
            the first image.

    Raises:
        RuntimeError: If the mask or an image cannot be opened.
        ValueError: If the mask does not have exactly one band.
    """

    def _open_dataset(path):
        # gdal.Open returns None on failure when GDAL exceptions are off.
        ds = gdal.Open(path)
        if ds is None:
            raise RuntimeError(f"Failed to open {path} with GDAL.")
        return ds

    def _save_patch(src_path, idx, i, j, h, w, subdir=None):
        # Crop window (i, j, h, w) from `src_path` into its own sub-directory.
        src_path = osp.normpath(src_path)
        src_name, src_ext = osp.splitext(osp.basename(src_path))
        subdir = subdir if subdir is not None else src_name
        dst_dir = osp.join(save_dir, subdir)
        os.makedirs(dst_dir, exist_ok=True)
        dst_name = FMT.format(idx=idx, ext=src_ext)
        dst_path = osp.join(dst_dir, dst_name)
        gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
        return dst_path

    if nonzero_ratio is not None:
        print(
            "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
        )

    mask_ds = _open_dataset(mask_path)
    if mask_ds.RasterCount != 1:
        raise ValueError("The mask image has more than 1 band.")
    quad_tree = QuadTree(min_blk_size=min_patch_size)

    print("Start building quad tree...")
    quad_tree.build_tree(mask_ds, bg_class)
    if visualize:
        print("Start drawing rectangles...")
        save_path = quad_tree.visualize_regions(im_paths[0])
        print(f"The visualization result is saved in {save_path} .")

    print("Quad tree has been built. Now start collecting nodes...")
    nodes = quad_tree.get_nodes(
        tar_cls=target_class, max_level=max_level, include_bg=include_bg)

    # Open each image once instead of once per node; only needed for the
    # non-zero ratio filter.
    im_datasets = []
    if nonzero_ratio is not None:
        im_datasets = [_open_dataset(path) for path in im_paths]

    print("Nodes collected. Saving patches...")
    for idx, node in enumerate(tqdm(nodes)):
        i, j, h, w = node.coords

        real_h = min(h, mask_ds.RasterYSize - i)
        real_w = min(w, mask_ds.RasterXSize - j)
        if real_h < h or real_w < w:
            # Skip incomplete patches
            continue

        is_valid = True
        for im_ds in im_datasets:
            arr = im_ds.ReadAsArray(j, i, real_w, real_h)
            if np.count_nonzero(arr) / arr.size < nonzero_ratio:
                is_valid = False
                break

        if is_valid:
            for src_path in im_paths:
                _save_patch(src_path, idx, i, j, real_h, real_w)
            _save_patch(mask_path, idx, i, j, real_h, real_w, 'mask')
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--im_paths", type=str, required=True, nargs='+',
        help="Path of images. Different images must have unique file names.")
    parser.add_argument(
        "--mask_path", type=str, required=True,
        help="Path of mask.")
    parser.add_argument(
        "--save_dir", type=str, default='output',
        help="Path to save the extracted patches.")
    parser.add_argument(
        "--min_patch_size", type=int, default=256,
        help="Minimum patch size (height and width).")
    parser.add_argument(
        "--bg_class", type=int, default=0,
        help="Index of the background category.")
    parser.add_argument(
        "--target_class", type=int, default=None,
        help="Index of the category of interest.")
    parser.add_argument(
        "--max_level", type=int, default=None,
        help="Maximum level of hierarchical patches.")
    parser.add_argument(
        "--include_bg", action='store_true',
        help="Include patches that contains only background pixels.")
    parser.add_argument(
        "--nonzero_ratio", type=float, default=None,
        help="Threshold for filtering out less informative patches.")
    parser.add_argument(
        "--visualize", action='store_true',
        help="Visualize the quadtree.")
    args = parser.parse_args()

    extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
                       args.min_patch_size, args.bg_class, args.target_class,
                       args.max_level, args.include_bg, args.nonzero_ratio,
                       args.visualize)