export_model.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. from ast import literal_eval
  17. from paddlers.tasks import load_model
  18. def get_parser():
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. '--model_dir',
  22. '-m',
  23. type=str,
  24. default=None,
  25. help='model directory path')
  26. parser.add_argument(
  27. '--save_dir',
  28. '-s',
  29. type=str,
  30. default=None,
  31. help='path to save inference model')
  32. parser.add_argument(
  33. '--fixed_input_shape',
  34. '-fs',
  35. type=str,
  36. default=None,
  37. help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
  38. return parser
  39. if __name__ == '__main__':
  40. parser = get_parser()
  41. args = parser.parse_args()
  42. # Get input shape
  43. fixed_input_shape = None
  44. if args.fixed_input_shape is not None:
  45. # Try to interpret the string as a list.
  46. fixed_input_shape = literal_eval(args.fixed_input_shape)
  47. # Check validaty
  48. if not isinstance(fixed_input_shape, list):
  49. raise ValueError(
  50. "fixed_input_shape should be of None or list type.")
  51. if len(fixed_input_shape) not in (2, 4):
  52. raise ValueError(
  53. "fixed_input_shape contains an incorrect number of elements.")
  54. if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
  55. raise ValueError(
  56. "Input width and height must be positive integers.")
  57. if len(fixed_input_shape) == 4 and fixed_input_shape[1] <= 0:
  58. raise ValueError(
  59. "The number of input channels must be a positive integer.")
  60. # Set environment variables
  61. os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
  62. os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
  63. # Load model from directory
  64. model = load_model(args.model_dir)
  65. # Do dynamic-to-static cast
  66. model.export_inference_model(args.save_dir, fixed_input_shape)