|
@@ -73,9 +73,11 @@ def load_model(model_dir, **params):
|
|
|
assert status == 'Infer', \
|
|
|
"Only exported models can be deployed for inference, but current model status is {}.".format(status)
|
|
|
|
|
|
- if not hasattr(paddlers.tasks, model_info['Model']):
|
|
|
- raise Exception("There is no {} attribute in paddlers.tasks.".format(
|
|
|
- model_info['Model']))
|
|
|
+ model_type = model_info['_Attributes']['model_type']
|
|
|
+ mod = getattr(paddlers.tasks, model_type)
|
|
|
+ if not hasattr(mod, model_info['Model']):
|
|
|
+ raise Exception("There is no {} attribute in {}.".format(model_info[
|
|
|
+ 'Model'], mod))
|
|
|
if 'model_name' in model_info['_init_params']:
|
|
|
del model_info['_init_params']['model_name']
|
|
|
|
|
@@ -88,7 +90,7 @@ def load_model(model_dir, **params):
|
|
|
)
|
|
|
params = model_info.pop('raw_params', {})
|
|
|
params.update(model_info['_init_params'])
|
|
|
- model = getattr(paddlers.tasks, model_info['Model'])(**params)
|
|
|
+ model = getattr(mod, model_info['Model'])(**params)
|
|
|
if with_net:
|
|
|
if status == 'Pruned' or osp.exists(
|
|
|
osp.join(model_dir, "prune.yml")):
|
|
@@ -127,9 +129,9 @@ def load_model(model_dir, **params):
|
|
|
else:
|
|
|
net_state_dict = paddle.load(osp.join(model_dir, 'model'))
|
|
|
if model.model_type in [
|
|
|
- 'classifier', 'segmenter', 'changedetector'
|
|
|
+ 'classifier', 'segmenter', 'change_detector'
|
|
|
]:
|
|
|
- # When exporting a classifier, segmenter, or changedetector,
|
|
|
+ # When exporting a classifier, segmenter, or change_detector,
|
|
|
# InferNet (or InferCDNet) is defined to append softmax and argmax operators to the model,
|
|
|
# so the parameter names all start with 'net.'
|
|
|
new_net_state_dict = {}
|