export_model.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. from ast import literal_eval
  17. from paddlers.tasks import load_model
  18. def get_parser():
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. '--model_dir',
  22. '-m',
  23. type=str,
  24. default=None,
  25. help='model directory path')
  26. parser.add_argument(
  27. '--save_dir',
  28. '-s',
  29. type=str,
  30. default=None,
  31. help='path to save inference model')
  32. parser.add_argument(
  33. '--fixed_input_shape',
  34. '-fs',
  35. type=str,
  36. default=None,
  37. help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
  38. return parser
  39. if __name__ == '__main__':
  40. parser = get_parser()
  41. args = parser.parse_args()
  42. # Get input shape
  43. fixed_input_shape = None
  44. if args.fixed_input_shape is not None:
  45. # Try to interpret the string as a list.
  46. fixed_input_shape = literal_eval(args.fixed_input_shape)
  47. # Check validaty
  48. if not isinstance(fixed_input_shape, list):
  49. raise ValueError(
  50. "fixed_input_shape should be of None or list type.")
  51. if len(fixed_input_shape) not in (2, 4):
  52. raise ValueError(
  53. "fixed_input_shape contains an incorrect number of elements.")
  54. if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
  55. raise ValueError(
  56. "Input width and height must be positive integers.")
  57. if len(fixed_input_shape) == 4 and fixed_input_shape[1] <= 0:
  58. raise ValueError(
  59. "The number of input channels must be a positive integer.")
  60. # Set environment variables
  61. os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
  62. os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
  63. # Load model from directory
  64. model = load_model(args.model_dir)
  65. # Do dynamic-to-static cast
  66. # XXX: Invoke a protected (single underscore) method outside of subclasses.
  67. model._export_inference_model(args.save_dir, fixed_input_shape)