Bläddra i källkod

[Feat] Add Interface to Set Losses (#18)

* Add interfaces to set losses

* Fix typo

* Fix import bugs

* import custom_models->rs_models
Lin Manhui 2 år sedan
förälder
incheckning
3f8ce38fb0
4 ändrade filer med 101 tillägg och 28 borttagningar
  1. 3 0
      paddlers/models/__init__.py
  2. 38 11
      paddlers/tasks/change_detector.py
  3. 28 7
      paddlers/tasks/classifier.py
  4. 32 10
      paddlers/tasks/segmenter.py

+ 3 - 0
paddlers/models/__init__.py

@@ -13,3 +13,6 @@
 # limitations under the License.
 
 from . import ppcls, ppdet, ppseg, ppgan
+import paddlers.models.ppseg.models.losses as seg_losses
+import paddlers.models.ppdet.modeling.losses as det_losses
+import paddlers.models.ppcls.loss as clas_losses

+ 38 - 11
paddlers/tasks/change_detector.py

@@ -28,6 +28,7 @@ import paddlers
 import paddlers.models.ppseg as ppseg
 import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
+from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
@@ -45,6 +46,7 @@ class BaseChangeDetector(BaseModel):
                  model_name,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
@@ -56,7 +58,7 @@ class BaseChangeDetector(BaseModel):
         self.model_name = model_name
         self.num_classes = num_classes
         self.use_mixed_loss = use_mixed_loss
-        self.losses = None
+        self.losses = losses
         self.labels = None
         if params.get('with_net', True):
             params.pop('with_net', None)
@@ -178,13 +180,13 @@ class BaseChangeDetector(BaseModel):
         if isinstance(self.use_mixed_loss, bool):
             if self.use_mixed_loss:
                 losses = [
-                    ppseg.models.CrossEntropyLoss(),
-                    ppseg.models.LovaszSoftmaxLoss()
+                    seg_losses.CrossEntropyLoss(),
+                    seg_losses.LovaszSoftmaxLoss()
                 ]
                 coef = [.8, .2]
-                loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ]
+                loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
             else:
-                loss_type = [ppseg.models.CrossEntropyLoss()]
+                loss_type = [seg_losses.CrossEntropyLoss()]
         else:
             losses, coef = list(zip(*self.use_mixed_loss))
             if not set(losses).issubset(
@@ -192,8 +194,8 @@ class BaseChangeDetector(BaseModel):
                 raise ValueError(
                     "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                 )
-            losses = [getattr(ppseg.models, loss)() for loss in losses]
-            loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))]
+            losses = [getattr(seg_losses, loss)() for loss in losses]
+            loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
         loss_coef = [1.0]
         losses = {'types': loss_type, 'coef': loss_coef}
         return losses
@@ -810,11 +812,17 @@ class BaseChangeDetector(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeChangeDetector object.")
 
+    def set_losses(self, losses, weights=None):
+        if weights is None:
+            weights = [1. for _ in range(len(losses))]
+        self.losses = {'types': losses, 'coef': weights}
+
 
 class CDNet(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=6,
                  **params):
         params.update({'in_channels': in_channels})
@@ -822,6 +830,7 @@ class CDNet(BaseChangeDetector):
             model_name='CDNet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -829,6 +838,7 @@ class FCEarlyFusion(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=6,
                  use_dropout=False,
                  **params):
@@ -837,6 +847,7 @@ class FCEarlyFusion(BaseChangeDetector):
             model_name='FCEarlyFusion',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -844,6 +855,7 @@ class FCSiamConc(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  use_dropout=False,
                  **params):
@@ -852,6 +864,7 @@ class FCSiamConc(BaseChangeDetector):
             model_name='FCSiamConc',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -859,6 +872,7 @@ class FCSiamDiff(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  use_dropout=False,
                  **params):
@@ -867,6 +881,7 @@ class FCSiamDiff(BaseChangeDetector):
             model_name='FCSiamDiff',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -874,6 +889,7 @@ class STANet(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  att_type='BAM',
                  ds_factor=1,
@@ -887,6 +903,7 @@ class STANet(BaseChangeDetector):
             model_name='STANet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -894,6 +911,7 @@ class BIT(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  backbone='resnet18',
                  n_stages=4,
@@ -925,6 +943,7 @@ class BIT(BaseChangeDetector):
             model_name='BIT',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -932,6 +951,7 @@ class SNUNet(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  width=32,
                  **params):
@@ -940,6 +960,7 @@ class SNUNet(BaseChangeDetector):
             model_name='SNUNet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -947,6 +968,7 @@ class DSIFN(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  use_dropout=False,
                  **params):
         params.update({'use_dropout': use_dropout})
@@ -954,13 +976,14 @@ class DSIFN(BaseChangeDetector):
             model_name='DSIFN',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
     def default_loss(self):
         if self.use_mixed_loss is False:
             return {
                 # XXX: make sure the shallow copy works correctly here.
-                'types': [ppseg.models.CrossEntropyLoss()] * 5,
+                'types': [seg_losses.CrossEntropyLoss()] * 5,
                 'coef': [1.0] * 5
             }
         else:
@@ -973,6 +996,7 @@ class DSAMNet(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  in_channels=3,
                  ca_ratio=8,
                  sa_kernel=7,
@@ -986,14 +1010,15 @@ class DSAMNet(BaseChangeDetector):
             model_name='DSAMNet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
     def default_loss(self):
         if self.use_mixed_loss is False:
             return {
                 'types': [
-                    ppseg.models.CrossEntropyLoss(), ppseg.models.DiceLoss(),
-                    ppseg.models.DiceLoss()
+                    seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss(),
+                    seg_losses.DiceLoss()
                 ],
                 'coef': [1.0, 0.05, 0.05]
             }
@@ -1007,6 +1032,7 @@ class ChangeStar(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  mid_channels=256,
                  inner_channels=16,
                  num_convs=4,
@@ -1022,13 +1048,14 @@ class ChangeStar(BaseChangeDetector):
             model_name='ChangeStar',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
     def default_loss(self):
         if self.use_mixed_loss is False:
             return {
                 # XXX: make sure the shallow copy works correctly here.
-                'types': [ppseg.models.CrossEntropyLoss()] * 4,
+                'types': [seglosses.CrossEntropyLoss()] * 4,
                 'coef': [1.0] * 4
             }
         else:

+ 28 - 7
paddlers/tasks/classifier.py

@@ -28,7 +28,7 @@ import paddlers.rs_models.clas as cmcls
 import paddlers.utils.logging as logging
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.models.ppcls.metric import build_metrics
-from paddlers.models.ppcls.loss import build_loss
+from paddlers.models import clas_losses
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
 from paddlers.transforms import Resize, decode_image
@@ -45,6 +45,7 @@ class BaseClassifier(BaseModel):
                  in_channels=3,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
@@ -59,7 +60,7 @@ class BaseClassifier(BaseModel):
         self.num_classes = num_classes
         self.use_mixed_loss = use_mixed_loss
         self.metrics = None
-        self.losses = None
+        self.losses = losses
         self.labels = None
         self._postprocess = None
         if params.get('with_net', True):
@@ -145,7 +146,7 @@ class BaseClassifier(BaseModel):
     def default_loss(self):
         # TODO: use mixed loss and other loss
         default_config = [{"CELoss": {"weight": 1.0}}]
-        return build_loss(default_config)
+        return clas_losses.build_loss(default_config)
 
     def default_optimizer(self,
                           parameters,
@@ -556,36 +557,56 @@ class BaseClassifier(BaseModel):
 
 
 class ResNet50_vd(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
         super(ResNet50_vd, self).__init__(
             model_name='ResNet50_vd',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
 class MobileNetV3_small_x1_0(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
         super(MobileNetV3_small_x1_0, self).__init__(
             model_name='MobileNetV3_small_x1_0',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
 class HRNet_W18_C(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
         super(HRNet_W18_C, self).__init__(
             model_name='HRNet_W18_C',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
 class CondenseNetV2_b(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
         super(CondenseNetV2_b, self).__init__(
             model_name='CondenseNetV2_b',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)

+ 32 - 10
paddlers/tasks/segmenter.py

@@ -26,10 +26,11 @@ from paddle.static import InputSpec
 import paddlers
 import paddlers.models.ppseg as ppseg
 import paddlers.rs_models.seg as cmseg
-from paddlers.utils import get_single_card_bs, DisablePrint
 import paddlers.utils.logging as logging
-from paddlers.utils.checkpoint import seg_pretrain_weights_dict
+from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
+from paddlers.utils import get_single_card_bs, DisablePrint
+from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 
@@ -41,6 +42,7 @@ class BaseSegmenter(BaseModel):
                  model_name,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
@@ -53,7 +55,7 @@ class BaseSegmenter(BaseModel):
         self.model_name = model_name
         self.num_classes = num_classes
         self.use_mixed_loss = use_mixed_loss
-        self.losses = None
+        self.losses = losses
         self.labels = None
         if params.get('with_net', True):
             params.pop('with_net', None)
@@ -160,13 +162,13 @@ class BaseSegmenter(BaseModel):
         if isinstance(self.use_mixed_loss, bool):
             if self.use_mixed_loss:
                 losses = [
-                    ppseg.models.CrossEntropyLoss(),
-                    ppseg.models.LovaszSoftmaxLoss()
+                    seg_losses.CrossEntropyLoss(),
+                    seg_losses.LovaszSoftmaxLoss()
                 ]
                 coef = [.8, .2]
-                loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ]
+                loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
             else:
-                loss_type = [ppseg.models.CrossEntropyLoss()]
+                loss_type = [seg_losses.CrossEntropyLoss()]
         else:
             losses, coef = list(zip(*self.use_mixed_loss))
             if not set(losses).issubset(
@@ -174,8 +176,8 @@ class BaseSegmenter(BaseModel):
                 raise ValueError(
                     "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                 )
-            losses = [getattr(ppseg.models, loss)() for loss in losses]
-            loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))]
+            losses = [getattr(seg_losses, loss)() for loss in losses]
+            loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
         if self.model_name == 'FastSCNN':
             loss_type *= 2
             loss_coef = [1.0, 0.4]
@@ -771,12 +773,18 @@ class BaseSegmenter(BaseModel):
             raise TypeError(
                 "`transforms.arrange` must be an ArrangeSegmenter object.")
 
+    def set_losses(self, losses, weights=None):
+        if weights is None:
+            weights = [1. for _ in range(len(losses))]
+        self.losses = {'types': losses, 'coef': weights}
+
 
 class UNet(BaseSegmenter):
     def __init__(self,
                  input_channel=3,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  use_deconv=False,
                  align_corners=False,
                  **params):
@@ -789,6 +797,7 @@ class UNet(BaseSegmenter):
             input_channel=input_channel,
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -798,6 +807,7 @@ class DeepLabV3P(BaseSegmenter):
                  num_classes=2,
                  backbone='ResNet50_vd',
                  use_mixed_loss=False,
+                 losses=None,
                  output_stride=8,
                  backbone_indices=(0, 3),
                  aspp_ratios=(1, 12, 24, 36),
@@ -826,6 +836,7 @@ class DeepLabV3P(BaseSegmenter):
             model_name='DeepLabV3P',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -833,6 +844,7 @@ class FastSCNN(BaseSegmenter):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  align_corners=False,
                  **params):
         params.update({'align_corners': align_corners})
@@ -840,6 +852,7 @@ class FastSCNN(BaseSegmenter):
             model_name='FastSCNN',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
@@ -848,6 +861,7 @@ class HRNet(BaseSegmenter):
                  num_classes=2,
                  width=48,
                  use_mixed_loss=False,
+                 losses=None,
                  align_corners=False,
                  **params):
         if width not in (18, 48):
@@ -867,6 +881,7 @@ class HRNet(BaseSegmenter):
             model_name='FCN',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
         self.model_name = 'HRNet'
 
@@ -875,6 +890,7 @@ class BiSeNetV2(BaseSegmenter):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
+                 losses=None,
                  align_corners=False,
                  **params):
         params.update({'align_corners': align_corners})
@@ -882,13 +898,19 @@ class BiSeNetV2(BaseSegmenter):
             model_name='BiSeNetV2',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 
 
 class FarSeg(BaseSegmenter):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
         super(FarSeg, self).__init__(
             model_name='FarSeg',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)