Browse Source

Fix bugs in unittests

Bobholamovic 2 years ago
parent
commit
587b4451ac
2 changed files with 22 additions and 13 deletions
  1. 17 9
      tests/transforms/test_functions.py
  2. 5 4
      tests/transforms/test_operators.py

+ 17 - 9
tests/transforms/test_functions.py

@@ -14,6 +14,8 @@
 
 import copy
 
+import numpy as np
+
 import paddlers.transforms as T
 from testing_utils import CpuCommonTest
 from data import build_input_from_file
@@ -22,7 +24,9 @@ __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
 
 
 def calc_err(a, b):
-    return (a - b).abs().mean()
+    a = a.astype('float64')
+    b = b.astype('float64')
+    return np.abs(a - b).mean()
 
 
 class TestMatchHistograms(CpuCommonTest):
@@ -37,12 +41,14 @@ class TestMatchHistograms(CpuCommonTest):
         for input in copy.deepcopy(self.inputs):
             for sample in input:
                 sample = decoder(sample)
-                im_out = T.functions.match_histograms(sample['image'],
-                                                      sample['image2'])
-                self.check_output_equal(im_out.shape, sample['image2'].shape)
-                self.assertEqual(im_out.dtype, sample['image2'].dtype)
+
                 im_out = T.functions.match_histograms(sample['image2'],
                                                       sample['image'])
+                self.check_output_equal(im_out.shape, sample['image2'].shape)
+                self.assertEqual(im_out.dtype, sample['image2'].dtype)
+
+                im_out = T.functions.match_histograms(sample['image'],
+                                                      sample['image2'])
                 self.check_output_equal(im_out.shape, sample['image'].shape)
                 self.assertEqual(im_out.dtype, sample['image'].dtype)
 
@@ -59,15 +65,17 @@ class TestMatchByRegression(CpuCommonTest):
         for input in copy.deepcopy(self.inputs):
             for sample in input:
                 sample = decoder(sample)
-                im_out = T.functions.match_by_regression(sample['image'],
-                                                         sample['image2'])
+
+                im_out = T.functions.match_by_regression(sample['image2'],
+                                                         sample['image'])
                 self.check_output_equal(im_out.shape, sample['image2'].shape)
                 self.assertEqual(im_out.dtype, sample['image2'].dtype)
                 err1 = calc_err(sample['image'], sample['image2'])
                 err2 = calc_err(sample['image'], im_out)
+
                 self.assertLessEqual(err2, err1)
-                im_out = T.functions.match_by_regression(sample['image2'],
-                                                         sample['image'])
+                im_out = T.functions.match_by_regression(sample['image'],
+                                                         sample['image2'])
                 self.check_output_equal(im_out.shape, sample['image'].shape)
                 self.assertEqual(im_out.dtype, sample['image'].dtype)
                 err1 = calc_err(sample['image'], sample['image2'])

+ 5 - 4
tests/transforms/test_operators.py

@@ -145,8 +145,7 @@ OP2FILTER = {
     'SelectBand': _filter_no_sar,
     'Dehaze': _filter_only_optical,
     'Normalize': _filter_only_optical,
-    'RandomDistort': _filter_only_optical,
-    'MatchRadiance': _filter_only_mt 
+    'RandomDistort': _filter_only_optical
 }
 
 
@@ -352,9 +351,11 @@ class TestTransform(CpuCommonTest):
         test_func(self)
 
     def test_MatchRadiance(self):
-        test_hist = make_test_func(T.MatchRadiance, 'hist')
+        test_hist = make_test_func(
+            T.MatchRadiance, 'hist', _filter=_filter_only_mt)
         test_hist(self)
-        test_lsr = make_test_func(T.MatchRadiance, 'lsr')
+        test_lsr = make_test_func(
+            T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
         test_lsr(self)