Эх сурвалжийг харах

fix(tools): fix and update some (#110)

* refactor(tools): update savegeotiff and timer

* fix(geojson2mask): add convert tiff to EPSG:4326

* fix(geojson2mask): fix use tiff's EPSG

* fix(mask2geojson): update hole and fix srs

* fix(tools): update merge raster2vector

* feat(tools): update spliter to support HSI

* fix(raster2vector): fix save geojson without ignore index

* fix(tools): name fixed
Yizhou Chen 3 жил өмнө
parent
commit
ec9d58bb0a

+ 3 - 0
.gitignore

@@ -126,6 +126,9 @@ venv.bak/
 .dmypy.json
 .dmypy.json
 dmypy.json
 dmypy.json
 
 
+# myvscode
+.vscode
+
 # Pyre type checker
 # Pyre type checker
 .pyre/
 .pyre/
 
 

+ 2 - 2
tools/coco2mask.py

@@ -25,7 +25,7 @@ import glob
 from tqdm import tqdm
 from tqdm import tqdm
 from PIL import Image
 from PIL import Image
 
 
-from utils import Timer
+from utils import timer
 
 
 
 
 def _mkdir_p(path):
 def _mkdir_p(path):
@@ -69,7 +69,7 @@ def _read_geojson(json_path):
         return annotations, sizes
         return annotations, sizes
 
 
 
 
-@Timer
+@timer
 def convert_data(raw_folder, end_folder):
 def convert_data(raw_folder, end_folder):
     print("-- Initializing --")
     print("-- Initializing --")
     img_folder = osp.join(raw_folder, "images")
     img_folder = osp.join(raw_folder, "images")

+ 12 - 5
tools/geojson2mask.py

@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+import os
 import codecs
 import codecs
 import cv2
 import cv2
 import numpy as np
 import numpy as np
 import argparse
 import argparse
 import geojson
 import geojson
 from tqdm import tqdm
 from tqdm import tqdm
-from utils import Raster, save_geotiff, Timer
+from utils import Raster, save_geotiff, vector_translate, timer
 
 
 
 
 def _gt_convert(x_geo, y_geo, geotf):
 def _gt_convert(x_geo, y_geo, geotf):
@@ -27,12 +28,16 @@ def _gt_convert(x_geo, y_geo, geotf):
     return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
     return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
 
 
 
 
-@Timer
+@timer
+# TODO: update for vector2raster
 def convert_data(image_path, geojson_path):
 def convert_data(image_path, geojson_path):
     raster = Raster(image_path)
     raster = Raster(image_path)
     tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
     tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
-    geo_reader = codecs.open(geojson_path, "r", encoding="utf-8")
+    # vector to EPSG from raster
+    temp_geojson_path = vector_translate(geojson_path, raster.proj)
+    geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
     feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
     feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
+    geo_reader.close()
     for feat in tqdm(feats):
     for feat in tqdm(feats):
         geo = feat["geometry"]
         geo = feat["geometry"]
         if geo["type"] == "Polygon":  # 多边形
         if geo["type"] == "Polygon":  # 多边形
@@ -40,7 +45,8 @@ def convert_data(image_path, geojson_path):
         elif geo["type"] == "MultiPolygon":  # 多面
         elif geo["type"] == "MultiPolygon":  # 多面
             geo_points = geo["coordinates"][0][0]
             geo_points = geo["coordinates"][0][0]
         else:
         else:
-            raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(geo["type"]))
+            raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(
+                geo["type"]))
         xy_points = np.array([
         xy_points = np.array([
             _gt_convert(point[0], point[1], raster.geot)
             _gt_convert(point[0], point[1], raster.geot)
             for point in geo_points
             for point in geo_points
@@ -49,13 +55,14 @@ def convert_data(image_path, geojson_path):
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
     ext = "." + geojson_path.split(".")[-1]
     ext = "." + geojson_path.split(".")[-1]
     save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
     save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
+    os.remove(temp_geojson_path)
 
 
 
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--image_path", type=str, required=True, \
 parser.add_argument("--image_path", type=str, required=True, \
                     help="The path of original image.")
                     help="The path of original image.")
 parser.add_argument("--geojson_path", type=str, required=True, \
 parser.add_argument("--geojson_path", type=str, required=True, \
-                    help="The path of geojson.")
+                    help="The path of geojson. (coordinate of geojson is WGS84)")
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 0 - 81
tools/mask2geojson.py

@@ -1,81 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import codecs
-import argparse
-
-import cv2
-import numpy as np
-import geojson
-from geojson import Polygon, Feature, FeatureCollection
-
-from utils import Raster, Timer
-
-
-def _gt_convert(x, y, geotf):
-    x_geo = geotf[0] + x * geotf[1] + y * geotf[2]
-    y_geo = geotf[3] + x * geotf[4] + y * geotf[5]
-    return x_geo, y_geo
-
-
-@Timer
-def convert_data(mask_path, save_path, epsilon=0):
-    raster = Raster(mask_path)
-    img = raster.getArray()
-    ext = save_path.split(".")[-1]
-    if ext != "json" and ext != "geojson":
-        raise ValueError("The ext of `save_path` must be `json` or `geojson`, not {}.".format(ext))
-    geo_writer = codecs.open(save_path, "w", encoding="utf-8")
-    clas = np.unique(img)
-    cv2_v = (cv2.__version__.split(".")[0] == "3")
-    feats = []
-    if not isinstance(epsilon, (int, float)):
-        epsilon = 0
-    for iclas in range(1, len(clas)):
-        tmp = np.zeros_like(img).astype("uint8")
-        tmp[img == iclas] = 1
-        # TODO: Detect internal and external contour
-        results = cv2.findContours(tmp, cv2.RETR_EXTERNAL,
-                                   cv2.CHAIN_APPROX_TC89_KCOS)
-        contours = results[1] if cv2_v else results[0]
-        # hierarchys = results[2] if cv2_v else results[1]
-        if len(contours) == 0:
-            continue
-        for contour in contours:
-            contour = cv2.approxPolyDP(contour, epsilon, True)
-            polys = []
-            for point in contour:
-                x, y = point[0]
-                xg, yg = _gt_convert(x, y, raster.geot)
-                polys.append((xg, yg))
-            polys.append(polys[0])
-            feat = Feature(
-                geometry=Polygon([polys]), properties={"class": int(iclas)})
-            feats.append(feat)
-    gjs = FeatureCollection(feats)
-    geo_writer.write(geojson.dumps(gjs))
-    geo_writer.close()
-
-
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="The path of mask tif.")
-parser.add_argument("--save_path", type=str, required=True, \
-                    help="The path to save the results, file suffix is `*.json/geojson`.")
-parser.add_argument("--epsilon", type=float, default=0, \
-                    help="The CV2 simplified parameters, `0` is the default.")
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    convert_data(args.mask_path, args.save_path, args.epsilon)

+ 0 - 93
tools/mask2shp.py

@@ -1,93 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import os.path as osp
-import argparse
-
-import numpy as np
-from PIL import Image
-try:
-    from osgeo import gdal, ogr, osr
-except ImportError:
-    import gdal
-    import ogr
-    import osr
-
-from utils import Raster, Timer
-
-
-def _mask2tif(mask_path, tmp_path, proj, geot):
-    mask = np.asarray(Image.open(mask_path))
-    if len(mask.shape) == 3:
-        mask = mask[:, :, 0]
-    row, columns = mask.shape[:2]
-    driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(tmp_path, columns, row, 1, gdal.GDT_UInt16)
-    dst_ds.SetGeoTransform(geot)
-    dst_ds.SetProjection(proj)
-    dst_ds.GetRasterBand(1).WriteArray(mask)
-    dst_ds.FlushCache()
-    return dst_ds
-
-
-def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
-    tmp_path = shp_save_path.replace(".shp", ".tif")
-    ds = _mask2tif(mask_path, tmp_path, proj, geot)
-    srcband = ds.GetRasterBand(1)
-    maskband = srcband.GetMaskBand()
-    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
-    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
-    ogr.RegisterAll()
-    drv = ogr.GetDriverByName("ESRI Shapefile")
-    if osp.exists(shp_save_path):
-        os.remove(shp_save_path)
-    dst_ds = drv.CreateDataSource(shp_save_path)
-    prosrs = osr.SpatialReference(wkt=ds.GetProjection())
-    dst_layer = dst_ds.CreateLayer(
-        "Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
-    dst_fieldname = "DN"
-    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
-    dst_layer.CreateField(fd)
-    gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
-    lyr = dst_ds.GetLayer()
-    lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
-    for holes in lyr:
-        lyr.DeleteFeature(holes.GetFID())
-    dst_ds.Destroy()
-    ds = None
-    os.remove(tmp_path)
-
-
-@Timer
-def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
-    src = Raster(srcimg_path)
-    _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index)
-    src = None
-
-
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--srcimg_path", type=str, required=True, \
-                    help="The path of original data with geoinfos.")
-parser.add_argument("--mask_path", type=str, required=True, \
-                    help="The path of mask data.")
-parser.add_argument("--save_path", type=str, default="output", \
-                    help="The path to save the results shapefile, `output` is the default.")
-parser.add_argument("--ignore_index", type=int, default=255, \
-                    help="It will not be converted to the value of SHP, `255` is the default.")
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    raster2shp(args.srcimg_path, args.mask_path, args.save_path,
-               args.ignore_index)

+ 3 - 26
tools/matcher.py

@@ -16,12 +16,8 @@ import argparse
 
 
 import numpy as np
 import numpy as np
 import cv2
 import cv2
-try:
-    from osgeo import gdal
-except ImportError:
-    import gdal
 
 
-from utils import Raster, raster2uint8, Timer
+from utils import Raster, raster2uint8, save_geotiff, timer
 
 
 class MatchError(Exception):
 class MatchError(Exception):
     def __str__(self):
     def __str__(self):
@@ -64,26 +60,7 @@ def _get_match_img(raster, bands):
     return ima
     return ima
 
 
 
 
-def _img2tif(ima, save_path, proj, geot, dtype):
-    if len(ima.shape) == 3:
-        row, columns, bands = ima.shape
-    else:
-        row, columns = ima.shape
-        bands = 1
-    driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(save_path, columns, row, bands, dtype)
-    dst_ds.SetGeoTransform(geot)
-    dst_ds.SetProjection(proj)
-    if bands != 1:
-        for b in range(bands):
-            dst_ds.GetRasterBand(b + 1).WriteArray(ima[:, :, b])
-    else:
-        dst_ds.GetRasterBand(1).WriteArray(ima)
-    dst_ds.FlushCache()
-    return dst_ds
-
-
-@Timer
+@timer
 def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
 def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im1_ras = Raster(im1_path)
     im1_ras = Raster(im1_path)
     im2_ras = Raster(im2_path)
     im2_ras = Raster(im2_path)
@@ -96,7 +73,7 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
                                     (im1_ras.width, im1_ras.height))
                                     (im1_ras.width, im1_ras.height))
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
-    _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
+    save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
 
 
 
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser = argparse.ArgumentParser(description="input parameters")

+ 2 - 2
tools/oif.py

@@ -19,7 +19,7 @@ from easydict import EasyDict as edict
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
 
 
-from utils import Raster, Timer
+from utils import Raster, timer
 
 
 def _calcOIF(rgb, stds, rho):
 def _calcOIF(rgb, stds, rho):
     r, g, b = rgb
     r, g, b = rgb
@@ -32,7 +32,7 @@ def _calcOIF(rgb, stds, rho):
     return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
     return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
 
 
 
 
-@Timer
+@timer
 def oif(img_path, topk=5):
 def oif(img_path, topk=5):
     raster = Raster(img_path)
     raster = Raster(img_path)
     img = raster.getArray()
     img = raster.getArray()

+ 2 - 2
tools/pca.py

@@ -18,10 +18,10 @@ import numpy as np
 import argparse
 import argparse
 from sklearn.decomposition import PCA
 from sklearn.decomposition import PCA
 from joblib import dump
 from joblib import dump
-from utils import Raster, Timer, save_geotiff
+from utils import Raster, save_geotiff, timer
 
 
 
 
-@Timer
+@timer
 def pca_train(img_path, save_dir="output", dim=3):
 def pca_train(img_path, save_dir="output", dim=3):
     raster = Raster(img_path)
     raster = Raster(img_path)
     im = raster.getArray()
     im = raster.getArray()

+ 102 - 0
tools/raster2vector.py

@@ -0,0 +1,102 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import argparse
+
+import numpy as np
+from PIL import Image
+try:
+    from osgeo import gdal, ogr, osr
+except ImportError:
+    import gdal
+    import ogr
+    import osr
+
+from utils import Raster, save_geotiff, timer
+
+
+def _mask2tif(mask_path, tmp_path, proj, geot):
+    dst_ds = save_geotiff(
+        np.asarray(Image.open(mask_path)),
+        tmp_path, proj, geot,  gdal.GDT_UInt16, False)
+    return dst_ds
+
+
+def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
+    if proj is None or geot is None:
+        tmp_path = None
+        ds = gdal.Open(mask_path)
+    else:
+        tmp_path = vec_save_path.replace("." + ext, ".tif")
+        ds = _mask2tif(mask_path, tmp_path, proj, geot)
+    srcband = ds.GetRasterBand(1)
+    maskband = srcband.GetMaskBand()
+    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
+    ogr.RegisterAll()
+    drv = ogr.GetDriverByName(
+        "ESRI Shapefile" if ext == "shp" else "GeoJSON"
+    )
+    if osp.exists(vec_save_path):
+        os.remove(vec_save_path)
+    dst_ds = drv.CreateDataSource(vec_save_path)
+    prosrs = osr.SpatialReference(wkt=ds.GetProjection())
+    dst_layer = dst_ds.CreateLayer(
+        "POLYGON", geom_type=ogr.wkbPolygon, srs=prosrs)
+    dst_fieldname = "CLAS"
+    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
+    dst_layer.CreateField(fd)
+    gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
+    # TODO: temporary: delete ignored values
+    dst_ds.Destroy()
+    ds = None
+    vec_ds = drv.Open(vec_save_path, 1)
+    lyr = vec_ds.GetLayer()
+    lyr.SetAttributeFilter("{} = '{}'".format(dst_fieldname, str(ignore_index)))
+    for holes in lyr:
+        lyr.DeleteFeature(holes.GetFID())
+    vec_ds.Destroy()
+    if tmp_path is not None:
+        os.remove(tmp_path)
+
+
+@timer
+def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
+    vec_ext = save_path.split(".")[-1].lower()
+    if vec_ext not in ["json", "geojson", "shp"]:
+        raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext))
+    ras_ext = srcimg_path.split(".")[-1].lower()
+    if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
+        src = Raster(srcimg_path)
+        _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext)
+        src = None
+    else:
+        _polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext)
+
+
+parser = argparse.ArgumentParser(description="input parameters")
+parser.add_argument("--mask_path", type=str, required=True, \
+                    help="The path of mask data.")
+parser.add_argument("--save_path", type=str, required=True, \
+                    help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.")
+parser.add_argument("--srcimg_path", type=str, default="", \
+                    help="The path of original data with geoinfos, `` is the default.")
+parser.add_argument("--ignore_index", type=int, default=255, \
+                    help="It will not be converted to the value of SHP, `255` is the default.")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)

+ 31 - 21
tools/spliter.py

@@ -16,42 +16,52 @@ import os
 import os.path as osp
 import os.path as osp
 import argparse
 import argparse
 from math import ceil
 from math import ceil
+from tqdm import tqdm
 
 
-from PIL import Image
+from utils import Raster, save_geotiff, timer
 
 
-from utils import Raster, Timer
 
 
+def _calc_window_tf(geot, loc):
+    x, hr, r1, y, r2, vr = geot
+    nx, ny = loc
+    return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
 
 
-@Timer
+
+@timer
 def split_data(image_path, mask_path, block_size, save_folder):
 def split_data(image_path, mask_path, block_size, save_folder):
     if not osp.exists(save_folder):
     if not osp.exists(save_folder):
         os.makedirs(save_folder)
         os.makedirs(save_folder)
         os.makedirs(osp.join(save_folder, "images"))
         os.makedirs(osp.join(save_folder, "images"))
         if mask_path is not None:
         if mask_path is not None:
             os.makedirs(osp.join(save_folder, "masks"))
             os.makedirs(osp.join(save_folder, "masks"))
-    image_name = image_path.replace("\\", "/").split("/")[-1].split(".")[0]
-    image = Raster(image_path, to_uint8=True)
+    image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".")
+    image = Raster(image_path)
     mask = Raster(mask_path) if mask_path is not None else None
     mask = Raster(mask_path) if mask_path is not None else None
-    if image.width != mask.width or image.height != mask.height:
+    if mask is not None and (image.width != mask.width or image.height != mask.height):
         raise ValueError("image's shape must equal mask's shape.")
         raise ValueError("image's shape must equal mask's shape.")
     rows = ceil(image.height / block_size)
     rows = ceil(image.height / block_size)
     cols = ceil(image.width / block_size)
     cols = ceil(image.width / block_size)
     total_number = int(rows * cols)
     total_number = int(rows * cols)
-    for r in range(rows):
-        for c in range(cols):
-            loc_start = (c * block_size, r * block_size)
-            image_title = Image.fromarray(image.getArray(
-                loc_start, (block_size, block_size))).convert("RGB")
-            image_save_path = osp.join(save_folder, "images", (
-                image_name + "_" + str(r) + "_" + str(c) + ".jpg"))
-            image_title.save(image_save_path, "JPEG")
-            if mask is not None:
-                mask_title = Image.fromarray(mask.getArray(
-                    loc_start, (block_size, block_size))).convert("L")
-                mask_save_path = osp.join(save_folder, "masks", (
-                    image_name + "_" + str(r) + "_" + str(c) + ".png"))
-                mask_title.save(mask_save_path, "PNG")
-            print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number))
+
+    with tqdm(total=total_number) as pbar:
+        for r in range(rows):
+            for c in range(cols):
+                loc_start = (c * block_size, r * block_size)
+                image_title = image.getArray(loc_start, (block_size, block_size))
+                image_save_path = osp.join(save_folder, "images", (
+                    image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
+                window_geotf = _calc_window_tf(image.geot, loc_start)
+                save_geotiff(
+                    image_title, image_save_path, image.proj, window_geotf
+                )
+                if mask is not None:
+                    mask_title = mask.getArray(loc_start, (block_size, block_size))
+                    mask_save_path = osp.join(save_folder, "masks", (
+                        image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
+                    save_geotiff(
+                        mask_title, mask_save_path, image.proj, window_geotf
+                    )
+                pbar.update(1)
 
 
 
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser = argparse.ArgumentParser(description="input parameters")

+ 2 - 1
tools/utils/__init__.py

@@ -17,4 +17,5 @@ import os.path as osp
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 
 
 from .raster import Raster, raster2uint8, save_geotiff
 from .raster import Raster, raster2uint8, save_geotiff
-from .timer import Timer
+from .vector import vector_translate
+from .timer import timer

+ 57 - 31
tools/utils/raster.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import os.path as osp
 import os.path as osp
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Optional
 
 
 import numpy as np
 import numpy as np
 
 
@@ -49,36 +49,45 @@ def _get_type(type_name: str) -> int:
 
 
 class Raster:
 class Raster:
     def __init__(self,
     def __init__(self,
-                 path: str,
+                 path: Optional[str],
+                 gdal_obj: Optional[gdal.Dataset]=None,
                  band_list: Union[List[int], Tuple[int], None]=None,
                  band_list: Union[List[int], Tuple[int], None]=None,
                  to_uint8: bool=False) -> None:
                  to_uint8: bool=False) -> None:
         """ Class of read raster.
         """ Class of read raster.
         Args:
         Args:
-            path (str): The path of raster.
+            path (Optional[str]): The path of raster.
+            gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
             band_list (Union[List[int], Tuple[int], None], optional): 
             band_list (Union[List[int], Tuple[int], None], optional): 
                 band list (start with 1) or None (all of bands). Defaults to None.
                 band list (start with 1) or None (all of bands). Defaults to None.
             to_uint8 (bool, optional): 
             to_uint8 (bool, optional): 
                 Convert uint8 or return raw data. Defaults to False.
                 Convert uint8 or return raw data. Defaults to False.
         """
         """
         super(Raster, self).__init__()
         super(Raster, self).__init__()
-        if osp.exists(path):
-            self.path = path
-            self.ext_type = path.split(".")[-1]
-            if self.ext_type.lower() in ["npy", "npz"]:
-                self._src_data = None
+        if path is not None:
+            if osp.exists(path):
+                self.path = path
+                self.ext_type = path.split(".")[-1]
+                if self.ext_type.lower() in ["npy", "npz"]:
+                    self._src_data = None
+                else:
+                    try:
+                        # raster format support in GDAL: 
+                        # https://www.osgeo.cn/gdal/drivers/raster/index.html
+                        self._src_data = gdal.Open(path)
+                    except:
+                        raise TypeError(
+                            "Unsupported data format: `{}`".format(self.ext_type))
             else:
             else:
-                try:
-                    # raster format support in GDAL: 
-                    # https://www.osgeo.cn/gdal/drivers/raster/index.html
-                    self._src_data = gdal.Open(path)
-                except:
-                    raise TypeError("Unsupported data format: `{}`".format(
-                        self.ext_type))
-            self.to_uint8 = to_uint8
-            self.setBands(band_list)
-            self._getInfo()
+                raise ValueError("The path {0} not exists.".format(path))
         else:
         else:
-            raise ValueError("The path {0} not exists.".format(path))
+            if gdal_obj is not None:
+                self._src_data = gdal_obj
+            else:
+                raise ValueError("At least one of `path` and `gdal_obj` is not None.")
+        self.to_uint8 = to_uint8
+        self._getInfo()
+        self.setBands(band_list)
+        self._getType()
 
 
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
         """ Set band of data.
         """ Set band of data.
@@ -86,7 +95,6 @@ class Raster:
             band_list (Union[List[int], Tuple[int], None]): 
             band_list (Union[List[int], Tuple[int], None]): 
                 band list (start with 1) or None (all of bands).
                 band list (start with 1) or None (all of bands).
         """
         """
-        self.bands = self._src_data.RasterCount
         if band_list is not None:
         if band_list is not None:
             if len(band_list) > self.bands:
             if len(band_list) > self.bands:
                 raise ValueError(
                 raise ValueError(
@@ -99,8 +107,8 @@ class Raster:
 
 
     def getArray(
     def getArray(
             self,
             self,
-            start_loc: Union[List[int], Tuple[int], None]=None,
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+            start_loc: Union[List[int], Tuple[int, int], None]=None,
+            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
         """ Get ndarray data 
         """ Get ndarray data 
         Args:
         Args:
             start_loc (Union[List[int], Tuple[int], None], optional): 
             start_loc (Union[List[int], Tuple[int], None], optional): 
@@ -123,13 +131,12 @@ class Raster:
         if self._src_data is not None:
         if self._src_data is not None:
             self.width = self._src_data.RasterXSize
             self.width = self._src_data.RasterXSize
             self.height = self._src_data.RasterYSize
             self.height = self._src_data.RasterYSize
+            self.bands = self._src_data.RasterCount
             self.geot = self._src_data.GetGeoTransform()
             self.geot = self._src_data.GetGeoTransform()
             self.proj = self._src_data.GetProjection()
             self.proj = self._src_data.GetProjection()
-            d_name = self._getBlock([0, 0], [1, 1]).dtype.name
         else:
         else:
             d_img = self._getNumpy()
             d_img = self._getNumpy()
             d_shape = d_img.shape
             d_shape = d_img.shape
-            d_name = d_img.dtype.name
             if len(d_shape) == 3:
             if len(d_shape) == 3:
                 self.height, self.width, self.bands = d_shape
                 self.height, self.width, self.bands = d_shape
             else:
             else:
@@ -137,6 +144,9 @@ class Raster:
                 self.bands = 1
                 self.bands = 1
             self.geot = None
             self.geot = None
             self.proj = None
             self.proj = None
+        
+    def _getType(self) -> None:
+        d_name = self.getArray([0, 0], [1, 1]).dtype.name
         self.datatype = _get_type(d_name)
         self.datatype = _get_type(d_name)
 
 
     def _getNumpy(self):
     def _getNumpy(self):
@@ -151,7 +161,9 @@ class Raster:
 
 
     def _getArray(
     def _getArray(
             self,
             self,
-            window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
+            window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
+        if self._src_data is None:
+            raise ValueError("The raster is None.")
         if window is not None:
         if window is not None:
             xoff, yoff, xsize, ysize = window
             xoff, yoff, xsize, ysize = window
         if self.band_list is None:
         if self.band_list is None:
@@ -183,8 +195,8 @@ class Raster:
 
 
     def _getBlock(
     def _getBlock(
             self,
             self,
-            start_loc: Union[List[int], Tuple[int]],
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+            start_loc: Union[List[int], Tuple[int, int]],
+            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
         if len(start_loc) != 2 or len(block_size) != 2:
         if len(start_loc) != 2 or len(block_size) != 2:
             raise ValueError("The length start_loc/block_size must be 2.")
             raise ValueError("The length start_loc/block_size must be 2.")
         xoff, yoff = start_loc
         xoff, yoff = start_loc
@@ -208,9 +220,21 @@ class Raster:
         return tmp
         return tmp
 
 
 
 
-def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
-    height, width, channel = image.shape
-    data_type = _get_type(image.dtype.name)
+def save_geotiff(image: np.ndarray, 
+                 save_path: str, 
+                 proj: str, 
+                 geotf: Tuple,
+                 use_type: Optional[int]=None,
+                 clear_ds: bool=True) -> None:
+    if len(image.shape) == 2:
+        height, width = image.shape
+        channel = 1
+    else:
+        height, width, channel = image.shape
+    if use_type is not None:
+        data_type = use_type
+    else:
+        data_type = _get_type(image.dtype.name)
     driver = gdal.GetDriverByName("GTiff")
     driver = gdal.GetDriverByName("GTiff")
     dst_ds = driver.Create(save_path, width, height, channel, data_type)
     dst_ds = driver.Create(save_path, width, height, channel, data_type)
     dst_ds.SetGeoTransform(geotf)
     dst_ds.SetGeoTransform(geotf)
@@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) ->
         band = dst_ds.GetRasterBand(1)
         band = dst_ds.GetRasterBand(1)
         band.WriteArray(image)
         band.WriteArray(image)
         dst_ds.FlushCache()
         dst_ds.FlushCache()
-    dst_ds = None
+    if clear_ds:
+        dst_ds = None
+    return dst_ds

+ 7 - 7
tools/utils/timer.py

@@ -13,14 +13,14 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import time
 import time
+from functools import wraps
 
 
 
 
-class Timer(object):
-    def __init__(self, func):
-        self.func = func
-
-    def __call__(self, *args, **kwds):
+def timer(func):
+    @wraps(func)
+    def wrapper(*args,**kwargs):
         start_time = time.time()
         start_time = time.time()
-        func_t = self.func(*args, **kwds)
+        result = func(*args,**kwargs)
         print("Total time: {0}.".format(time.time() - start_time))
         print("Total time: {0}.".format(time.time() - start_time))
-        return func_t
+        return result
+    return wrapper

+ 53 - 0
tools/utils/vector.py

@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference: https://zhuanlan.zhihu.com/p/378918221
+
+try:
+    from osgeo import gdal, ogr, osr
+except:
+    import gdal
+    import ogr
+    import osr
+
+
+def vector_translate(geojson_path: str,
+                     wo_wkt: str,
+                     g_type: str="POLYGON",
+                     dim: str="XY") -> str:
+	ogr.RegisterAll()
+	gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+	data = ogr.Open(geojson_path)
+	layer = data.GetLayer()
+	spatial = layer.GetSpatialRef()
+	layerName = layer.GetName()
+	data.Destroy()
+	dstSRS = osr.SpatialReference()
+	dstSRS.ImportFromWkt(wo_wkt)
+	ext = "." + geojson_path.split(".")[-1]
+	save_path = geojson_path.replace(ext, ("_tmp" + ext))
+	options = gdal.VectorTranslateOptions(
+		srcSRS=spatial,
+		dstSRS=dstSRS,
+		reproject=True,
+		layerName=layerName,
+		geometryType=g_type,
+		dim=dim
+	)
+	gdal.VectorTranslate(
+		save_path,
+		srcDS=geojson_path,
+		options=options
+	)
+	return save_path