浏览代码

Merge pull request #33 from Bobholamovic/recalib_as_op

[Feat] Add MatchRadiance Op
cc 2 年之前
父节点
当前提交
4203fcde83
共有 4 个文件被更改,包括 92 次插入26 次删除
  1. 1 0
      docs/intro/transforms.md
  2. 38 3
      paddlers/transforms/operators.py
  3. 28 6
      tests/transforms/test_functions.py
  4. 25 17
      tests/transforms/test_operators.py

+ 1 - 0
docs/intro/transforms.md

@@ -8,6 +8,7 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为
 | -------------------- | ------------------------------------------------- | -------- | ---- |
 | CenterCrop           | 对输入影像进行中心裁剪。 | 所有任务 | ... |
 | Dehaze               | 对输入图像进行去雾。 | 所有任务 | ... |
+| MatchRadiance        | 对两个时相的输入影像进行相对辐射校正。 | 变化检测 | ... |
 | MixupImage           | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本。 | 目标检测 | ... |
 | Normalize            | 对输入影像应用标准化。 | 所有任务 | ... |
 | Pad                  | 将输入影像填充到指定的大小。 | 所有任务 | ... |

+ 38 - 3
paddlers/transforms/operators.py

@@ -35,7 +35,8 @@ from .functions import (
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
     vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
     resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape)
+    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
+    match_by_regression, match_histograms)
 
 __all__ = [
     "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
@@ -43,8 +44,9 @@ __all__ = [
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
     "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
-    "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
-    "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask"
+    "RandomFlipOrRotate", "ReloadMask", "MatchRadiance", "ArrangeSegmenter",
+    "ArrangeChangeDetector", "ArrangeClassifier", "ArrangeDetector",
+    "ArrangeRestorer"
 ]
 
 interp_dict = {
@@ -1928,6 +1930,39 @@ class ReloadMask(Transform):
         return sample
 
 
+class MatchRadiance(Transform):
+    """
+    Perform relative radiometric correction between bi-temporal images.
+
+    Args:
+        method (str, optional): Method used to match the radiance of the
+            bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
+            for histogram matching and 'lsr' stands for least-squares 
+            regression. Default: 'hist'.
+    """
+
+    def __init__(self, method='hist'):
+        super(MatchRadiance, self).__init__()
+
+        if method == 'hist':
+            self._match_func = match_histograms
+        elif method == 'lsr':
+            self._match_func = match_by_regression
+        else:
+            raise ValueError(
+                "{} is not a supported radiometric correction method.".format(
+                    method))
+
+        self.method = method
+
+    def apply(self, sample):
+        if 'image2' not in sample:
+            raise ValueError("'image2' is not found in the sample.")
+
+        sample['image2'] = self._match_func(sample['image2'], sample['image'])
+        return sample
+
+
 class Arrange(Transform):
     def __init__(self, mode):
         super().__init__()

+ 28 - 6
tests/transforms/test_functions.py

@@ -14,6 +14,8 @@
 
 import copy
 
+import numpy as np
+
 import paddlers.transforms as T
 from testing_utils import CpuCommonTest
 from data import build_input_from_file
@@ -21,6 +23,12 @@ from data import build_input_from_file
 __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
 
 
+def calc_err(a, b):
+    a = a.astype('float64')
+    b = b.astype('float64')
+    return np.abs(a - b).mean()
+
+
 class TestMatchHistograms(CpuCommonTest):
     def setUp(self):
         self.inputs = [
@@ -33,12 +41,16 @@ class TestMatchHistograms(CpuCommonTest):
         for input in copy.deepcopy(self.inputs):
             for sample in input:
                 sample = decoder(sample)
-                im_out = T.functions.match_histograms(sample['image'],
-                                                      sample['image2'])
-                self.check_output_equal(im_out.shape, sample['image2'].shape)
+
                 im_out = T.functions.match_histograms(sample['image2'],
                                                       sample['image'])
+                self.check_output_equal(im_out.shape, sample['image2'].shape)
+                self.assertEqual(im_out.dtype, sample['image2'].dtype)
+
+                im_out = T.functions.match_histograms(sample['image'],
+                                                      sample['image2'])
                 self.check_output_equal(im_out.shape, sample['image'].shape)
+                self.assertEqual(im_out.dtype, sample['image'].dtype)
 
 
 class TestMatchByRegression(CpuCommonTest):
@@ -53,9 +65,19 @@ class TestMatchByRegression(CpuCommonTest):
         for input in copy.deepcopy(self.inputs):
             for sample in input:
                 sample = decoder(sample)
-                im_out = T.functions.match_by_regression(sample['image'],
-                                                         sample['image2'])
-                self.check_output_equal(im_out.shape, sample['image2'].shape)
+
                 im_out = T.functions.match_by_regression(sample['image2'],
                                                          sample['image'])
+                self.check_output_equal(im_out.shape, sample['image2'].shape)
+                self.assertEqual(im_out.dtype, sample['image2'].dtype)
+                err1 = calc_err(sample['image'], sample['image2'])
+                err2 = calc_err(sample['image'], im_out)
+
+                self.assertLessEqual(err2, err1)
+                im_out = T.functions.match_by_regression(sample['image'],
+                                                         sample['image2'])
                 self.check_output_equal(im_out.shape, sample['image'].shape)
+                self.assertEqual(im_out.dtype, sample['image'].dtype)
+                err1 = calc_err(sample['image'], sample['image2'])
+                err2 = calc_err(im_out, sample['image2'])
+                self.assertLessEqual(err2, err1)

+ 25 - 17
tests/transforms/test_operators.py

@@ -54,30 +54,30 @@ def _add_op_tests(cls):
                 filter_ = OP2FILTER.get(op_name, None)
                 setattr(
                     cls, attr_name, make_test_func(
-                        op_class, filter_=filter_))
+                        op_class, _filter=filter_))
     return cls
 
 
 def make_test_func(op_class,
                    *args,
-                   in_hook=None,
-                   out_hook=None,
-                   filter_=None,
+                   _in_hook=None,
+                   _out_hook=None,
+                   _filter=None,
                    **kwargs):
     def _test_func(self):
         op = op_class(*args, **kwargs)
         decoder = T.DecodeImg()
         inputs = map(decoder, copy.deepcopy(self.inputs))
         for i, input_ in enumerate(inputs):
-            if filter_ is not None:
-                input_ = filter_(input_)
+            if _filter is not None:
+                input_ = _filter(input_)
             with self.subTest(i=i):
                 for sample in input_:
-                    if in_hook:
-                        sample = in_hook(sample)
+                    if _in_hook:
+                        sample = _in_hook(sample)
                     sample = op(sample)
-                    if out_hook:
-                        sample = out_hook(sample)
+                    if _out_hook:
+                        sample = _out_hook(sample)
 
     return _test_func
 
@@ -308,15 +308,15 @@ class TestTransform(CpuCommonTest):
 
         test_func_not_keep_ratio = make_test_func(
             T.Resize,
-            in_hook=_in_hook,
-            out_hook=_out_hook_not_keep_ratio,
+            _in_hook=_in_hook,
+            _out_hook=_out_hook_not_keep_ratio,
             target_size=TARGET_SIZE,
             keep_ratio=False)
         test_func_not_keep_ratio(self)
         test_func_keep_ratio = make_test_func(
             T.Resize,
-            in_hook=_in_hook,
-            out_hook=_out_hook_keep_ratio,
+            _in_hook=_in_hook,
+            _out_hook=_out_hook_keep_ratio,
             target_size=TARGET_SIZE,
             keep_ratio=True)
         test_func_keep_ratio(self)
@@ -345,11 +345,19 @@ class TestTransform(CpuCommonTest):
 
         test_func = make_test_func(
             T.RandomFlipOrRotate,
-            in_hook=_in_hook,
-            out_hook=_out_hook,
-            filter_=_filter_no_det)
+            _in_hook=_in_hook,
+            _out_hook=_out_hook,
+            _filter=_filter_no_det)
         test_func(self)
 
+    def test_MatchRadiance(self):
+        test_hist = make_test_func(
+            T.MatchRadiance, 'hist', _filter=_filter_only_mt)
+        test_hist(self)
+        test_lsr = make_test_func(
+            T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
+        test_lsr(self)
+
 
 class TestCompose(CpuCommonTest):
     pass