# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import paddle
import numpy as np
from paddle.static import InputSpec

from testing_utils import CommonTest


class _TestModelNamespace:
    """Holder class for the reusable `TestModel` base test case."""

    class TestModel(CommonTest):
        """Base test case that builds models from specs and runs them.

        Subclasses are expected to set `MODEL_CLASS` and to override
        `set_specs`, `set_inputs`, `set_targets` and `check_output`.
        `self.inputs`, `self.models` and `self.targets` are consumed in
        lockstep with `zip`, so iteration stops at the shortest of them.
        """

        # Class instantiated by `build_model` with the keyword arguments
        # of each spec.
        MODEL_CLASS = None
        # Fallback (height, width) used by `get_shape`.
        DEFAULT_HW = (256, 256)
        # Fallback batch size used by `get_shape`.
        DEFAULT_BATCH_SIZE = 2

        def setUp(self):
            """Populate specs, inputs, targets and models for a test."""
            # `set_models` reads `self.specs`, so specs are set first.
            self.set_specs()
            self.set_inputs()
            self.set_targets()
            self.set_models()

        def test_forward(self):
            """Call each model on its input and check the output."""
            cases = zip(self.inputs, self.models, self.targets)
            for i, (model_input, model, target) in enumerate(cases):
                with self.subTest(i=i):
                    # A list input is unpacked into positional arguments.
                    if isinstance(model_input, list):
                        output = model(*model_input)
                    else:
                        output = model(model_input)
                    self.check_output(output, target)

        def test_to_static(self):
            """Check that each model can be passed to `to_static`.

            Only the conversion call is made; the converted model is not
            executed here.
            """
            # Targets are not used, but stay in the zip so that the number
            # of iterations matches `test_forward`.
            cases = zip(self.inputs, self.models, self.targets)
            for i, (model_input, model, _) in enumerate(cases):
                with self.subTest(i=i):
                    paddle.jit.to_static(
                        model,
                        input_spec=self.get_input_spec(model, model_input))

        def check_output(self, output, target):
            """Compare a model output with its target.

            The base implementation accepts everything; subclasses
            override it with real assertions.
            """

        def set_specs(self):
            """Set `self.specs`, the per-model specs (none by default)."""
            self.specs = []

        def set_models(self):
            """Set `self.models` from `self.specs`.

            The result is a generator: each model is built only when the
            consuming loop reaches it, and it can be iterated only once.
            """
            self.models = (self.build_model(spec) for spec in self.specs)

        def set_inputs(self):
            """Set `self.inputs`, one entry per model (none by default)."""
            self.inputs = []

        def set_targets(self):
            """Set `self.targets`, one entry per model (none by default)."""
            self.targets = []

        def build_model(self, spec):
            """Instantiate `MODEL_CLASS` from a spec and configure it.

            Two reserved keys are removed before the remaining items are
            passed to `MODEL_CLASS` as keyword arguments:

            * `_phase` (default `'train'`): `'train'` calls
              `model.train()`, `'eval'` calls `model.eval()`.
            * `_stop_grad` (default `False`): when truthy, every parameter
              of the model gets `stop_gradient = True`, whatever the
              phase is.

            Args:
                spec: Mapping of constructor keyword arguments, optionally
                    with the reserved keys above. It is not modified.

            Returns:
                The constructed model.
            """
            # Work on a copy so that the caller's spec keeps its reserved
            # keys and can be reused.
            kwargs = dict(spec)
            phase = kwargs.pop('_phase', 'train')
            stop_grad = kwargs.pop('_stop_grad', False)

            model = self.MODEL_CLASS(**kwargs)

            if phase == 'train':
                model.train()
            elif phase == 'eval':
                model.eval()
            # Any other phase value leaves the model mode untouched.

            # Applied independently of the phase so that the option is
            # never silently ignored.
            if stop_grad:
                for p in model.parameters():
                    p.stop_gradient = True

            return model

        def get_shape(self, c, b=None, h=None, w=None):
            """Return the shape tuple `(b, c, h, w)`.

            Args:
                c: Number of channels.
                b: Batch size; `DEFAULT_BATCH_SIZE` when None.
                h: Height; first item of `DEFAULT_HW` when None.
                w: Width; second item of `DEFAULT_HW` when None.

            Returns:
                A 4-tuple `(b, c, h, w)`.
            """
            default_h, default_w = self.DEFAULT_HW
            # Each dimension falls back on its own, so an explicit `h` is
            # kept even when `w` is omitted (and vice versa).
            if h is None:
                h = default_h
            if w is None:
                w = default_w
            if b is None:
                b = self.DEFAULT_BATCH_SIZE
            return (b, c, h, w)

        def get_zeros_array(self, c, b=None, h=None, w=None):
            """Return `np.zeros` of the shape given by `get_shape`."""
            return np.zeros(self.get_shape(c, b, h, w))

        def get_randn_tensor(self, c, b=None, h=None, w=None):
            """Return `paddle.randn` of the shape given by `get_shape`."""
            return paddle.randn(self.get_shape(c, b, h, w))

        def get_input_spec(self, model, input):
            """Build a list of `InputSpec` objects for `model.forward`.

            The parameter names of `model.forward` are paired, in order,
            with the given tensors; pairing stops at the shorter of the
            two.

            Args:
                model: Object whose `forward` signature provides the names.
                input: A single tensor or a list of tensors; each must
                    expose a `shape` attribute.

            Returns:
                A list with one `InputSpec` per paired tensor.
            """
            tensors = input if isinstance(input, list) else [input]
            param_names = inspect.signature(model.forward).parameters
            # XXX: Hard-code dtype
            return [
                InputSpec(
                    shape=tensor.shape, name=param_name, dtype='float32')
                for param_name, tensor in zip(param_names, tensors)
            ]


# Module-level alias so that the base test case can be referenced as
# `TestModel` without going through `_TestModelNamespace`.
TestModel = _TestModelNamespace.TestModel