Преглед на файлове

Add extract_ms_patches code

Bobholamovic преди 2 години
родител
ревизия
ba60035e59
променени са 2 файла, в които са добавени 296 реда и са изтрити 0 реда
  1. 23 0
      tests/tools/run_extract_ms_patches.py
  2. 273 0
      tools/extract_ms_patches.py

+ 23 - 0
tests/tools/run_extract_ms_patches.py

@@ -0,0 +1,23 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import run_script
+
+if __name__ == '__main__':
+    with tempfile.TemporaryDirectory() as td:
+        run_script(
+            f"python extract_ms_patches.py --im_paths ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --min_patch_size 32 --save_dir {td}",
+            wd="../tools")

+ 273 - 0
tools/extract_ms_patches.py

@@ -0,0 +1,273 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import argparse
+from collections import deque
+from functools import reduce
+
+import paddlers
+import numpy as np
+try:
+    from osgeo import gdal
+except:
+    import gdal
+from tqdm import tqdm
+
+from utils import time_it
+
+IGN_CLS = 255
+FMT = "im_{idx}{ext}"
+
+
+class QuadTreeNode(object):
+    def __init__(self, i, j, h, w, level, cls_info=None):
+        super().__init__()
+        self.i = i
+        self.j = j
+        self.h = h
+        self.w = w
+        self.level = level
+        self.cls_info = cls_info
+        self.reset_children()
+
+    @property
+    def area(self):
+        return self.h * self.w
+
+    @property
+    def is_bg_node(self):
+        return self.cls_info is None
+
+    @property
+    def coords(self):
+        return (self.i, self.j, self.h, self.w)
+
+    def get_cls_cnt(self, cls):
+        if self.cls_info is None or cls >= len(self.cls_info):
+            return 0
+        return self.cls_info[cls]
+
+    def get_children(self):
+        for child in self.children:
+            if child is not None:
+                yield child
+
+    def reset_children(self):
+        self.children = [None, None, None, None]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
+
+
+class QuadTree(object):
+    def __init__(self, min_blk_size=256):
+        super().__init__()
+        self.min_blk_size = min_blk_size
+        self.h = None
+        self.w = None
+        self.root = None
+
+    def build_tree(self, mask_band, bg_cls=0):
+        cls_info_table = self.preprocess(mask_band, bg_cls)
+        n_rows = len(cls_info_table)
+        if n_rows == 0:
+            return None
+        n_cols = len(cls_info_table[0])
+        self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
+                                     n_cols - 1, 0)
+        return self.root
+
+    def preprocess(self, mask_ds, bg_cls):
+        h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
+        s = self.min_blk_size
+        if s >= h or s >= w:
+            raise ValueError("`min_blk_size` must be smaller than image size.")
+        cls_info_table = []
+        for i in range(0, h, s):
+            cls_info_row = []
+            for j in range(0, w, s):
+                if i + s > h:
+                    ch = h - i
+                else:
+                    ch = s
+                if j + s > w:
+                    cw = w - j
+                else:
+                    cw = s
+                arr = mask_ds.ReadAsArray(j, i, cw, ch)
+                bins = np.bincount(arr.ravel())
+                if len(bins) > IGN_CLS:
+                    bins = np.delete(bins, IGN_CLS)
+                if bins.sum() == bins[bg_cls]:
+                    cls_info_row.append(None)
+                else:
+                    cls_info_row.append(bins)
+            cls_info_table.append(cls_info_row)
+        return cls_info_table
+
+    def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
+        if i_ed < i_st or j_ed < j_st:
+            return None
+
+        i = i_st * self.min_blk_size
+        j = j_st * self.min_blk_size
+        h = (i_ed - i_st + 1) * self.min_blk_size
+        w = (j_ed - j_st + 1) * self.min_blk_size
+
+        if i_ed == i_st and j_ed == j_st:
+            return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])
+
+        i_mid = (i_ed + i_st) // 2
+        j_mid = (j_ed + j_st) // 2
+
+        root = QuadTreeNode(i, j, h, w, level)
+
+        root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
+                                            j_mid, level + 1)
+        root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
+                                            j_mid + 1, j_ed, level + 1)
+        root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
+                                            j_st, j_mid, level + 1)
+        root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
+                                            j_mid + 1, j_ed, level + 1)
+
+        bins_list = [
+            node.cls_info for node in root.get_children()
+            if node.cls_info is not None
+        ]
+        if len(bins_list) > 0:
+            merged_bins = reduce(merge_bins, bins_list)
+            root.cls_info = merged_bins
+        else:
+            # Merge nodes
+            root.reset_children()
+
+        return root
+
+    def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
+        nodes = []
+        q = deque()
+        q.append(self.root)
+        while q:
+            node = q.popleft()
+            if max_level is None or node.level < max_level:
+                for child in node.get_children():
+                    if not include_bg and child.is_bg_node:
+                        continue
+                    if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
+                        continue
+                    nodes.append(child)
+                    q.append(child)
+        return nodes
+
+    def print_tree(self, node=None, level=0):
+        if node is None:
+            node = self.root
+        print(' ' * level + '-', node)
+        for child in node.get_children():
+            self.print_tree(child, level + 1)
+
+
+def merge_bins(bins1, bins2):
+    if len(bins1) < len(bins2):
+        return merge_bins(bins2, bins1)
+    elif len(bins1) == len(bins2):
+        return bins1 + bins2
+    else:
+        return bins1 + np.concatenate(
+            [bins2, np.zeros(len(bins1) - len(bins2))])
+
+
+@time_it
+def extract_ms_patches(im_paths,
+                       mask_path,
+                       save_dir,
+                       min_patch_size=256,
+                       bg_class=0,
+                       target_class=None,
+                       max_level=None,
+                       include_bg=False,
+                       nonzero_ratio=None):
+    def _save_patch(src_path, i, j, h, w, subdir=None):
+        src_path = osp.normpath(src_path)
+        src_name, src_ext = osp.splitext(osp.basename(src_path))
+        subdir = subdir if subdir is not None else src_name
+        dst_dir = osp.join(save_dir, subdir)
+        if not osp.exists(dst_dir):
+            os.makedirs(dst_dir)
+        dst_name = FMT.format(idx=idx, ext=src_ext)
+        dst_path = osp.join(dst_dir, dst_name)
+        gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
+        return dst_path
+
+    if nonzero_ratio is not None:
+        print(
+            "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
+        )
+
+    mask_ds = gdal.Open(mask_path)
+    quad_tree = QuadTree(min_blk_size=min_patch_size)
+    if mask_ds.RasterCount != 1:
+        raise ValueError("The mask image has more than 1 band.")
+    print("Start building quad tree...")
+    quad_tree.build_tree(mask_ds, bg_class)
+    print("Quad tree has been built. Now start collecting nodes...")
+    nodes = quad_tree.get_nodes(
+        tar_cls=target_class, max_level=max_level, include_bg=include_bg)
+    print("Nodes collected. Saving patches...")
+    for idx, node in enumerate(tqdm(nodes)):
+        i, j, h, w = node.coords
+        h = min(h, mask_ds.RasterYSize - i)
+        w = min(w, mask_ds.RasterXSize - j)
+        is_valid = True
+        if nonzero_ratio is not None:
+            for src_path in im_paths:
+                im_ds = gdal.Open(src_path)
+                arr = im_ds.ReadAsArray(j, i, w, h)
+                if np.count_nonzero(arr) / arr.size < nonzero_ratio:
+                    is_valid = False
+                    break
+        if is_valid:
+            for src_path in im_paths:
+                _save_patch(src_path, i, j, h, w)
+            _save_patch(mask_path, i, j, h, w, 'mask')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--im_paths", type=str, required=True, nargs='+', \
+                        help="Path of images. Different images must have unique file names.")
+    parser.add_argument("--mask_path", type=str, required=True, \
+                        help="Path of mask.")
+    parser.add_argument("--save_dir", type=str, required=True, \
+                        help="Path to save the extracted patches.")
+    parser.add_argument("--min_patch_size", type=int, default=256, \
+                        help="Minimum patch size (height and width).")
+    parser.add_argument("--bg_class", type=int, default=0, \
+                        help="Index of the background category.")
+    parser.add_argument("--target_class", type=int, default=None, \
+                        help="Index of the category of interest.")
+    parser.add_argument("--max_level", type=int, default=None, \
+                        help="Maximum level of hierarchical patches.")
+    parser.add_argument("--include_bg", action='store_true', \
+                        help="Include patches that contains only background pixels.")
+    parser.add_argument("--nonzero_ratio", type=float, default=None, \
+                        help="Threshold for filtering out less informative patches.")
+    args = parser.parse_args()
+
+    extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
+                       args.min_patch_size, args.bg_class, args.target_class,
+                       args.max_level, args.include_bg, args.nonzero_ratio)