|
@@ -14,13 +14,14 @@
|
|
|
|
|
|
import os
|
|
|
import os.path as osp
|
|
|
-from functools import partial
|
|
|
import time
|
|
|
import copy
|
|
|
import math
|
|
|
-import yaml
|
|
|
import json
|
|
|
+from functools import partial, wraps
|
|
|
+from inspect import signature
|
|
|
|
|
|
+import yaml
|
|
|
import paddle
|
|
|
from paddle.io import DataLoader, DistributedBatchSampler
|
|
|
from paddleslim import QAT
|
|
@@ -28,17 +29,40 @@ from paddleslim.analysis import flops
|
|
|
from paddleslim import L1NormFilterPruner, FPGMFilterPruner
|
|
|
|
|
|
import paddlers
|
|
|
+import paddlers.utils.logging as logging
|
|
|
from paddlers.transforms import arrange_transforms
|
|
|
from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
|
|
|
get_pretrain_weights, load_pretrain_weights,
|
|
|
load_checkpoint, SmoothedValue, TrainingStats,
|
|
|
_get_shared_memory_size_in_M, EarlyStop)
|
|
|
-import paddlers.utils.logging as logging
|
|
|
from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
|
|
|
from .utils.infer_nets import InferNet, InferCDNet
|
|
|
|
|
|
|
|
|
-class BaseModel:
|
|
|
+class ModelMeta(type):
|
|
|
+ def __new__(cls, name, bases, attrs):
|
|
|
+ def _deco(init_func):
|
|
|
+ @wraps(init_func)
|
|
|
+ def _wrapper(self, *args, **kwargs):
|
|
|
+ if hasattr(self, '_raw_params'):
|
|
|
+ ret = init_func(self, *args, **kwargs)
|
|
|
+ else:
|
|
|
+ sig = signature(init_func)
|
|
|
+ bnd_args = sig.bind(self, *args, **kwargs)
|
|
|
+ raw_params = bnd_args.arguments
|
|
|
+ raw_params.pop('self')
|
|
|
+ self._raw_params = raw_params
|
|
|
+ ret = init_func(self, *args, **kwargs)
|
|
|
+ return ret
|
|
|
+
|
|
|
+ return _wrapper
|
|
|
+
|
|
|
+ old_init_func = attrs['__init__']
|
|
|
+ attrs['__init__'] = _deco(old_init_func)
|
|
|
+ return type.__new__(cls, name, bases, attrs)
|
|
|
+
|
|
|
+
|
|
|
+class BaseModel(metaclass=ModelMeta):
|
|
|
def __init__(self, model_type):
|
|
|
self.model_type = model_type
|
|
|
self.in_channels = None
|
|
@@ -128,7 +152,11 @@ class BaseModel:
|
|
|
model_name=self.model_name,
|
|
|
checkpoint=resume_checkpoint)
|
|
|
|
|
|
- def get_model_info(self):
|
|
|
+ def get_model_info(self, get_raw_params=False, inplace=True):
|
|
|
+ if inplace:
|
|
|
+ init_params = self.init_params
|
|
|
+ else:
|
|
|
+ init_params = copy.deepcopy(self.init_params)
|
|
|
info = dict()
|
|
|
info['version'] = paddlers.__version__
|
|
|
info['Model'] = self.__class__.__name__
|
|
@@ -138,16 +166,19 @@ class BaseModel:
|
|
|
('fixed_input_shape', self.fixed_input_shape),
|
|
|
('best_accuracy', self.best_accuracy),
|
|
|
('best_model_epoch', self.best_model_epoch)])
|
|
|
- if 'self' in self.init_params:
|
|
|
- del self.init_params['self']
|
|
|
- if '__class__' in self.init_params:
|
|
|
- del self.init_params['__class__']
|
|
|
- if 'model_name' in self.init_params:
|
|
|
- del self.init_params['model_name']
|
|
|
- if 'params' in self.init_params:
|
|
|
- del self.init_params['params']
|
|
|
+ if 'self' in init_params:
|
|
|
+ del init_params['self']
|
|
|
+ if '__class__' in init_params:
|
|
|
+ del init_params['__class__']
|
|
|
+ if 'model_name' in init_params:
|
|
|
+ del init_params['model_name']
|
|
|
+ if 'params' in init_params:
|
|
|
+ del init_params['params']
|
|
|
+
|
|
|
+ info['_init_params'] = init_params
|
|
|
|
|
|
- info['_init_params'] = self.init_params
|
|
|
+ if get_raw_params:
|
|
|
+ info['raw_params'] = self._raw_params
|
|
|
|
|
|
try:
|
|
|
primary_metric_key = list(self.eval_metrics.keys())[0]
|
|
@@ -191,7 +222,7 @@ class BaseModel:
|
|
|
if osp.exists(save_dir):
|
|
|
os.remove(save_dir)
|
|
|
os.makedirs(save_dir)
|
|
|
- model_info = self.get_model_info()
|
|
|
+ model_info = self.get_model_info(get_raw_params=True)
|
|
|
model_info['status'] = self.status
|
|
|
|
|
|
paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
|