Bobholamovic 2 ani în urmă
părinte
comite
116fb9be61

+ 15 - 3
paddlers/transforms/__init__.py

@@ -23,15 +23,27 @@ from paddlers import transforms as T
 def decode_image(im_path,
                  to_rgb=True,
                  to_uint8=True,
-                 decode_rgb=True,
-                 decode_sar=False):
+                 decode_bgr=True,
+                 decode_sar=True):
+    """
+    Decode an image.
+    
+    Args:
+        to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
+        to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
+        decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. 
+            Defaults to True.
+        decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a
+            SAR image. Defaults to True.
+    """
+
     # Do a presence check. `osp.exists` assumes `im_path` is a path-like object.
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
     decoder = T.DecodeImg(
         to_rgb=to_rgb,
         to_uint8=to_uint8,
-        decode_rgb=decode_rgb,
+        decode_bgr=decode_bgr,
         decode_sar=decode_sar)
     # Deepcopy to avoid inplace modification
     sample = {'image': copy.deepcopy(im_path)}

+ 9 - 11
paddlers/transforms/operators.py

@@ -126,19 +126,21 @@ class DecodeImg(Transform):
     Args:
         to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
         to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
-        decode_rgb (bool, optional): If the image to decode is a non-geo RGB image (e.g., jpeg images), set this argument to True. Defaults to True.
-        decode_sar (bool, optional): If the image to decode is a SAR image, set this argument to True. Defaults to False.
+        decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image. 
+            Defaults to True.
+        decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a
+            SAR image. Defaults to True.
     """
 
     def __init__(self,
                  to_rgb=True,
                  to_uint8=True,
-                 decode_rgb=True,
-                 decode_sar=False):
+                 decode_bgr=True,
+                 decode_sar=True):
         super(DecodeImg, self).__init__()
         self.to_rgb = to_rgb
         self.to_uint8 = to_uint8
-        self.decode_rgb = decode_rgb
+        self.decode_bgr = decode_bgr
         self.decode_sar = decode_sar
 
     def read_img(self, img_path):
@@ -159,11 +161,7 @@ class DecodeImg(Transform):
             if dataset == None:
                 raise IOError('Can not open', img_path)
             im_data = dataset.ReadAsArray()
-            if self.decode_sar:
-                if im_data.ndim != 2:
-                    raise ValueError(
-                        f"SAR images should have exactly 2 channels, but the image has {im_data.ndim} channels."
-                    )
+            if im_data.ndim == 2 and self.decode_sar:
                 im_data = to_intensity(im_data)  # the image is read as SAR data
                 im_data = im_data[:, :, np.newaxis]
             else:
@@ -171,7 +169,7 @@ class DecodeImg(Transform):
                     im_data = im_data.transpose((1, 2, 0))
             return im_data
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
-            if self.decode_rgb:
+            if self.decode_bgr:
                 return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
             else:

+ 3 - 7
tests/deploy/test_predictor.py

@@ -100,13 +100,9 @@ class TestPredictor(CommonTest):
             for key in dict_.keys():
                 if key in ignore_keys:
                     continue
-                if isinstance(dict_[key], (list, dict)):
-                    self.check_dict_equal(
-                        dict_[key], expected_dict[key], ignore_keys=ignore_keys)
-                else:
-                    # Use higher tolerance
-                    self.check_output_equal(
-                        dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
+                # Use higher tolerance
+                self.check_output_equal(
+                    dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
 
 
 @TestPredictor.add_tests

+ 42 - 0
tests/transforms/test_operators.py

@@ -116,6 +116,18 @@ def _is_mt(sample):
     return 'image2' in sample
 
 
+def _is_seg(sample):
+    return 'mask' in sample and 'image2' not in sample
+
+
+def _is_det(sample):
+    return 'gt_bbox' in sample or 'gt_poly' in sample
+
+
+def _is_clas(sample):
+    return 'label' in sample
+
+
 _filter_only_optical = _InputFilter([_is_optical])
 _filter_only_sar = _InputFilter([_is_sar])
 _filter_only_multispectral = _InputFilter([_is_multispectral])
@@ -123,6 +135,7 @@ _filter_no_multispectral = _filter_only_optical | _filter_only_sar
 _filter_no_sar = _filter_only_optical | _filter_only_multispectral
 _filter_no_optical = _filter_only_sar | _filter_only_multispectral
 _filter_only_mt = _InputFilter([_is_mt])
+_filter_no_det = _InputFilter([_is_seg, _is_clas, _is_mt])
 
 OP2FILTER = {
     'RandomSwap': _filter_only_mt,
@@ -262,6 +275,35 @@ class TestTransform(CpuCommonTest):
             keep_ratio=True)
         test_func_keep_ratio(self)
 
+    def test_RandomFlipOrRotate(self):
+        def _in_hook(sample):
+            if 'image2' in sample:
+                self.im_diff = (
+                    sample['image'] - sample['image2']).astype('float64')
+            elif 'mask' in sample:
+                self.im_diff = (
+                    sample['image'][..., 0] - sample['mask']).astype('float64')
+            return sample
+
+        def _out_hook(sample):
+            im_diff = None
+            if 'image2' in sample:
+                im_diff = (sample['image'] - sample['image2']).astype('float64')
+            elif 'mask' in sample:
+                im_diff = (
+                    sample['image'][..., 0] - sample['mask']).astype('float64')
+            if im_diff is not None:
+                self.check_output_equal(im_diff.max(), self.im_diff.max())
+                self.check_output_equal(im_diff.min(), self.im_diff.min())
+            return sample
+
+        test_func = make_test_func(
+            T.RandomFlipOrRotate,
+            in_hook=_in_hook,
+            out_hook=_out_hook,
+            filter_=_filter_no_det)
+        test_func(self)
+
 
 class TestCompose(CpuCommonTest):
     pass