Bladeren bron

fix conflicts

michaelowenliu 2 jaren geleden
bovenliggende
commit
104b6150f8

+ 46 - 15
paddlers/tasks/base.py

@@ -14,13 +14,14 @@
 
 import os
 import os.path as osp
-from functools import partial
 import time
 import copy
 import math
-import yaml
 import json
+from functools import partial, wraps
+from inspect import signature
 
+import yaml
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
 from paddleslim import QAT
@@ -28,17 +29,40 @@ from paddleslim.analysis import flops
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 
 import paddlers
+import paddlers.utils.logging as logging
 from paddlers.transforms import arrange_transforms
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                             get_pretrain_weights, load_pretrain_weights,
                             load_checkpoint, SmoothedValue, TrainingStats,
                             _get_shared_memory_size_in_M, EarlyStop)
-import paddlers.utils.logging as logging
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
 from .utils.infer_nets import InferNet, InferCDNet
 
 
-class BaseModel:
+class ModelMeta(type):
+    def __new__(cls, name, bases, attrs):
+        def _deco(init_func):
+            @wraps(init_func)
+            def _wrapper(self, *args, **kwargs):
+                if hasattr(self, '_raw_params'):
+                    ret = init_func(self, *args, **kwargs)
+                else:
+                    sig = signature(init_func)
+                    bnd_args = sig.bind(self, *args, **kwargs)
+                    raw_params = bnd_args.arguments
+                    raw_params.pop('self')
+                    self._raw_params = raw_params
+                    ret = init_func(self, *args, **kwargs)
+                return ret
+
+            return _wrapper
+
+        old_init_func = attrs['__init__']
+        attrs['__init__'] = _deco(old_init_func)
+        return type.__new__(cls, name, bases, attrs)
+
+
+class BaseModel(metaclass=ModelMeta):
     def __init__(self, model_type):
         self.model_type = model_type
         self.in_channels = None
@@ -128,7 +152,11 @@ class BaseModel:
                 model_name=self.model_name,
                 checkpoint=resume_checkpoint)
 
-    def get_model_info(self):
+    def get_model_info(self, get_raw_params=False, inplace=True):
+        if inplace:
+            init_params = self.init_params
+        else:
+            init_params = copy.deepcopy(self.init_params)
         info = dict()
         info['version'] = paddlers.__version__
         info['Model'] = self.__class__.__name__
@@ -138,16 +166,19 @@ class BaseModel:
              ('fixed_input_shape', self.fixed_input_shape),
              ('best_accuracy', self.best_accuracy),
              ('best_model_epoch', self.best_model_epoch)])
-        if 'self' in self.init_params:
-            del self.init_params['self']
-        if '__class__' in self.init_params:
-            del self.init_params['__class__']
-        if 'model_name' in self.init_params:
-            del self.init_params['model_name']
-        if 'params' in self.init_params:
-            del self.init_params['params']
+        if 'self' in init_params:
+            del init_params['self']
+        if '__class__' in init_params:
+            del init_params['__class__']
+        if 'model_name' in init_params:
+            del init_params['model_name']
+        if 'params' in init_params:
+            del init_params['params']
+
+        info['_init_params'] = init_params
 
-        info['_init_params'] = self.init_params
+        if get_raw_params:
+            info['raw_params'] = self._raw_params
 
         try:
             primary_metric_key = list(self.eval_metrics.keys())[0]
@@ -191,7 +222,7 @@ class BaseModel:
             if osp.exists(save_dir):
                 os.remove(save_dir)
             os.makedirs(save_dir)
-        model_info = self.get_model_info()
+        model_info = self.get_model_info(get_raw_params=True)
         model_info['status'] = self.status
 
         paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))

+ 21 - 13
paddlers/tasks/load_model.py

@@ -50,27 +50,31 @@ def load_rcnn_inference_model(model_dir):
 def load_model(model_dir, **params):
     """
     Load saved model from a given directory.
+
     Args:
         model_dir(str): The directory where the model is saved.
+
     Returns:
         The model loaded from the directory.
     """
+
     if not osp.exists(model_dir):
-        logging.error("model_dir '{}' does not exists!".format(model_dir))
+        logging.error("Directory '{}' does not exist!".format(model_dir))
     if not osp.exists(osp.join(model_dir, "model.yml")):
-        raise Exception("There's no model.yml in {}".format(model_dir))
+        raise Exception("There is no file named model.yml in {}.".format(
+            model_dir))
+
     with open(osp.join(model_dir, "model.yml")) as f:
         model_info = yaml.load(f.read(), Loader=yaml.Loader)
-    f.close()
 
     status = model_info['status']
     with_net = params.get('with_net', True)
     if not with_net:
         assert status == 'Infer', \
-            "Only exported inference models can be deployed, current model status is {}".format(status)
+            "Only exported models can be deployed for inference, but current model status is {}.".format(status)
 
     if not hasattr(paddlers.tasks, model_info['Model']):
-        raise Exception("There's no attribute {} in paddlers.tasks".format(
+        raise Exception("There is no {} attribute in paddlers.tasks.".format(
             model_info['Model']))
     if 'model_name' in model_info['_init_params']:
         del model_info['_init_params']['model_name']
@@ -78,8 +82,9 @@ def load_model(model_dir, **params):
     model_info['_init_params'].update({'with_net': with_net})
 
     with paddle.utils.unique_name.guard():
-        model = getattr(paddlers.tasks, model_info['Model'])(
-            **model_info['_init_params'])
+        params = model_info.pop('raw_params', {})
+        params.update(model_info['_init_params'])
+        model = getattr(paddlers.tasks, model_info['Model'])(**params)
         if with_net:
             if status == 'Pruned' or osp.exists(
                     osp.join(model_dir, "prune.yml")):
@@ -110,18 +115,19 @@ def load_model(model_dir, **params):
             if status == 'Infer':
                 if osp.exists(osp.join(model_dir, "quant.yml")):
                     logging.error(
-                        "Exported quantized model can not be loaded, only deployment is supported.",
+                        "Exported quantized model can not be loaded, because quant.yml is not found.",
                         exit=True)
                 model.net = model._build_inference_net()
                 if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                     net_state_dict = load_rcnn_inference_model(model_dir)
                 else:
                     net_state_dict = paddle.load(osp.join(model_dir, 'model'))
-                    if model.model_type in ['classifier', 'segmenter'
-                                            ] and 'rc' in version:
-                        # When exporting a classifier and segmenter,
-                        # InferNet is defined to append softmax and argmax operators to the model,
-                        # so parameter name starts with 'net.'
+                    if model.model_type in [
+                            'classifier', 'segmenter', 'changedetector'
+                    ]:
+                        # When exporting a classifier, segmenter, or changedetector,
+                        # InferNet (or InferCDNet) is defined to append softmax and argmax operators to the model,
+                        # so the parameter names all start with 'net.'
                         new_net_state_dict = {}
                         for k, v in net_state_dict.items():
                             new_net_state_dict['net.' + k] = v
@@ -139,6 +145,8 @@ def load_model(model_dir, **params):
         for k, v in model_info['_Attributes'].items():
             if k in model.__dict__:
                 model.__dict__[k] = v
+
     logging.info("Model[{}] loaded.".format(model_info['Model']))
     model.status = status
+
     return model

+ 2 - 2
tutorials/train/semantic_segmentation/deeplabv3p.py

@@ -79,10 +79,10 @@ model.train(
     eval_dataset=eval_dataset,
     save_interval_epochs=5,
     # 每多少次迭代记录一次日志
-    log_interval_steps=50,
+    log_interval_steps=4,
     save_dir=EXP_DIR,
     # 初始学习率大小
-    learning_rate=0.01,
+    learning_rate=0.001,
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     early_stop=False,
     # 是否启用VisualDL日志功能

+ 60 - 0
tutorials/train/semantic_segmentation/run_with_clean_log.py

@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+
+import sys
+import subprocess
+from io import RawIOBase
+
+
+class StreamFilter(RawIOBase):
+    def __init__(self, conds, stream):
+        super().__init__()
+        self.conds = conds
+        self.stream = stream
+
+    def readinto(self, _):
+        pass
+
+    def write(self, msg):
+        if all(cond(msg) for cond in self.conds):
+            self.stream.write(msg)
+        else:
+            pass
+
+
+class CleanLog(object):
+    def __init__(self, filter_, stream_name):
+        self.filter = filter_
+        self.stream_name = stream_name
+        self.old_stream = getattr(sys, stream_name)
+
+    def __enter__(self):
+        setattr(sys, self.stream_name, self.filter)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        setattr(sys, self.stream_name, self.old_stream)
+
+
+if __name__ == '__main__':
+    if len(sys.argv) < 2:
+        raise TypeError("请指定需要运行的脚本!")
+
+    tar_file = sys.argv[1]
+    gdal_filter = StreamFilter([
+        lambda msg: "Sum of Photometric type-related color channels and ExtraSamples doesn't match SamplesPerPixel." not in msg
+    ], sys.stdout)
+    with CleanLog(gdal_filter, 'stdout'):
+        proc = subprocess.Popen(
+            ["python", tar_file],
+            stderr=subprocess.STDOUT,
+            stdout=subprocess.PIPE,
+            text=True)
+        while True:
+            try:
+                out_line = proc.stdout.readline()
+                if out_line == '' and proc.poll() is not None:
+                    break
+                if out_line:
+                    print(out_line, end='')
+            except KeyboardInterrupt:
+                import signal
+                proc.send_signal(signal.SIGINT)

+ 2 - 2
tutorials/train/semantic_segmentation/unet.py

@@ -77,10 +77,10 @@ model.train(
     eval_dataset=eval_dataset,
     save_interval_epochs=5,
     # 每多少次迭代记录一次日志
-    log_interval_steps=50,
+    log_interval_steps=4,
     save_dir=EXP_DIR,
     # 初始学习率大小
-    learning_rate=0.01,
+    learning_rate=0.001,
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     early_stop=False,
     # 是否启用VisualDL日志功能