|
@@ -43,42 +43,13 @@ class DiceLoss(nn.Layer):
|
|
|
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true)
|
|
|
|
|
|
|
|
|
-class MultiClassDiceLoss(nn.Layer):
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- weight,
|
|
|
- batch=True,
|
|
|
- ignore_index=-1,
|
|
|
- do_softmax=False,
|
|
|
- **kwargs, ):
|
|
|
- super(MultiClassDiceLoss, self).__init__()
|
|
|
- self.ignore_index = ignore_index
|
|
|
- self.weight = weight
|
|
|
- self.do_softmax = do_softmax
|
|
|
- self.binary_diceloss = DiceLoss(batch)
|
|
|
-
|
|
|
- def forward(self, y_pred, y_true):
|
|
|
- if self.do_softmax:
|
|
|
- y_pred = paddle.nn.functional.softmax(y_pred, axis=1)
|
|
|
- y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2)
|
|
|
- total_loss = 0.0
|
|
|
- tmp_i = 0.0
|
|
|
- for i in range(y_pred.shape[1]):
|
|
|
- if i != self.ignore_index:
|
|
|
- diceloss = self.binary_diceloss(y_pred[:, i, :, :],
|
|
|
- y_true[:, i, :, :])
|
|
|
- total_loss += paddle.multiply(diceloss, self.weight[i])
|
|
|
- tmp_i += 1.0
|
|
|
- return total_loss / tmp_i
|
|
|
-
|
|
|
-
|
|
|
class DiceBCELoss(nn.Layer):
|
|
|
"""Binary change detection task loss"""
|
|
|
|
|
|
def __init__(self):
|
|
|
super(DiceBCELoss, self).__init__()
|
|
|
self.bce_loss = nn.BCELoss()
|
|
|
- self.binnary_dice = DiceLoss()
|
|
|
+ self.binary_dice = DiceLoss()
|
|
|
|
|
|
def forward(self, scores, labels, do_sigmoid=True):
|
|
|
if len(scores.shape) > 3:
|
|
@@ -87,29 +58,11 @@ class DiceBCELoss(nn.Layer):
|
|
|
labels = labels.squeeze(1)
|
|
|
if do_sigmoid:
|
|
|
scores = paddle.nn.functional.sigmoid(scores.clone())
|
|
|
- diceloss = self.binnary_dice(scores, labels)
|
|
|
+ diceloss = self.binary_dice(scores, labels)
|
|
|
bceloss = self.bce_loss(scores, labels)
|
|
|
return diceloss + bceloss
|
|
|
|
|
|
|
|
|
-class McDiceBCELoss(nn.Layer):
|
|
|
- """Multi-class change detection task loss"""
|
|
|
-
|
|
|
- def __init__(self, weight, do_sigmoid=True):
|
|
|
- super(McDiceBCELoss, self).__init__()
|
|
|
- self.ce_loss = nn.CrossEntropyLoss(weight)
|
|
|
- self.dice = MultiClassDiceLoss(weight, do_sigmoid)
|
|
|
-
|
|
|
- def forward(self, scores, labels):
|
|
|
- if len(scores.shape) < 4:
|
|
|
- scores = scores.unsqueeze(1)
|
|
|
- if len(labels.shape) < 4:
|
|
|
- labels = labels.unsqueeze(1)
|
|
|
- diceloss = self.dice(scores, labels)
|
|
|
- bceloss = self.ce_loss(scores, labels)
|
|
|
- return diceloss + bceloss
|
|
|
-
|
|
|
-
|
|
|
def fccdn_ssl_loss(logits_list, labels):
|
|
|
"""
|
|
|
Self-supervised learning loss for change detection.
|
|
@@ -160,11 +113,11 @@ def fccdn_ssl_loss(logits_list, labels):
|
|
|
|
|
|
# Seg loss
|
|
|
labels_downsample = labels_downsample.astype(paddle.float32)
|
|
|
- loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False)
|
|
|
- loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False)
|
|
|
- loss_aux += 0.2 * criterion_ssl(
|
|
|
- out3, labels_downsample - pred_seg_post_tmp2, False)
|
|
|
- loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
|
|
|
- False)
|
|
|
+ loss_aux = criterion_ssl(out1, pred_seg_post_tmp1, False)
|
|
|
+ loss_aux += criterion_ssl(out2, pred_seg_pre_tmp1, False)
|
|
|
+ loss_aux += criterion_ssl(out3, labels_downsample - pred_seg_post_tmp2,
|
|
|
+ False)
|
|
|
+ loss_aux += criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
|
|
|
+ False)
|
|
|
|
|
|
return loss_aux
|