Bladeren bron

Merge pull request #37 from Bobholamovic/quadtree

[Feat] Add Multi-scale Patch Extraction Tool
cc 2 jaren geleden
bovenliggende
commit
66894e3adb
11 gewijzigde bestanden met toevoegingen van 442 en 84 verwijderingen
  1. 44 20
      docs/data/tools.md
  2. 2 2
      tests/run_tests.sh
  3. 23 0
      tests/tools/run_extract_ms_patches.py
  4. 7 7
      tools/coco2mask.py
  5. 308 0
      tools/extract_ms_patches.py
  6. 10 8
      tools/mask2geojson.py
  7. 10 10
      tools/mask2shape.py
  8. 12 12
      tools/match.py
  9. 7 7
      tools/oif.py
  10. 8 8
      tools/pca.py
  11. 11 10
      tools/split.py

+ 44 - 20
docs/data/tools.md

@@ -2,13 +2,14 @@
 
 PaddleRS在`tools`目录中提供了丰富的遥感影像处理工具,包括:
 
-- `coco2mask.py`:用于将COCO格式的标注文件转换为png格式。
+- `coco2mask.py`:用于将COCO格式的标注文件转换为.png格式。
 - `mask2shape.py`:用于将模型推理输出的.png格式栅格标签转换为矢量格式。
 - `mask2geojson.py`:用于将模型推理输出的.png格式栅格标签转换为GeoJSON格式。
 - `match.py`:用于实现两幅影像的配准。
 - `split.py`:用于对大幅面影像数据进行切片。
 - `coco_tools/`:COCO工具合集,用于统计处理COCO格式标注文件。
 - `prepare_dataset/`:数据集预处理脚本合集。
+- `extract_ms_patches.py`:从整幅遥感影像中提取多尺度影像块。
 
 ## 使用说明
 
@@ -20,7 +21,7 @@ cd tools
 
 ### coco2mask
 
-`coco2mask.py`的主要功能是将图像以及对应的COCO格式的分割标签转换为图像与.png格式的标签,结果会分别存放在`img`和`gt`两个目录中。相关的数据样例可以参考[中国典型城市建筑物实例数据集](https://www.scidb.cn/detail?dataSetId=806674532768153600&dataSetType=journal)。对于mask,保存结果为单通道的伪彩色像。使用方式如下:
+`coco2mask.py`的主要功能是将影像以及对应的COCO格式的分割标签转换为影像与.png格式的标签,结果会分别存放在`img`和`gt`两个目录中。相关的数据样例可以参考[中国典型城市建筑物实例数据集](https://www.scidb.cn/detail?dataSetId=806674532768153600&dataSetType=journal)。对于mask,保存结果为单通道的伪彩色像。使用方式如下:
 
 ```shell
 python coco2mask.py --raw_dir {输入目录路径} --save_dir {输出目录路径}
@@ -28,8 +29,8 @@ python coco2mask.py --raw_dir {输入目录路径} --save_dir {输出目录路
 
 其中:
 
-- `raw_dir`:存放原始数据的目录,其中像存放在`images`子目录中,标签以`xxx.json`格式保存。
-- `save_dir`:保存输出结果的目录,其中像保存在`img`子目录中,.png格式的标签保存在`gt`子目录中。
+- `raw_dir`:存放原始数据的目录,其中像存放在`images`子目录中,标签以`xxx.json`格式保存。
+- `save_dir`:保存输出结果的目录,其中像保存在`img`子目录中,.png格式的标签保存在`gt`子目录中。
 
 ### mask2shape
 
@@ -41,10 +42,10 @@ python mask2shape.py --srcimg_path {带有地理信息的原始影像路径} --m
 
 其中:
 
-- `srcimg_path`:原始影像路径,需要带有地理坐标信息,以便为生成的shapefile提供crs等信息。
+- `srcimg_path`:原始影像路径,需要带有地理元信息,以便为生成的shapefile提供地理投影坐标系等信息。
 - `mask_path`:模型推理得到的.png格式的分割结果。
 - `save_path`:保存shapefile的路径,默认为`output`。
-- `ignore_index`:需要在shapefile中忽略的索引值(例如分割任务中的背景类),默认为255。
+- `ignore_index`:需要在shapefile中忽略的索引值(例如分割任务中的背景类),默认为`255`
 
 ### mask2geojson
 
@@ -69,37 +70,37 @@ python match.py --im1_path [时相1影像路径] --im2_path [时相2影像路径
 
 其中:
 
-- `im1_path`:时相1影像路径。该影像必须包含地理信息,且配准过程中以该影像为基准像。
+- `im1_path`:时相1影像路径。该影像必须包含地理信息,且配准过程中以该影像为基准像。
 - `im2_path`:时相2影像路径。该影像的地理信息将不被用到。配准过程中将该影像配准到时相1影像。
 - `im1_bands`:时相1影像用于配准的波段,指定为三通道(分别代表R、G、B)或单通道,默认为[1, 2, 3]。
 - `im2_bands`:时相2影像用于配准的波段,指定为三通道(分别代表R、G、B)或单通道,默认为[1, 2, 3]。
-- `save_path`: 配准后时相2影像输出路径。
+- `save_path` 配准后时相2影像输出路径。
 
 ### split
 
-`split.py`的主要功能是将大幅面遥感图像划分为图像块,这些图像块可以作为训练时的输入。使用方式如下:
+`split.py`的主要功能是将大幅面遥感影像划分为影像块,这些影像块可以作为训练时的输入。使用方式如下:
 
 ```shell
-python split.py --image_path {输入影像路径} [--mask_path {真值标签路径}] [--block_size {像块尺寸}] [--save_dir {输出目录}]
+python split.py --image_path {输入影像路径} [--mask_path {真值标签路径}] [--block_size {像块尺寸}] [--save_dir {输出目录}]
 ```
 
 其中:
 
-- `image_path`:需要切分的像的路径。
-- `mask_path`:一同切分的标签图像路径,默认没有
-- `block_size`:切分像块大小,默认为512。
-- `save_folder`:保存切分后结果的文件夹路径,默认为`output`。
+- `image_path`:需要切分的像的路径。
+- `mask_path`:一同切分的标签影像路径,默认为`None`
+- `block_size`:切分像块大小,默认为512。
+- `save_dir`:保存切分后结果的文件夹路径,默认为`output`。
 
 ### coco_tools
 
 目前`coco_tools`目录中共包含6个工具,各工具功能如下:
 
-- `json_InfoShow.py`:    打印json文件中各个字典的基本信息;
-- `json_ImgSta.py`:      统计json文件中的图像信息,生成统计表、统计图;
-- `json_AnnoSta.py`:     统计json文件中的标注信息,生成统计表、统计图;
-- `json_Img2Json.py`:    统计test集图像,生成json文件;
-- `json_Split.py`:       将json文件中的内容划分为train set和val set;
-- `json_Merge.py`:       将多个json文件合并为一个。
+- `json_InfoShow.py`    打印json文件中各个字典的基本信息;
+- `json_ImgSta.py`:      统计json文件中的影像信息,生成统计表、统计图;
+- `json_AnnoSta.py`     统计json文件中的标注信息,生成统计表、统计图;
+- `json_Img2Json.py`:    统计test集影像,生成json文件;
+- `json_Split.py`       将json文件中的内容划分为train set和val set;
+- `json_Merge.py`       将多个json文件合并为一个。
 
 详细使用方法请参见[coco_tools使用说明](coco_tools.md)。
 
@@ -123,3 +124,26 @@ python prepare_dataset/prepare_levircd.py --help
 - `--ratios`:对于支持子集随机划分的数据集,指定需要划分的各个子集的样本比例。示例:`--ratios 0.7 0.2 0.1`。
 
 您可以在[此文档](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/data_prep.md)中查看PaddleRS提供哪些数据集的预处理脚本。
+
+### extract_ms_patches
+
+`extract_ms_patches.py`的主要功能是利用四叉树从整幅遥感影像中提取不同尺度的包含感兴趣目标的影像块,提取的影像块可用作模型训练样本。使用方式如下:
+
+```shell
+python extract_ms_patches.py --image_paths {一个或多个输入影像路径} --mask_path {真值标签路径} [--save_dir {输出目录}] [--min_patch_size {最小的影像块尺寸}] [--bg_class {背景类类别编号}] [--target_class {目标类类别编号}] [--max_level {检索的最大尺度层级}] [--include_bg] [--nonzero_ratio {影像块中非零像素占比阈值}] [--visualize]
+```
+
+其中:
+
+- `image_paths`:源影像路径,可以指定多个路径。
+- `mask_path`:真值标签路径。
+- `save_dir`:保存切分后结果的文件夹路径,默认为`output`。
+- `min_patch_size`:提取的影像块的最小尺寸(以影像块长/宽的像素个数计),即四叉树的叶子结点在图中覆盖的最小范围,默认为`256`。
+- `bg_class`:背景类别的类别编号,默认为`0`。
+- `target_class`:目标类别的类别编号,若为`None`,则表示所有背景类别以外的类别均为目标类别,默认为`None`。
+- `max_level`:检索的最大尺度层级,若为`None`,则表示不限制层级,默认为`None`。
+- `include_bg`:若指定此选项,则也保存那些仅包含背景类别、不包含目标类别的影像块。
+- `--nonzero_ratio`:指定一个阈值,对于任意一幅源影像,若影像块中非零像素占比小于此阈值,则该影像块将被舍弃。若为`None`,则表示不进行过滤。默认为`None`。
+- `--visualize`:若指定此选项,则程序执行完毕后将生成图像`./vis_quadtree.png`,其中保存有四叉树中节点情况的可视化结果,一个例子如下图所示:
+
+[vis_quadtree_example](https://user-images.githubusercontent.com/21275753/189264850-f94b3d7b-c631-47b1-9833-0800de2ccf54.png)

+ 2 - 2
tests/run_tests.sh

@@ -9,8 +9,8 @@ bash download_test_data.sh
 python -m unittest discover -v
 
 # Test tools
-for script in $(ls run*.py); do
-    python ${script}
+for script in $(ls tools/run*.py); do
+    PYTHONPATH="$(pwd)" python ${script}
 done
 
 # Test tutorials

+ 23 - 0
tests/tools/run_extract_ms_patches.py

@@ -0,0 +1,23 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import run_script
+
+if __name__ == '__main__':
+    with tempfile.TemporaryDirectory() as td:
+        run_script(
+            f"python extract_ms_patches.py --im_paths ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --min_patch_size 32 --save_dir {td}",
+            wd="../tools")

+ 7 - 7
tools/coco2mask.py

@@ -19,8 +19,9 @@ import json
 import argparse
 from collections import defaultdict
 
-import cv2
+import paddlers
 import numpy as np
+import cv2
 import glob
 from tqdm import tqdm
 from PIL import Image
@@ -101,12 +102,11 @@ def convert_data(raw_dir, end_dir):
                           lab_save_path)
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--raw_dir", type=str, required=True, \
-                    help="Directory that contains original data, where `images` stores the original image and `annotation.json` stores the corresponding annotation information.")
-parser.add_argument("--save_dir", type=str, required=True, \
-                    help="Directory to save the results, where `img` stores the image and `gt` stores the label.")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--raw_dir", type=str, required=True, \
+                        help="Directory that contains original data, where `images` stores the original image and `annotation.json` stores the corresponding annotation information.")
+    parser.add_argument("--save_dir", type=str, required=True, \
+                        help="Directory to save the results, where `img` stores the image and `gt` stores the label.")
     args = parser.parse_args()
     convert_data(args.raw_dir, args.save_dir)

+ 308 - 0
tools/extract_ms_patches.py

@@ -0,0 +1,308 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import argparse
+from collections import deque
+from functools import reduce
+
+import paddlers
+import numpy as np
+import cv2
+try:
+    from osgeo import gdal
+except:
+    import gdal
+from tqdm import tqdm
+
+from utils import time_it
+
+IGN_CLS = 255
+FMT = "im_{idx}{ext}"
+
+
+class QuadTreeNode(object):
+    def __init__(self, i, j, h, w, level, cls_info=None):
+        super().__init__()
+        self.i = i
+        self.j = j
+        self.h = h
+        self.w = w
+        self.level = level
+        self.cls_info = cls_info
+        self.reset_children()
+
+    @property
+    def area(self):
+        return self.h * self.w
+
+    @property
+    def is_bg_node(self):
+        return self.cls_info is None
+
+    @property
+    def coords(self):
+        return (self.i, self.j, self.h, self.w)
+
+    def get_cls_cnt(self, cls):
+        if self.cls_info is None or cls >= len(self.cls_info):
+            return 0
+        return self.cls_info[cls]
+
+    def get_children(self):
+        for child in self.children:
+            if child is not None:
+                yield child
+
+    def reset_children(self):
+        self.children = [None, None, None, None]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
+
+
+class QuadTree(object):
+    def __init__(self, min_blk_size=256):
+        super().__init__()
+        self.min_blk_size = min_blk_size
+        self.h = None
+        self.w = None
+        self.root = None
+
+    def build_tree(self, mask_band, bg_cls=0):
+        cls_info_table = self.preprocess(mask_band, bg_cls)
+        n_rows = len(cls_info_table)
+        if n_rows == 0:
+            return None
+        n_cols = len(cls_info_table[0])
+        self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
+                                     n_cols - 1, 0)
+        return self.root
+
+    def preprocess(self, mask_ds, bg_cls):
+        h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
+        s = self.min_blk_size
+        if s >= h or s >= w:
+            raise ValueError("`min_blk_size` must be smaller than image size.")
+        cls_info_table = []
+        for i in range(0, h, s):
+            cls_info_row = []
+            for j in range(0, w, s):
+                if i + s > h:
+                    ch = h - i
+                else:
+                    ch = s
+                if j + s > w:
+                    cw = w - j
+                else:
+                    cw = s
+                arr = mask_ds.ReadAsArray(j, i, cw, ch)
+                bins = np.bincount(arr.ravel())
+                if len(bins) > IGN_CLS:
+                    bins = np.delete(bins, IGN_CLS)
+                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
+                    cls_info_row.append(None)
+                else:
+                    cls_info_row.append(bins)
+            cls_info_table.append(cls_info_row)
+        return cls_info_table
+
+    def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
+        if i_ed < i_st or j_ed < j_st:
+            return None
+
+        i = i_st * self.min_blk_size
+        j = j_st * self.min_blk_size
+        h = (i_ed - i_st + 1) * self.min_blk_size
+        w = (j_ed - j_st + 1) * self.min_blk_size
+
+        if i_ed == i_st and j_ed == j_st:
+            return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])
+
+        i_mid = (i_ed + i_st) // 2
+        j_mid = (j_ed + j_st) // 2
+
+        root = QuadTreeNode(i, j, h, w, level)
+
+        root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
+                                            j_mid, level + 1)
+        root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
+                                            j_mid + 1, j_ed, level + 1)
+        root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
+                                            j_st, j_mid, level + 1)
+        root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
+                                            j_mid + 1, j_ed, level + 1)
+
+        bins_list = [
+            node.cls_info for node in root.get_children()
+            if node.cls_info is not None
+        ]
+        if len(bins_list) > 0:
+            merged_bins = reduce(merge_bins, bins_list)
+            root.cls_info = merged_bins
+        else:
+            # Merge nodes
+            root.reset_children()
+
+        return root
+
+    def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
+        nodes = []
+        q = deque()
+        q.append(self.root)
+        while q:
+            node = q.popleft()
+            if max_level is None or node.level < max_level:
+                for child in node.get_children():
+                    if not include_bg and child.is_bg_node:
+                        continue
+                    if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
+                        continue
+                    nodes.append(child)
+                    q.append(child)
+        return nodes
+
+    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
+        im = paddlers.transforms.decode_image(im_path)
+        if im.ndim == 2:
+            im = np.stack([im] * 3, axis=2)
+        elif im.ndim == 3:
+            c = im.shape[2]
+            if c < 3:
+                raise ValueError(
+                    "For multi-spectral images, the number of bands should not be less than 3."
+                )
+            else:
+                # Take first three bands as R, G, and B
+                im = im[..., :3]
+        else:
+            raise ValueError("Unrecognized data format.")
+        nodes = self.get_nodes(include_bg=True)
+        vis = np.ascontiguousarray(im)
+        for node in nodes:
+            i, j, h, w = node.coords
+            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2)
+        cv2.imwrite(save_path, vis)
+        return save_path
+
+    def print_tree(self, node=None, level=0):
+        if node is None:
+            node = self.root
+        print(' ' * level + '-', node)
+        for child in node.get_children():
+            self.print_tree(child, level + 1)
+
+
+def merge_bins(bins1, bins2):
+    if len(bins1) < len(bins2):
+        return merge_bins(bins2, bins1)
+    elif len(bins1) == len(bins2):
+        return bins1 + bins2
+    else:
+        return bins1 + np.concatenate(
+            [bins2, np.zeros(len(bins1) - len(bins2))])
+
+
+@time_it
+def extract_ms_patches(im_paths,
+                       mask_path,
+                       save_dir,
+                       min_patch_size=256,
+                       bg_class=0,
+                       target_class=None,
+                       max_level=None,
+                       include_bg=False,
+                       nonzero_ratio=None,
+                       visualize=False):
+    def _save_patch(src_path, i, j, h, w, subdir=None):
+        src_path = osp.normpath(src_path)
+        src_name, src_ext = osp.splitext(osp.basename(src_path))
+        subdir = subdir if subdir is not None else src_name
+        dst_dir = osp.join(save_dir, subdir)
+        if not osp.exists(dst_dir):
+            os.makedirs(dst_dir)
+        dst_name = FMT.format(idx=idx, ext=src_ext)
+        dst_path = osp.join(dst_dir, dst_name)
+        gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
+        return dst_path
+
+    if nonzero_ratio is not None:
+        print(
+            "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
+        )
+
+    mask_ds = gdal.Open(mask_path)
+    quad_tree = QuadTree(min_blk_size=min_patch_size)
+    if mask_ds.RasterCount != 1:
+        raise ValueError("The mask image has more than 1 band.")
+    print("Start building quad tree...")
+    quad_tree.build_tree(mask_ds, bg_class)
+    if visualize:
+        print("Start drawing rectangles...")
+        save_path = quad_tree.visualize_regions(im_paths[0])
+        print(f"The visualization result is saved in {save_path} .")
+    print("Quad tree has been built. Now start collecting nodes...")
+    nodes = quad_tree.get_nodes(
+        tar_cls=target_class, max_level=max_level, include_bg=include_bg)
+    print("Nodes collected. Saving patches...")
+    for idx, node in enumerate(tqdm(nodes)):
+        i, j, h, w = node.coords
+        real_h = min(h, mask_ds.RasterYSize - i)
+        real_w = min(w, mask_ds.RasterXSize - j)
+        if real_h < h or real_w < w:
+            # Skip incomplete patches
+            continue
+        is_valid = True
+        if nonzero_ratio is not None:
+            for src_path in im_paths:
+                im_ds = gdal.Open(src_path)
+                arr = im_ds.ReadAsArray(j, i, real_w, real_h)
+                if np.count_nonzero(arr) / arr.size < nonzero_ratio:
+                    is_valid = False
+                    break
+        if is_valid:
+            for src_path in im_paths:
+                _save_patch(src_path, i, j, real_h, real_w)
+            _save_patch(mask_path, i, j, real_h, real_w, 'mask')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--im_paths", type=str, required=True, nargs='+', \
+                        help="Path of images. Different images must have unique file names.")
+    parser.add_argument("--mask_path", type=str, required=True, \
+                        help="Path of mask.")
+    parser.add_argument("--save_dir", type=str, default='output', \
+                        help="Path to save the extracted patches.")
+    parser.add_argument("--min_patch_size", type=int, default=256, \
+                        help="Minimum patch size (height and width).")
+    parser.add_argument("--bg_class", type=int, default=0, \
+                        help="Index of the background category.")
+    parser.add_argument("--target_class", type=int, default=None, \
+                        help="Index of the category of interest.")
+    parser.add_argument("--max_level", type=int, default=None, \
+                        help="Maximum level of hierarchical patches.")
+    parser.add_argument("--include_bg", action='store_true', \
+                        help="Include patches that contains only background pixels.")
+    parser.add_argument("--nonzero_ratio", type=float, default=None, \
+                        help="Threshold for filtering out less informative patches.")
+    parser.add_argument("--visualize", action='store_true', \
+                        help="Visualize the quadtree.")
+    args = parser.parse_args()
+
+    extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
+                       args.min_patch_size, args.bg_class, args.target_class,
+                       args.max_level, args.include_bg, args.nonzero_ratio,
+                       args.visualize)

+ 10 - 8
tools/mask2geojson.py

@@ -14,11 +14,14 @@
 
 import os
 import codecs
-import cv2
-import numpy as np
 import argparse
+
+import paddlers
+import numpy as np
+import cv2
 import geojson
 from tqdm import tqdm
+
 from utils import Raster, save_geotiff, translate_vector, time_it
 
 
@@ -60,12 +63,11 @@ def convert_data(image_path, geojson_path):
     os.remove(temp_geojson_path)
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="Path of mask data.")
-parser.add_argument("--save_path", type=str, required=True, \
-                    help="Path to store the GeoJSON file (the coordinate system is WGS84).")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mask_path", type=str, required=True, \
+                        help="Path of mask data.")
+    parser.add_argument("--save_path", type=str, required=True, \
+                        help="Path to store the GeoJSON file (the coordinate system is WGS84).")
     args = parser.parse_args()
     convert_data(args.raster_path, args.geojson_path)

+ 10 - 10
tools/mask2shape.py

@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import argparse
 
+import paddlers
 import numpy as np
 from PIL import Image
 try:
@@ -89,17 +90,16 @@ def mask2shape(srcimg_path, mask_path, save_path, ignore_index=255):
                            vec_ext)
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="Path of mask data.")
-parser.add_argument("--save_path", type=str, required=True, \
-                    help="Path to save the shape file (the extension is .json/geojson or .shp).")
-parser.add_argument("--srcimg_path", type=str, default="", \
-                    help="Path of original data with geoinfo. Default to empty.")
-parser.add_argument("--ignore_index", type=int, default=255, \
-                    help="The ignored index will not be converted to a value in the shape file. Default value is 255.")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mask_path", type=str, required=True, \
+                        help="Path of mask data.")
+    parser.add_argument("--save_path", type=str, required=True, \
+                        help="Path to save the shape file (the extension is .json/geojson or .shp).")
+    parser.add_argument("--srcimg_path", type=str, default="", \
+                        help="Path of original data with geoinfo. Default to empty.")
+    parser.add_argument("--ignore_index", type=int, default=255, \
+                        help="The ignored index will not be converted to a value in the shape file. Default value is 255.")
     args = parser.parse_args()
     mask2shape(args.srcimg_path, args.mask_path, args.save_path,
                args.ignore_index)

+ 12 - 12
tools/match.py

@@ -14,6 +14,7 @@
 
 import argparse
 
+import paddlers
 import numpy as np
 import cv2
 
@@ -78,19 +79,18 @@ def match(im1_path,
                  im1_ras.datatype)
 
 
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument('--im1_path', type=str, required=True, \
-                    help="Path of time1 image (with geoinfo).")
-parser.add_argument('--im2_path', type=str, required=True, \
-                    help="Path of time2 image.")
-parser.add_argument('--save_path', type=str, required=True, \
-                    help="Path to save matching result.")
-parser.add_argument('--im1_bands', type=int, nargs="+", default=[1, 2, 3], \
-                    help="Bands of im1 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
-parser.add_argument('--im2_bands', type=int, nargs="+", default=[1, 2, 3], \
-                    help="Bands of im2 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="input parameters")
+    parser.add_argument('--im1_path', type=str, required=True, \
+                        help="Path of time1 image (with geoinfo).")
+    parser.add_argument('--im2_path', type=str, required=True, \
+                        help="Path of time2 image.")
+    parser.add_argument('--save_path', type=str, required=True, \
+                        help="Path to save matching result.")
+    parser.add_argument('--im1_bands', type=int, nargs="+", default=[1, 2, 3], \
+                        help="Bands of im1 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
+    parser.add_argument('--im2_bands', type=int, nargs="+", default=[1, 2, 3], \
+                        help="Bands of im2 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
     args = parser.parse_args()
     match(args.im1_path, args.im2_path, args.save_path, args.im1_bands,
           args.im2_bands)

+ 7 - 7
tools/oif.py

@@ -14,10 +14,11 @@
 
 import itertools
 import argparse
-from easydict import EasyDict as edict
 
+import paddlers
 import numpy as np
 import pandas as pd
+from easydict import EasyDict as edict
 
 from utils import Raster, time_it
 
@@ -55,12 +56,11 @@ def oif(img_path, topk=5):
         print("Bands: {0}, OIF value: {1}.".format(k, v))
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--im_path", type=str, required=True, \
-                    help="Path of HSIs image.")
-parser.add_argument("--topk", type=int, default=5, \
-                    help="Number of top results. The default value is 5.")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--im_path", type=str, required=True, \
+                        help="Path of HSIs image.")
+    parser.add_argument("--topk", type=int, default=5, \
+                        help="Number of top results. The default value is 5.")
     args = parser.parse_args()
     oif(args.im_path, args.topk)

+ 8 - 8
tools/pca.py

@@ -17,6 +17,7 @@ import os.path as osp
 import numpy as np
 import argparse
 
+import paddlers
 from sklearn.decomposition import PCA
 from joblib import dump
 
@@ -43,14 +44,13 @@ def pca_train(img_path, save_dir="output", dim=3):
         save_dir))
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--im_path", type=str, required=True, \
-                    help="Path of HSIs image.")
-parser.add_argument("--save_dir", type=str, default="output", \
-                    help="Directory to save PCA params(*.joblib). Default: output.")
-parser.add_argument("--dim", type=int, default=3, \
-                    help="Dimension to reduce to. Default: 3.")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--im_path", type=str, required=True, \
+                        help="Path of HSIs image.")
+    parser.add_argument("--save_dir", type=str, default="output", \
+                        help="Directory to save PCA params(*.joblib). Default: output.")
+    parser.add_argument("--dim", type=int, default=3, \
+                        help="Dimension to reduce to. Default: 3.")
     args = parser.parse_args()
     pca_train(args.im_path, args.save_dir, args.dim)

+ 11 - 10
tools/split.py

@@ -16,6 +16,8 @@ import os
 import os.path as osp
 import argparse
 from math import ceil
+
+import paddlers
 from tqdm import tqdm
 
 from utils import Raster, save_geotiff, time_it
@@ -67,16 +69,15 @@ def split_data(image_path, mask_path, block_size, save_dir):
                 pbar.update(1)
 
 
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--image_path", type=str, required=True, \
-                    help="Path of input image.")
-parser.add_argument("--mask_path", type=str, default=None, \
-                    help="Path of input labels.")
-parser.add_argument("--block_size", type=int, default=512, \
-                    help="Size of image block. Default value is 512.")
-parser.add_argument("--save_dir", type=str, default="dataset", \
-                    help="Directory to save the results. Default value is 'dataset'.")
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="input parameters")
+    parser.add_argument("--image_path", type=str, required=True, \
+                        help="Path of input image.")
+    parser.add_argument("--mask_path", type=str, default=None, \
+                        help="Path of input labels.")
+    parser.add_argument("--block_size", type=int, default=512, \
+                        help="Size of image block. Default value is 512.")
+    parser.add_argument("--save_dir", type=str, default="dataset", \
+                        help="Directory to save the results. Default value is 'dataset'.")
     args = parser.parse_args()
     split_data(args.image_path, args.mask_path, args.block_size, args.save_dir)