Explorar el Código

[Feature] Add raster class

geoyee hace 3 años
padre
commit
0c35d0e108

+ 1 - 0
paddlers/datasets/__init__.py

@@ -1,2 +1,3 @@
 from .voc import VOCDetection
 from .seg_dataset import SegDataset
+from .raster import Raster

+ 140 - 0
paddlers/datasets/raster.py

@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import numpy as np
+from typing import List, Tuple, Union
+from paddlers.utils import raster2uint8
+
+try:
+    from osgeo import gdal
+except:
+    import gdal
+
+
+class Raster:
+    def __init__(self, 
+                 path: str,
+                 band_list: Union[List[int], Tuple[int], None]=None, 
+                 is_sar: bool=False,  # TODO: Remove this param
+                 is_src: bool=False) -> None:
+        """ Class of read raster.
+
+        Args:
+            path (str): The path of raster.
+            band_list (Union[List[int], Tuple[int], None], optional): 
+                band list (start with 1) or None (all of bands). Defaults to None.
+            is_sar (bool, optional): The raster is SAR or not. Defaults to False.
+            is_src (bool, optional): 
+                Return raw data or not (convert uint8/float32). Defaults to False.
+        """
+        super(Raster, self).__init__()
+        if osp.exists(path):
+            self.path = path
+            self.__src_data = gdal.Open(path)
+            self.__getInfo()
+            self.is_sar = is_sar
+            self.is_src = is_src
+            self.setBands(band_list)
+        else:
+            raise ValueError("The path {0} not exists.".format(path))
+
+    def setBands(self,
+                 band_list: Union[List[int], Tuple[int], None]) -> None:
+        """ Set band of data.
+
+        Args:
+            band_list (Union[List[int], Tuple[int], None]): 
+                band list (start with 1) or None (all of bands).
+        """
+        if band_list is not None:
+            if len(band_list) > self.bands:
+                raise ValueError("The lenght of band_list must be less than {0}.".format(str(self.bands)))
+            if max(band_list) > self.bands or min(band_list) < 1:
+                raise ValueError("The range of band_list must within [1, {0}].".format(str(self.bands)))
+        self.band_list = band_list
+
+    def getArray(self,
+                 start_loc: Union[List[int], Tuple[int], None]=None, 
+                 block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        """ Get ndarray data 
+
+        Args:
+            start_loc (Union[List[int], Tuple[int], None], optional): 
+                Coordinates of the upper left corner of the block, if None means return full image.
+            block_size (Union[List[int], Tuple[int]], optional): 
+                Block size. Defaults to [512, 512].
+
+        Returns:
+            np.ndarray: data's ndarray.
+        """
+        if start_loc is None:
+            return self.__getAarray()
+        else:
+            return self.__getBlock(start_loc, block_size)
+
+    def __getInfo(self) -> None:
+        self.bands = self.__src_data.RasterCount
+        self.width = self.__src_data.RasterXSize
+        self.height = self.__src_data.RasterYSize
+
+    def __getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
+        if window is not None:
+            xoff, yoff, xsize, ysize = window
+        if self.band_list is None:
+            if window is None:
+                ima = self.__src_data.ReadAsArray()
+            else:
+                ima = self.__src_data.ReadAsArray(xoff, yoff, xsize, ysize)
+        else:
+            band_array = []
+            for b in self.band_list:
+                if window is None:
+                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray()
+                else:
+                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray(xoff, yoff, xsize, ysize)
+                band_array.append(band_i)
+            ima = np.stack(band_array, axis=0)
+        if self.bands == 1:
+            if self.is_sar:
+                ima = abs(ima)
+        else:
+            ima = ima.transpose((1, 2, 0))
+        if self.is_src is False:
+            ima = raster2uint8(ima)
+        return ima
+
+    def __getBlock(self,
+                   start_loc: Union[List[int], Tuple[int]], 
+                   block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        if len(start_loc) != 2 or len(block_size) != 2:
+            raise ValueError("The length start_loc/block_size must be 2.")
+        xoff, yoff = start_loc
+        xsize, ysize = block_size
+        if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height):
+            raise ValueError(
+                "start_loc must be within [0-{0}, 0-{1}].".format(str(self.width), str(self.height)))
+        if xoff + xsize > self.width:
+            xsize = self.width - xoff
+        if yoff + ysize > self.height:
+            ysize = self.height - yoff
+        ima = self.__getAarray([int(xoff), int(yoff), int(xsize), int(ysize)])
+        h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape
+        if self.bands != 1:
+            tmp = np.zeros((block_size[0], block_size[1], self.bands), dtype=ima.dtype)
+            tmp[:h, :w, :] = ima
+        else:
+            tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype)
+            tmp[:h, :w] = ima
+        return tmp

+ 1 - 0
paddlers/utils/__init__.py

@@ -22,3 +22,4 @@ from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats
 from .shm import _get_shared_memory_size_in_M
+from .convert import raster2uint8

+ 95 - 0
paddlers/utils/convert.py

@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import operator
+from functools import reduce
+
+
+def raster2uint8(image: np.ndarray) -> np.ndarray:
+    """ Convert raster to uint8.
+    Args:
+        image (np.ndarray): image.
+    Returns:
+        np.ndarray: image on uint8.
+    """
+    dtype = image.dtype.name
+    dtypes = ["uint8", "uint16", "float32"]
+    if dtype not in dtypes:
+        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
+    if dtype == "uint8":
+        return image
+    else:
+        if dtype == "float32":
+            image = _sample_norm(image)
+        return _two_percentLinear(image)
+
+
+# 2% linear stretch
+def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
+    def _gray_process(gray, maxout=max_out, minout=min_out):
+        # Get the corresponding gray level at 98% histogram
+        high_value = np.percentile(gray, 98)
+        low_value = np.percentile(gray, 2)
+        truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
+        processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
+        return processed_gray
+    if len(image.shape) == 3:
+        processes = []
+        for b in range(image.shape[-1]):
+            processes.append(_gray_process(image[:, :, b]))
+        result = np.stack(processes, axis=2)
+    else:  # if len(image.shape) == 2
+        result = _gray_process(image)
+    return np.uint8(result)
+
+
+# Simple image standardization
+def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
+    stretches = []
+    if len(image.shape) == 3:
+        for b in range(image.shape[-1]):
+            stretched = _stretch(image[:, :, b], NUMS)
+            stretched /= float(NUMS)
+            stretches.append(stretched)
+        stretched_img = np.stack(stretches, axis=2)
+    else:  # if len(image.shape) == 2
+        stretched_img = _stretch(image, NUMS)
+    return np.uint8(stretched_img * 255)
+
+
+# Histogram equalization
+def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
+    hist = _histogram(ima, NUMS)
+    lut = []
+    for bt in range(0, len(hist), NUMS):
+        # Step size
+        step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+        # Create balanced lookup table
+        n = 0
+        for i in range(NUMS):
+            lut.append(n / step)
+            n += hist[i + bt]
+        np.take(lut, ima, out=ima)
+        return ima
+
+
+# Calculate histogram
+def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
+    bins = list(range(0, NUMS))
+    flat = ima.flat
+    n = np.searchsorted(np.sort(flat), bins)
+    n = np.concatenate([n, [len(flat)]])
+    hist = n[1:] - n[:-1]
+    return hist