Преглед изворни кода

[Feature] [BugFix] Update/Fix to_uint8 (#59)

Yizhou Chen пре 3 година
родитељ
комит
0728d38e21
1 измењених фајлова са 16 додато и 41 уклоњено
  1. 16 41
      paddlers/transforms/functions.py

+ 16 - 41
paddlers/transforms/functions.py

@@ -15,10 +15,8 @@
 import cv2
 import numpy as np
 import copy
-import operator
 import shapely.ops
 from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
-from functools import reduce
 from sklearn.decomposition import PCA
 from sklearn.linear_model import LinearRegression
 from skimage import exposure
@@ -383,18 +381,19 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle
 
 
-def to_uint8(im):
+def to_uint8(im, is_linear=False):
     """ Convert raster to uint8.
     
     Args:
         im (np.ndarray): The image.
+        is_linear (bool, optional): Use 2% linear stretch or not. Default is False.
 
     Returns:
         np.ndarray: Image on uint8.
     """
 
     # 2% linear stretch
-    def _two_percentLinear(image, max_out=255, min_out=0):
+    def _two_percent_linear(image, max_out=255, min_out=0):
         def _gray_process(gray, maxout=max_out, minout=min_out):
             # get the corresponding gray level at 98% histogram
             high_value = np.percentile(gray, 98)
@@ -402,7 +401,7 @@ def to_uint8(im):
             truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
             processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
                              (maxout - minout)
-            return processed_gray
+            return np.uint8(processed_gray)
 
         if len(image.shape) == 3:
             processes = []
@@ -414,52 +413,28 @@ def to_uint8(im):
         return np.uint8(result)
 
     # simple image standardization
-    def _sample_norm(image, NUMS=65536):
+    def _sample_norm(image):
         stretches = []
         if len(image.shape) == 3:
             for b in range(image.shape[-1]):
-                stretched = _stretch(image[:, :, b], NUMS)
-                stretched /= float(NUMS)
+                stretched = exposure.equalize_hist(image[:, :, b])
+                stretched /= float(np.max(stretched))
                 stretches.append(stretched)
             stretched_img = np.stack(stretches, axis=2)
         else:  # if len(image.shape) == 2
-            stretched_img = _stretch(image, NUMS)
+            stretched_img = exposure.equalize_hist(image)
         return np.uint8(stretched_img * 255)
 
-    # histogram equalization
-    def _stretch(ima, NUMS):
-        hist = _histogram(ima, NUMS)
-        lut = []
-        for bt in range(0, len(hist), NUMS):
-            # step size
-            step = reduce(operator.add, hist[bt:bt + NUMS]) / (NUMS - 1)
-            # create balanced lookup table
-            n = 0
-            for i in range(NUMS):
-                lut.append(n / step)
-                n += hist[i + bt]
-            np.take(lut, ima, out=ima)
-            return ima
-
-    # calculate histogram
-    def _histogram(ima, NUMS):
-        bins = list(range(0, NUMS))
-        flat = ima.flat
-        n = np.searchsorted(np.sort(flat), bins)
-        n = np.concatenate([n, [len(flat)]])
-        hist = n[1:] - n[:-1]
-        return hist
-
     dtype = im.dtype.name
-    dtypes = ["uint8", "uint16", "float32"]
+    dtypes = ["uint8", "uint16", "uint32", "float32"]
     if dtype not in dtypes:
-        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
-    if dtype == "uint8":
-        return im
-    else:
-        if dtype == "float32":
-            im = _sample_norm(im)
-        return _two_percentLinear(im)
+        raise ValueError(
+            f"'dtype' must be uint8/uint16/uint32/float32, not {dtype}.")
+    if dtype != "uint8":
+        im = _sample_norm(im)
+    if is_linear:
+        im = _two_percent_linear(im)
+    return im
 
 
 def to_intensity(im):