浏览代码

Add tools unittests

Bobholamovic 3 年之前
父节点
当前提交
dc1a407581

+ 16 - 3
tests/testing_utils.py

@@ -16,11 +16,22 @@
 
 import unittest
 import warnings
+import subprocess
 
 import numpy as np
 import paddle
 
-__all__ = ['CommonTest', 'CpuCommonTest']
+__all__ = ['CommonTest', 'CpuCommonTest', 'run_script']
+
+
+def run_script(cmd, silent=True, wd=None):
+    # XXX: This function is not safe!!!
+    cfg = dict(check=True, shell=True)
+    if silent:
+        cfg['stdout'] = subprocess.DEVNULL
+    if wd is not None:
+        cmd = f"cd {wd} && {cmd}"
+    return subprocess.run(cmd, **cfg)
 
 
 # Assume all elements has same data type
@@ -140,7 +151,8 @@ class _CommonTestNamespace:
                                rtol=1.e-5,
                                atol=1.e-8):
             '''
-                Check whether result and expected result are equal, including shape. 
+            Check whether result and expected result are equal, including shape. 
+            
             Args:
                 result: str, int, bool, set, np.ndarray.
                     The result needs to be checked.
@@ -159,7 +171,8 @@ class _CommonTestNamespace:
                                    rtol=1.e-5,
                                    atol=1.e-8):
             '''
-                Check whether result and expected result are not equal, including shape. 
+            Check whether result and expected result are not equal, including shape. 
+
             Args:
                 result: str, int, bool, set, np.ndarray.
                     The result needs to be checked.

+ 0 - 0
tests/tools/test_coco2mask.py


+ 0 - 0
tests/tools/test_geojson2mask.py


+ 25 - 0
tests/tools/test_match.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import CpuCommonTest, run_script
+
+
+class TestMatch(CpuCommonTest):
+    def test_script(self):
+        with tempfile.TemporaryDirectory() as td:
+            run_script(
+                f"python match.py --im1_path ../tests/data/ssmt/multispectral_t1.tif --im2_path ../tests/data/ssmt/multispectral_t1.tif --save_path {td}/out.tiff",
+                wd='../tools')

+ 24 - 0
tests/tools/test_oif.py

@@ -0,0 +1,24 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import CpuCommonTest, run_script
+
+
+class TestOIF(CpuCommonTest):
+    def test_script(self):
+        run_script(
+            f"python oif.py --im_path ../tests/data/ssst/multispectral.tif",
+            wd='../tools')

+ 25 - 0
tests/tools/test_pca.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import CpuCommonTest, run_script
+
+
+class TestPCA(CpuCommonTest):
+    def test_script(self):
+        with tempfile.TemporaryDirectory() as td:
+            run_script(
+                f"python pca.py --im_path ../tests/data/ssst/multispectral.tif --save_dir {td} --dim 5",
+                wd='../tools')

+ 0 - 0
tests/tools/test_raster2vector.py


+ 25 - 0
tests/tools/test_split.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+from testing_utils import CpuCommonTest, run_script
+
+
+class TestSplit(CpuCommonTest):
+    def test_script(self):
+        with tempfile.TemporaryDirectory() as td:
+            run_script(
+                f"python split.py --image_path ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --block_size 128 --save_dir {td}",
+                wd='../tools')

+ 15 - 12
tools/match.py

@@ -62,32 +62,35 @@ def _get_match_img(raster, bands):
 
 
 @time_it
-def match(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
+def match(im1_path,
+          im2_path,
+          save_path,
+          im1_bands=[1, 2, 3],
+          im2_bands=[1, 2, 3]):
     im1_ras = Raster(im1_path)
     im2_ras = Raster(im2_path)
     im1 = _get_match_img(im1_ras._src_data, im1_bands)
     im2 = _get_match_img(im2_ras._src_data, im2_bands)
     H = _calcu_tf(im1, im2)
-    # test
-    # im2_t = cv2.warpPerspective(im2, H, (im1.shape[1], im1.shape[0]))
-    # cv2.imwrite("B_M.png", cv2.cvtColor(im2_t, cv2.COLOR_RGB2BGR))
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
                                     (im1_ras.width, im1_ras.height))
-    save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
     save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot,
                  im1_ras.datatype)
 
 
 parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--im1_path", type=str, required=True, \
+parser.add_argument('--im1_path', type=str, required=True, \
                     help="Path of time1 image (with geoinfo).")
-parser.add_argument("--im2_path", type=str, required=True, \
+parser.add_argument('--im2_path', type=str, required=True, \
                     help="Path of time2 image.")
-parser.add_argument("--im1_bands", type=int, nargs="+", default=[1, 2, 3], \
-                    help="Bands of im1 to be used for matching, RGB or monochrome. `[1, 2, 3]` is the default value.")
-parser.add_argument("--im2_bands", type=int, nargs="+", default=[1, 2, 3], \
-                    help="Bands of im2 to be used for matching, RGB or monochrome. `[1, 2, 3]` is the default value.")
+parser.add_argument('--save_path', type=str, required=True, \
+                    help="Path to save matching result.")
+parser.add_argument('--im1_bands', type=int, nargs="+", default=[1, 2, 3], \
+                    help="Bands of im1 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
+parser.add_argument('--im2_bands', type=int, nargs="+", default=[1, 2, 3], \
+                    help="Bands of im2 to be used for matching, RGB or monochrome. The default value is [1, 2, 3].")
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    match(args.im1_path, args.im2_path, args.im1_bands, args.im2_bands)
+    match(args.im1_path, args.im2_path, args.save_path, args.im1_bands,
+          args.im2_bands)

+ 1 - 1
tools/oif.py

@@ -59,7 +59,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--im_path", type=str, required=True, \
                     help="Path of HSIs image.")
 parser.add_argument("--topk", type=int, default=5, \
-                    help="Number of top results. `5` is the default value.")
+                    help="Number of top results. The default value is 5.")
 
 if __name__ == "__main__":
     args = parser.parse_args()

+ 2 - 2
tools/pca.py

@@ -47,9 +47,9 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--im_path", type=str, required=True, \
                     help="Path of HSIs image.")
 parser.add_argument("--save_dir", type=str, default="output", \
-                    help="Directory to save PCA params(*.joblib). Default: `output`.")
+                    help="Directory to save PCA params(*.joblib). Default: output.")
 parser.add_argument("--dim", type=int, default=3, \
-                    help="Dimension to reduce to. Default: `3`.")
+                    help="Dimension to reduce to. Default: 3.")
 
 if __name__ == "__main__":
     args = parser.parse_args()

+ 70 - 70
tools/geojson2mask.py → tools/raster2geotiff.py

@@ -1,70 +1,70 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import codecs
-import cv2
-import numpy as np
-import argparse
-import geojson
-from tqdm import tqdm
-from utils import Raster, save_geotiff, translate_vector, time_it
-
-
-def _gt_convert(x_geo, y_geo, geotf):
-    a = np.array([[geotf[1], geotf[2]], [geotf[4], geotf[5]]])
-    b = np.array([x_geo - geotf[0], y_geo - geotf[3]])
-    return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
-
-
-@time_it
-# TODO: update for vector2raster
-def convert_data(image_path, geojson_path):
-    raster = Raster(image_path)
-    tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
-    # vector to EPSG from raster
-    temp_geojson_path = translate_vector(geojson_path, raster.proj)
-    geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
-    feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
-    geo_reader.close()
-    for feat in tqdm(feats):
-        geo = feat["geometry"]
-        if geo["type"] == "Polygon":  # 多边形
-            geo_points = geo["coordinates"][0]
-        elif geo["type"] == "MultiPolygon":  # 多面
-            geo_points = geo["coordinates"][0][0]
-        else:
-            raise TypeError(
-                "Geometry type must be `Polygon` or `MultiPolygon`, not {}.".
-                format(geo["type"]))
-        xy_points = np.array([
-            _gt_convert(point[0], point[1], raster.geot) for point in geo_points
-        ]).astype(np.int32)
-        # TODO: Label category
-        cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
-    ext = "." + geojson_path.split(".")[-1]
-    save_geotiff(tmp_img,
-                 geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
-    os.remove(temp_geojson_path)
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--image_path", type=str, required=True, \
-                    help="Path of original image.")
-parser.add_argument("--geojson_path", type=str, required=True, \
-                    help="Path of the geojson file (the coordinate system should be WGS84).")
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    convert_data(args.image_path, args.geojson_path)
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import codecs
+import cv2
+import numpy as np
+import argparse
+import geojson
+from tqdm import tqdm
+from utils import Raster, save_geotiff, translate_vector, time_it
+
+
+def _gt_convert(x_geo, y_geo, geotf):
+    a = np.array([[geotf[1], geotf[2]], [geotf[4], geotf[5]]])
+    b = np.array([x_geo - geotf[0], y_geo - geotf[3]])
+    return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
+
+
+@time_it
+# TODO: update for vector2raster
+def convert_data(image_path, geojson_path):
+    raster = Raster(image_path)
+    tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
+    # vector to EPSG from raster
+    temp_geojson_path = translate_vector(geojson_path, raster.proj)
+    geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
+    feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
+    geo_reader.close()
+    for feat in tqdm(feats):
+        geo = feat["geometry"]
+        if geo["type"] == "Polygon":  # 多边形
+            geo_points = geo["coordinates"][0]
+        elif geo["type"] == "MultiPolygon":  # 多面
+            geo_points = geo["coordinates"][0][0]
+        else:
+            raise TypeError(
+                "Geometry type must be `Polygon` or `MultiPolygon`, not {}.".
+                format(geo["type"]))
+        xy_points = np.array([
+            _gt_convert(point[0], point[1], raster.geot) for point in geo_points
+        ]).astype(np.int32)
+        # TODO: Label category
+        cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
+    ext = "." + geojson_path.split(".")[-1]
+    save_geotiff(tmp_img,
+                 geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
+    os.remove(temp_geojson_path)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--raster_path", type=str, required=True, \
+                    help="Path of original raster image.")
+parser.add_argument("--geotiff_path", type=str, required=True, \
+                    help="Path to store the geotiff file (the coordinate system is WGS84).")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    convert_data(args.raster_path, args.geotiff_path)

+ 2 - 2
tools/raster2vector.py

@@ -93,11 +93,11 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--mask_path", type=str, required=True, \
                     help="Path of mask data.")
 parser.add_argument("--save_path", type=str, required=True, \
-                    help="Path to save the shape file (the file suffix is `*.json/geojson` or `*.shp`).")
+                    help="Path to save the shape file (the file suffix is `.json/geojson` or `.shp`).")
 parser.add_argument("--srcimg_path", type=str, default="", \
                     help="Path of original data with geoinfo. Default to empty.")
 parser.add_argument("--ignore_index", type=int, default=255, \
-                    help="The ignored index will not be converted to a value in the shape file. Default value is `255`.")
+                    help="The ignored index will not be converted to a value in the shape file. Default value is 255.")
 
 if __name__ == "__main__":
     args = parser.parse_args()

+ 14 - 15
tools/split.py

@@ -28,12 +28,12 @@ def _calc_window_tf(geot, loc):
 
 
 @time_it
-def split_data(image_path, mask_path, block_size, save_folder):
-    if not osp.exists(save_folder):
-        os.makedirs(save_folder)
-        os.makedirs(osp.join(save_folder, "images"))
-        if mask_path is not None:
-            os.makedirs(osp.join(save_folder, "masks"))
+def split_data(image_path, mask_path, block_size, save_dir):
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    os.makedirs(osp.join(save_dir, "images"))
+    if mask_path is not None:
+        os.makedirs(osp.join(save_dir, "masks"))
     image_name, image_ext = image_path.replace("\\",
                                                "/").split("/")[-1].split(".")
     image = Raster(image_path)
@@ -51,7 +51,7 @@ def split_data(image_path, mask_path, block_size, save_folder):
                 loc_start = (c * block_size, r * block_size)
                 image_title = image.getArray(loc_start,
                                              (block_size, block_size))
-                image_save_path = osp.join(save_folder, "images", (
+                image_save_path = osp.join(save_dir, "images", (
                     image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
                 window_geotf = _calc_window_tf(image.geot, loc_start)
                 save_geotiff(image_title, image_save_path, image.proj,
@@ -59,7 +59,7 @@ def split_data(image_path, mask_path, block_size, save_folder):
                 if mask is not None:
                     mask_title = mask.getArray(loc_start,
                                                (block_size, block_size))
-                    mask_save_path = osp.join(save_folder, "masks",
+                    mask_save_path = osp.join(save_dir, "masks",
                                               (image_name + "_" + str(r) + "_" +
                                                str(c) + "." + image_ext))
                     save_geotiff(mask_title, mask_save_path, image.proj,
@@ -69,15 +69,14 @@ def split_data(image_path, mask_path, block_size, save_folder):
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--image_path", type=str, required=True, \
-                    help="The path of big image data.")
+                    help="Path of input image.")
 parser.add_argument("--mask_path", type=str, default=None, \
-                    help="The path of big image label data.")
+                    help="Path of input labels.")
 parser.add_argument("--block_size", type=int, default=512, \
-                    help="The size of image block, `512` is the default.")
-parser.add_argument("--save_folder", type=str, default="dataset", \
-                    help="The folder path to save the results, `dataset` is the default.")
+                    help="Size of image block. Default value is 512.")
+parser.add_argument("--save_dir", type=str, default="dataset", \
+                    help="Directory to save the results. Default value is `dataset`.")
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    split_data(args.image_path, args.mask_path, args.block_size,
-               args.save_folder)
+    split_data(args.image_path, args.mask_path, args.block_size, args.save_dir)

+ 9 - 9
tools/utils/raster.py

@@ -53,15 +53,15 @@ class Raster:
                  band_list: Union[List[int], Tuple[int], None]=None,
                  to_uint8: bool=False) -> None:
         """
-        Class of read raster.
+        Class of raster reader.
         
         Args:
-            path (Optional[str]): The path of raster.
-            gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
+            path (Optional[str]): Path of raster file.
+            gdal_obj (Optional[Any], optional): GDAL dataset. Defaults to None.
             band_list (Union[List[int], Tuple[int], None], optional): 
-                band list (start with 1) or None (all of bands). Defaults to None.
+                Select a set of bands (the band index starts from 1) or None (read all bands). Defaults to None.
             to_uint8 (bool, optional): 
-                Convert uint8 or return raw data. Defaults to False.
+                Whether to convert data type to uint8. Defaults to False.
         """
         super(Raster, self).__init__()
         if path is not None:
@@ -93,11 +93,11 @@ class Raster:
 
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
         """ 
-        Set band of data.
+        Set bands of data.
         
         Args:
             band_list (Union[List[int], Tuple[int], None]): 
-                band list (start with 1) or None (all of bands).
+                Select a set of bands (the band index starts from 1) or None (read all bands). Defaults to None.
         """
         if band_list is not None:
             if len(band_list) > self.bands:
@@ -114,11 +114,11 @@ class Raster:
                  block_size: Union[List[int], Tuple[int, int]]=[512, 512]
                  ) -> np.ndarray:
         """ 
-        Get ndarray data 
+        Fetch data in a ndarray.
         
         Args:
             start_loc (Union[List[int], Tuple[int], None], optional): 
-                Coordinates of the upper left corner of the block, if None means return full image.
+                Coordinates of the upper left corner of the block. None value means returning full image.
             block_size (Union[List[int], Tuple[int]], optional): 
                 Block size. Defaults to [512, 512].