Bladeren bron

[Fix] Fix decode_image array-str inconsistency bug

Bobholamovic 2 jaren geleden
bovenliggende
commit
87ecbafa94
4 gewijzigde bestanden met toevoegingen van 12 en 11 verwijderingen
  1. 2 2
      docs/apis/data.md
  2. 2 0
      paddlers/tasks/change_detector.py
  3. 6 6
      paddlers/transforms/__init__.py
  4. 2 3
      paddlers/transforms/operators.py

+ 2 - 2
docs/apis/data.md

@@ -133,13 +133,13 @@
 |参数名称|类型|参数说明|默认值|
 |-------|----|--------|-----|
 |`im_path`|`str`|输入图像路径。||
-|`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
+|`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。该参数已废弃,在将来可能被移除,请尽可能避免使用。|`True`|
 |`to_uint8`|`bool`|若为`True`,则将读取的影像数据量化并转换为uint8类型。|`True`|
 |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
 |`decode_sar`|`bool`|若为`True`,则自动将单通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
 |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
 |`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`|
-|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb``to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`|
+|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`为`True`而`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`|
 
 返回格式如下:
 

+ 2 - 0
paddlers/tasks/change_detector.py

@@ -629,6 +629,8 @@ class BaseChangeDetector(BaseModel):
             if isinstance(im1, str) or isinstance(im2, str):
                 im1 = decode_image(im1, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
+                np.save('im1_whole.npy', im1)
+                np.save('im2_whole.npy', im2)
             ori_shape = im1.shape[:2]
             # XXX: sample does not contain 'image_t1' and 'image_t2'.
             sample = {'image': im1, 'image2': im2}

+ 6 - 6
paddlers/transforms/__init__.py

@@ -33,8 +33,8 @@ def decode_image(im_path,
     
     Args:
         im_path (str): Path of the image to decode.
-        to_rgb (bool, optional): If True, convert input image(s) from BGR format to 
-            RGB format. Defaults to True.
+        to_rgb (bool, optional): (Deprecated) If True, convert input image(s) from 
+            BGR format to RGB format. Defaults to True.
         to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to 
             uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo 
@@ -46,9 +46,9 @@ def decode_image(im_path,
             the image. Defaults to False.
         use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if 
             `to_uint8` is True. Defaults to False.
-        read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
-            to False. Setting `read_raw` takes precedence over setting `to_rgb` and 
-            `to_uint8`. Defaults to False.
+        read_raw (bool, optional): If True, equivalent to setting `to_rgb` to `True` and
+            `to_uint8` to False. Setting `read_raw` takes precedence over setting `to_rgb` 
+            and `to_uint8`. Defaults to False.
     
     Returns:
         np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. 
@@ -61,7 +61,7 @@ def decode_image(im_path,
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
     if read_raw:
-        to_rgb = False
+        to_rgb = True
         to_uint8 = False
     decoder = T.DecodeImg(
         to_rgb=to_rgb,

+ 2 - 3
paddlers/transforms/operators.py

@@ -257,6 +257,8 @@ class DecodeImg(Transform):
             else:
                 im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                      cv2.IMREAD_ANYCOLOR)
+            if self.to_rgb and im_data.shape[-1] == 3:
+                im_data = cv2.cvtColor(im_data, cv2.COLOR_BGR2RGB)
         elif ext == '.npy':
             im_data = np.load(img_path)
         else:
@@ -282,9 +284,6 @@ class DecodeImg(Transform):
         else:
             image = im_path
 
-        if self.to_rgb and image.shape[-1] == 3:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
         if self.to_uint8:
             image = F.to_uint8(image, stretch=self.use_stretch)