export_model.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. from ast import literal_eval
  17. from paddlers.tasks import load_model
  18. def get_parser():
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument('--model_dir', '-m', type=str, default=None, help='model directory path')
  21. parser.add_argument('--save_dir', '-s', type=str, default=None, help='path to save inference model')
  22. parser.add_argument('--fixed_input_shape', '-fs', type=str, default=None,
  23. help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
  24. return parser
  25. if __name__ == '__main__':
  26. parser = get_parser()
  27. args = parser.parse_args()
  28. # Get input shape
  29. fixed_input_shape = None
  30. if args.fixed_input_shape is not None:
  31. # Try to interpret the string as a list.
  32. fixed_input_shape = literal_eval(args.fixed_input_shape)
  33. # Check validaty
  34. if not isinstance(fixed_input_shape, list):
  35. raise ValueError("fixed_input_shape should be of None or list type.")
  36. if len(fixed_input_shape) not in (2, 4):
  37. raise ValueError("fixed_input_shape contains an incorrect number of elements.")
  38. if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
  39. raise ValueError("the input width and height must be positive integers.")
  40. if len(fixed_input_shape)==4 and fixed_input_shape[1] <= 0:
  41. raise ValueError("the number of input channels must be a positive integer.")
  42. # Set environment variables
  43. os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
  44. os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
  45. # Load model from directory
  46. model = load_model(args.model_dir)
  47. # Do dynamic-to-static cast
  48. # XXX: Invoke a protected (single underscore) method outside of subclasses.
  49. model._export_inference_model(args.save_dir, fixed_input_shape)