Bobholamovic пре 2 година
родитељ
комит
055486f834
11 измењених фајлова са 160 додато и 144 уклоњено
  1. 20 20
      tools/coco2mask.py
  2. 12 12
      tools/geojson2mask.py
  3. 11 9
      tools/match.py
  4. 6 5
      tools/oif.py
  5. 14 11
      tools/pca.py
  6. 19 16
      tools/raster2vector.py
  7. 19 15
      tools/split.py
  8. 2 2
      tools/utils/__init__.py
  9. 31 24
      tools/utils/raster.py
  10. 5 4
      tools/utils/timer.py
  11. 21 26
      tools/utils/vector.py

+ 20 - 20
tools/coco2mask.py

@@ -25,7 +25,7 @@ import glob
 from tqdm import tqdm
 from tqdm import tqdm
 from PIL import Image
 from PIL import Image
 
 
-from utils import timer
+from utils import time_it
 
 
 
 
 def _mkdir_p(path):
 def _mkdir_p(path):
@@ -69,30 +69,30 @@ def _read_geojson(json_path):
         return annotations, sizes
         return annotations, sizes
 
 
 
 
-@timer
-def convert_data(raw_folder, end_folder):
+@time_it
+def convert_data(raw_dir, end_dir):
     print("-- Initializing --")
     print("-- Initializing --")
-    img_folder = osp.join(raw_folder, "images")
-    save_img_folder = osp.join(end_folder, "img")
-    save_lab_folder = osp.join(end_folder, "gt")
-    _mkdir_p(save_img_folder)
-    _mkdir_p(save_lab_folder)
-    names = os.listdir(img_folder)
+    img_dir = osp.join(raw_dir, "images")
+    save_img_dir = osp.join(end_dir, "img")
+    save_lab_dir = osp.join(end_dir, "gt")
+    _mkdir_p(save_img_dir)
+    _mkdir_p(save_lab_dir)
+    names = os.listdir(img_dir)
     print("-- Loading annotations --")
     print("-- Loading annotations --")
     anns = {}
     anns = {}
     sizes = {}
     sizes = {}
-    jsons = glob.glob(osp.join(raw_folder, "*.json"))
+    jsons = glob.glob(osp.join(raw_dir, "*.json"))
     for json in jsons:
     for json in jsons:
         j_ann, j_size = _read_geojson(json)
         j_ann, j_size = _read_geojson(json)
         anns.update(j_ann)
         anns.update(j_ann)
         sizes.update(j_size)
         sizes.update(j_size)
-    print("-- Converting datas --")
+    print("-- Converting data --")
     for k in tqdm(names):
     for k in tqdm(names):
         # for k in tqdm(anns.keys()):
         # for k in tqdm(anns.keys()):
-        img_path = osp.join(img_folder, k)
-        img_save_path = osp.join(save_img_folder, k)
+        img_path = osp.join(img_dir, k)
+        img_save_path = osp.join(save_img_dir, k)
         ext = "." + k.split(".")[-1]
         ext = "." + k.split(".")[-1]
-        lab_save_path = osp.join(save_lab_folder, k.replace(ext, ".png"))
+        lab_save_path = osp.join(save_lab_dir, k.replace(ext, ".png"))
         shutil.copy(img_path, img_save_path)
         shutil.copy(img_path, img_save_path)
         if k in anns.keys():
         if k in anns.keys():
             _save_mask(anns[k], sizes[k], lab_save_path)
             _save_mask(anns[k], sizes[k], lab_save_path)
@@ -101,12 +101,12 @@ def convert_data(raw_folder, end_folder):
                           lab_save_path)
                           lab_save_path)
 
 
 
 
-parser = argparse.ArgumentParser(description="input parameters")
-parser.add_argument("--raw_folder", type=str, required=True, \
-                    help="The folder path about original data, where `images` saves the original image, `annotation.json` saves the corresponding annotation information.")
-parser.add_argument("--save_folder", type=str, required=True, \
-                    help="The folder path to save the results, where `img` saves the image and `gt` saves the label.")
+parser = argparse.ArgumentParser()
+parser.add_argument("--raw_dir", type=str, required=True, \
+                    help="Directory that contains original data, where `images` stores the original image and `annotation.json` stores the corresponding annotation information.")
+parser.add_argument("--save_dir", type=str, required=True, \
+                    help="Directory to save the results, where `img` stores the image and `gt` stores the label.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()
-    convert_data(args.raw_folder, args.save_folder)
+    convert_data(args.raw_dir, args.save_dir)

+ 12 - 12
tools/geojson2mask.py

@@ -19,7 +19,7 @@ import numpy as np
 import argparse
 import argparse
 import geojson
 import geojson
 from tqdm import tqdm
 from tqdm import tqdm
-from utils import Raster, save_geotiff, vector_translate, timer
+from utils import Raster, save_geotiff, translate_vector, time_it
 
 
 
 
 def _gt_convert(x_geo, y_geo, geotf):
 def _gt_convert(x_geo, y_geo, geotf):
@@ -28,13 +28,13 @@ def _gt_convert(x_geo, y_geo, geotf):
     return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
     return np.round(np.linalg.solve(a, b)).tolist()  # 解一元二次方程
 
 
 
 
-@timer
+@time_it
 # TODO: update for vector2raster
 # TODO: update for vector2raster
 def convert_data(image_path, geojson_path):
 def convert_data(image_path, geojson_path):
     raster = Raster(image_path)
     raster = Raster(image_path)
     tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
     tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
     # vector to EPSG from raster
     # vector to EPSG from raster
-    temp_geojson_path = vector_translate(geojson_path, raster.proj)
+    temp_geojson_path = translate_vector(geojson_path, raster.proj)
     geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
     geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
     feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
     feats = geojson.loads(geo_reader.read())["features"]  # 所有图像块
     geo_reader.close()
     geo_reader.close()
@@ -45,25 +45,25 @@ def convert_data(image_path, geojson_path):
         elif geo["type"] == "MultiPolygon":  # 多面
         elif geo["type"] == "MultiPolygon":  # 多面
             geo_points = geo["coordinates"][0][0]
             geo_points = geo["coordinates"][0][0]
         else:
         else:
-            raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(
-                geo["type"]))
+            raise TypeError(
+                "Geometry type must be `Polygon` or `MultiPolygon`, not {}.".
+                format(geo["type"]))
         xy_points = np.array([
         xy_points = np.array([
-            _gt_convert(point[0], point[1], raster.geot)
-            for point in geo_points
+            _gt_convert(point[0], point[1], raster.geot) for point in geo_points
         ]).astype(np.int32)
         ]).astype(np.int32)
         # TODO: Label category
         # TODO: Label category
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
     ext = "." + geojson_path.split(".")[-1]
     ext = "." + geojson_path.split(".")[-1]
-    save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
+    save_geotiff(tmp_img,
+                 geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
     os.remove(temp_geojson_path)
     os.remove(temp_geojson_path)
 
 
 
 
-parser = argparse.ArgumentParser(description="input parameters")
+parser = argparse.ArgumentParser()
 parser.add_argument("--image_path", type=str, required=True, \
 parser.add_argument("--image_path", type=str, required=True, \
-                    help="The path of original image.")
+                    help="Path of original image.")
 parser.add_argument("--geojson_path", type=str, required=True, \
 parser.add_argument("--geojson_path", type=str, required=True, \
-                    help="The path of geojson. (coordinate of geojson is WGS84)")
-
+                    help="Path of the geojson file (the coordinate system should be WGS84).")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()

+ 11 - 9
tools/matcher.py → tools/match.py

@@ -17,7 +17,8 @@ import argparse
 import numpy as np
 import numpy as np
 import cv2
 import cv2
 
 
-from utils import Raster, raster2uint8, save_geotiff, timer
+from utils import Raster, raster2uint8, save_geotiff, time_it
+
 
 
 class MatchError(Exception):
 class MatchError(Exception):
     def __str__(self):
     def __str__(self):
@@ -60,8 +61,8 @@ def _get_match_img(raster, bands):
     return ima
     return ima
 
 
 
 
-@timer
-def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
+@time_it
+def match(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im1_ras = Raster(im1_path)
     im1_ras = Raster(im1_path)
     im2_ras = Raster(im2_path)
     im2_ras = Raster(im2_path)
     im1 = _get_match_img(im1_ras._src_data, im1_bands)
     im1 = _get_match_img(im1_ras._src_data, im1_bands)
@@ -73,19 +74,20 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
                                     (im1_ras.width, im1_ras.height))
                                     (im1_ras.width, im1_ras.height))
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
-    save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
+    save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot,
+                 im1_ras.datatype)
 
 
 
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--im1_path", type=str, required=True, \
 parser.add_argument("--im1_path", type=str, required=True, \
-                    help="The path of time1 image (with geoinfo).")
+                    help="Path of time1 image (with geoinfo).")
 parser.add_argument("--im2_path", type=str, required=True, \
 parser.add_argument("--im2_path", type=str, required=True, \
-                    help="The path of time2 image.")
+                    help="Path of time2 image.")
 parser.add_argument("--im1_bands", type=int, nargs="+", default=[1, 2, 3], \
 parser.add_argument("--im1_bands", type=int, nargs="+", default=[1, 2, 3], \
-                    help="The time1 image's band used for matching, RGB or monochrome, `[1, 2, 3]` is the default.")
+                    help="Bands of im1 to be used for matching, RGB or monochrome. `[1, 2, 3]` is the default value.")
 parser.add_argument("--im2_bands", type=int, nargs="+", default=[1, 2, 3], \
 parser.add_argument("--im2_bands", type=int, nargs="+", default=[1, 2, 3], \
-                    help="The time2 image's band used for matching, RGB or monochrome, `[1, 2, 3]` is the default.")
+                    help="Bands of im2 to be used for matching, RGB or monochrome. `[1, 2, 3]` is the default value.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()
-    matching(args.im1_path, args.im2_path, args.im1_bands, args.im2_bands)
+    match(args.im1_path, args.im2_path, args.im1_bands, args.im2_bands)

+ 6 - 5
tools/oif.py

@@ -19,7 +19,8 @@ from easydict import EasyDict as edict
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
 
 
-from utils import Raster, timer
+from utils import Raster, time_it
+
 
 
 def _calcOIF(rgb, stds, rho):
 def _calcOIF(rgb, stds, rho):
     r, g, b = rgb
     r, g, b = rgb
@@ -32,7 +33,7 @@ def _calcOIF(rgb, stds, rho):
     return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
     return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
 
 
 
 
-@timer
+@time_it
 def oif(img_path, topk=5):
 def oif(img_path, topk=5):
     raster = Raster(img_path)
     raster = Raster(img_path)
     img = raster.getArray()
     img = raster.getArray()
@@ -54,11 +55,11 @@ def oif(img_path, topk=5):
         print("Bands: {0}, OIF value: {1}.".format(k, v))
         print("Bands: {0}, OIF value: {1}.".format(k, v))
 
 
 
 
-parser = argparse.ArgumentParser(description="input parameters")
+parser = argparse.ArgumentParser()
 parser.add_argument("--im_path", type=str, required=True, \
 parser.add_argument("--im_path", type=str, required=True, \
-                    help="The path of HSIs image.")
+                    help="Path of HSIs image.")
 parser.add_argument("--topk", type=int, default=5, \
 parser.add_argument("--topk", type=int, default=5, \
-                    help="Number of top results, `5` is the default.")
+                    help="Number of top results. `5` is the default value.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()

+ 14 - 11
tools/pca.py

@@ -16,12 +16,14 @@ import os
 import os.path as osp
 import os.path as osp
 import numpy as np
 import numpy as np
 import argparse
 import argparse
+
 from sklearn.decomposition import PCA
 from sklearn.decomposition import PCA
 from joblib import dump
 from joblib import dump
-from utils import Raster, save_geotiff, timer
+
+from utils import Raster, save_geotiff, time_it
 
 
 
 
-@timer
+@time_it
 def pca_train(img_path, save_dir="output", dim=3):
 def pca_train(img_path, save_dir="output", dim=3):
     raster = Raster(img_path)
     raster = Raster(img_path)
     im = raster.getArray()
     im = raster.getArray()
@@ -33,20 +35,21 @@ def pca_train(img_path, save_dir="output", dim=3):
     name = osp.splitext(osp.normpath(img_path).split(os.sep)[-1])[0]
     name = osp.splitext(osp.normpath(img_path).split(os.sep)[-1])[0]
     model_save_path = osp.join(save_dir, (name + "_pca.joblib"))
     model_save_path = osp.join(save_dir, (name + "_pca.joblib"))
     image_save_path = osp.join(save_dir, (name + "_pca.tif"))
     image_save_path = osp.join(save_dir, (name + "_pca.tif"))
-    dump(pca_model, model_save_path)  # save model
-    output = pca_model.transform(n_im).reshape((raster.height, raster.width, -1))
-    save_geotiff(output, image_save_path, raster.proj, raster.geot)  # save tiff
-    print("The Image and model of PCA saved in {}.".format(save_dir))
+    dump(pca_model, model_save_path)  # Save model
+    output = pca_model.transform(n_im).reshape(
+        (raster.height, raster.width, -1))
+    save_geotiff(output, image_save_path, raster.proj, raster.geot)  # Save tiff
+    print("The output image and the PCA model are saved in {}.".format(
+        save_dir))
 
 
 
 
-parser = argparse.ArgumentParser(description="input parameters")
+parser = argparse.ArgumentParser()
 parser.add_argument("--im_path", type=str, required=True, \
 parser.add_argument("--im_path", type=str, required=True, \
-                    help="The path of HSIs image.")
+                    help="Path of HSIs image.")
 parser.add_argument("--save_dir", type=str, default="output", \
 parser.add_argument("--save_dir", type=str, default="output", \
-                    help="The params(*.joblib) saved folder, `output` is the default.")
+                    help="Directory to save PCA params(*.joblib). Default: `output`.")
 parser.add_argument("--dim", type=int, default=3, \
 parser.add_argument("--dim", type=int, default=3, \
-                    help="The dimension after reduced, `3` is the default.")
-
+                    help="Dimension to reduce to. Default: `3`.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()

+ 19 - 16
tools/raster2vector.py

@@ -25,13 +25,13 @@ except ImportError:
     import ogr
     import ogr
     import osr
     import osr
 
 
-from utils import Raster, save_geotiff, timer
+from utils import Raster, save_geotiff, time_it
 
 
 
 
 def _mask2tif(mask_path, tmp_path, proj, geot):
 def _mask2tif(mask_path, tmp_path, proj, geot):
     dst_ds = save_geotiff(
     dst_ds = save_geotiff(
-        np.asarray(Image.open(mask_path)),
-        tmp_path, proj, geot,  gdal.GDT_UInt16, False)
+        np.asarray(Image.open(mask_path)), tmp_path, proj, geot,
+        gdal.GDT_UInt16, False)
     return dst_ds
     return dst_ds
 
 
 
 
@@ -47,9 +47,7 @@ def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
     gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
     gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
     gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
     gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
     ogr.RegisterAll()
     ogr.RegisterAll()
-    drv = ogr.GetDriverByName(
-        "ESRI Shapefile" if ext == "shp" else "GeoJSON"
-    )
+    drv = ogr.GetDriverByName("ESRI Shapefile" if ext == "shp" else "GeoJSON")
     if osp.exists(vec_save_path):
     if osp.exists(vec_save_path):
         os.remove(vec_save_path)
         os.remove(vec_save_path)
     dst_ds = drv.CreateDataSource(vec_save_path)
     dst_ds = drv.CreateDataSource(vec_save_path)
@@ -73,30 +71,35 @@ def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
         os.remove(tmp_path)
         os.remove(tmp_path)
 
 
 
 
-@timer
+@time_it
 def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
 def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
     vec_ext = save_path.split(".")[-1].lower()
     vec_ext = save_path.split(".")[-1].lower()
     if vec_ext not in ["json", "geojson", "shp"]:
     if vec_ext not in ["json", "geojson", "shp"]:
-        raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext))
+        raise ValueError(
+            "The ext of `save_path` must be `json/geojson` or `shp`, not {}.".
+            format(vec_ext))
     ras_ext = srcimg_path.split(".")[-1].lower()
     ras_ext = srcimg_path.split(".")[-1].lower()
     if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
     if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
         src = Raster(srcimg_path)
         src = Raster(srcimg_path)
-        _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext)
+        _polygonize_raster(mask_path, save_path, src.proj, src.geot,
+                           ignore_index, vec_ext)
         src = None
         src = None
     else:
     else:
-        _polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext)
+        _polygonize_raster(mask_path, save_path, None, None, ignore_index,
+                           vec_ext)
 
 
 
 
-parser = argparse.ArgumentParser(description="input parameters")
+parser = argparse.ArgumentParser()
 parser.add_argument("--mask_path", type=str, required=True, \
 parser.add_argument("--mask_path", type=str, required=True, \
-                    help="The path of mask data.")
+                    help="Path of mask data.")
 parser.add_argument("--save_path", type=str, required=True, \
 parser.add_argument("--save_path", type=str, required=True, \
-                    help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.")
+                    help="Path to save the shape file (the file suffix is `*.json/geojson` or `*.shp`).")
 parser.add_argument("--srcimg_path", type=str, default="", \
 parser.add_argument("--srcimg_path", type=str, default="", \
-                    help="The path of original data with geoinfos, `` is the default.")
+                    help="Path of original data with geoinfo. Default to empty.")
 parser.add_argument("--ignore_index", type=int, default=255, \
 parser.add_argument("--ignore_index", type=int, default=255, \
-                    help="It will not be converted to the value of SHP, `255` is the default.")
+                    help="The ignored index will not be converted to a value in the shape file. Default value is `255`.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()
-    raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)
+    raster2vector(args.srcimg_path, args.mask_path, args.save_path,
+                  args.ignore_index)

+ 19 - 15
tools/spliter.py → tools/split.py

@@ -18,7 +18,7 @@ import argparse
 from math import ceil
 from math import ceil
 from tqdm import tqdm
 from tqdm import tqdm
 
 
-from utils import Raster, save_geotiff, timer
+from utils import Raster, save_geotiff, time_it
 
 
 
 
 def _calc_window_tf(geot, loc):
 def _calc_window_tf(geot, loc):
@@ -27,17 +27,19 @@ def _calc_window_tf(geot, loc):
     return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
     return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
 
 
 
 
-@timer
+@time_it
 def split_data(image_path, mask_path, block_size, save_folder):
 def split_data(image_path, mask_path, block_size, save_folder):
     if not osp.exists(save_folder):
     if not osp.exists(save_folder):
         os.makedirs(save_folder)
         os.makedirs(save_folder)
         os.makedirs(osp.join(save_folder, "images"))
         os.makedirs(osp.join(save_folder, "images"))
         if mask_path is not None:
         if mask_path is not None:
             os.makedirs(osp.join(save_folder, "masks"))
             os.makedirs(osp.join(save_folder, "masks"))
-    image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".")
+    image_name, image_ext = image_path.replace("\\",
+                                               "/").split("/")[-1].split(".")
     image = Raster(image_path)
     image = Raster(image_path)
     mask = Raster(mask_path) if mask_path is not None else None
     mask = Raster(mask_path) if mask_path is not None else None
-    if mask is not None and (image.width != mask.width or image.height != mask.height):
+    if mask is not None and (image.width != mask.width or
+                             image.height != mask.height):
         raise ValueError("image's shape must equal mask's shape.")
         raise ValueError("image's shape must equal mask's shape.")
     rows = ceil(image.height / block_size)
     rows = ceil(image.height / block_size)
     cols = ceil(image.width / block_size)
     cols = ceil(image.width / block_size)
@@ -47,20 +49,21 @@ def split_data(image_path, mask_path, block_size, save_folder):
         for r in range(rows):
         for r in range(rows):
             for c in range(cols):
             for c in range(cols):
                 loc_start = (c * block_size, r * block_size)
                 loc_start = (c * block_size, r * block_size)
-                image_title = image.getArray(loc_start, (block_size, block_size))
+                image_title = image.getArray(loc_start,
+                                             (block_size, block_size))
                 image_save_path = osp.join(save_folder, "images", (
                 image_save_path = osp.join(save_folder, "images", (
                     image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
                     image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
                 window_geotf = _calc_window_tf(image.geot, loc_start)
                 window_geotf = _calc_window_tf(image.geot, loc_start)
-                save_geotiff(
-                    image_title, image_save_path, image.proj, window_geotf
-                )
+                save_geotiff(image_title, image_save_path, image.proj,
+                             window_geotf)
                 if mask is not None:
                 if mask is not None:
-                    mask_title = mask.getArray(loc_start, (block_size, block_size))
-                    mask_save_path = osp.join(save_folder, "masks", (
-                        image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
-                    save_geotiff(
-                        mask_title, mask_save_path, image.proj, window_geotf
-                    )
+                    mask_title = mask.getArray(loc_start,
+                                               (block_size, block_size))
+                    mask_save_path = osp.join(save_folder, "masks",
+                                              (image_name + "_" + str(r) + "_" +
+                                               str(c) + "." + image_ext))
+                    save_geotiff(mask_title, mask_save_path, image.proj,
+                                 window_geotf)
                 pbar.update(1)
                 pbar.update(1)
 
 
 
 
@@ -76,4 +79,5 @@ parser.add_argument("--save_folder", type=str, default="dataset", \
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     args = parser.parse_args()
     args = parser.parse_args()
-    split_data(args.image_path, args.mask_path, args.block_size, args.save_folder)
+    split_data(args.image_path, args.mask_path, args.block_size,
+               args.save_folder)

+ 2 - 2
tools/utils/__init__.py

@@ -17,5 +17,5 @@ import os.path as osp
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 
 
 from .raster import Raster, raster2uint8, save_geotiff
 from .raster import Raster, raster2uint8, save_geotiff
-from .vector import vector_translate
-from .timer import timer
+from .vector import translate_vector
+from .timer import time_it

+ 31 - 24
tools/utils/raster.py

@@ -16,14 +16,13 @@ import os.path as osp
 from typing import List, Tuple, Union, Optional
 from typing import List, Tuple, Union, Optional
 
 
 import numpy as np
 import numpy as np
-
-from paddlers.transforms.functions import to_uint8 as raster2uint8
-
 try:
 try:
     from osgeo import gdal
     from osgeo import gdal
 except:
 except:
     import gdal
     import gdal
 
 
+from paddlers.transforms.functions import to_uint8 as raster2uint8
+
 
 
 def _get_type(type_name: str) -> int:
 def _get_type(type_name: str) -> int:
     if type_name in ["bool", "uint8"]:
     if type_name in ["bool", "uint8"]:
@@ -53,7 +52,9 @@ class Raster:
                  gdal_obj: Optional[gdal.Dataset]=None,
                  gdal_obj: Optional[gdal.Dataset]=None,
                  band_list: Union[List[int], Tuple[int], None]=None,
                  band_list: Union[List[int], Tuple[int], None]=None,
                  to_uint8: bool=False) -> None:
                  to_uint8: bool=False) -> None:
-        """ Class of read raster.
+        """
+        Class of read raster.
+        
         Args:
         Args:
             path (Optional[str]): The path of raster.
             path (Optional[str]): The path of raster.
             gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
             gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
@@ -75,22 +76,25 @@ class Raster:
                         # https://www.osgeo.cn/gdal/drivers/raster/index.html
                         # https://www.osgeo.cn/gdal/drivers/raster/index.html
                         self._src_data = gdal.Open(path)
                         self._src_data = gdal.Open(path)
                     except:
                     except:
-                        raise TypeError(
-                            "Unsupported data format: `{}`".format(self.ext_type))
+                        raise TypeError("Unsupported data format: `{}`".format(
+                            self.ext_type))
             else:
             else:
                 raise ValueError("The path {0} not exists.".format(path))
                 raise ValueError("The path {0} not exists.".format(path))
         else:
         else:
             if gdal_obj is not None:
             if gdal_obj is not None:
                 self._src_data = gdal_obj
                 self._src_data = gdal_obj
             else:
             else:
-                raise ValueError("At least one of `path` and `gdal_obj` is not None.")
+                raise ValueError(
+                    "At least one of `path` and `gdal_obj` is not None.")
         self.to_uint8 = to_uint8
         self.to_uint8 = to_uint8
         self._getInfo()
         self._getInfo()
         self.setBands(band_list)
         self.setBands(band_list)
         self._getType()
         self._getType()
 
 
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
     def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
-        """ Set band of data.
+        """ 
+        Set band of data.
+        
         Args:
         Args:
             band_list (Union[List[int], Tuple[int], None]): 
             band_list (Union[List[int], Tuple[int], None]): 
                 band list (start with 1) or None (all of bands).
                 band list (start with 1) or None (all of bands).
@@ -105,16 +109,19 @@ class Raster:
                                  format(str(self.bands)))
                                  format(str(self.bands)))
         self.band_list = band_list
         self.band_list = band_list
 
 
-    def getArray(
-            self,
-            start_loc: Union[List[int], Tuple[int, int], None]=None,
-            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
-        """ Get ndarray data 
+    def getArray(self,
+                 start_loc: Union[List[int], Tuple[int, int], None]=None,
+                 block_size: Union[List[int], Tuple[int, int]]=[512, 512]
+                 ) -> np.ndarray:
+        """ 
+        Get ndarray data 
+        
         Args:
         Args:
             start_loc (Union[List[int], Tuple[int], None], optional): 
             start_loc (Union[List[int], Tuple[int], None], optional): 
                 Coordinates of the upper left corner of the block, if None means return full image.
                 Coordinates of the upper left corner of the block, if None means return full image.
             block_size (Union[List[int], Tuple[int]], optional): 
             block_size (Union[List[int], Tuple[int]], optional): 
                 Block size. Defaults to [512, 512].
                 Block size. Defaults to [512, 512].
+
         Returns:
         Returns:
             np.ndarray: data's ndarray.
             np.ndarray: data's ndarray.
         """
         """
@@ -144,7 +151,7 @@ class Raster:
                 self.bands = 1
                 self.bands = 1
             self.geot = None
             self.geot = None
             self.proj = None
             self.proj = None
-        
+
     def _getType(self) -> None:
     def _getType(self) -> None:
         d_name = self.getArray([0, 0], [1, 1]).dtype.name
         d_name = self.getArray([0, 0], [1, 1]).dtype.name
         self.datatype = _get_type(d_name)
         self.datatype = _get_type(d_name)
@@ -159,9 +166,9 @@ class Raster:
             ima = np.stack(band_array, axis=0)
             ima = np.stack(band_array, axis=0)
         return ima
         return ima
 
 
-    def _getArray(
-            self,
-            window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
+    def _getArray(self,
+                  window: Union[None, List[int], Tuple[int, int, int, int]]=None
+                  ) -> np.ndarray:
         if self._src_data is None:
         if self._src_data is None:
             raise ValueError("The raster is None.")
             raise ValueError("The raster is None.")
         if window is not None:
         if window is not None:
@@ -193,10 +200,10 @@ class Raster:
             ima = raster2uint8(ima)
             ima = raster2uint8(ima)
         return ima
         return ima
 
 
-    def _getBlock(
-            self,
-            start_loc: Union[List[int], Tuple[int, int]],
-            block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
+    def _getBlock(self,
+                  start_loc: Union[List[int], Tuple[int, int]],
+                  block_size: Union[List[int], Tuple[int, int]]=[512, 512]
+                  ) -> np.ndarray:
         if len(start_loc) != 2 or len(block_size) != 2:
         if len(start_loc) != 2 or len(block_size) != 2:
             raise ValueError("The length start_loc/block_size must be 2.")
             raise ValueError("The length start_loc/block_size must be 2.")
         xoff, yoff = start_loc
         xoff, yoff = start_loc
@@ -220,9 +227,9 @@ class Raster:
         return tmp
         return tmp
 
 
 
 
-def save_geotiff(image: np.ndarray, 
-                 save_path: str, 
-                 proj: str, 
+def save_geotiff(image: np.ndarray,
+                 save_path: str,
+                 proj: str,
                  geotf: Tuple,
                  geotf: Tuple,
                  use_type: Optional[int]=None,
                  use_type: Optional[int]=None,
                  clear_ds: bool=True) -> None:
                  clear_ds: bool=True) -> None:

+ 5 - 4
tools/utils/timer.py

@@ -16,11 +16,12 @@ import time
 from functools import wraps
 from functools import wraps
 
 
 
 
-def timer(func):
+def time_it(func):
     @wraps(func)
     @wraps(func)
-    def wrapper(*args,**kwargs):
+    def wrapper(*args, **kwargs):
         start_time = time.time()
         start_time = time.time()
-        result = func(*args,**kwargs)
-        print("Total time: {0}.".format(time.time() - start_time))
+        result = func(*args, **kwargs)
+        print("Total time consumed: {0}.".format(time.time() - start_time))
         return result
         return result
+
     return wrapper
     return wrapper

+ 21 - 26
tools/utils/vector.py

@@ -22,32 +22,27 @@ except:
     import osr
     import osr
 
 
 
 
-def vector_translate(geojson_path: str,
+def translate_vector(geojson_path: str,
                      wo_wkt: str,
                      wo_wkt: str,
                      g_type: str="POLYGON",
                      g_type: str="POLYGON",
                      dim: str="XY") -> str:
                      dim: str="XY") -> str:
-	ogr.RegisterAll()
-	gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
-	data = ogr.Open(geojson_path)
-	layer = data.GetLayer()
-	spatial = layer.GetSpatialRef()
-	layerName = layer.GetName()
-	data.Destroy()
-	dstSRS = osr.SpatialReference()
-	dstSRS.ImportFromWkt(wo_wkt)
-	ext = "." + geojson_path.split(".")[-1]
-	save_path = geojson_path.replace(ext, ("_tmp" + ext))
-	options = gdal.VectorTranslateOptions(
-		srcSRS=spatial,
-		dstSRS=dstSRS,
-		reproject=True,
-		layerName=layerName,
-		geometryType=g_type,
-		dim=dim
-	)
-	gdal.VectorTranslate(
-		save_path,
-		srcDS=geojson_path,
-		options=options
-	)
-	return save_path
+    ogr.RegisterAll()
+    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+    data = ogr.Open(geojson_path)
+    layer = data.GetLayer()
+    spatial = layer.GetSpatialRef()
+    layerName = layer.GetName()
+    data.Destroy()
+    dstSRS = osr.SpatialReference()
+    dstSRS.ImportFromWkt(wo_wkt)
+    ext = "." + geojson_path.split(".")[-1]
+    save_path = geojson_path.replace(ext, ("_tmp" + ext))
+    options = gdal.VectorTranslateOptions(
+        srcSRS=spatial,
+        dstSRS=dstSRS,
+        reproject=True,
+        layerName=layerName,
+        geometryType=g_type,
+        dim=dim)
+    gdal.VectorTranslate(save_path, srcDS=geojson_path, options=options)
+    return save_path