|
@@ -14,6 +14,8 @@
|
|
|
|
|
|
import copy
|
|
|
|
|
|
+import numpy as np
|
|
|
+
|
|
|
import paddlers.transforms as T
|
|
|
from testing_utils import CpuCommonTest
|
|
|
from data import build_input_from_file
|
|
@@ -22,7 +24,9 @@ __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
|
|
|
|
|
|
|
|
|
def calc_err(a, b):
|
|
|
- return (a - b).abs().mean()
|
|
|
+ a = a.astype('float64')
|
|
|
+ b = b.astype('float64')
|
|
|
+ return np.abs(a - b).mean()
|
|
|
|
|
|
|
|
|
class TestMatchHistograms(CpuCommonTest):
|
|
@@ -37,12 +41,14 @@ class TestMatchHistograms(CpuCommonTest):
|
|
|
for input in copy.deepcopy(self.inputs):
|
|
|
for sample in input:
|
|
|
sample = decoder(sample)
|
|
|
- im_out = T.functions.match_histograms(sample['image'],
|
|
|
- sample['image2'])
|
|
|
- self.check_output_equal(im_out.shape, sample['image2'].shape)
|
|
|
- self.assertEqual(im_out.dtype, sample['image2'].dtype)
|
|
|
+
|
|
|
im_out = T.functions.match_histograms(sample['image2'],
|
|
|
sample['image'])
|
|
|
+ self.check_output_equal(im_out.shape, sample['image2'].shape)
|
|
|
+ self.assertEqual(im_out.dtype, sample['image2'].dtype)
|
|
|
+
|
|
|
+ im_out = T.functions.match_histograms(sample['image'],
|
|
|
+ sample['image2'])
|
|
|
self.check_output_equal(im_out.shape, sample['image'].shape)
|
|
|
self.assertEqual(im_out.dtype, sample['image'].dtype)
|
|
|
|
|
@@ -59,15 +65,17 @@ class TestMatchByRegression(CpuCommonTest):
|
|
|
for input in copy.deepcopy(self.inputs):
|
|
|
for sample in input:
|
|
|
sample = decoder(sample)
|
|
|
- im_out = T.functions.match_by_regression(sample['image'],
|
|
|
- sample['image2'])
|
|
|
+
|
|
|
+ im_out = T.functions.match_by_regression(sample['image2'],
|
|
|
+ sample['image'])
|
|
|
self.check_output_equal(im_out.shape, sample['image2'].shape)
|
|
|
self.assertEqual(im_out.dtype, sample['image2'].dtype)
|
|
|
err1 = calc_err(sample['image'], sample['image2'])
|
|
|
err2 = calc_err(sample['image'], im_out)
|
|
|
+
|
|
|
self.assertLessEqual(err2, err1)
|
|
|
- im_out = T.functions.match_by_regression(sample['image2'],
|
|
|
- sample['image'])
|
|
|
+ im_out = T.functions.match_by_regression(sample['image'],
|
|
|
+ sample['image2'])
|
|
|
self.check_output_equal(im_out.shape, sample['image'].shape)
|
|
|
self.assertEqual(im_out.dtype, sample['image'].dtype)
|
|
|
err1 = calc_err(sample['image'], sample['image2'])
|