test_model.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import paddle
  16. import numpy as np
  17. from paddle.static import InputSpec
  18. from testing_utils import CommonTest
  19. class _TestModelNamespace:
  20. class TestModel(CommonTest):
  21. MODEL_CLASS = None
  22. DEFAULT_HW = (256, 256)
  23. DEFAULT_BATCH_SIZE = 2
  24. def setUp(self):
  25. self.set_specs()
  26. self.set_inputs()
  27. self.set_targets()
  28. self.set_models()
  29. def test_forward(self):
  30. for i, (
  31. input, model, target
  32. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  33. with self.subTest(i=i):
  34. output = model(input)
  35. self.check_output(output, target)
  36. def test_to_static(self):
  37. for i, (
  38. input, model, target
  39. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  40. with self.subTest(i=i):
  41. static_model = self.convert_to_static(model, input)
  42. output = static_model(output)
  43. self.check_output(output, target)
  44. def check_output(self, output, target):
  45. pass
  46. def set_specs(self):
  47. self.specs = []
  48. def set_models(self):
  49. self.models = (self.build_model(spec) for spec in self.specs)
  50. def set_inputs(self):
  51. self.inputs = []
  52. def set_targets(self):
  53. self.targets = []
  54. def build_model(self, spec):
  55. if '_phase' in spec:
  56. phase = spec.pop('_phase')
  57. else:
  58. phase = 'train'
  59. if '_stop_grad' in spec:
  60. stop_grad = spec.pop('_stop_grad')
  61. else:
  62. stop_grad = False
  63. model = self.MODEL_CLASS(**spec)
  64. if phase == 'train':
  65. model.train()
  66. elif phase == 'eval':
  67. model.eval()
  68. if stop_grad:
  69. for p in model.parameters():
  70. p.stop_gradient = True
  71. return model
  72. def get_shape(self, c, b=None, h=None, w=None):
  73. if h is None or w is None:
  74. h, w = self.DEFAULT_HW
  75. if b is None:
  76. b = self.DEFAULT_BATCH_SIZE
  77. return (b, c, h, w)
  78. def get_zeros_array(self, c, b=None, h=None, w=None):
  79. shape = self.get_shape(c, b, h, w)
  80. return np.zeros(shape)
  81. def get_randn_tensor(self, c, b=None, h=None, w=None):
  82. shape = self.get_shape(c, b, h, w)
  83. return paddle.randn(shape)
  84. def get_input_spec(self, model, input):
  85. if not isinstance(input, list):
  86. input = [input]
  87. input_spec = []
  88. for param_name, tensor in zip(
  89. inspect.signature(model.forward).parameters, input):
  90. # XXX: Hard-code dtype
  91. input_spec.append(
  92. InputSpec(
  93. shape=tensor.shape, name=param_name, dtype='float32'))
  94. return input_spec
  95. def convert_to_static(self, model, input):
  96. return paddle.jit.to_static(
  97. model, input_spec=self.get_input_spec(model, input))
  98. TestModel = _TestModelNamespace.TestModel