Browse Source

fix(tools): fix and update some (#110)

* refactor(tools): update savegeotiff and timer

* fix(geojson2mask): add convert tiff to EPSG:4326

* fix(geojson2mask): fix use tiff's EPSG

* fix(mask2geojson): update hole and fix srs

* fix(tools): update merge raster2vector

* feat(tools): update spliter to support HSI

* fix(raster2vector): fix save geojson without ignore index

* fix(tools): name fixed
Yizhou Chen 2 years ago
parent
commit
ec9d58bb0a

+ 3 - 0
.gitignore

@@ -126,6 +126,9 @@ venv.bak/
 .dmypy.json
 dmypy.json
 
+# myvscode
+.vscode
+
 # Pyre type checker
 .pyre/
 

+ 2 - 2
tools/coco2mask.py

@@ -25,7 +25,7 @@ import glob
 from tqdm import tqdm
 from PIL import Image
 
-from utils import Timer
+from utils import timer
 
 
 def _mkdir_p(path):
@@ -69,7 +69,7 @@ def _read_geojson(json_path):
         return annotations, sizes
 
 
-@Timer
+@timer
 def convert_data(raw_folder, end_folder):
     print("-- Initializing --")
     img_folder = osp.join(raw_folder, "images")

+ 12 - 5
tools/geojson2mask.py

@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import codecs
 import cv2
 import numpy as np
 import argparse
 import geojson
 from tqdm import tqdm
-from utils import Raster, save_geotiff, Timer
+from utils import Raster, save_geotiff, vector_translate, timer
 
 
 def _gt_convert(x_geo, y_geo, geotf):
@@ -27,12 +28,16 @@ def _gt_convert(x_geo, y_geo, geotf):
     return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
 
 
-@Timer
+@timer
+# TODO: update for vector2raster
 def convert_data(image_path, geojson_path):
     raster = Raster(image_path)
     tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
-    geo_reader = codecs.open(geojson_path, "r", encoding="utf-8")
+    # vector to EPSG from raster
+    temp_geojson_path = vector_translate(geojson_path, raster.proj)
+    geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
     feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
+    geo_reader.close()
     for feat in tqdm(feats):
         geo = feat["geometry"]
         if geo["type"] == "Polygon":  # 多边形
@@ -40,7 +45,8 @@ def convert_data(image_path, geojson_path):
         elif geo["type"] == "MultiPolygon":  # 多面
             geo_points = geo["coordinates"][0][0]
         else:
-            raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(geo["type"]))
+            raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(
+                geo["type"]))
         xy_points = np.array([
             _gt_convert(point[0], point[1], raster.geot)
             for point in geo_points
@@ -49,13 +55,14 @@ def convert_data(image_path, geojson_path):
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
     ext = "." + geojson_path.split(".")[-1]
     save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
+    os.remove(temp_geojson_path)
 
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--image_path", type=str, required=True, \
                     help="The path of original image.")
 parser.add_argument("--geojson_path", type=str, required=True, \
-                    help="The path of geojson.")
+                    help="The path of geojson. (coordinate of geojson is WGS84)")
 
 
 if __name__ == "__main__":

+ 0 - 81
tools/mask2geojson.py

@@ -1,81 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import codecs
-import argparse
-
-import cv2
-import numpy as np
-import geojson
-from geojson import Polygon, Feature, FeatureCollection
-
-from utils import Raster, Timer
-
-
-def _gt_convert(x, y, geotf):
-    x_geo = geotf[0] + x * geotf[1] + y * geotf[2]
-    y_geo = geotf[3] + x * geotf[4] + y * geotf[5]
-    return x_geo, y_geo
-
-
-@Timer
-def convert_data(mask_path, save_path, epsilon=0):
-    raster = Raster(mask_path)
-    img = raster.getArray()
-    ext = save_path.split(".")[-1]
-    if ext != "json" and ext != "geojson":
-        raise ValueError("The ext of `save_path` must be `json` or `geojson`, not {}.".format(ext))
-    geo_writer = codecs.open(save_path, "w", encoding="utf-8")
-    clas = np.unique(img)
-    cv2_v = (cv2.__version__.split(".")[0] == "3")
-    feats = []
-    if not isinstance(epsilon, (int, float)):
-        epsilon = 0
-    for iclas in range(1, len(clas)):
-        tmp = np.zeros_like(img).astype("uint8")
-        tmp[img == iclas] = 1
-        # TODO: Detect internal and external contour
-        results = cv2.findContours(tmp, cv2.RETR_EXTERNAL,
-                                   cv2.CHAIN_APPROX_TC89_KCOS)
-        contours = results[1] if cv2_v else results[0]
-        # hierarchys = results[2] if cv2_v else results[1]
-        if len(contours) == 0:
-            continue
-        for contour in contours:
-            contour = cv2.approxPolyDP(contour, epsilon, True)
-            polys = []
-            for point in contour:
-                x, y = point[0]
-                xg, yg = _gt_convert(x, y, raster.geot)
-                polys.append((xg, yg))
-            polys.append(polys[0])
-            feat = Feature(
-                geometry=Polygon([polys]), properties={"class": int(iclas)})
-            feats.append(feat)
-    gjs = FeatureCollection(feats)
-    geo_writer.write(geojson.dumps(gjs))
-    geo_writer.close()
-
-
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="The path of mask tif.")
-parser.add_argument("--save_path", type=str, required=True, \
-                    help="The path to save the results, file suffix is `*.json/geojson`.")
-parser.add_argument("--epsilon", type=float, default=0, \
-                    help="The CV2 simplified parameters, `0` is the default.")
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    convert_data(args.mask_path, args.save_path, args.epsilon)

+ 0 - 93
tools/mask2shp.py

@@ -1,93 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import os.path as osp
-import argparse
-
-import numpy as np
-from PIL import Image
-try:
-    from osgeo import gdal, ogr, osr
-except ImportError:
-    import gdal
-    import ogr
-    import osr
-
-from utils import Raster, Timer
-
-
-def _mask2tif(mask_path, tmp_path, proj, geot):
-    mask = np.asarray(Image.open(mask_path))
-    if len(mask.shape) == 3:
-        mask = mask[:, :, 0]
-    row, columns = mask.shape[:2]
-    driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(tmp_path, columns, row, 1, gdal.GDT_UInt16)
-    dst_ds.SetGeoTransform(geot)
-    dst_ds.SetProjection(proj)
-    dst_ds.GetRasterBand(1).WriteArray(mask)
-    dst_ds.FlushCache()
-    return dst_ds
-
-
-def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
-    tmp_path = shp_save_path.replace(".shp", ".tif")
-    ds = _mask2tif(mask_path, tmp_path, proj, geot)
-    srcband = ds.GetRasterBand(1)
-    maskband = srcband.GetMaskBand()
-    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
-    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
-    ogr.RegisterAll()
-    drv = ogr.GetDriverByName("ESRI Shapefile")
-    if osp.exists(shp_save_path):
-        os.remove(shp_save_path)
-    dst_ds = drv.CreateDataSource(shp_save_path)
-    prosrs = osr.SpatialReference(wkt=ds.GetProjection())
-    dst_layer = dst_ds.CreateLayer(
-        "Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
-    dst_fieldname = "DN"
-    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
-    dst_layer.CreateField(fd)
-    gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
-    lyr = dst_ds.GetLayer()
-    lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
-    for holes in lyr:
-        lyr.DeleteFeature(holes.GetFID())
-    dst_ds.Destroy()
-    ds = None
-    os.remove(tmp_path)
-
-
-@Timer
-def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
-    src = Raster(srcimg_path)
-    _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index)
-    src = None
-
-
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--srcimg_path", type=str, required=True, \
-                    help="The path of original data with geoinfos.")
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="The path of mask data.")
-parser.add_argument("--save_path", type=str, default="output", \
-                    help="The path to save the results shapefile, `output` is the default.")
-parser.add_argument("--ignore_index", type=int, default=255, \
-                    help="It will not be converted to the value of SHP, `255` is the default.")
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    raster2shp(args.srcimg_path, args.mask_path, args.save_path,
-               args.ignore_index)

+ 3 - 26
tools/matcher.py

@@ -16,12 +16,8 @@ import argparse
 
 import numpy as np
 import cv2
-try:
-    from osgeo import gdal
-except ImportError:
-    import gdal
 
-from utils import Raster, raster2uint8, Timer
+from utils import Raster, raster2uint8, save_geotiff, timer
 
 class MatchError(Exception):
     def __str__(self):
@@ -64,26 +60,7 @@ def _get_match_img(raster, bands):
     return ima
 
 
-def _img2tif(ima, save_path, proj, geot, dtype):
-    if len(ima.shape) == 3:
-        row, columns, bands = ima.shape
-    else:
-        row, columns = ima.shape
-        bands = 1
-    driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(save_path, columns, row, bands, dtype)
-    dst_ds.SetGeoTransform(geot)
-    dst_ds.SetProjection(proj)
-    if bands != 1:
-        for b in range(bands):
-            dst_ds.GetRasterBand(b + 1).WriteArray(ima[:, :, b])
-    else:
-        dst_ds.GetRasterBand(1).WriteArray(ima)
-    dst_ds.FlushCache()
-    return dst_ds
-
-
-@Timer
+@timer
 def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im1_ras = Raster(im1_path)
     im2_ras = Raster(im2_path)
@@ -96,7 +73,7 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
                                     (im1_ras.width, im1_ras.height))
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
-    _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
+    save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
 
 
 parser = argparse.ArgumentParser(description="input parameters")

+ 2 - 2
tools/oif.py

@@ -19,7 +19,7 @@ from easydict import EasyDict as edict
 import numpy as np
 import pandas as pd
 
-from utils import Raster, Timer
+from utils import Raster, timer
 
 def _calcOIF(rgb, stds, rho):
     r, g, b = rgb
@@ -32,7 +32,7 @@ def _calcOIF(rgb, stds, rho):
     return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
 
 
-@Timer
+@timer
 def oif(img_path, topk=5):
     raster = Raster(img_path)
     img = raster.getArray()

+ 2 - 2
tools/pca.py

@@ -18,10 +18,10 @@ import numpy as np
 import argparse
 from sklearn.decomposition import PCA
 from joblib import dump
-from utils import Raster, Timer, save_geotiff
+from utils import Raster, save_geotiff, timer
 
 
-@Timer
+@timer
 def pca_train(img_path, save_dir="output", dim=3):
     raster = Raster(img_path)
     im = raster.getArray()

+ 102 - 0
tools/raster2vector.py

@@ -0,0 +1,102 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import argparse
+
+import numpy as np
+from PIL import Image
+try:
+    from osgeo import gdal, ogr, osr
+except ImportError:
+    import gdal
+    import ogr
+    import osr
+
+from utils import Raster, save_geotiff, timer
+
+
+def _mask2tif(mask_path, tmp_path, proj, geot):
+    dst_ds = save_geotiff(
+        np.asarray(Image.open(mask_path)),
+        tmp_path, proj, geot,  gdal.GDT_UInt16, False)
+    return dst_ds
+
+
+def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
+    if proj is None or geot is None:
+        tmp_path = None
+        ds = gdal.Open(mask_path)
+    else:
+        tmp_path = vec_save_path.replace("." + ext, ".tif")
+        ds = _mask2tif(mask_path, tmp_path, proj, geot)
+    srcband = ds.GetRasterBand(1)
+    maskband = srcband.GetMaskBand()
+    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
+    ogr.RegisterAll()
+    drv = ogr.GetDriverByName(
+        "ESRI Shapefile" if ext == "shp" else "GeoJSON"
+    )
+    if osp.exists(vec_save_path):
+        os.remove(vec_save_path)
+    dst_ds = drv.CreateDataSource(vec_save_path)
+    prosrs = osr.SpatialReference(wkt=ds.GetProjection())
+    dst_layer = dst_ds.CreateLayer(
+        "POLYGON", geom_type=ogr.wkbPolygon, srs=prosrs)
+    dst_fieldname = "CLAS"
+    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
+    dst_layer.CreateField(fd)
+    gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
+    # TODO: temporary: delete ignored values
+    dst_ds.Destroy()
+    ds = None
+    vec_ds = drv.Open(vec_save_path, 1)
+    lyr = vec_ds.GetLayer()
+    lyr.SetAttributeFilter("{} = '{}'".format(dst_fieldname, str(ignore_index)))
+    for holes in lyr:
+        lyr.DeleteFeature(holes.GetFID())
+    vec_ds.Destroy()
+    if tmp_path is not None:
+        os.remove(tmp_path)
+
+
+@timer
+def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
+    vec_ext = save_path.split(".")[-1].lower()
+    if vec_ext not in ["json", "geojson", "shp"]:
+        raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext))
+    ras_ext = srcimg_path.split(".")[-1].lower()
+    if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
+        src = Raster(srcimg_path)
+        _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext)
+        src = None
+    else:
+        _polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext)
+
+
+parser = argparse.ArgumentParser(description="input parameters")
+parser.add_argument("--mask_path", type=str, required=True, \
+                    help="The path of mask data.")
+parser.add_argument("--save_path", type=str, required=True, \
+                    help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.")
+parser.add_argument("--srcimg_path", type=str, default="", \
+                    help="The path of original data with geoinfos, `` is the default.")
+parser.add_argument("--ignore_index", type=int, default=255, \
+                    help="It will not be converted to the value of SHP, `255` is the default.")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)

+ 31 - 21
tools/spliter.py

@@ -16,42 +16,52 @@ import os
 import os.path as osp
 import argparse
 from math import ceil
+from tqdm import tqdm
 
-from PIL import Image
+from utils import Raster, save_geotiff, timer
 
-from utils import Raster, Timer
 
+def _calc_window_tf(geot, loc):
+    x, hr, r1, y, r2, vr = geot
+    nx, ny = loc
+    return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
 
-@Timer
+
+@timer
 def split_data(image_path, mask_path, block_size, save_folder):
     if not osp.exists(save_folder):
         os.makedirs(save_folder)
         os.makedirs(osp.join(save_folder, "images"))
         if mask_path is not None:
             os.makedirs(osp.join(save_folder, "masks"))
-    image_name = image_path.replace("\\", "/").split("/")[-1].split(".")[0]
-    image = Raster(image_path, to_uint8=True)
+    image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".")
+    image = Raster(image_path)
     mask = Raster(mask_path) if mask_path is not None else None
-    if image.width != mask.width or image.height != mask.height:
+    if mask is not None and (image.width != mask.width or image.height != mask.height):
         raise ValueError("image's shape must equal mask's shape.")
     rows = ceil(image.height / block_size)
     cols = ceil(image.width / block_size)
     total_number = int(rows * cols)
-    for r in range(rows):
-        for c in range(cols):
-            loc_start = (c * block_size, r * block_size)
-            image_title = Image.fromarray(image.getArray(
-                loc_start, (block_size, block_size))).convert("RGB")
-            image_save_path = osp.join(save_folder, "images", (
-                image_name + "_" + str(r) + "_" + str(c) + ".jpg"))
-            image_title.save(image_save_path, "JPEG")
-            if mask is not None:
-                mask_title = Image.fromarray(mask.getArray(
-                    loc_start, (block_size, block_size))).convert("L")
-                mask_save_path = osp.join(save_folder, "masks", (
-                    image_name + "_" + str(r) + "_" + str(c) + ".png"))
-                mask_title.save(mask_save_path, "PNG")
-            print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number))
+
+    with tqdm(total=total_number) as pbar:
+        for r in range(rows):
+            for c in range(cols):
+                loc_start = (c * block_size, r * block_size)
+                image_title = image.getArray(loc_start, (block_size, block_size))
+                image_save_path = osp.join(save_folder, "images", (
+                    image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
+                window_geotf = _calc_window_tf(image.geot, loc_start)
+                save_geotiff(
+                    image_title, image_save_path, image.proj, window_geotf
+                )
+                if mask is not None:
+                    mask_title = mask.getArray(loc_start, (block_size, block_size))
+                    mask_save_path = osp.join(save_folder, "masks", (
+                        image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
+                    save_geotiff(
+                        mask_title, mask_save_path, image.proj, window_geotf
+                    )
+                pbar.update(1)
 
 
 parser = argparse.ArgumentParser(description="input parameters")

+ 2 - 1
tools/utils/__init__.py

@@ -17,4 +17,5 @@ import os.path as osp
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 
 from .raster import Raster, raster2uint8, save_geotiff
-from .timer import Timer
+from .vector import vector_translate
+from .timer import timer

+ 57 - 31
tools/utils/raster.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os.path as osp
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Optional
 
 import numpy as np
 
@@ -49,36 +49,45 @@ def _get_type(type_name: str) -> int:
 
 class Raster:
     def __init__(self,
-                 path: str,
+                 path: Optional[str],
+                 gdal_obj: Optional[gdal.Dataset]=None,
                  band_list: Union[List[int], Tuple[int], None]=None,
                  to_uint8: bool=False) -> None:
         """ Class of read raster.
         Args:
-            path (str): The path of raster.
+            path (Optional[str]): The path of raster.
+            gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
             band_list (Union[List[int], Tuple[int], None], optional): 
                 band list (start with 1) or None (all of bands). Defaults to None.
             to_uint8 (bool, optional): 
                 Convert uint8 or return raw data. Defaults to False.
         """
         super(Raster, self).__init__()
-        if osp.exists(path):
-            self.path = path
-            self.ext_type = path.split(".")[-1]
-            if self.ext_type.lower() in ["npy", "npz"]:
-                self._src_data = None
+        if path is not None:
+            if osp.exists(path):
+                self.path = path
+                self.ext_type = path.split(".")[-1]
+                if self.ext_type.lower() in ["npy", "npz"]:
+                    self._src_data = None
+                else:
+                    try:
+                        # raster format support in GDAL: 
+                        # https://www.osgeo.cn/gdal/drivers/raster/index.html
+                        self._src_data = gdal.Open(path)
+                    except:
+                        raise TypeError(
+                            "Unsupported data format: `{}`".format(self.ext_type))
             else:
-                try:
-                    # raster format support in GDAL: 
-                    # https://www.osgeo.cn/gdal/drivers/raster/index.html
-                    self._src_data = gdal.Open(path)
-                except:
-                    raise TypeError("Unsupported data format: `{}`".format(
-                        self.ext_type))
-            self.to_uint8 = to_uint8
-            self.setBands(band_list)
-            self._getInfo()
+                raise ValueError("The path {0} not exists.".format(path))
         else:
-            raise ValueError("The path {0} not exists.".format(path))
+            if gdal_obj is not None:
+                self._src_data = gdal_obj
+            else:
+                raise ValueError("At least one of `path` and `gdal_obj` is not None.")
+        self.to_uint8 = to_uint8
+        self._getInfo()
+        self.setBands(band_list)
+        self._getType()
 
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
         """ Set band of data.
@@ -86,7 +95,6 @@ class Raster:
             band_list (Union[List[int], Tuple[int], None]): 
                 band list (start with 1) or None (all of bands).
         """
-        self.bands = self._src_data.RasterCount
         if band_list is not None:
             if len(band_list) > self.bands:
                 raise ValueError(
@@ -99,8 +107,8 @@ class Raster:
 
     def getArray(
             self,
-            start_loc: Union[List[int], Tuple[int], None]=None,
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+            start_loc: Union[List[int], Tuple[int, int], None]=None,
+            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
         """ Get ndarray data 
         Args:
             start_loc (Union[List[int], Tuple[int], None], optional): 
@@ -123,13 +131,12 @@ class Raster:
         if self._src_data is not None:
             self.width = self._src_data.RasterXSize
             self.height = self._src_data.RasterYSize
+            self.bands = self._src_data.RasterCount
             self.geot = self._src_data.GetGeoTransform()
             self.proj = self._src_data.GetProjection()
-            d_name = self._getBlock([0, 0], [1, 1]).dtype.name
         else:
             d_img = self._getNumpy()
             d_shape = d_img.shape
-            d_name = d_img.dtype.name
             if len(d_shape) == 3:
                 self.height, self.width, self.bands = d_shape
             else:
@@ -137,6 +144,9 @@ class Raster:
                 self.bands = 1
             self.geot = None
             self.proj = None
+        
+    def _getType(self) -> None:
+        d_name = self.getArray([0, 0], [1, 1]).dtype.name
         self.datatype = _get_type(d_name)
 
     def _getNumpy(self):
@@ -151,7 +161,9 @@ class Raster:
 
     def _getArray(
             self,
-            window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
+            window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
+        if self._src_data is None:
+            raise ValueError("The raster is None.")
         if window is not None:
             xoff, yoff, xsize, ysize = window
         if self.band_list is None:
@@ -183,8 +195,8 @@ class Raster:
 
     def _getBlock(
             self,
-            start_loc: Union[List[int], Tuple[int]],
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+            start_loc: Union[List[int], Tuple[int, int]],
+            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
         if len(start_loc) != 2 or len(block_size) != 2:
             raise ValueError("The length start_loc/block_size must be 2.")
         xoff, yoff = start_loc
@@ -208,9 +220,21 @@ class Raster:
         return tmp
 
 
-def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
-    height, width, channel = image.shape
-    data_type = _get_type(image.dtype.name)
+def save_geotiff(image: np.ndarray, 
+                 save_path: str, 
+                 proj: str, 
+                 geotf: Tuple,
+                 use_type: Optional[int]=None,
+                 clear_ds: bool=True) -> None:
+    if len(image.shape) == 2:
+        height, width = image.shape
+        channel = 1
+    else:
+        height, width, channel = image.shape
+    if use_type is not None:
+        data_type = use_type
+    else:
+        data_type = _get_type(image.dtype.name)
     driver = gdal.GetDriverByName("GTiff")
     dst_ds = driver.Create(save_path, width, height, channel, data_type)
     dst_ds.SetGeoTransform(geotf)
@@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) ->
         band = dst_ds.GetRasterBand(1)
         band.WriteArray(image)
         dst_ds.FlushCache()
-    dst_ds = None
+    if clear_ds:
+        dst_ds = None
+    return dst_ds

+ 7 - 7
tools/utils/timer.py

@@ -13,14 +13,14 @@
 # limitations under the License.
 
 import time
+from functools import wraps
 
 
-class Timer(object):
-    def __init__(self, func):
-        self.func = func
-
-    def __call__(self, *args, **kwds):
+def timer(func):
+    @wraps(func)
+    def wrapper(*args,**kwargs):
         start_time = time.time()
-        func_t = self.func(*args, **kwds)
+        result = func(*args,**kwargs)
         print("Total time: {0}.".format(time.time() - start_time))
-        return func_t
+        return result
+    return wrapper

+ 53 - 0
tools/utils/vector.py

@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference: https://zhuanlan.zhihu.com/p/378918221
+
+try:
+    from osgeo import gdal, ogr, osr
+except:
+    import gdal
+    import ogr
+    import osr
+
+
+def vector_translate(geojson_path: str,
+                     wo_wkt: str,
+                     g_type: str="POLYGON",
+                     dim: str="XY") -> str:
+	ogr.RegisterAll()
+	gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+	data = ogr.Open(geojson_path)
+	layer = data.GetLayer()
+	spatial = layer.GetSpatialRef()
+	layerName = layer.GetName()
+	data.Destroy()
+	dstSRS = osr.SpatialReference()
+	dstSRS.ImportFromWkt(wo_wkt)
+	ext = "." + geojson_path.split(".")[-1]
+	save_path = geojson_path.replace(ext, ("_tmp" + ext))
+	options = gdal.VectorTranslateOptions(
+		srcSRS=spatial,
+		dstSRS=dstSRS,
+		reproject=True,
+		layerName=layerName,
+		geometryType=g_type,
+		dim=dim
+	)
+	gdal.VectorTranslate(
+		save_path,
+		srcDS=geojson_path,
+		options=options
+	)
+	return save_path