|
@@ -12,6 +12,8 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import inspect
|
|
|
+
|
|
|
import paddle
|
|
|
import paddlers
|
|
|
from paddlers.tasks.change_detector import BaseChangeDetector
|
|
@@ -27,12 +29,21 @@ def make_trainer(net_type, *args, **kwargs):
|
|
|
use_mixed_loss=False,
|
|
|
losses=None,
|
|
|
**params):
|
|
|
- super().__init__(
|
|
|
+ sig = inspect.signature(net_type.__init__)
|
|
|
+ net_params = {
|
|
|
+ k: p.default
|
|
|
+ for k, p in sig.parameters.items() if not p.default is p.empty
|
|
|
+ }
|
|
|
+ net_params.pop('self', None)
|
|
|
+ net_params.pop('num_classes', None)
|
|
|
+ net_params.update(params)
|
|
|
+
|
|
|
+ super(trainer_type, self).__init__(
|
|
|
model_name=net_type.__name__,
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
losses=losses,
|
|
|
- **params)
|
|
|
+ **net_params)
|
|
|
|
|
|
if not issubclass(net_type, paddle.nn.Layer):
|
|
|
raise TypeError("net must be a subclass of paddle.nn.Layer")
|