Browse Source

Add visualization

Bobholamovic 2 years ago
parent
commit
79e6621fd1
1 changed files with 43 additions and 8 deletions
  1. 43 8
      tools/extract_ms_patches.py

+ 43 - 8
tools/extract_ms_patches.py

@@ -20,6 +20,7 @@ from functools import reduce
 
 import paddlers
 import numpy as np
+import cv2
 try:
     from osgeo import gdal
 except:
@@ -111,7 +112,7 @@ class QuadTree(object):
                 bins = np.bincount(arr.ravel())
                 if len(bins) > IGN_CLS:
                     bins = np.delete(bins, IGN_CLS)
-                if bins.sum() == bins[bg_cls]:
+                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                     cls_info_row.append(None)
                 else:
                     cls_info_row.append(bins)
@@ -173,6 +174,29 @@ class QuadTree(object):
                     q.append(child)
         return nodes
 
+    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
+        im = paddlers.transforms.decode_image(im_path)
+        if im.ndim == 2:
+            im = np.stack([im] * 3, axis=2)
+        elif im.ndim == 3:
+            c = im.shape[2]
+            if c < 3:
+                raise ValueError(
+                    "For multi-spectral images, the number of bands should not be less than 3."
+                )
+            else:
+                # Take first three bands as R, G, and B
+                im = im[..., :3]
+        else:
+            raise ValueError("Unrecognized data format.")
+        nodes = self.get_nodes(include_bg=True)
+        vis = np.ascontiguousarray(im)
+        for node in nodes:
+            i, j, h, w = node.coords
+            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2)
+        cv2.imwrite(save_path, vis)
+        return save_path
+
     def print_tree(self, node=None, level=0):
         if node is None:
             node = self.root
@@ -200,7 +224,8 @@ def extract_ms_patches(im_paths,
                        target_class=None,
                        max_level=None,
                        include_bg=False,
-                       nonzero_ratio=None):
+                       nonzero_ratio=None,
+                       visualize=False):
     def _save_patch(src_path, i, j, h, w, subdir=None):
         src_path = osp.normpath(src_path)
         src_name, src_ext = osp.splitext(osp.basename(src_path))
@@ -224,26 +249,33 @@ def extract_ms_patches(im_paths,
         raise ValueError("The mask image has more than 1 band.")
     print("Start building quad tree...")
     quad_tree.build_tree(mask_ds, bg_class)
+    if visualize:
+        print("Start drawing rectangles...")
+        save_path = quad_tree.visualize_regions(im_paths[0])
+        print(f"The visualization result is saved in {save_path} .")
     print("Quad tree has been built. Now start collecting nodes...")
     nodes = quad_tree.get_nodes(
         tar_cls=target_class, max_level=max_level, include_bg=include_bg)
     print("Nodes collected. Saving patches...")
     for idx, node in enumerate(tqdm(nodes)):
         i, j, h, w = node.coords
-        h = min(h, mask_ds.RasterYSize - i)
-        w = min(w, mask_ds.RasterXSize - j)
+        real_h = min(h, mask_ds.RasterYSize - i)
+        real_w = min(w, mask_ds.RasterXSize - j)
+        if real_h < h or real_w < w:
+            # Skip incomplete patches
+            continue
         is_valid = True
         if nonzero_ratio is not None:
             for src_path in im_paths:
                 im_ds = gdal.Open(src_path)
-                arr = im_ds.ReadAsArray(j, i, w, h)
+                arr = im_ds.ReadAsArray(j, i, real_w, real_h)
                 if np.count_nonzero(arr) / arr.size < nonzero_ratio:
                     is_valid = False
                     break
         if is_valid:
             for src_path in im_paths:
-                _save_patch(src_path, i, j, h, w)
-            _save_patch(mask_path, i, j, h, w, 'mask')
+                _save_patch(src_path, i, j, real_h, real_w)
+            _save_patch(mask_path, i, j, real_h, real_w, 'mask')
 
 
 if __name__ == '__main__':
@@ -266,8 +298,11 @@ if __name__ == '__main__':
                         help="Include patches that contains only background pixels.")
     parser.add_argument("--nonzero_ratio", type=float, default=None, \
                         help="Threshold for filtering out less informative patches.")
+    parser.add_argument("--visualize", action='store_true', \
+                        help="Visualize the quadtree.")
     args = parser.parse_args()
 
     extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
                        args.min_patch_size, args.bg_class, args.target_class,
-                       args.max_level, args.include_bg, args.nonzero_ratio)
+                       args.max_level, args.include_bg, args.nonzero_ratio,
+                       args.visualize)