# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/common_test.py

import functools
import subprocess
import unittest
import warnings

import numpy as np
import paddle

__all__ = ['CommonTest', 'CpuCommonTest', 'run_script']


def run_script(cmd, silent=True, wd=None, timeout=None, echo=True):
    """
    Run a shell command and return the completed process.

    Args:
        cmd (str): Command line handed to the shell.
        silent (bool, optional): If True, the command's stdout is discarded.
            Default: True.
        wd (str|None, optional): If given, the command is prefixed with
            ``cd {wd} &&`` so that it runs from that directory. Default: None.
        timeout (float|None, optional): Forwarded to `subprocess.run`.
            Default: None.
        echo (bool, optional): If True, print the final command line before
            running it. Default: True.

    Returns:
        subprocess.CompletedProcess: The result of `subprocess.run`.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (the command is run with ``check=True``).
    """
    # XXX: This function is not safe!!!
    # The command (and `wd`) are interpolated into a shell string verbatim.
    full_cmd = cmd if wd is None else f"cd {wd} && {cmd}"
    if echo:
        print(full_cmd)

    run_kwargs = {'check': True, 'shell': True, 'timeout': timeout}
    if silent:
        run_kwargs['stdout'] = subprocess.DEVNULL
    return subprocess.run(full_cmd, **run_kwargs)


# Assumes all elements have the same data type.
def get_container_type(container):
    """
    Return the element type of a (possibly nested) list or tuple.

    Only the first element of each level is inspected. An empty list or tuple
    yields its own type, and any object that is not exactly a list or tuple
    yields `type(container)`.
    """
    current = container
    # Descend through the first element until something that is not a
    # non-empty list/tuple is reached.
    while type(current) in (list, tuple) and len(current) > 0:
        current = current[0]
    return type(current)


class _CommonTestNamespace:
    # Wrap the subclasses of unittest.TestCase that are expected to be inherited from.
    class CommonTest(unittest.TestCase):
        CATCH_WARNINGS = False

        def __init__(self, methodName='runTest'):
            super(CommonTest, self).__init__(methodName=methodName)
            self.config = {}
            self.places = ['cpu']
            if paddle.is_compiled_with_cuda():
                self.places.append('gpu')

        @classmethod
        def setUpClass(cls):
            """
            Set the decorators for all test function
            """

            for key, value in cls.__dict__.items():
                if key.startswith('test'):
                    decorator_func_list = ["_test_places"]
                    if cls.CATCH_WARNINGS:
                        decorator_func_list.append("_catch_warnings")
                    for decorator_func in decorator_func_list:
                        decorator_func = getattr(CommonTest, decorator_func)
                        value = decorator_func(value)
                    setattr(cls, key, value)

        def _catch_warnings(func):
            """
            Catch the warnings and treat them as errors for each test.
            """

            def wrapper(self, *args, **kwargs):
                with warnings.catch_warnings(record=True) as w:
                    warnings.resetwarnings()
                    # Ignore specified warnings
                    warning_white_list = [UserWarning]
                    for warning in warning_white_list:
                        warnings.simplefilter("ignore", warning)
                    func(self, *args, **kwargs)
                    msg = None if len(w) == 0 else w[0].message
                    self.assertFalse(len(w) > 0, msg)

            return wrapper

        def _test_places(func):
            """
            Setting the running place for each test.
            """

            def wrapper(self, *args, **kwargs):
                places = self.places
                for place in places:
                    paddle.set_device(place)
                    func(self, *args, **kwargs)

            return wrapper

        def _check_output_impl(self,
                               result,
                               expected_result,
                               rtol,
                               atol,
                               equal=True):
            assertForNormalType = self.assertNotEqual
            assertForFloat = self.assertFalse
            if equal:
                assertForNormalType = self.assertEqual
                assertForFloat = self.assertTrue

            result_t = type(result)
            error_msg = "Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}"
            if result_t in [list, tuple]:
                result_t = get_container_type(result)
            if result_t in [
                    str, int, bool, set, np.bool, np.int32, np.int64, np.str
            ]:
                assertForNormalType(
                    result,
                    expected_result,
                    msg=error_msg.format(paddle.get_device(), expected_result,
                                         result, self.__class__.__name__))
            elif result_t in [float, np.ndarray, np.float32, np.float64]:
                assertForFloat(
                    np.allclose(
                        result, expected_result, rtol=rtol, atol=atol),
                    msg=error_msg.format(paddle.get_device(), expected_result,
                                         result, self.__class__.__name__))
                if result_t == np.ndarray:
                    assertForNormalType(
                        result.shape,
                        expected_result.shape,
                        msg=error_msg.format(
                            paddle.get_device(), expected_result.shape,
                            result.shape, self.__class__.__name__))
            else:
                raise ValueError(
                    "result type must be str, int, bool, set, np.bool, np.int32, "
                    "np.int64, np.str, float, np.ndarray, np.float32, np.float64"
                )

        def check_output_equal(self,
                               result,
                               expected_result,
                               rtol=1.e-5,
                               atol=1.e-8):
            """
            Check whether result and expected result are equal, including shape. 
            
            Args:
                result (str|int|bool|set|np.ndarray):
                    The result needs to be checked.
                expected_result (str|int|bool|set|np.ndarray): The type has to be same as
                    result's. Use the expected result to check result.
                rtol (float, optional):
                    relative tolerance, default 1.e-5.
                atol (float, optional):
                    absolute tolerance, default 1.e-8
            """

            self._check_output_impl(result, expected_result, rtol, atol)

        def check_output_not_equal(self,
                                   result,
                                   expected_result,
                                   rtol=1.e-5,
                                   atol=1.e-8):
            """
            Check whether result and expected result are not equal, including shape. 

            Args:
                result (str|int|bool|set|np.ndarray):
                    The result needs to be checked.
                expected_result (str|int|bool|set|np.ndarray): The type has to be same 
                    as result's. Use the expected result to check result.
                rtol (float, optional):
                    relative tolerance, default 1.e-5.
                atol (float, optional):
                    absolute tolerance, default 1.e-8
            """

            self._check_output_impl(
                result, expected_result, rtol, atol, equal=False)

    class CpuCommonTest(CommonTest):
        def __init__(self, methodName='runTest'):
            super(CpuCommonTest, self).__init__(methodName=methodName)
            self.places = ['cpu']


# Expose the nested test-case classes at module level; these are the names
# listed in `__all__`.
CommonTest, CpuCommonTest = (
    _CommonTestNamespace.CommonTest,
    _CommonTestNamespace.CpuCommonTest,
)