Bobholamovic 2 anni fa
parent
commit
a3a915dca2
2 ha cambiato i file con 13 aggiunte e 7 eliminazioni
  1. 9 6
      paddlers/transforms/functions.py
  2. 4 1
      paddlers/transforms/operators.py

+ 9 - 6
paddlers/transforms/functions.py

@@ -382,12 +382,14 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle
 
 
-def to_uint8(im, stretch=False):
+def to_uint8(im, norm=True, stretch=False):
     """
     Convert raster data to uint8 type.
     
     Args:
         im (np.ndarray): Input raster image.
+        norm (bool, optional): Use hist equalization to normalize each band or not. 
+            Default is True.
         stretch (bool, optional): Use 2% linear stretch or not. Default is False.
 
     Returns:
@@ -395,7 +397,7 @@ def to_uint8(im, stretch=False):
     """
 
     # 2% linear stretch
-    def _two_percent_linear(image, max_out=255, min_out=0):
+    def _two_percent_linear(image, max_out=1., min_out=0.):
         def _gray_process(gray, maxout=max_out, minout=min_out):
             # Get the corresponding gray level at 98% in the histogram.
             high_value = np.percentile(gray, 98)
@@ -403,7 +405,7 @@ def to_uint8(im, stretch=False):
             truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
             processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
                              (maxout - minout)
-            return np.uint8(processed_gray)
+            return processed_gray
 
         if len(image.shape) == 3:
             processes = []
@@ -412,7 +414,7 @@ def to_uint8(im, stretch=False):
             result = np.stack(processes, axis=2)
         else:  # if len(image.shape) == 2
             result = _gray_process(image)
-        return np.uint8(result)
+        return result
 
     # Simple image standardization
     def _sample_norm(image):
@@ -425,13 +427,14 @@ def to_uint8(im, stretch=False):
             stretched_img = np.stack(stretches, axis=2)
         else:  # if len(image.shape) == 2
             stretched_img = exposure.equalize_hist(image)
-        return np.uint8(stretched_img * 255)
+        return stretched_img
 
     dtype = im.dtype.name
-    if dtype != "uint8":
+    if dtype != "uint8" and norm:
         im = _sample_norm(im)
     if stretch:
         im = _two_percent_linear(im)
+    im = np.uint8(im * 255)
     return im
 
 

+ 4 - 1
paddlers/transforms/operators.py

@@ -288,7 +288,10 @@ class DecodeImg(Transform):
             image = im_path
 
         if self.to_uint8:
-            image = F.to_uint8(image, stretch=self.use_stretch)
+            if self.use_stretch:
+                image = F.to_uint8(image, norm=False, stretch=True)
+            else:
+                image = F.to_uint8(image, norm=True, stretch=False)
 
         if self.read_geo_info:
             return image, geo_info_dict