test_functions.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import paddlers.transforms as T
  16. from testing_utils import CpuCommonTest
  17. from data import build_input_from_file
  18. __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
  19. def calc_err(a, b):
  20. return (a - b).abs().mean()
  21. class TestMatchHistograms(CpuCommonTest):
  22. def setUp(self):
  23. self.inputs = [
  24. build_input_from_file(
  25. "data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
  26. ]
  27. def test_output_shape(self):
  28. decoder = T.DecodeImg()
  29. for input in copy.deepcopy(self.inputs):
  30. for sample in input:
  31. sample = decoder(sample)
  32. im_out = T.functions.match_histograms(sample['image'],
  33. sample['image2'])
  34. self.check_output_equal(im_out.shape, sample['image2'].shape)
  35. self.assertEqual(im_out.dtype, sample['image2'].dtype)
  36. im_out = T.functions.match_histograms(sample['image2'],
  37. sample['image'])
  38. self.check_output_equal(im_out.shape, sample['image'].shape)
  39. self.assertEqual(im_out.dtype, sample['image'].dtype)
  40. class TestMatchByRegression(CpuCommonTest):
  41. def setUp(self):
  42. self.inputs = [
  43. build_input_from_file(
  44. "data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
  45. ]
  46. def test_output_shape(self):
  47. decoder = T.DecodeImg()
  48. for input in copy.deepcopy(self.inputs):
  49. for sample in input:
  50. sample = decoder(sample)
  51. im_out = T.functions.match_by_regression(sample['image'],
  52. sample['image2'])
  53. self.check_output_equal(im_out.shape, sample['image2'].shape)
  54. self.assertEqual(im_out.dtype, sample['image2'].dtype)
  55. err1 = calc_err(sample['image'], sample['image2'])
  56. err2 = calc_err(sample['image'], im_out)
  57. self.assertLessEqual(err2, err1)
  58. im_out = T.functions.match_by_regression(sample['image2'],
  59. sample['image'])
  60. self.check_output_equal(im_out.shape, sample['image'].shape)
  61. self.assertEqual(im_out.dtype, sample['image'].dtype)
  62. err1 = calc_err(sample['image'], sample['image2'])
  63. err2 = calc_err(im_out, sample['image2'])
  64. self.assertLessEqual(err2, err1)