123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Based on https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/common_test.py
- import unittest
- import warnings
- import numpy as np
- import paddle
- __all__ = ['CommonTest', 'CpuCommonTest']
- # Assume all elements has same data type
- def get_container_type(container):
- container_t = type(container)
- if container_t in [list, tuple]:
- if len(container) == 0:
- return container_t
- return get_container_type(container[0])
- return container_t
- class _CommonTestNamespace:
- # Wrap the subclasses of unittest.TestCase that are expected to be inherited from.
- class CommonTest(unittest.TestCase):
- CATCH_WARNINGS = False
- def __init__(self, methodName='runTest'):
- super(CommonTest, self).__init__(methodName=methodName)
- self.config = {}
- self.places = ['cpu']
- if paddle.is_compiled_with_cuda():
- self.places.append('gpu')
- @classmethod
- def setUpClass(cls):
- '''
- Set the decorators for all test function
- '''
- for key, value in cls.__dict__.items():
- if key.startswith('test'):
- decorator_func_list = ["_test_places"]
- if cls.CATCH_WARNINGS:
- decorator_func_list.append("_catch_warnings")
- for decorator_func in decorator_func_list:
- decorator_func = getattr(CommonTest, decorator_func)
- value = decorator_func(value)
- setattr(cls, key, value)
- def _catch_warnings(func):
- '''
- Catch the warnings and treat them as errors for each test.
- '''
- def wrapper(self, *args, **kwargs):
- with warnings.catch_warnings(record=True) as w:
- warnings.resetwarnings()
- # ignore specified warnings
- warning_white_list = [UserWarning]
- for warning in warning_white_list:
- warnings.simplefilter("ignore", warning)
- func(self, *args, **kwargs)
- msg = None if len(w) == 0 else w[0].message
- self.assertFalse(len(w) > 0, msg)
- return wrapper
- def _test_places(func):
- '''
- Setting the running place for each test.
- '''
- def wrapper(self, *args, **kwargs):
- places = self.places
- for place in places:
- paddle.set_device(place)
- func(self, *args, **kwargs)
- return wrapper
- def _check_output_impl(self,
- result,
- expected_result,
- rtol,
- atol,
- equal=True):
- assertForNormalType = self.assertNotEqual
- assertForFloat = self.assertFalse
- if equal:
- assertForNormalType = self.assertEqual
- assertForFloat = self.assertTrue
- result_t = type(result)
- error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
- if result_t in [list, tuple]:
- result_t = get_container_type(result)
- if result_t in [
- str, int, bool, set, np.bool, np.int32, np.int64, np.str
- ]:
- assertForNormalType(
- result,
- expected_result,
- msg=error_msg.format(paddle.get_device(), expected_result,
- result, self.__class__.__name__))
- elif result_t in [float, np.ndarray, np.float32, np.float64]:
- assertForFloat(
- np.allclose(
- result, expected_result, rtol=rtol, atol=atol),
- msg=error_msg.format(paddle.get_device(), expected_result,
- result, self.__class__.__name__))
- if result_t == np.ndarray:
- assertForNormalType(
- result.shape,
- expected_result.shape,
- msg=error_msg.format(
- paddle.get_device(), expected_result.shape,
- result.shape, self.__class__.__name__))
- else:
- raise ValueError(
- 'result type must be str, int, bool, set, np.bool, np.int32, '
- 'np.int64, np.str, float, np.ndarray, np.float32, np.float64'
- )
- def check_output_equal(self,
- result,
- expected_result,
- rtol=1.e-5,
- atol=1.e-8):
- '''
- Check whether result and expected result are equal, including shape.
- Args:
- result: str, int, bool, set, np.ndarray.
- The result needs to be checked.
- expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
- Use the expected result to check result.
- rtol: float
- relative tolerance, default 1.e-5.
- atol: float
- absolute tolerance, default 1.e-8
- '''
- self._check_output_impl(result, expected_result, rtol, atol)
- def check_output_not_equal(self,
- result,
- expected_result,
- rtol=1.e-5,
- atol=1.e-8):
- '''
- Check whether result and expected result are not equal, including shape.
- Args:
- result: str, int, bool, set, np.ndarray.
- The result needs to be checked.
- expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
- Use the expected result to check result.
- rtol: float
- relative tolerance, default 1.e-5.
- atol: float
- absolute tolerance, default 1.e-8
- '''
- self._check_output_impl(
- result, expected_result, rtol, atol, equal=False)
- class CpuCommonTest(CommonTest):
- def __init__(self, methodName='runTest'):
- super(CpuCommonTest, self).__init__(methodName=methodName)
- self.places = ['cpu']
- CommonTest = _CommonTestNamespace.CommonTest
- CpuCommonTest = _CommonTestNamespace.CpuCommonTest
|