1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import os
- import argparse
- from ast import literal_eval
- from paddlers.tasks import load_model
def get_parser():
    """Build the command-line parser for the inference-model export script.

    Returns:
        argparse.ArgumentParser: A parser accepting ``--model_dir``/``-m``,
        ``--save_dir``/``-s`` and ``--fixed_input_shape``/``-fs``. Each option
        is parsed as a string and defaults to ``None`` when omitted.
    """
    arg_parser = argparse.ArgumentParser()
    # (option flags, help text) for every option; all share type/default.
    option_specs = (
        (('--model_dir', '-m'), 'model directory path'),
        (('--save_dir', '-s'), 'path to save inference model'),
        (('--fixed_input_shape', '-fs'),
         "export inference model with fixed input shape: [w,h] or [n,c,w,h]"),
    )
    for flags, help_text in option_specs:
        arg_parser.add_argument(*flags, type=str, default=None, help=help_text)
    return arg_parser
- if __name__ == '__main__':
- parser = get_parser()
- args = parser.parse_args()
-
- fixed_input_shape = None
- if args.fixed_input_shape is not None:
-
- fixed_input_shape = literal_eval(args.fixed_input_shape)
-
- if not isinstance(fixed_input_shape, list):
- raise ValueError("fixed_input_shape should be of None or list type.")
- if len(fixed_input_shape) not in (2, 4):
- raise ValueError("fixed_input_shape contains an incorrect number of elements.")
- if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
- raise ValueError("the input width and height must be positive integers.")
- if len(fixed_input_shape)==4 and fixed_input_shape[1] <= 0:
- raise ValueError("the number of input channels must be a positive integer.")
-
- os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
- os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
-
- model = load_model(args.model_dir)
-
-
- model._export_inference_model(args.save_dir, fixed_input_shape)
|