test_model.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import numpy as np
  16. from test_utils import CommonTest
  17. class TestModel(CommonTest):
  18. MODEL_CLASS = None
  19. DEFAULT_HW = (256, 256)
  20. DEFAULT_BATCH_SIZE = 2
  21. def setUp(self):
  22. self.set_specs()
  23. self.set_inputs()
  24. self.set_targets()
  25. self.set_models()
  26. def test_forward(self):
  27. for i, (input, model, target
  28. ) in enumerate(zip(self.inputs, self.models, self.targets)):
  29. with self.subTest(i=i):
  30. output = model(input)
  31. self.check_output(output, target)
  32. def check_output(self, output, target):
  33. pass
  34. def set_specs(self):
  35. self.specs = []
  36. def set_models(self):
  37. self.models = (self.build_model(spec) for spec in self.specs)
  38. def set_inputs(self):
  39. self.inputs = []
  40. def set_targets(self):
  41. self.targets = []
  42. def build_model(self, spec):
  43. if '_phase' in spec:
  44. phase = spec.pop('_phase')
  45. else:
  46. phase = 'train'
  47. if '_stop_grad' in spec:
  48. stop_grad = spec.pop('_stop_grad')
  49. else:
  50. stop_grad = False
  51. model = self.MODEL_CLASS(**spec)
  52. if phase == 'train':
  53. model.train()
  54. elif phase == 'eval':
  55. model.eval()
  56. if stop_grad:
  57. for p in model.parameters():
  58. p.stop_gradient = True
  59. return model
  60. def get_shape(self, c, b=None, h=None, w=None):
  61. if h is None or w is None:
  62. h, w = self.DEFAULT_HW
  63. if b is None:
  64. b = self.DEFAULT_BATCH_SIZE
  65. return (b, c, h, w)
  66. def get_zeros_array(self, c, b=None, h=None, w=None):
  67. shape = self.get_shape(c, b, h, w)
  68. return np.zeros(shape)
  69. def get_randn_tensor(self, c, b=None, h=None, w=None):
  70. shape = self.get_shape(c, b, h, w)
  71. return paddle.randn(shape)