|  | @@ -14,29 +14,55 @@
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import os
 |  |  import os
 | 
											
												
													
														|  |  import os.path as osp
 |  |  import os.path as osp
 | 
											
												
													
														|  | -from functools import partial
 |  | 
 | 
											
												
													
														|  |  import time
 |  |  import time
 | 
											
												
													
														|  |  import copy
 |  |  import copy
 | 
											
												
													
														|  |  import math
 |  |  import math
 | 
											
												
													
														|  | -import yaml
 |  | 
 | 
											
												
													
														|  |  import json
 |  |  import json
 | 
											
												
													
														|  | 
 |  | +from functools import partial, wraps
 | 
											
												
													
														|  | 
 |  | +from inspect import signature
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import yaml
 | 
											
												
													
														|  |  import paddle
 |  |  import paddle
 | 
											
												
													
														|  |  from paddle.io import DataLoader, DistributedBatchSampler
 |  |  from paddle.io import DataLoader, DistributedBatchSampler
 | 
											
												
													
														|  |  from paddleslim import QAT
 |  |  from paddleslim import QAT
 | 
											
												
													
														|  |  from paddleslim.analysis import flops
 |  |  from paddleslim.analysis import flops
 | 
											
												
													
														|  |  from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 |  |  from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  import paddlers
 |  |  import paddlers
 | 
											
												
													
														|  | 
 |  | +import paddlers.utils.logging as logging
 | 
											
												
													
														|  |  from paddlers.transforms import arrange_transforms
 |  |  from paddlers.transforms import arrange_transforms
 | 
											
												
													
														|  |  from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
 |  |  from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
 | 
											
												
													
														|  |                              get_pretrain_weights, load_pretrain_weights,
 |  |                              get_pretrain_weights, load_pretrain_weights,
 | 
											
												
													
														|  |                              load_checkpoint, SmoothedValue, TrainingStats,
 |  |                              load_checkpoint, SmoothedValue, TrainingStats,
 | 
											
												
													
														|  |                              _get_shared_memory_size_in_M, EarlyStop)
 |  |                              _get_shared_memory_size_in_M, EarlyStop)
 | 
											
												
													
														|  | -import paddlers.utils.logging as logging
 |  | 
 | 
											
												
													
														|  |  from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 |  |  from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 | 
											
												
													
														|  |  from .utils.infer_nets import InferNet, InferCDNet
 |  |  from .utils.infer_nets import InferNet, InferCDNet
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -class BaseModel:
 |  | 
 | 
											
												
													
														|  | 
 |  | +class ModelMeta(type):
 | 
											
												
													
														|  | 
 |  | +    def __new__(cls, name, bases, attrs):
 | 
											
												
													
														|  | 
 |  | +        def _deco(init_func):
 | 
											
												
													
														|  | 
 |  | +            @wraps(init_func)
 | 
											
												
													
														|  | 
 |  | +            def _wrapper(self, *args, **kwargs):
 | 
											
												
													
														|  | 
 |  | +                if hasattr(self, '_raw_params'):
 | 
											
												
													
														|  | 
 |  | +                    ret = init_func(self, *args, **kwargs)
 | 
											
												
													
														|  | 
 |  | +                else:
 | 
											
												
													
														|  | 
 |  | +                    sig = signature(init_func)
 | 
											
												
													
														|  | 
 |  | +                    bnd_args = sig.bind(self, *args, **kwargs)
 | 
											
												
													
														|  | 
 |  | +                    raw_params = bnd_args.arguments
 | 
											
												
													
														|  | 
 |  | +                    raw_params.pop('self')
 | 
											
												
													
														|  | 
 |  | +                    self._raw_params = raw_params
 | 
											
												
													
														|  | 
 |  | +                    ret = init_func(self, *args, **kwargs)
 | 
											
												
													
														|  | 
 |  | +                return ret
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            return _wrapper
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        old_init_func = attrs['__init__']
 | 
											
												
													
														|  | 
 |  | +        attrs['__init__'] = _deco(old_init_func)
 | 
											
												
													
														|  | 
 |  | +        return type.__new__(cls, name, bases, attrs)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class BaseModel(metaclass=ModelMeta):
 | 
											
												
													
														|  |      def __init__(self, model_type):
 |  |      def __init__(self, model_type):
 | 
											
												
													
														|  |          self.model_type = model_type
 |  |          self.model_type = model_type
 | 
											
												
													
														|  |          self.in_channels = None
 |  |          self.in_channels = None
 | 
											
										
											
												
													
														|  | @@ -126,7 +152,11 @@ class BaseModel:
 | 
											
												
													
														|  |                  model_name=self.model_name,
 |  |                  model_name=self.model_name,
 | 
											
												
													
														|  |                  checkpoint=resume_checkpoint)
 |  |                  checkpoint=resume_checkpoint)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def get_model_info(self):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def get_model_info(self, get_raw_params=False, inplace=True):
 | 
											
												
													
														|  | 
 |  | +        if inplace:
 | 
											
												
													
														|  | 
 |  | +            init_params = self.init_params
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            init_params = copy.deepcopy(self.init_params)
 | 
											
												
													
														|  |          info = dict()
 |  |          info = dict()
 | 
											
												
													
														|  |          info['version'] = paddlers.__version__
 |  |          info['version'] = paddlers.__version__
 | 
											
												
													
														|  |          info['Model'] = self.__class__.__name__
 |  |          info['Model'] = self.__class__.__name__
 | 
											
										
											
												
													
														|  | @@ -136,16 +166,19 @@ class BaseModel:
 | 
											
												
													
														|  |               ('fixed_input_shape', self.fixed_input_shape),
 |  |               ('fixed_input_shape', self.fixed_input_shape),
 | 
											
												
													
														|  |               ('best_accuracy', self.best_accuracy),
 |  |               ('best_accuracy', self.best_accuracy),
 | 
											
												
													
														|  |               ('best_model_epoch', self.best_model_epoch)])
 |  |               ('best_model_epoch', self.best_model_epoch)])
 | 
											
												
													
														|  | -        if 'self' in self.init_params:
 |  | 
 | 
											
												
													
														|  | -            del self.init_params['self']
 |  | 
 | 
											
												
													
														|  | -        if '__class__' in self.init_params:
 |  | 
 | 
											
												
													
														|  | -            del self.init_params['__class__']
 |  | 
 | 
											
												
													
														|  | -        if 'model_name' in self.init_params:
 |  | 
 | 
											
												
													
														|  | -            del self.init_params['model_name']
 |  | 
 | 
											
												
													
														|  | -        if 'params' in self.init_params:
 |  | 
 | 
											
												
													
														|  | -            del self.init_params['params']
 |  | 
 | 
											
												
													
														|  | 
 |  | +        if 'self' in init_params:
 | 
											
												
													
														|  | 
 |  | +            del init_params['self']
 | 
											
												
													
														|  | 
 |  | +        if '__class__' in init_params:
 | 
											
												
													
														|  | 
 |  | +            del init_params['__class__']
 | 
											
												
													
														|  | 
 |  | +        if 'model_name' in init_params:
 | 
											
												
													
														|  | 
 |  | +            del init_params['model_name']
 | 
											
												
													
														|  | 
 |  | +        if 'params' in init_params:
 | 
											
												
													
														|  | 
 |  | +            del init_params['params']
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        info['_init_params'] = init_params
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        info['_init_params'] = self.init_params
 |  | 
 | 
											
												
													
														|  | 
 |  | +        if get_raw_params:
 | 
											
												
													
														|  | 
 |  | +            info['raw_params'] = self._raw_params
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          try:
 |  |          try:
 | 
											
												
													
														|  |              primary_metric_key = list(self.eval_metrics.keys())[0]
 |  |              primary_metric_key = list(self.eval_metrics.keys())[0]
 | 
											
										
											
												
													
														|  | @@ -189,7 +222,7 @@ class BaseModel:
 | 
											
												
													
														|  |              if osp.exists(save_dir):
 |  |              if osp.exists(save_dir):
 | 
											
												
													
														|  |                  os.remove(save_dir)
 |  |                  os.remove(save_dir)
 | 
											
												
													
														|  |              os.makedirs(save_dir)
 |  |              os.makedirs(save_dir)
 | 
											
												
													
														|  | -        model_info = self.get_model_info()
 |  | 
 | 
											
												
													
														|  | 
 |  | +        model_info = self.get_model_info(get_raw_params=True)
 | 
											
												
													
														|  |          model_info['status'] = self.status
 |  |          model_info['status'] = self.status
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
 |  |          paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
 |