|
@@ -23,12 +23,12 @@ from attach_tools import Attach
|
|
|
attach = Attach.to(paddlers.tasks.change_detector)
|
|
|
|
|
|
|
|
|
-def make_trainer(net_type, *args, **kwargs):
|
|
|
+def make_trainer(net_type, attach_trainer=True):
|
|
|
def _init_func(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
losses=None,
|
|
|
- **params):
|
|
|
+ **_params_):
|
|
|
sig = inspect.signature(net_type.__init__)
|
|
|
net_params = {
|
|
|
k: p.default
|
|
@@ -36,7 +36,13 @@ def make_trainer(net_type, *args, **kwargs):
|
|
|
}
|
|
|
net_params.pop('self', None)
|
|
|
net_params.pop('num_classes', None)
|
|
|
- net_params.update(params)
|
|
|
+ # Special rule to parse arguments from `_params_`.
|
|
|
+ # When using pdrs.tasks.load_model, `_params_`` is a dict with the key '_params_'.
|
|
|
+ # This bypasses the dynamic modification/creation of function signature.
|
|
|
+ if '_params_' not in _params_:
|
|
|
+ net_params.update(_params_)
|
|
|
+ else:
|
|
|
+ net_params.update(_params_['_params_'])
|
|
|
|
|
|
super(trainer_type, self).__init__(
|
|
|
model_name=net_type.__name__,
|
|
@@ -52,7 +58,13 @@ def make_trainer(net_type, *args, **kwargs):
|
|
|
|
|
|
trainer_type = type(trainer_name, (BaseChangeDetector, ),
|
|
|
{'__init__': _init_func})
|
|
|
+ if attach_trainer:
|
|
|
+ trainer_type = attach(trainer_type)
|
|
|
+ return trainer_type
|
|
|
|
|
|
+
|
|
|
+def make_trainer_and_build(net_type, *args, **kwargs):
|
|
|
+ trainer_type = make_trainer(net_type, attach_trainer=True)
|
|
|
return trainer_type(*args, **kwargs)
|
|
|
|
|
|
|