Bläddra i källkod

[Fix] Fix PCA used (#78)

Yizhou Chen 3 år sedan
förälder
incheckning
de61f6007f
6 ändrade filer med 337 tillägg och 264 borttagningar
  1. 20 25
      paddlers/transforms/functions.py
  2. 15 10
      paddlers/transforms/operators.py
  3. 2 2
      tools/geojson2mask.py
  4. 53 0
      tools/pca.py
  5. 20 20
      tools/utils/__init__.py
  6. 227 207
      tools/utils/raster.py

+ 20 - 25
paddlers/transforms/functions.py

@@ -18,9 +18,9 @@ import copy
 import numpy as np
 import shapely.ops
 from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
-from sklearn.decomposition import PCA
 from sklearn.linear_model import LinearRegression
 from skimage import exposure
+from joblib import load
 
 
 def normalize(im, mean, std, min_value=[0, 0, 0], max_value=[255, 255, 255]):
@@ -427,10 +427,6 @@ def to_uint8(im, is_linear=False):
         return np.uint8(stretched_img * 255)
 
     dtype = im.dtype.name
-    dtypes = ["uint8", "uint16", "uint32", "float32"]
-    if dtype not in dtypes:
-        raise ValueError(
-            f"'dtype' must be uint8/uint16/uint32/float32, not {dtype}.")
     if dtype != "uint8":
         im = _sample_norm(im)
     if is_linear:
@@ -533,26 +529,6 @@ def de_haze(im, gamma=False):
     return (result * 255).astype("uint8")
 
 
-def pca(im, dim=3, whiten=True):
-    """ Dimensionality reduction of PCA. 
-
-    Args:
-        im (np.ndarray): The image.
-        dim (int, optional): Reserved dimensions. Defaults to 3.
-        whiten (bool, optional): PCA whiten or not. Defaults to True.
-
-    Returns:
-        np.ndarray: The image after PCA.
-    """
-    H, W, C = im.shape
-    n_im = np.reshape(im, (-1, C))
-    pca = PCA(n_components=dim, whiten=whiten)
-    im_pca = pca.fit_transform(n_im)
-    result = np.reshape(im_pca, (H, W, dim))
-    result = np.clip(result, 0, 1)
-    return (result * 255).astype("uint8")
-
-
 def match_histograms(im, ref):
     """
     Match the cumulative histogram of one image to another.
@@ -615,3 +591,22 @@ def match_by_regression(im, ref, pif_loc=None):
         matched = _linear_regress(im, ref, pif_loc).astype(im.dtype)
 
     return matched
+
+
+def inv_pca(im, joblib_path):
+    """
+    Restore PCA result.
+
+    Args:
+        im (np.ndarray): The input image after PCA.
+        joblib_path (str): Path of *.joblib about PCA.
+
+    Returns:
+        np.ndarray: The raw input image.
+    """
+    pca = load(joblib_path)
+    H, W, C = im.shape
+    n_im = np.reshape(im, (-1, C))
+    r_im = pca.inverse_transform(n_im)
+    r_im = np.reshape(r_im, (H, W, -1))
+    return r_im

+ 15 - 10
paddlers/transforms/operators.py

@@ -27,11 +27,12 @@ import numpy as np
 import cv2
 import imghdr
 from PIL import Image
+from joblib import load
 
 import paddlers
 from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
-    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, pca, select_bands, \
+    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, select_bands, \
     to_intensity, to_uint8, img_flip, img_simple_rotate
 
 __all__ = [
@@ -242,7 +243,7 @@ class Compose(Transform):
         ValueError: Invalid length of transforms.
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms, to_uint8=True):
         super(Compose, self).__init__()
         if not isinstance(transforms, list):
             raise TypeError(
@@ -253,7 +254,7 @@ class Compose(Transform):
                 'Length of transforms must not be less than 1, but received is {}'
                 .format(len(transforms)))
         self.transforms = transforms
-        self.decode_image = ImgDecoder()
+        self.decode_image = ImgDecoder(to_uint8=to_uint8)
         self.arrange_outputs = None
         self.apply_im_only = False
 
@@ -1552,18 +1553,22 @@ class DimReducing(Transform):
     Use PCA to reduce input image(s) dimension.
 
     Args: 
-        dim (int, optional): Reserved dimensions. Defaults to 3.
-        whiten (bool, optional): PCA whiten or not. Defaults to True.
+        joblib_path (str): Path of *.joblib about PCA.
     """
 
-    def __init__(self, dim=3, whiten=True):
+    def __init__(self, joblib_path):
         super(DimReducing, self).__init__()
-        self.dim = dim
-        self.whiten = whiten
+        ext = joblib_path.split(".")[-1]
+        if ext != "joblib":
+            raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(ext))
+        self.pca = load(joblib_path)
 
     def apply_im(self, image):
-        image = pca(image, self.dim, self.whiten)
-        return image
+        H, W, C = image.shape
+        n_im = np.reshape(image, (-1, C))
+        im_pca = self.pca.transform(n_im)
+        result = np.reshape(im_pca, (H, W, -1))
+        return result
 
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])

+ 2 - 2
tools/geojson2mask.py

@@ -18,7 +18,7 @@ import numpy as np
 import argparse
 import geojson
 from tqdm import tqdm
-from utils import Raster, save_mask_geotiff, Timer
+from utils import Raster, save_geotiff, Timer
 
 
 def _gt_convert(x_geo, y_geo, geotf):
@@ -48,7 +48,7 @@ def convert_data(image_path, geojson_path):
         # TODO: Label category
         cv2.fillPoly(tmp_img, [xy_points], 1)  # 多边形填充
     ext = "." + geojson_path.split(".")[-1]
-    save_mask_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
+    save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
 
 
 parser = argparse.ArgumentParser(description="input parameters")

+ 53 - 0
tools/pca.py

@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import numpy as np
+import argparse
+from sklearn.decomposition import PCA
+from joblib import dump
+from utils import Raster, Timer, save_geotiff
+
+
+@Timer
+def pca_train(img_path, save_dir="output", dim=3):
+    raster = Raster(img_path)
+    im = raster.getArray()
+    n_im = np.reshape(im, (-1, raster.bands))
+    pca = PCA(n_components=dim, whiten=True)
+    pca_model = pca.fit(n_im)
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    name = osp.splitext(osp.normpath(img_path).split(os.sep)[-1])[0]
+    model_save_path = osp.join(save_dir, (name + "_pca.joblib"))
+    image_save_path = osp.join(save_dir, (name + "_pca.tif"))
+    dump(pca_model, model_save_path)  # save model
+    output = pca_model.transform(n_im).reshape((raster.height, raster.width, -1))
+    save_geotiff(output, image_save_path, raster.proj, raster.geot)  # save tiff
+    print("The Image and model of PCA saved in {}.".format(save_dir))
+
+
+parser = argparse.ArgumentParser(description="input parameters")
+parser.add_argument("--im_path", type=str, required=True, \
+                    help="The path of HSIs image.")
+parser.add_argument("--save_dir", type=str, default="output", \
+                    help="The params(*.joblib) saved folder, `output` is the default.")
+parser.add_argument("--dim", type=int, default=3, \
+                    help="The dimension after reduced, `3` is the default.")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    pca_train(args.im_path, args.save_dir, args.dim)

+ 20 - 20
tools/utils/__init__.py

@@ -1,20 +1,20 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-import os.path as osp
-sys.path.insert(0, osp.abspath(".."))  # add workspace
-
-from .raster import Raster, save_mask_geotiff, raster2uint8
-from .timer import Timer
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os.path as osp
+sys.path.insert(0, osp.abspath(".."))  # add workspace
+
+from .raster import Raster, raster2uint8, save_geotiff
+from .timer import Timer

+ 227 - 207
tools/utils/raster.py

@@ -1,207 +1,227 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os.path as osp
-from typing import List, Tuple, Union
-
-import numpy as np
-
-from paddlers.transforms.functions import to_uint8 as raster2uint8
-
-try:
-    from osgeo import gdal
-except:
-    import gdal
-
-
-class Raster:
-    def __init__(self,
-                 path: str,
-                 band_list: Union[List[int], Tuple[int], None]=None,
-                 to_uint8: bool=False) -> None:
-        """ Class of read raster.
-
-        Args:
-            path (str): The path of raster.
-            band_list (Union[List[int], Tuple[int], None], optional): 
-                band list (start with 1) or None (all of bands). Defaults to None.
-            to_uint8 (bool, optional): 
-                Convert uint8 or return raw data. Defaults to False.
-        """
-        super(Raster, self).__init__()
-        if osp.exists(path):
-            self.path = path
-            self.ext_type = path.split(".")[-1]
-            if self.ext_type.lower() in ["npy", "npz"]:
-                self._src_data = None
-            else:
-                try:
-                    # raster format support in GDAL: 
-                    # https://www.osgeo.cn/gdal/drivers/raster/index.html
-                    self._src_data = gdal.Open(path)
-                except:
-                    raise TypeError("Unsupported data format: `{}`".format(
-                        self.ext_type))
-            self.to_uint8 = to_uint8
-            self.setBands(band_list)
-            self._getInfo()
-        else:
-            raise ValueError("The path {0} not exists.".format(path))
-
-    def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
-        """ Set band of data.
-
-        Args:
-            band_list (Union[List[int], Tuple[int], None]): 
-                band list (start with 1) or None (all of bands).
-        """
-        self.bands = self._src_data.RasterCount
-        if band_list is not None:
-            if len(band_list) > self.bands:
-                raise ValueError(
-                    "The lenght of band_list must be less than {0}.".format(
-                        str(self.bands)))
-            if max(band_list) > self.bands or min(band_list) < 1:
-                raise ValueError("The range of band_list must within [1, {0}].".
-                                 format(str(self.bands)))
-        self.band_list = band_list
-
-    def getArray(
-            self,
-            start_loc: Union[List[int], Tuple[int], None]=None,
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
-        """ Get ndarray data 
-
-        Args:
-            start_loc (Union[List[int], Tuple[int], None], optional): 
-                Coordinates of the upper left corner of the block, if None means return full image.
-            block_size (Union[List[int], Tuple[int]], optional): 
-                Block size. Defaults to [512, 512].
-
-        Returns:
-            np.ndarray: data's ndarray.
-        """
-        if self._src_data is not None:
-            if start_loc is None:
-                return self._getArray()
-            else:
-                return self._getBlock(start_loc, block_size)
-        else:
-            print("Numpy doesn't support blocking temporarily.")
-            return self._getNumpy()
-
-    def _getInfo(self) -> None:
-        if self._src_data is not None:
-            self.width = self._src_data.RasterXSize
-            self.height = self._src_data.RasterYSize
-            self.geot = self._src_data.GetGeoTransform()
-            self.proj = self._src_data.GetProjection()
-            d_name = self._getBlock([0, 0], [1, 1]).dtype.name
-        else:
-            d_img = self._getNumpy()
-            d_shape = d_img.shape
-            d_name = d_img.dtype.name
-            if len(d_shape) == 3:
-                self.height, self.width, self.bands = d_shape
-            else:
-                self.height, self.width = d_shape
-                self.bands = 1
-            self.geot = None
-            self.proj = None
-        if "int8" in d_name:
-            self.datatype = gdal.GDT_Byte
-        elif "int16" in d_name:
-            self.datatype = gdal.GDT_UInt16
-        else:
-            self.datatype = gdal.GDT_Float32
-
-    def _getNumpy(self):
-        ima = np.load(self.path)
-        if self.band_list is not None:
-            band_array = []
-            for b in self.band_list:
-                band_i = ima[:, :, b - 1]
-                band_array.append(band_i)
-            ima = np.stack(band_array, axis=0)
-        return ima
-
-    def _getArray(
-            self,
-            window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
-        if window is not None:
-            xoff, yoff, xsize, ysize = window
-        if self.band_list is None:
-            if window is None:
-                ima = self._src_data.ReadAsArray()
-            else:
-                ima = self._src_data.ReadAsArray(xoff, yoff, xsize, ysize)
-        else:
-            band_array = []
-            for b in self.band_list:
-                if window is None:
-                    band_i = self._src_data.GetRasterBand(b).ReadAsArray()
-                else:
-                    band_i = self._src_data.GetRasterBand(b).ReadAsArray(
-                        xoff, yoff, xsize, ysize)
-                band_array.append(band_i)
-            ima = np.stack(band_array, axis=0)
-        if self.bands == 1:
-            if len(ima.shape) == 3:
-                ima = ima.squeeze(0)
-            # the type is complex means this is a SAR data
-            if isinstance(type(ima[0, 0]), complex):
-                ima = abs(ima)
-        else:
-            ima = ima.transpose((1, 2, 0))
-        if self.to_uint8 is True:
-            ima = raster2uint8(ima)
-        return ima
-
-    def _getBlock(
-            self,
-            start_loc: Union[List[int], Tuple[int]],
-            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
-        if len(start_loc) != 2 or len(block_size) != 2:
-            raise ValueError("The length start_loc/block_size must be 2.")
-        xoff, yoff = start_loc
-        xsize, ysize = block_size
-        if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height):
-            raise ValueError("start_loc must be within [0-{0}, 0-{1}].".format(
-                str(self.width), str(self.height)))
-        if xoff + xsize > self.width:
-            xsize = self.width - xoff
-        if yoff + ysize > self.height:
-            ysize = self.height - yoff
-        ima = self._getArray([int(xoff), int(yoff), int(xsize), int(ysize)])
-        h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape
-        if self.bands != 1:
-            tmp = np.zeros(
-                (block_size[0], block_size[1], self.bands), dtype=ima.dtype)
-            tmp[:h, :w, :] = ima
-        else:
-            tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype)
-            tmp[:h, :w] = ima
-        return tmp
-
-
-def save_mask_geotiff(mask: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
-    height, width = mask.shape
-    driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(save_path, width, height, 1, gdal.GDT_UInt16)
-    dst_ds.SetGeoTransform(geotf)
-    dst_ds.SetProjection(proj)
-    band = dst_ds.GetRasterBand(1)
-    band.WriteArray(mask)
-    dst_ds.FlushCache()
-    dst_ds = None
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+from typing import List, Tuple, Union
+
+import numpy as np
+
+from paddlers.transforms.functions import to_uint8 as raster2uint8
+
+try:
+    from osgeo import gdal
+except:
+    import gdal
+
+
+def _get_type(type_name: str) -> int:
+    if type_name in ["bool", "uint8"]:
+        gdal_type = gdal.GDT_Byte
+    elif type_name in ["int8", "int16"]:
+        gdal_type = gdal.GDT_Int16
+    elif type_name == "uint16":
+        gdal_type = gdal.GDT_UInt16
+    elif type_name == "int32":
+        gdal_type = gdal.GDT_Int32
+    elif type_name == "uint32":
+        gdal_type = gdal.GDT_UInt32
+    elif type_name in ["int64", "uint64", "float16", "float32"]:
+        gdal_type = gdal.GDT_Float32
+    elif type_name == "float64":
+        gdal_type = gdal.GDT_Float64
+    elif type_name == "complex64":
+        gdal_type = gdal.GDT_CFloat64
+    else:
+        raise TypeError("Non-suported data type `{}`.".format(type_name))
+    return gdal_type
+
+
+class Raster:
+    def __init__(self,
+                 path: str,
+                 band_list: Union[List[int], Tuple[int], None]=None,
+                 to_uint8: bool=False) -> None:
+        """ Class of read raster.
+        Args:
+            path (str): The path of raster.
+            band_list (Union[List[int], Tuple[int], None], optional): 
+                band list (start with 1) or None (all of bands). Defaults to None.
+            to_uint8 (bool, optional): 
+                Convert uint8 or return raw data. Defaults to False.
+        """
+        super(Raster, self).__init__()
+        if osp.exists(path):
+            self.path = path
+            self.ext_type = path.split(".")[-1]
+            if self.ext_type.lower() in ["npy", "npz"]:
+                self._src_data = None
+            else:
+                try:
+                    # raster format support in GDAL: 
+                    # https://www.osgeo.cn/gdal/drivers/raster/index.html
+                    self._src_data = gdal.Open(path)
+                except:
+                    raise TypeError("Unsupported data format: `{}`".format(
+                        self.ext_type))
+            self.to_uint8 = to_uint8
+            self.setBands(band_list)
+            self._getInfo()
+        else:
+            raise ValueError("The path {0} not exists.".format(path))
+
+    def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
+        """ Set band of data.
+        Args:
+            band_list (Union[List[int], Tuple[int], None]): 
+                band list (start with 1) or None (all of bands).
+        """
+        self.bands = self._src_data.RasterCount
+        if band_list is not None:
+            if len(band_list) > self.bands:
+                raise ValueError(
+                    "The lenght of band_list must be less than {0}.".format(
+                        str(self.bands)))
+            if max(band_list) > self.bands or min(band_list) < 1:
+                raise ValueError("The range of band_list must within [1, {0}].".
+                                 format(str(self.bands)))
+        self.band_list = band_list
+
+    def getArray(
+            self,
+            start_loc: Union[List[int], Tuple[int], None]=None,
+            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        """ Get ndarray data 
+        Args:
+            start_loc (Union[List[int], Tuple[int], None], optional): 
+                Coordinates of the upper left corner of the block, if None means return full image.
+            block_size (Union[List[int], Tuple[int]], optional): 
+                Block size. Defaults to [512, 512].
+        Returns:
+            np.ndarray: data's ndarray.
+        """
+        if self._src_data is not None:
+            if start_loc is None:
+                return self._getArray()
+            else:
+                return self._getBlock(start_loc, block_size)
+        else:
+            print("Numpy doesn't support blocking temporarily.")
+            return self._getNumpy()
+
+    def _getInfo(self) -> None:
+        if self._src_data is not None:
+            self.width = self._src_data.RasterXSize
+            self.height = self._src_data.RasterYSize
+            self.geot = self._src_data.GetGeoTransform()
+            self.proj = self._src_data.GetProjection()
+            d_name = self._getBlock([0, 0], [1, 1]).dtype.name
+        else:
+            d_img = self._getNumpy()
+            d_shape = d_img.shape
+            d_name = d_img.dtype.name
+            if len(d_shape) == 3:
+                self.height, self.width, self.bands = d_shape
+            else:
+                self.height, self.width = d_shape
+                self.bands = 1
+            self.geot = None
+            self.proj = None
+        self.datatype = _get_type(d_name)
+
+    def _getNumpy(self):
+        ima = np.load(self.path)
+        if self.band_list is not None:
+            band_array = []
+            for b in self.band_list:
+                band_i = ima[:, :, b - 1]
+                band_array.append(band_i)
+            ima = np.stack(band_array, axis=0)
+        return ima
+
+    def _getArray(
+            self,
+            window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
+        if window is not None:
+            xoff, yoff, xsize, ysize = window
+        if self.band_list is None:
+            if window is None:
+                ima = self._src_data.ReadAsArray()
+            else:
+                ima = self._src_data.ReadAsArray(xoff, yoff, xsize, ysize)
+        else:
+            band_array = []
+            for b in self.band_list:
+                if window is None:
+                    band_i = self._src_data.GetRasterBand(b).ReadAsArray()
+                else:
+                    band_i = self._src_data.GetRasterBand(b).ReadAsArray(
+                        xoff, yoff, xsize, ysize)
+                band_array.append(band_i)
+            ima = np.stack(band_array, axis=0)
+        if self.bands == 1:
+            if len(ima.shape) == 3:
+                ima = ima.squeeze(0)
+            # the type is complex means this is a SAR data
+            if isinstance(type(ima[0, 0]), complex):
+                ima = abs(ima)
+        else:
+            ima = ima.transpose((1, 2, 0))
+        if self.to_uint8 is True:
+            ima = raster2uint8(ima)
+        return ima
+
+    def _getBlock(
+            self,
+            start_loc: Union[List[int], Tuple[int]],
+            block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        if len(start_loc) != 2 or len(block_size) != 2:
+            raise ValueError("The length start_loc/block_size must be 2.")
+        xoff, yoff = start_loc
+        xsize, ysize = block_size
+        if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height):
+            raise ValueError("start_loc must be within [0-{0}, 0-{1}].".format(
+                str(self.width), str(self.height)))
+        if xoff + xsize > self.width:
+            xsize = self.width - xoff
+        if yoff + ysize > self.height:
+            ysize = self.height - yoff
+        ima = self._getArray([int(xoff), int(yoff), int(xsize), int(ysize)])
+        h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape
+        if self.bands != 1:
+            tmp = np.zeros(
+                (block_size[0], block_size[1], self.bands), dtype=ima.dtype)
+            tmp[:h, :w, :] = ima
+        else:
+            tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype)
+            tmp[:h, :w] = ima
+        return tmp
+
+
+def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
+    height, width, channel = image.shape
+    data_type = _get_type(image.dtype.name)
+    driver = gdal.GetDriverByName("GTiff")
+    dst_ds = driver.Create(save_path, width, height, channel, data_type)
+    dst_ds.SetGeoTransform(geotf)
+    dst_ds.SetProjection(proj)
+    if channel > 1:
+        for i in range(channel):
+            band = dst_ds.GetRasterBand(i + 1)
+            band.WriteArray(image[:, :, i])
+            dst_ds.FlushCache()
+    else:
+        band = dst_ds.GetRasterBand(1)
+        band.WriteArray(image)
+        dst_ds.FlushCache()
+    dst_ds = None