Explorar el Código

Remove unused classes in fccdn

Bobholamovic hace 2 años
padre
commit
92a5086c79

+ 8 - 55
paddlers/rs_models/cd/losses/fccdn_loss.py

@@ -43,42 +43,13 @@ class DiceLoss(nn.Layer):
         return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true)
 
 
-class MultiClassDiceLoss(nn.Layer):
-    def __init__(
-            self,
-            weight,
-            batch=True,
-            ignore_index=-1,
-            do_softmax=False,
-            **kwargs, ):
-        super(MultiClassDiceLoss, self).__init__()
-        self.ignore_index = ignore_index
-        self.weight = weight
-        self.do_softmax = do_softmax
-        self.binary_diceloss = DiceLoss(batch)
-
-    def forward(self, y_pred, y_true):
-        if self.do_softmax:
-            y_pred = paddle.nn.functional.softmax(y_pred, axis=1)
-        y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2)
-        total_loss = 0.0
-        tmp_i = 0.0
-        for i in range(y_pred.shape[1]):
-            if i != self.ignore_index:
-                diceloss = self.binary_diceloss(y_pred[:, i, :, :],
-                                                y_true[:, i, :, :])
-                total_loss += paddle.multiply(diceloss, self.weight[i])
-                tmp_i += 1.0
-        return total_loss / tmp_i
-
-
 class DiceBCELoss(nn.Layer):
     """Binary change detection task loss"""
 
     def __init__(self):
         super(DiceBCELoss, self).__init__()
         self.bce_loss = nn.BCELoss()
-        self.binnary_dice = DiceLoss()
+        self.binary_dice = DiceLoss()
 
     def forward(self, scores, labels, do_sigmoid=True):
         if len(scores.shape) > 3:
@@ -87,29 +58,11 @@ class DiceBCELoss(nn.Layer):
             labels = labels.squeeze(1)
         if do_sigmoid:
             scores = paddle.nn.functional.sigmoid(scores.clone())
-        diceloss = self.binnary_dice(scores, labels)
+        diceloss = self.binary_dice(scores, labels)
         bceloss = self.bce_loss(scores, labels)
         return diceloss + bceloss
 
 
-class McDiceBCELoss(nn.Layer):
-    """Multi-class change detection task loss"""
-
-    def __init__(self, weight, do_sigmoid=True):
-        super(McDiceBCELoss, self).__init__()
-        self.ce_loss = nn.CrossEntropyLoss(weight)
-        self.dice = MultiClassDiceLoss(weight, do_sigmoid)
-
-    def forward(self, scores, labels):
-        if len(scores.shape) < 4:
-            scores = scores.unsqueeze(1)
-        if len(labels.shape) < 4:
-            labels = labels.unsqueeze(1)
-        diceloss = self.dice(scores, labels)
-        bceloss = self.ce_loss(scores, labels)
-        return diceloss + bceloss
-
-
 def fccdn_ssl_loss(logits_list, labels):
     """
     Self-supervised learning loss for change detection.
@@ -160,11 +113,11 @@ def fccdn_ssl_loss(logits_list, labels):
 
     # Seg loss
     labels_downsample = labels_downsample.astype(paddle.float32)
-    loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False)
-    loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False)
-    loss_aux += 0.2 * criterion_ssl(
-        out3, labels_downsample - pred_seg_post_tmp2, False)
-    loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
-                                    False)
+    loss_aux = criterion_ssl(out1, pred_seg_post_tmp1, False)
+    loss_aux += criterion_ssl(out2, pred_seg_pre_tmp1, False)
+    loss_aux += criterion_ssl(out3, labels_downsample - pred_seg_post_tmp2,
+                              False)
+    loss_aux += criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
+                              False)
 
     return loss_aux

+ 1 - 1
paddlers/tasks/change_detector.py

@@ -1067,7 +1067,7 @@ class FCCDN(BaseChangeDetector):
             return {
                 'types':
                 [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss],
-                'coef': [1.0, 1.0]
+                'coef': [1.0, 0.2]
             }
         else:
             raise ValueError(

+ 1 - 1
tutorials/train/change_detection/fccdn.py

@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCCDN()
 
 # 执行模型训练
 model.train(
-    num_epochs=10,
+    num_epochs=15,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,