custom_trainer.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import paddle
  2. import paddlers
  3. from paddlers.tasks.change_detector import BaseChangeDetector
  4. from attach_tools import Attach
  # Class decorator applied to CustomTrainer below. Built from
  # attach_tools.Attach targeting paddlers.tasks.change_detector; presumably
  # it registers the decorated class into that module so it can be looked up
  # like a built-in trainer (attach_tools is not visible here — confirm).
  attach = Attach.to(paddlers.tasks.change_detector)
  def make_trainer(net_type, *args, **kwargs):
      """Create a change-detection trainer class around ``net_type`` and instantiate it.

      A new subclass of ``BaseChangeDetector`` named ``net_type.__name__`` is
      built on the fly. Its ``__init__`` forwards to the base-class
      ``__init__`` with ``model_name=net_type.__name__``.

      Args:
          net_type: Network class; must be a subclass of ``paddle.nn.Layer``.
          *args: Positional arguments for the generated trainer's
              ``__init__``, i.e. ``num_classes``, ``use_mixed_loss``,
              ``losses`` in that order.
          **kwargs: Keyword arguments for the generated trainer's
              ``__init__``. Keys other than ``num_classes``,
              ``use_mixed_loss`` and ``losses`` are passed on to
              ``BaseChangeDetector.__init__`` as extra keyword arguments.

      Returns:
          An instance of the dynamically created trainer class.

      Raises:
          TypeError: If ``net_type`` is not a class deriving from
              ``paddle.nn.Layer``.
      """
      # Validate before doing any work. The isinstance() check short-circuits,
      # so a non-class argument (e.g. a network *instance*) gets this message
      # instead of the unrelated one issubclass() would raise.
      if not (isinstance(net_type, type)
              and issubclass(net_type, paddle.nn.Layer)):
          raise TypeError("net must be a subclass of paddle.nn.Layer")

      def _init_func(self,
                     num_classes=2,
                     use_mixed_loss=False,
                     losses=None,
                     **params):
          # Zero-argument super() only works in functions written inside a
          # class body (it relies on the compiler-created __class__ cell).
          # This function is attached via type(), so name the class
          # explicitly. trainer_type is a closure variable that is bound below,
          # before any instance can be created.
          super(trainer_type, self).__init__(
              model_name=net_type.__name__,
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              losses=losses,
              **params)

      trainer_name = net_type.__name__
      trainer_type = type(trainer_name, (BaseChangeDetector, ),
                          {'__init__': _init_func})
      return trainer_type(*args, **kwargs)
  @attach
  class CustomTrainer(BaseChangeDetector):
      """Change-detection trainer that passes ``model_name='CustomModel'``.

      On top of the standard ``BaseChangeDetector`` arguments, it accepts three
      model-specific options (``in_channels``, ``att_types``, ``use_dropout``)
      and forwards them as extra keyword arguments. Presumably the base class
      resolves ``'CustomModel'`` to a network registered elsewhere (not
      visible here — confirm).
      """

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   losses=None,
                   in_channels=3,
                   att_types='ct',
                   use_dropout=False,
                   **params):
          """Initialize the trainer.

          Args:
              num_classes: Passed through to ``BaseChangeDetector``.
              use_mixed_loss: Passed through to ``BaseChangeDetector``.
              losses: Passed through to ``BaseChangeDetector``.
              in_channels: Model option, forwarded as an extra keyword arg.
              att_types: Model option, forwarded as an extra keyword arg.
              use_dropout: Model option, forwarded as an extra keyword arg.
              **params: Any further keyword arguments, forwarded unchanged.
          """
          # The model options follow the caller's extra kwargs, so the base
          # class receives them in the same order as before: caller keys
          # first, then in_channels, att_types, use_dropout.
          super().__init__(
              model_name='CustomModel',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              losses=losses,
              **params,
              in_channels=in_channels,
              att_types=att_types,
              use_dropout=use_dropout)