test_model.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import paddle
  16. import numpy as np
  17. from paddle.static import InputSpec
  18. from paddlers.utils import logging
  19. from testing_utils import CommonTest
  20. class _TestModelNamespace:
  21. class TestModel(CommonTest):
  22. MODEL_CLASS = None
  23. DEFAULT_HW = (256, 256)
  24. DEFAULT_BATCH_SIZE = 2
  25. def setUp(self):
  26. self.set_specs()
  27. self.set_inputs()
  28. self.set_targets()
  29. self.set_models()
  30. def test_forward(self):
  31. for i, (
  32. input, model, target
  33. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  34. try:
  35. if isinstance(input, list):
  36. output = model(*input)
  37. else:
  38. output = model(input)
  39. self.check_output(output, target)
  40. except:
  41. logging.warning(f"Model built with spec{i} failed!")
  42. raise
  43. def test_to_static(self):
  44. for i, (
  45. input, model, target
  46. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  47. try:
  48. static_model = paddle.jit.to_static(
  49. model, input_spec=self.get_input_spec(model, input))
  50. except:
  51. logging.warning(f"Model built with spec{i} failed!")
  52. raise
  53. def check_output(self, output, target):
  54. pass
  55. def set_specs(self):
  56. self.specs = []
  57. def set_models(self):
  58. self.models = (self.build_model(spec) for spec in self.specs)
  59. def set_inputs(self):
  60. self.inputs = []
  61. def set_targets(self):
  62. self.targets = []
  63. def build_model(self, spec):
  64. if '_phase' in spec:
  65. phase = spec.pop('_phase')
  66. else:
  67. phase = 'train'
  68. if '_stop_grad' in spec:
  69. stop_grad = spec.pop('_stop_grad')
  70. else:
  71. stop_grad = False
  72. model = self.MODEL_CLASS(**spec)
  73. if phase == 'train':
  74. model.train()
  75. elif phase == 'eval':
  76. model.eval()
  77. if stop_grad:
  78. for p in model.parameters():
  79. p.stop_gradient = True
  80. return model
  81. def get_shape(self, c, b=None, h=None, w=None):
  82. if h is None or w is None:
  83. h, w = self.DEFAULT_HW
  84. if b is None:
  85. b = self.DEFAULT_BATCH_SIZE
  86. return (b, c, h, w)
  87. def get_zeros_array(self, c, b=None, h=None, w=None):
  88. shape = self.get_shape(c, b, h, w)
  89. return np.zeros(shape)
  90. def get_randn_tensor(self, c, b=None, h=None, w=None):
  91. shape = self.get_shape(c, b, h, w)
  92. return paddle.randn(shape)
  93. def get_input_spec(self, model, input):
  94. if not isinstance(input, list):
  95. input = [input]
  96. input_spec = []
  97. for param_name, tensor in zip(
  98. inspect.signature(model.forward).parameters, input):
  99. # XXX: Hard-code dtype
  100. input_spec.append(
  101. InputSpec(
  102. shape=tensor.shape, name=param_name, dtype='float32'))
  103. return input_spec
  104. def allow_oom(cls):
  105. def _deco(func):
  106. def _wrapper(self, *args, **kwargs):
  107. try:
  108. func(self, *args, **kwargs)
  109. except (SystemError, RuntimeError, OSError, MemoryError) as e:
  110. # XXX: This may not cover all OOM cases.
  111. msg = str(e)
  112. if "Out of memory error" in msg \
  113. or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg \
  114. or isinstance(e, MemoryError):
  115. logging.warning("An OOM error has been ignored.")
  116. else:
  117. raise
  118. return _wrapper
  119. for key, value in inspect.getmembers(cls):
  120. if key.startswith('test'):
  121. value = _deco(value)
  122. setattr(cls, key, value)
  123. return cls
  124. TestModel = _TestModelNamespace.TestModel