# custom_trainer.py
import paddlers
from paddlers.tasks.change_detector import BaseChangeDetector
from attach_tools import Attach

# Build the decorator object bound to the `paddlers.tasks.change_detector`
# module; it is applied as `@attach` to the trainer class defined below.
# NOTE(review): presumably `Attach.to(module)` makes decorated classes
# reachable from that module (e.g. so paddlers can resolve them by name) —
# confirm against attach_tools.
attach = Attach.to(paddlers.tasks.change_detector)
  5. @attach
  6. class CustomTrainer(BaseChangeDetector):
  7. def __init__(self,
  8. num_classes=2,
  9. use_mixed_loss=False,
  10. losses=None,
  11. in_channels=3,
  12. att_types='ct',
  13. use_dropout=False,
  14. **params):
  15. params.update({
  16. 'in_channels': in_channels,
  17. 'att_types': att_types,
  18. 'use_dropout': use_dropout
  19. })
  20. super().__init__(
  21. model_name='CustomModel',
  22. num_classes=num_classes,
  23. use_mixed_loss=use_mixed_loss,
  24. losses=losses,
  25. **params)