|
@@ -13,7 +13,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
import os.path as osp
|
|
|
-from typing import List, Tuple, Union
|
|
|
+from typing import List, Tuple, Union, Optional
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
@@ -49,36 +49,45 @@ def _get_type(type_name: str) -> int:
|
|
|
|
|
|
class Raster:
|
|
|
def __init__(self,
|
|
|
- path: str,
|
|
|
+ path: Optional[str],
|
|
|
+ gdal_obj: Optional[gdal.Dataset]=None,
|
|
|
band_list: Union[List[int], Tuple[int], None]=None,
|
|
|
to_uint8: bool=False) -> None:
|
|
|
""" Class of read raster.
|
|
|
Args:
|
|
|
- path (str): The path of raster.
|
|
|
+ path (Optional[str]): The path of raster.
|
|
|
+ gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
|
|
|
band_list (Union[List[int], Tuple[int], None], optional):
|
|
|
band list (start with 1) or None (all of bands). Defaults to None.
|
|
|
to_uint8 (bool, optional):
|
|
|
Convert uint8 or return raw data. Defaults to False.
|
|
|
"""
|
|
|
super(Raster, self).__init__()
|
|
|
- if osp.exists(path):
|
|
|
- self.path = path
|
|
|
- self.ext_type = path.split(".")[-1]
|
|
|
- if self.ext_type.lower() in ["npy", "npz"]:
|
|
|
- self._src_data = None
|
|
|
+ if path is not None:
|
|
|
+ if osp.exists(path):
|
|
|
+ self.path = path
|
|
|
+ self.ext_type = path.split(".")[-1]
|
|
|
+ if self.ext_type.lower() in ["npy", "npz"]:
|
|
|
+ self._src_data = None
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ # raster format support in GDAL:
|
|
|
+ # https://www.osgeo.cn/gdal/drivers/raster/index.html
|
|
|
+ self._src_data = gdal.Open(path)
|
|
|
+ except:
|
|
|
+ raise TypeError(
|
|
|
+ "Unsupported data format: `{}`".format(self.ext_type))
|
|
|
else:
|
|
|
- try:
|
|
|
- # raster format support in GDAL:
|
|
|
- # https://www.osgeo.cn/gdal/drivers/raster/index.html
|
|
|
- self._src_data = gdal.Open(path)
|
|
|
- except:
|
|
|
- raise TypeError("Unsupported data format: `{}`".format(
|
|
|
- self.ext_type))
|
|
|
- self.to_uint8 = to_uint8
|
|
|
- self.setBands(band_list)
|
|
|
- self._getInfo()
|
|
|
+ raise ValueError("The path {0} not exists.".format(path))
|
|
|
else:
|
|
|
- raise ValueError("The path {0} not exists.".format(path))
|
|
|
+ if gdal_obj is not None:
|
|
|
+ self._src_data = gdal_obj
|
|
|
+ else:
|
|
|
+ raise ValueError("At least one of `path` and `gdal_obj` is not None.")
|
|
|
+ self.to_uint8 = to_uint8
|
|
|
+ self._getInfo()
|
|
|
+ self.setBands(band_list)
|
|
|
+ self._getType()
|
|
|
|
|
|
def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
|
|
|
""" Set band of data.
|
|
@@ -86,7 +95,6 @@ class Raster:
|
|
|
band_list (Union[List[int], Tuple[int], None]):
|
|
|
band list (start with 1) or None (all of bands).
|
|
|
"""
|
|
|
- self.bands = self._src_data.RasterCount
|
|
|
if band_list is not None:
|
|
|
if len(band_list) > self.bands:
|
|
|
raise ValueError(
|
|
@@ -99,8 +107,8 @@ class Raster:
|
|
|
|
|
|
def getArray(
|
|
|
self,
|
|
|
- start_loc: Union[List[int], Tuple[int], None]=None,
|
|
|
- block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
|
|
|
+ start_loc: Union[List[int], Tuple[int, int], None]=None,
|
|
|
+ block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
|
|
|
""" Get ndarray data
|
|
|
Args:
|
|
|
start_loc (Union[List[int], Tuple[int], None], optional):
|
|
@@ -123,13 +131,12 @@ class Raster:
|
|
|
if self._src_data is not None:
|
|
|
self.width = self._src_data.RasterXSize
|
|
|
self.height = self._src_data.RasterYSize
|
|
|
+ self.bands = self._src_data.RasterCount
|
|
|
self.geot = self._src_data.GetGeoTransform()
|
|
|
self.proj = self._src_data.GetProjection()
|
|
|
- d_name = self._getBlock([0, 0], [1, 1]).dtype.name
|
|
|
else:
|
|
|
d_img = self._getNumpy()
|
|
|
d_shape = d_img.shape
|
|
|
- d_name = d_img.dtype.name
|
|
|
if len(d_shape) == 3:
|
|
|
self.height, self.width, self.bands = d_shape
|
|
|
else:
|
|
@@ -137,6 +144,9 @@ class Raster:
|
|
|
self.bands = 1
|
|
|
self.geot = None
|
|
|
self.proj = None
|
|
|
+
|
|
|
+ def _getType(self) -> None:
|
|
|
+ d_name = self.getArray([0, 0], [1, 1]).dtype.name
|
|
|
self.datatype = _get_type(d_name)
|
|
|
|
|
|
def _getNumpy(self):
|
|
@@ -151,7 +161,9 @@ class Raster:
|
|
|
|
|
|
def _getArray(
|
|
|
self,
|
|
|
- window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
|
|
|
+ window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
|
|
|
+ if self._src_data is None:
|
|
|
+ raise ValueError("The raster is None.")
|
|
|
if window is not None:
|
|
|
xoff, yoff, xsize, ysize = window
|
|
|
if self.band_list is None:
|
|
@@ -183,8 +195,8 @@ class Raster:
|
|
|
|
|
|
def _getBlock(
|
|
|
self,
|
|
|
- start_loc: Union[List[int], Tuple[int]],
|
|
|
- block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
|
|
|
+ start_loc: Union[List[int], Tuple[int, int]],
|
|
|
+ block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
|
|
|
if len(start_loc) != 2 or len(block_size) != 2:
|
|
|
raise ValueError("The length start_loc/block_size must be 2.")
|
|
|
xoff, yoff = start_loc
|
|
@@ -208,9 +220,21 @@ class Raster:
|
|
|
return tmp
|
|
|
|
|
|
|
|
|
-def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
|
|
|
- height, width, channel = image.shape
|
|
|
- data_type = _get_type(image.dtype.name)
|
|
|
+def save_geotiff(image: np.ndarray,
|
|
|
+ save_path: str,
|
|
|
+ proj: str,
|
|
|
+ geotf: Tuple,
|
|
|
+ use_type: Optional[int]=None,
|
|
|
+ clear_ds: bool=True) -> None:
|
|
|
+ if len(image.shape) == 2:
|
|
|
+ height, width = image.shape
|
|
|
+ channel = 1
|
|
|
+ else:
|
|
|
+ height, width, channel = image.shape
|
|
|
+ if use_type is not None:
|
|
|
+ data_type = use_type
|
|
|
+ else:
|
|
|
+ data_type = _get_type(image.dtype.name)
|
|
|
driver = gdal.GetDriverByName("GTiff")
|
|
|
dst_ds = driver.Create(save_path, width, height, channel, data_type)
|
|
|
dst_ds.SetGeoTransform(geotf)
|
|
@@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) ->
|
|
|
band = dst_ds.GetRasterBand(1)
|
|
|
band.WriteArray(image)
|
|
|
dst_ds.FlushCache()
|
|
|
- dst_ds = None
|
|
|
+ if clear_ds:
|
|
|
+ dst_ds = None
|
|
|
+ return dst_ds
|