|
@@ -20,6 +20,7 @@ from functools import reduce
|
|
|
|
|
|
import paddlers
|
|
|
import numpy as np
|
|
|
+import cv2
|
|
|
try:
|
|
|
from osgeo import gdal
|
|
|
except:
|
|
@@ -111,7 +112,7 @@ class QuadTree(object):
|
|
|
bins = np.bincount(arr.ravel())
|
|
|
if len(bins) > IGN_CLS:
|
|
|
bins = np.delete(bins, IGN_CLS)
|
|
|
- if bins.sum() == bins[bg_cls]:
|
|
|
+ if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
|
|
|
cls_info_row.append(None)
|
|
|
else:
|
|
|
cls_info_row.append(bins)
|
|
@@ -173,6 +174,29 @@ class QuadTree(object):
|
|
|
q.append(child)
|
|
|
return nodes
|
|
|
|
|
|
+ def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
|
|
|
+ im = paddlers.transforms.decode_image(im_path)
|
|
|
+ if im.ndim == 2:
|
|
|
+ im = np.stack([im] * 3, axis=2)
|
|
|
+ elif im.ndim == 3:
|
|
|
+ c = im.shape[2]
|
|
|
+ if c < 3:
|
|
|
+ raise ValueError(
|
|
|
+ "For multi-spectral images, the number of bands should not be less than 3."
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # Take first three bands as R, G, and B
|
|
|
+ im = im[..., :3]
|
|
|
+ else:
|
|
|
+ raise ValueError("Unrecognized data format.")
|
|
|
+ nodes = self.get_nodes(include_bg=True)
|
|
|
+ vis = np.ascontiguousarray(im)
|
|
|
+ for node in nodes:
|
|
|
+ i, j, h, w = node.coords
|
|
|
+ vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2)
|
|
|
+ cv2.imwrite(save_path, vis)
|
|
|
+ return save_path
|
|
|
+
|
|
|
def print_tree(self, node=None, level=0):
|
|
|
if node is None:
|
|
|
node = self.root
|
|
@@ -200,7 +224,8 @@ def extract_ms_patches(im_paths,
|
|
|
target_class=None,
|
|
|
max_level=None,
|
|
|
include_bg=False,
|
|
|
- nonzero_ratio=None):
|
|
|
+ nonzero_ratio=None,
|
|
|
+ visualize=False):
|
|
|
def _save_patch(src_path, i, j, h, w, subdir=None):
|
|
|
src_path = osp.normpath(src_path)
|
|
|
src_name, src_ext = osp.splitext(osp.basename(src_path))
|
|
@@ -224,26 +249,33 @@ def extract_ms_patches(im_paths,
|
|
|
raise ValueError("The mask image has more than 1 band.")
|
|
|
print("Start building quad tree...")
|
|
|
quad_tree.build_tree(mask_ds, bg_class)
|
|
|
+ if visualize:
|
|
|
+ print("Start drawing rectangles...")
|
|
|
+ save_path = quad_tree.visualize_regions(im_paths[0])
|
|
|
+ print(f"The visualization result is saved in {save_path} .")
|
|
|
print("Quad tree has been built. Now start collecting nodes...")
|
|
|
nodes = quad_tree.get_nodes(
|
|
|
tar_cls=target_class, max_level=max_level, include_bg=include_bg)
|
|
|
print("Nodes collected. Saving patches...")
|
|
|
for idx, node in enumerate(tqdm(nodes)):
|
|
|
i, j, h, w = node.coords
|
|
|
- h = min(h, mask_ds.RasterYSize - i)
|
|
|
- w = min(w, mask_ds.RasterXSize - j)
|
|
|
+ real_h = min(h, mask_ds.RasterYSize - i)
|
|
|
+ real_w = min(w, mask_ds.RasterXSize - j)
|
|
|
+ if real_h < h or real_w < w:
|
|
|
+ # Skip incomplete patches
|
|
|
+ continue
|
|
|
is_valid = True
|
|
|
if nonzero_ratio is not None:
|
|
|
for src_path in im_paths:
|
|
|
im_ds = gdal.Open(src_path)
|
|
|
- arr = im_ds.ReadAsArray(j, i, w, h)
|
|
|
+ arr = im_ds.ReadAsArray(j, i, real_w, real_h)
|
|
|
if np.count_nonzero(arr) / arr.size < nonzero_ratio:
|
|
|
is_valid = False
|
|
|
break
|
|
|
if is_valid:
|
|
|
for src_path in im_paths:
|
|
|
- _save_patch(src_path, i, j, h, w)
|
|
|
- _save_patch(mask_path, i, j, h, w, 'mask')
|
|
|
+ _save_patch(src_path, i, j, real_h, real_w)
|
|
|
+ _save_patch(mask_path, i, j, real_h, real_w, 'mask')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
@@ -266,8 +298,11 @@ if __name__ == '__main__':
|
|
|
help="Include patches that contains only background pixels.")
|
|
|
parser.add_argument("--nonzero_ratio", type=float, default=None, \
|
|
|
help="Threshold for filtering out less informative patches.")
|
|
|
+ parser.add_argument("--visualize", action='store_true', \
|
|
|
+ help="Visualize the quadtree.")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
|
|
|
args.min_patch_size, args.bg_class, args.target_class,
|
|
|
- args.max_level, args.include_bg, args.nonzero_ratio)
|
|
|
+ args.max_level, args.include_bg, args.nonzero_ratio,
|
|
|
+ args.visualize)
|