ソースを参照

Fix to_uint8 bug

Bobholamovic 2 年 前
コミット
2cc44b7a8f
1 ファイル変更17 行追加7 行削除
  1. 17 7
      paddlers/transforms/functions.py

+ 17 - 7
paddlers/transforms/functions.py

@@ -396,6 +396,14 @@ def to_uint8(im, norm=True, stretch=False):
         np.ndarray: Image data with unit8 type.
     """
 
+    EPS = 1e-32
+
+    def _minmax_norm(image):
+        image = image.astype(np.float32)
+        min_val = image.min()
+        max_val = image.max()
+        return (image - min_val) / (max_val - min_val + EPS)
+
     # 2% linear stretch
     def _two_percent_linear(image, max_out=1., min_out=0.):
         def _gray_process(gray, maxout=max_out, minout=min_out):
@@ -403,7 +411,7 @@ def to_uint8(im, norm=True, stretch=False):
             high_value = np.percentile(gray, 98)
             low_value = np.percentile(gray, 2)
             truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
-            processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
+            processed_gray = ((truncated_gray - low_value) / (high_value - low_value + EPS)) * \
                              (maxout - minout)
             return processed_gray
 
@@ -416,26 +424,28 @@ def to_uint8(im, norm=True, stretch=False):
             result = _gray_process(image)
         return result
 
-    # Simple image standardization
-    def _sample_norm(image):
+    def _equalize_hist(image):
         stretches = []
         if len(image.shape) == 3:
             for b in range(image.shape[-1]):
                 stretched = exposure.equalize_hist(image[:, :, b])
-                stretched /= float(np.max(stretched))
+                stretched /= float(np.max(stretched)) + EPS
                 stretches.append(stretched)
             stretched_img = np.stack(stretches, axis=2)
         else:  # if len(image.shape) == 2
             stretched_img = exposure.equalize_hist(image)
+            stretched_img /= float(np.max(stretched_img)) + EPS
         return stretched_img
 
     dtype = im.dtype.name
-    if dtype == "uint8" and !stretch:
+    if dtype == 'uint8' and not stretch:
         return im
-    if dtype != "uint8" and norm:
-        im = _sample_norm(im)
     if stretch:
         im = _two_percent_linear(im)
+    else:
+        im = _minmax_norm(im)
+    if norm:
+        im = _equalize_hist(im)
     im = np.uint8(im * 255)
     return im