test_model.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import paddle
  16. import numpy as np
  17. from paddle.static import InputSpec
  18. from testing_utils import CommonTest
  19. class _TestModelNamespace:
  20. class TestModel(CommonTest):
  21. MODEL_CLASS = None
  22. DEFAULT_HW = (256, 256)
  23. DEFAULT_BATCH_SIZE = 2
  24. def setUp(self):
  25. self.set_specs()
  26. self.set_inputs()
  27. self.set_targets()
  28. self.set_models()
  29. def test_forward(self):
  30. for i, (
  31. input, model, target
  32. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  33. with self.subTest(i=i):
  34. if isinstance(input, list):
  35. output = model(*input)
  36. else:
  37. output = model(input)
  38. self.check_output(output, target)
  39. def test_to_static(self):
  40. for i, (
  41. input, model, target
  42. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  43. with self.subTest(i=i):
  44. static_model = paddle.jit.to_static(
  45. model, input_spec=self.get_input_spec(model, input))
  46. def check_output(self, output, target):
  47. pass
  48. def set_specs(self):
  49. self.specs = []
  50. def set_models(self):
  51. self.models = (self.build_model(spec) for spec in self.specs)
  52. def set_inputs(self):
  53. self.inputs = []
  54. def set_targets(self):
  55. self.targets = []
  56. def build_model(self, spec):
  57. if '_phase' in spec:
  58. phase = spec.pop('_phase')
  59. else:
  60. phase = 'train'
  61. if '_stop_grad' in spec:
  62. stop_grad = spec.pop('_stop_grad')
  63. else:
  64. stop_grad = False
  65. model = self.MODEL_CLASS(**spec)
  66. if phase == 'train':
  67. model.train()
  68. elif phase == 'eval':
  69. model.eval()
  70. if stop_grad:
  71. for p in model.parameters():
  72. p.stop_gradient = True
  73. return model
  74. def get_shape(self, c, b=None, h=None, w=None):
  75. if h is None or w is None:
  76. h, w = self.DEFAULT_HW
  77. if b is None:
  78. b = self.DEFAULT_BATCH_SIZE
  79. return (b, c, h, w)
  80. def get_zeros_array(self, c, b=None, h=None, w=None):
  81. shape = self.get_shape(c, b, h, w)
  82. return np.zeros(shape)
  83. def get_randn_tensor(self, c, b=None, h=None, w=None):
  84. shape = self.get_shape(c, b, h, w)
  85. return paddle.randn(shape)
  86. def get_input_spec(self, model, input):
  87. if not isinstance(input, list):
  88. input = [input]
  89. input_spec = []
  90. for param_name, tensor in zip(
  91. inspect.signature(model.forward).parameters, input):
  92. # XXX: Hard-code dtype
  93. input_spec.append(
  94. InputSpec(
  95. shape=tensor.shape, name=param_name, dtype='float32'))
  96. return input_spec
  97. TestModel = _TestModelNamespace.TestModel