瀏覽代碼

[Fix] Store params when saving training models (#85)

* [Fix] Store params when saving training models

* [Feat] Suppress GDAL warnings

* [Fix] Use raw arguments
Lin Manhui 3 年之前
父節點
當前提交
24515fb78a

+ 48 - 15
paddlers/tasks/base.py

@@ -14,29 +14,55 @@
 
 
 import os
 import os
 import os.path as osp
 import os.path as osp
-from functools import partial
 import time
 import time
 import copy
 import copy
 import math
 import math
-import yaml
 import json
 import json
+from functools import partial, wraps
+from inspect import signature
+
+import yaml
 import paddle
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
 from paddle.io import DataLoader, DistributedBatchSampler
 from paddleslim import QAT
 from paddleslim import QAT
 from paddleslim.analysis import flops
 from paddleslim.analysis import flops
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
+
 import paddlers
 import paddlers
+import paddlers.utils.logging as logging
 from paddlers.transforms import arrange_transforms
 from paddlers.transforms import arrange_transforms
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                             get_pretrain_weights, load_pretrain_weights,
                             get_pretrain_weights, load_pretrain_weights,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             _get_shared_memory_size_in_M, EarlyStop)
                             _get_shared_memory_size_in_M, EarlyStop)
-import paddlers.utils.logging as logging
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 from .utils.infer_nets import InferNet, InferCDNet
 from .utils.infer_nets import InferNet, InferCDNet
 
 
 
 
-class BaseModel:
+class ModelMeta(type):
+    def __new__(cls, name, bases, attrs):
+        def _deco(init_func):
+            @wraps(init_func)
+            def _wrapper(self, *args, **kwargs):
+                if hasattr(self, '_raw_params'):
+                    ret = init_func(self, *args, **kwargs)
+                else:
+                    sig = signature(init_func)
+                    bnd_args = sig.bind(self, *args, **kwargs)
+                    raw_params = bnd_args.arguments
+                    raw_params.pop('self')
+                    self._raw_params = raw_params
+                    ret = init_func(self, *args, **kwargs)
+                return ret
+
+            return _wrapper
+
+        old_init_func = attrs['__init__']
+        attrs['__init__'] = _deco(old_init_func)
+        return type.__new__(cls, name, bases, attrs)
+
+
+class BaseModel(metaclass=ModelMeta):
     def __init__(self, model_type):
     def __init__(self, model_type):
         self.model_type = model_type
         self.model_type = model_type
         self.in_channels = None
         self.in_channels = None
@@ -126,7 +152,11 @@ class BaseModel:
                 model_name=self.model_name,
                 model_name=self.model_name,
                 checkpoint=resume_checkpoint)
                 checkpoint=resume_checkpoint)
 
 
-    def get_model_info(self):
+    def get_model_info(self, get_raw_params=False, inplace=True):
+        if inplace:
+            init_params = self.init_params
+        else:
+            init_params = copy.deepcopy(self.init_params)
         info = dict()
         info = dict()
         info['version'] = paddlers.__version__
         info['version'] = paddlers.__version__
         info['Model'] = self.__class__.__name__
         info['Model'] = self.__class__.__name__
@@ -136,16 +166,19 @@ class BaseModel:
              ('fixed_input_shape', self.fixed_input_shape),
              ('fixed_input_shape', self.fixed_input_shape),
              ('best_accuracy', self.best_accuracy),
              ('best_accuracy', self.best_accuracy),
              ('best_model_epoch', self.best_model_epoch)])
              ('best_model_epoch', self.best_model_epoch)])
-        if 'self' in self.init_params:
-            del self.init_params['self']
-        if '__class__' in self.init_params:
-            del self.init_params['__class__']
-        if 'model_name' in self.init_params:
-            del self.init_params['model_name']
-        if 'params' in self.init_params:
-            del self.init_params['params']
+        if 'self' in init_params:
+            del init_params['self']
+        if '__class__' in init_params:
+            del init_params['__class__']
+        if 'model_name' in init_params:
+            del init_params['model_name']
+        if 'params' in init_params:
+            del init_params['params']
+
+        info['_init_params'] = init_params
 
 
-        info['_init_params'] = self.init_params
+        if get_raw_params:
+            info['raw_params'] = self._raw_params
 
 
         try:
         try:
             primary_metric_key = list(self.eval_metrics.keys())[0]
             primary_metric_key = list(self.eval_metrics.keys())[0]
@@ -189,7 +222,7 @@ class BaseModel:
             if osp.exists(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
                 os.remove(save_dir)
             os.makedirs(save_dir)
             os.makedirs(save_dir)
-        model_info = self.get_model_info()
+        model_info = self.get_model_info(get_raw_params=True)
         model_info['status'] = self.status
         model_info['status'] = self.status
 
 
         paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
         paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))

+ 21 - 13
paddlers/tasks/load_model.py

@@ -48,27 +48,31 @@ def load_rcnn_inference_model(model_dir):
 def load_model(model_dir, **params):
 def load_model(model_dir, **params):
     """
     """
     Load saved model from a given directory.
     Load saved model from a given directory.
+
     Args:
     Args:
         model_dir(str): The directory where the model is saved.
         model_dir(str): The directory where the model is saved.
+
     Returns:
     Returns:
         The model loaded from the directory.
         The model loaded from the directory.
     """
     """
+
     if not osp.exists(model_dir):
     if not osp.exists(model_dir):
-        logging.error("model_dir '{}' does not exists!".format(model_dir))
+        logging.error("Directory '{}' does not exist!".format(model_dir))
     if not osp.exists(osp.join(model_dir, "model.yml")):
     if not osp.exists(osp.join(model_dir, "model.yml")):
-        raise Exception("There's no model.yml in {}".format(model_dir))
+        raise Exception("There is no file named model.yml in {}.".format(
+            model_dir))
+
     with open(osp.join(model_dir, "model.yml")) as f:
     with open(osp.join(model_dir, "model.yml")) as f:
         model_info = yaml.load(f.read(), Loader=yaml.Loader)
         model_info = yaml.load(f.read(), Loader=yaml.Loader)
-    f.close()
 
 
     status = model_info['status']
     status = model_info['status']
     with_net = params.get('with_net', True)
     with_net = params.get('with_net', True)
     if not with_net:
     if not with_net:
         assert status == 'Infer', \
         assert status == 'Infer', \
-            "Only exported inference models can be deployed, current model status is {}".format(status)
+            "Only exported models can be deployed for inference, but current model status is {}.".format(status)
 
 
     if not hasattr(paddlers.tasks, model_info['Model']):
     if not hasattr(paddlers.tasks, model_info['Model']):
-        raise Exception("There's no attribute {} in paddlers.tasks".format(
+        raise Exception("There is no {} attribute in paddlers.tasks.".format(
             model_info['Model']))
             model_info['Model']))
     if 'model_name' in model_info['_init_params']:
     if 'model_name' in model_info['_init_params']:
         del model_info['_init_params']['model_name']
         del model_info['_init_params']['model_name']
@@ -76,8 +80,9 @@ def load_model(model_dir, **params):
     model_info['_init_params'].update({'with_net': with_net})
     model_info['_init_params'].update({'with_net': with_net})
 
 
     with paddle.utils.unique_name.guard():
     with paddle.utils.unique_name.guard():
-        model = getattr(paddlers.tasks, model_info['Model'])(
-            **model_info['_init_params'])
+        params = model_info.pop('raw_params', {})
+        params.update(model_info['_init_params'])
+        model = getattr(paddlers.tasks, model_info['Model'])(**params)
         if with_net:
         if with_net:
             if status == 'Pruned' or osp.exists(
             if status == 'Pruned' or osp.exists(
                     osp.join(model_dir, "prune.yml")):
                     osp.join(model_dir, "prune.yml")):
@@ -108,18 +113,19 @@ def load_model(model_dir, **params):
             if status == 'Infer':
             if status == 'Infer':
                 if osp.exists(osp.join(model_dir, "quant.yml")):
                 if osp.exists(osp.join(model_dir, "quant.yml")):
                     logging.error(
                     logging.error(
-                        "Exported quantized model can not be loaded, only deployment is supported.",
+                        "Exported quantized model can not be loaded, because quant.yml is not found.",
                         exit=True)
                         exit=True)
                 model.net = model._build_inference_net()
                 model.net = model._build_inference_net()
                 if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                 if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                     net_state_dict = load_rcnn_inference_model(model_dir)
                     net_state_dict = load_rcnn_inference_model(model_dir)
                 else:
                 else:
                     net_state_dict = paddle.load(osp.join(model_dir, 'model'))
                     net_state_dict = paddle.load(osp.join(model_dir, 'model'))
-                    if model.model_type in ['classifier', 'segmenter'
-                                            ] and 'rc' in version:
-                        # When exporting a classifier and segmenter,
-                        # InferNet is defined to append softmax and argmax operators to the model,
-                        # so parameter name starts with 'net.'
+                    if model.model_type in [
+                            'classifier', 'segmenter', 'changedetector'
+                    ]:
+                        # When exporting a classifier, segmenter, or changedetector,
+                        # InferNet (or InferCDNet) is defined to append softmax and argmax operators to the model,
+                        # so the parameter names all start with 'net.'
                         new_net_state_dict = {}
                         new_net_state_dict = {}
                         for k, v in net_state_dict.items():
                         for k, v in net_state_dict.items():
                             new_net_state_dict['net.' + k] = v
                             new_net_state_dict['net.' + k] = v
@@ -137,6 +143,8 @@ def load_model(model_dir, **params):
         for k, v in model_info['_Attributes'].items():
         for k, v in model_info['_Attributes'].items():
             if k in model.__dict__:
             if k in model.__dict__:
                 model.__dict__[k] = v
                 model.__dict__[k] = v
+
     logging.info("Model[{}] loaded.".format(model_info['Model']))
     logging.info("Model[{}] loaded.".format(model_info['Model']))
     model.status = status
     model.status = status
+
     return model
     return model

+ 2 - 2
tutorials/train/semantic_segmentation/deeplabv3p.py

@@ -79,10 +79,10 @@ model.train(
     eval_dataset=eval_dataset,
     eval_dataset=eval_dataset,
     save_interval_epochs=5,
     save_interval_epochs=5,
     # 每多少次迭代记录一次日志
     # 每多少次迭代记录一次日志
-    log_interval_steps=50,
+    log_interval_steps=4,
     save_dir=EXP_DIR,
     save_dir=EXP_DIR,
     # 初始学习率大小
     # 初始学习率大小
-    learning_rate=0.01,
+    learning_rate=0.001,
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     early_stop=False,
     early_stop=False,
     # 是否启用VisualDL日志功能
     # 是否启用VisualDL日志功能

+ 60 - 0
tutorials/train/semantic_segmentation/run_with_clean_log.py

@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+
+import sys
+import subprocess
+from io import RawIOBase
+
+
+class StreamFilter(RawIOBase):
+    def __init__(self, conds, stream):
+        super().__init__()
+        self.conds = conds
+        self.stream = stream
+
+    def readinto(self, _):
+        pass
+
+    def write(self, msg):
+        if all(cond(msg) for cond in self.conds):
+            self.stream.write(msg)
+        else:
+            pass
+
+
+class CleanLog(object):
+    def __init__(self, filter_, stream_name):
+        self.filter = filter_
+        self.stream_name = stream_name
+        self.old_stream = getattr(sys, stream_name)
+
+    def __enter__(self):
+        setattr(sys, self.stream_name, self.filter)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        setattr(sys, self.stream_name, self.old_stream)
+
+
+if __name__ == '__main__':
+    if len(sys.argv) < 2:
+        raise TypeError("请指定需要运行的脚本!")
+
+    tar_file = sys.argv[1]
+    gdal_filter = StreamFilter([
+        lambda msg: "Sum of Photometric type-related color channels and ExtraSamples doesn't match SamplesPerPixel." not in msg
+    ], sys.stdout)
+    with CleanLog(gdal_filter, 'stdout'):
+        proc = subprocess.Popen(
+            ["python", tar_file],
+            stderr=subprocess.STDOUT,
+            stdout=subprocess.PIPE,
+            text=True)
+        while True:
+            try:
+                out_line = proc.stdout.readline()
+                if out_line == '' and proc.poll() is not None:
+                    break
+                if out_line:
+                    print(out_line, end='')
+            except KeyboardInterrupt:
+                import signal
+                proc.send_signal(signal.SIGINT)

+ 2 - 2
tutorials/train/semantic_segmentation/unet.py

@@ -77,10 +77,10 @@ model.train(
     eval_dataset=eval_dataset,
     eval_dataset=eval_dataset,
     save_interval_epochs=5,
     save_interval_epochs=5,
     # 每多少次迭代记录一次日志
     # 每多少次迭代记录一次日志
-    log_interval_steps=50,
+    log_interval_steps=4,
     save_dir=EXP_DIR,
     save_dir=EXP_DIR,
     # 初始学习率大小
     # 初始学习率大小
-    learning_rate=0.01,
+    learning_rate=0.001,
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     early_stop=False,
     early_stop=False,
     # 是否启用VisualDL日志功能
     # 是否启用VisualDL日志功能