Просмотр исходного кода

[Clean] clean raster to transforms func

geoyee 3 лет назад
Родитель
Сommit
a7194612bf

+ 1 - 2
paddlers/datasets/__init__.py

@@ -15,5 +15,4 @@
 from .voc import VOCDetection
 from .seg_dataset import SegDataset
 from .cd_dataset import CDDataset
-from .clas_dataset import ClasDataset
-from .raster import Raster
+from .clas_dataset import ClasDataset

+ 127 - 8
paddlers/transforms/functions.py

@@ -14,10 +14,11 @@
 
 import cv2
 import numpy as np
-
+import copy
+import operator
 import shapely.ops
 from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
-import copy
+from functools import reduce
 from sklearn.decomposition import PCA
 
 
@@ -194,6 +195,122 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle
 
 
+def to_uint8(im):
+    """ Convert raster to uint8.
+    
+    Args:
+        im (np.ndarray): The image.
+
+    Returns:
+        np.ndarray: Image on uint8.
+    """
+    # 2% linear stretch
+    def _two_percentLinear(image, max_out=255, min_out=0):
+        def _gray_process(gray, maxout=max_out, minout=min_out):
+            # get the corresponding gray level at 98% histogram
+            high_value = np.percentile(gray, 98)
+            low_value = np.percentile(gray, 2)
+            truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
+            processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
+                             (maxout - minout)
+            return processed_gray
+        if len(image.shape) == 3:
+            processes = []
+            for b in range(image.shape[-1]):
+                processes.append(_gray_process(image[:, :, b]))
+            result = np.stack(processes, axis=2)
+        else:  # if len(image.shape) == 2
+            result = _gray_process(image)
+        return np.uint8(result)
+
+    # simple image standardization
+    def _sample_norm(image, NUMS=65536):
+        stretches = []
+        if len(image.shape) == 3:
+            for b in range(image.shape[-1]):
+                stretched = _stretch(image[:, :, b], NUMS)
+                stretched /= float(NUMS)
+                stretches.append(stretched)
+            stretched_img = np.stack(stretches, axis=2)
+        else:  # if len(image.shape) == 2
+            stretched_img = _stretch(image, NUMS)
+        return np.uint8(stretched_img * 255)
+
+    # histogram equalization
+    def _stretch(ima, NUMS):
+        hist = _histogram(ima, NUMS)
+        lut = []
+        for bt in range(0, len(hist), NUMS):
+            # step size
+            step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+            # create balanced lookup table
+            n = 0
+            for i in range(NUMS):
+                lut.append(n / step)
+                n += hist[i + bt]
+            np.take(lut, ima, out=ima)
+            return ima
+
+    # calculate histogram
+    def _histogram(ima, NUMS):
+        bins = list(range(0, NUMS))
+        flat = ima.flat
+        n = np.searchsorted(np.sort(flat), bins)
+        n = np.concatenate([n, [len(flat)]])
+        hist = n[1:] - n[:-1]
+        return hist
+
+    dtype = im.dtype.name
+    dtypes = ["uint8", "uint16", "float32"]
+    if dtype not in dtypes:
+        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
+    if dtype == "uint8":
+        return im
+    else:
+        if dtype == "float32":
+            im = _sample_norm(im)
+        return _two_percentLinear(im)
+
+
+def to_intensity(im):
+    """ calculate SAR data's intensity diagram.
+
+    Args:
+        im (np.ndarray): The SAR image.
+
+    Returns:
+        np.ndarray: Intensity diagram.
+    """
+    if len(im.shape) != 2:
+        raise ValueError("im's shape must be 2.")
+    # the type is complex means this is a SAR data
+    if isinstance(type(im[0, 0]), complex):
+        im = abs(im)
+    return im
+
+
+def select_bands(im, band_list=[1, 2, 3]):
+    """ Select bands.
+
+    Args:
+        im (np.ndarray): The image.
+        band_list (list, optional): Bands of selected (Start with 1). Defaults to [1, 2, 3].
+
+    Returns:
+        np.ndarray: The image after band selected.
+    """
+    total_band = im.shape[-1]
+    result = []
+    for band in band_list:
+        band = int(band - 1)
+        if band < 0 or band >= total_band:
+            raise ValueError(
+                "The element in band_list must > 1 and <= {}.".format(str(total_band)))
+        result.append()
+    ima = np.stack(result, axis=0)
+    return ima
+
+
 def matching(im1, im2):
     """ Match two images, used change detection. (Just RGB)
 
@@ -214,8 +331,10 @@ def matching(im1, im2):
     for m, n in mathces:
         if m.distance < 0.75 * n.distance:
             good_matches.append([m])
-    src_automatic_points = np.float32([kp1[m[0].queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
-    den_automatic_points = np.float32([kp2[m[0].trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
+    src_automatic_points = np.float32([kp1[m[0].queryIdx].pt \
+                                      for m in good_matches]).reshape(-1, 1, 2)
+    den_automatic_points = np.float32([kp2[m[0].trainIdx].pt \
+                                      for m in good_matches]).reshape(-1, 1, 2)
     H, _ = cv2.findHomography(src_automatic_points, den_automatic_points, cv2.RANSAC, 5.0)
     im1_t = cv2.warpPerspective(im1, H, (im2.shape[1], im2.shape[0]))
     return im1_t, im2
@@ -231,7 +350,7 @@ def de_haze(im, gamma=False):
     Returns:
         np.ndarray: The image after defogged.
     """
-    def guided_filter(I, p, r, eps):
+    def _guided_filter(I, p, r, eps):
         m_I = cv2.boxFilter(I, -1, (r, r))
         m_p = cv2.boxFilter(p, -1, (r, r))
         m_Ip = cv2.boxFilter(I * p, -1, (r, r))
@@ -244,11 +363,11 @@ def de_haze(im, gamma=False):
         m_b = cv2.boxFilter(b, -1, (r, r))
         return m_a * I + m_b
 
-    def de_fog(im, r, w, maxatmo_mask, eps):
+    def _de_fog(im, r, w, maxatmo_mask, eps):
         # im is RGB and range[0, 1]
         atmo_mask = np.min(im, 2)
         dark_channel = cv2.erode(atmo_mask, np.ones((15, 15)))
-        atmo_mask = guided_filter(atmo_mask, dark_channel, r, eps)
+        atmo_mask = _guided_filter(atmo_mask, dark_channel, r, eps)
         bins = 2000
         ht = np.histogram(atmo_mask, bins)
         d = np.cumsum(ht[0]) / float(atmo_mask.size)
@@ -262,7 +381,7 @@ def de_haze(im, gamma=False):
     if np.max(im) > 1:
         im = im / 255.
     result = np.zeros(im.shape)
-    mask_img, atmo_illum = de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
+    mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
     for k in range(3):
         result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
     result = np.clip(result, 0, 1)

+ 1 - 2
paddlers/utils/__init__.py

@@ -21,5 +21,4 @@ from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkp
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats
-from .shm import _get_shared_memory_size_in_M
-from .convert import raster2uint8
+from .shm import _get_shared_memory_size_in_M

+ 0 - 95
paddlers/utils/convert.py

@@ -1,95 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import operator
-from functools import reduce
-
-
-def raster2uint8(image: np.ndarray) -> np.ndarray:
-    """ Convert raster to uint8.
-    Args:
-        image (np.ndarray): image.
-    Returns:
-        np.ndarray: image on uint8.
-    """
-    dtype = image.dtype.name
-    dtypes = ["uint8", "uint16", "float32"]
-    if dtype not in dtypes:
-        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
-    if dtype == "uint8":
-        return image
-    else:
-        if dtype == "float32":
-            image = _sample_norm(image)
-        return _two_percentLinear(image)
-
-
-# 2% linear stretch
-def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
-    def _gray_process(gray, maxout=max_out, minout=min_out):
-        # get the corresponding gray level at 98% histogram
-        high_value = np.percentile(gray, 98)
-        low_value = np.percentile(gray, 2)
-        truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
-        processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
-        return processed_gray
-    if len(image.shape) == 3:
-        processes = []
-        for b in range(image.shape[-1]):
-            processes.append(_gray_process(image[:, :, b]))
-        result = np.stack(processes, axis=2)
-    else:  # if len(image.shape) == 2
-        result = _gray_process(image)
-    return np.uint8(result)
-
-
-# simple image standardization
-def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
-    stretches = []
-    if len(image.shape) == 3:
-        for b in range(image.shape[-1]):
-            stretched = _stretch(image[:, :, b], NUMS)
-            stretched /= float(NUMS)
-            stretches.append(stretched)
-        stretched_img = np.stack(stretches, axis=2)
-    else:  # if len(image.shape) == 2
-        stretched_img = _stretch(image, NUMS)
-    return np.uint8(stretched_img * 255)
-
-
-# histogram equalization
-def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
-    hist = _histogram(ima, NUMS)
-    lut = []
-    for bt in range(0, len(hist), NUMS):
-        # step size
-        step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
-        # create balanced lookup table
-        n = 0
-        for i in range(NUMS):
-            lut.append(n / step)
-            n += hist[i + bt]
-        np.take(lut, ima, out=ima)
-        return ima
-
-
-# calculate histogram
-def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
-    bins = list(range(0, NUMS))
-    flat = ima.flat
-    n = np.searchsorted(np.sort(flat), bins)
-    n = np.concatenate([n, [len(flat)]])
-    hist = n[1:] - n[:-1]
-    return hist

+ 2 - 5
tools/mask2shp.py

@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-import os.path as osp
-sys.path.insert(0, osp.abspath(".."))  # add workspace
-
 import os
+import os.path as osp
 import numpy as np
 import argparse
 from PIL import Image
-from paddlers.datasets.raster import Raster
+from utils import Raster
 
 try:
     from osgeo import gdal, ogr, osr

+ 2 - 5
tools/spliter.py

@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-import os.path as osp
-sys.path.insert(0, osp.abspath(".."))  # add workspace
-
 import os
+import os.path as osp
 import argparse
 from math import ceil
 from PIL import Image
-from paddlers.datasets.raster import Raster
+from utils import Raster
 
 
 def split_data(image_path, block_size, save_folder):

+ 19 - 0
tools/utils/__init__.py

@@ -0,0 +1,19 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os.path as osp
+sys.path.insert(0, osp.abspath(".."))  # add workspace
+
+from .raster import Raster

+ 1 - 1
paddlers/datasets/raster.py → tools/utils/raster.py

@@ -15,7 +15,7 @@
 import os.path as osp
 import numpy as np
 from typing import List, Tuple, Union
-from paddlers.utils import raster2uint8
+from paddlers.transforms.functions import to_uint8 as raster2uint8
 
 try:
     from osgeo import gdal