# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import yaml
import numpy as np
import paddle
import paddleslim
import paddlers
import paddlers.utils.logging as logging
from paddlers.transforms import build_transforms
  22. def load_rcnn_inference_model(model_dir):
  23. paddle.enable_static()
  24. exe = paddle.static.Executor(paddle.CPUPlace())
  25. path_prefix = osp.join(model_dir, "model")
  26. prog, _, _ = paddle.static.load_inference_model(path_prefix, exe)
  27. paddle.disable_static()
  28. extra_var_info = paddle.load(osp.join(model_dir, "model.pdiparams.info"))
  29. net_state_dict = dict()
  30. static_state_dict = dict()
  31. for name, var in prog.state_dict().items():
  32. static_state_dict[name] = np.array(var)
  33. for var_name in static_state_dict:
  34. if var_name not in extra_var_info:
  35. continue
  36. structured_name = extra_var_info[var_name].get('structured_name', None)
  37. if structured_name is None:
  38. continue
  39. net_state_dict[structured_name] = static_state_dict[var_name]
  40. return net_state_dict
def load_model(model_dir, **params):
    """
    Load saved model from a given directory.

    Args:
        model_dir (str): The directory where the model is saved.

    Keyword Args:
        with_net (bool): Whether to build the network object. Defaults to
            True. May only be False for exported models (status 'Infer').

    Returns:
        The model loaded from the directory.
    """
    # NOTE(review): if the directory is missing this logs an error but does
    # not obviously raise here — presumably paddlers' logging.error aborts
    # by default, or the model.yml check below fails next; confirm.
    if not osp.exists(model_dir):
        logging.error("Directory '{}' does not exist!".format(model_dir))
    if not osp.exists(osp.join(model_dir, "model.yml")):
        raise Exception("There is no file named model.yml in {}.".format(
            model_dir))
    # NOTE(review): yaml.Loader can construct arbitrary Python objects —
    # only load model.yml files from trusted sources.
    with open(osp.join(model_dir, "model.yml")) as f:
        model_info = yaml.load(f.read(), Loader=yaml.Loader)
    status = model_info['status']
    with_net = params.get('with_net', True)
    if not with_net:
        # Skipping network construction only makes sense for exported
        # (inference-ready) models.
        assert status == 'Infer', \
            "Only exported models can be deployed for inference, but current model status is {}.".format(status)
    # Resolve the task module (e.g. paddlers.tasks.<model_type>) and the
    # concrete model class named in model.yml.
    model_type = model_info['_Attributes']['model_type']
    mod = getattr(paddlers.tasks, model_type)
    if not hasattr(mod, model_info['Model']):
        raise Exception("There is no {} attribute in {}.".format(model_info[
            'Model'], mod))
    # 'model_name' is not a constructor argument; drop it before building.
    if 'model_name' in model_info['_init_params']:
        del model_info['_init_params']['model_name']
    model_info['_init_params'].update({'with_net': with_net})
    # Fresh unique-name scope so parameter names do not clash with any
    # previously constructed networks in this process.
    with paddle.utils.unique_name.guard():
        if 'raw_params' not in model_info:
            logging.warning(
                "Cannot find raw_params. Default arguments will be used to construct the model."
            )
        # Constructor args: raw_params (if saved) overridden by _init_params.
        params = model_info.pop('raw_params', {})
        params.update(model_info['_init_params'])
        model = getattr(mod, model_info['Model'])(**params)
        if with_net:
            # Re-apply pruning so the network topology matches the saved
            # (pruned) weights before loading them.
            if status == 'Pruned' or osp.exists(
                    osp.join(model_dir, "prune.yml")):
                with open(osp.join(model_dir, "prune.yml")) as f:
                    pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                inputs = pruning_info['pruner_inputs']
                if model.model_type == 'detector':
                    # Detector pruners take a list with one dict of tensors.
                    inputs = [{
                        k: paddle.to_tensor(v)
                        for k, v in inputs.items()
                    }]
                    model.net.eval()
                model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                    model.net, inputs=inputs)
                model.pruning_ratios = pruning_info['pruning_ratios']
                model.pruner.prune_vars(
                    ratios=model.pruning_ratios,
                    axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
            # Re-apply quantization (QAT) so quantized weights can be set.
            if status == 'Quantized' or osp.exists(
                    osp.join(model_dir, "quant.yml")):
                with open(osp.join(model_dir, "quant.yml")) as f:
                    quant_info = yaml.load(f.read(), Loader=yaml.Loader)
                model.quant_config = quant_info['quant_config']
                model.quantizer = paddleslim.QAT(model.quant_config)
                model.quantizer.quantize(model.net)
            if status == 'Infer':
                # NOTE(review): this branch fires when quant.yml IS present,
                # yet the message says it "is not found" — the wording looks
                # inverted; confirm intended message upstream.
                if osp.exists(osp.join(model_dir, "quant.yml")):
                    logging.error(
                        "Exported quantized model can not be loaded, because quant.yml is not found.",
                        exit=True)
                model.net = model._build_inference_net()
                if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                    # RCNN exports are static-graph programs; remap their
                    # parameters back to structured (dygraph) names.
                    net_state_dict = load_rcnn_inference_model(model_dir)
                else:
                    net_state_dict = paddle.load(osp.join(model_dir, 'model'))
                    if model.model_type in [
                            'classifier', 'segmenter', 'change_detector'
                    ]:
                        # When exporting a classifier, segmenter, or change_detector,
                        # InferNet (or InferCDNet) is defined to append softmax and argmax operators to the model,
                        # so the parameter names all start with 'net.'
                        new_net_state_dict = {}
                        for k, v in net_state_dict.items():
                            new_net_state_dict['net.' + k] = v
                        net_state_dict = new_net_state_dict
            else:
                # Non-exported checkpoints store weights in model.pdparams.
                net_state_dict = paddle.load(
                    osp.join(model_dir, 'model.pdparams'))
            model.net.set_state_dict(net_state_dict)
    # Restore the saved inference-time transforms, if any.
    if 'Transforms' in model_info:
        model.test_transforms = build_transforms(model_info['Transforms'])
    # Copy saved attributes back onto the model, but only those the model
    # instance already defines.
    if '_Attributes' in model_info:
        for k, v in model_info['_Attributes'].items():
            if k in model.__dict__:
                model.__dict__[k] = v
    logging.info("Model[{}] loaded.".format(model_info['Model']))
    model.status = status
    return model