Bladeren bron

[论文复现赛] FCCDN (#23)

精度验收通过,代码符合规范,论文复现成功。

Co-authored-by: liuxtakeoff <763848861.qq.com>
liuxtakeoff 2 jaren geleden
bovenliggende
commit
bbbbd3c7c1

+ 2 - 0
paddlers/rs_models/cd/__init__.py

@@ -23,3 +23,5 @@ from .fc_ef import FCEarlyFusion
 from .fc_siam_conc import FCSiamConc
 from .fc_siam_diff import FCSiamDiff
 from .changeformer import ChangeFormer
+from .fccdn import FCCDN
+from .losses import fccdn_ssl_loss

+ 478 - 0
paddlers/rs_models/cd/fccdn.py

@@ -0,0 +1,478 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3
+
+bn_mom = 1 - 0.0003
+
+
+class NLBlock(nn.Layer):
+    def __init__(self, in_channels):
+        super(NLBlock, self).__init__()
+        self.conv_v = BasicConv(
+            in_ch=in_channels,
+            out_ch=in_channels,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_channels, momentum=0.9))
+        self.W = BasicConv(
+            in_ch=in_channels,
+            out_ch=in_channels,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_channels, momentum=0.9),
+            act=nn.ReLU())
+
+    def forward(self, x):
+        batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
+        value = self.conv_v(x)
+        value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]])
+        value = value.transpose([0, 2, 1])  # B * (H*W) * value_channels
+        key = x.reshape([batch_size, c, h * w])  # B * key_channels * (H*W)
+        query = x.reshape([batch_size, c, h * w])
+        query = query.transpose([0, 2, 1])
+
+        sim_map = paddle.matmul(query, key)  # B * (H*W) * (H*W)
+        sim_map = (c**-.5) * sim_map  # B * (H*W) * (H*W)
+        sim_map = nn.functional.softmax(sim_map, axis=-1)  # B * (H*W) * (H*W)
+
+        context = paddle.matmul(sim_map, value)
+        context = context.transpose([0, 2, 1])
+        context = context.reshape([batch_size, c, *x.shape[2:]])
+        context = self.W(context)
+
+        return context
+
+
+class NLFPN(nn.Layer):
+    """ Non-local feature parymid network"""
+
+    def __init__(self, in_dim, reduction=True):
+        super(NLFPN, self).__init__()
+        if reduction:
+            self.reduction = BasicConv(
+                in_ch=in_dim,
+                out_ch=in_dim // 4,
+                kernel_size=1,
+                norm=nn.BatchNorm2D(
+                    in_dim // 4, momentum=bn_mom),
+                act=nn.ReLU())
+            self.re_reduction = BasicConv(
+                in_ch=in_dim // 4,
+                out_ch=in_dim,
+                kernel_size=1,
+                norm=nn.BatchNorm2D(
+                    in_dim, momentum=bn_mom),
+                act=nn.ReLU())
+            in_dim = in_dim // 4
+        else:
+            self.reduction = None
+            self.re_reduction = None
+        self.conv_e1 = BasicConv(
+            in_dim,
+            in_dim,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim, momentum=bn_mom),
+            act=nn.ReLU())
+        self.conv_e2 = BasicConv(
+            in_dim,
+            in_dim * 2,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim * 2, momentum=bn_mom),
+            act=nn.ReLU())
+        self.conv_e3 = BasicConv(
+            in_dim * 2,
+            in_dim * 4,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim * 4, momentum=bn_mom),
+            act=nn.ReLU())
+        self.conv_d1 = BasicConv(
+            in_dim,
+            in_dim,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim, momentum=bn_mom),
+            act=nn.ReLU())
+        self.conv_d2 = BasicConv(
+            in_dim * 2,
+            in_dim,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim, momentum=bn_mom),
+            act=nn.ReLU())
+        self.conv_d3 = BasicConv(
+            in_dim * 4,
+            in_dim * 2,
+            kernel_size=3,
+            norm=nn.BatchNorm2D(
+                in_dim * 2, momentum=bn_mom),
+            act=nn.ReLU())
+        self.nl3 = NLBlock(in_dim * 2)
+        self.nl2 = NLBlock(in_dim)
+        self.nl1 = NLBlock(in_dim)
+
+        self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2)
+        self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2)
+
+    def forward(self, x):
+        if self.reduction is not None:
+            x = self.reduction(x)
+        e1 = self.conv_e1(x)  # C,H,W
+        e2 = self.conv_e2(self.downsample_x2(e1))  # 2C,H/2,W/2
+        e3 = self.conv_e3(self.downsample_x2(e2))  # 4C,H/4,W/4
+
+        d3 = self.conv_d3(e3)  # 2C,H/4,W/4
+        nl = self.nl3(d3)
+        d3 = self.upsample_x2(paddle.multiply(d3, nl))  ##2C,H/2,W/2
+        d2 = self.conv_d2(e2 + d3)  # C,H/2,W/2
+        nl = self.nl2(d2)
+        d2 = self.upsample_x2(paddle.multiply(d2, nl))  # C,H,W
+        d1 = self.conv_d1(e1 + d2)
+        nl = self.nl1(d1)
+        d1 = paddle.multiply(d1, nl)  # C,H,W
+        if self.re_reduction is not None:
+            d1 = self.re_reduction(d1)
+
+        return d1
+
+
+class Cat(nn.Layer):
+    def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False):
+        super(Cat, self).__init__()
+        self.do_upsample = upsample
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+        self.conv2d = BasicConv(
+            in_chn_high + in_chn_low,
+            out_chn,
+            kernel_size=1,
+            norm=nn.BatchNorm2D(
+                out_chn, momentum=bn_mom),
+            act=nn.ReLU())
+
+    def forward(self, x, y):
+        if self.do_upsample:
+            x = self.upsample(x)
+
+        x = paddle.concat((x, y), 1)
+
+        return self.conv2d(x)
+
+
+class DoubleConv(nn.Layer):
+    def __init__(self, in_chn, out_chn, stride=1, dilation=1):
+        super(DoubleConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2D(
+                in_chn,
+                out_chn,
+                kernel_size=3,
+                stride=stride,
+                dilation=dilation,
+                padding=dilation),
+            nn.BatchNorm2D(
+                out_chn, momentum=bn_mom),
+            nn.ReLU(),
+            nn.Conv2D(
+                out_chn, out_chn, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2D(
+                out_chn, momentum=bn_mom),
+            nn.ReLU())
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class SEModule(nn.Layer):
+    def __init__(self, channels, reduction_channels):
+        super(SEModule, self).__init__()
+        self.fc1 = nn.Conv2D(
+            channels,
+            reduction_channels,
+            kernel_size=1,
+            padding=0,
+            bias_attr=True)
+        self.ReLU = nn.ReLU()
+        self.fc2 = nn.Conv2D(
+            reduction_channels,
+            channels,
+            kernel_size=1,
+            padding=0,
+            bias_attr=True)
+
+    def forward(self, x):
+        x_se = x.reshape(
+            [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape(
+                [x.shape[0], x.shape[1], 1, 1])
+
+        x_se = self.fc1(x_se)
+        x_se = self.ReLU(x_se)
+        x_se = self.fc2(x_se)
+        return x * F.sigmoid(x_se)
+
+
+class BasicBlock(nn.Layer):
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 downsample=None,
+                 use_se=False,
+                 stride=1,
+                 dilation=1):
+        super(BasicBlock, self).__init__()
+        first_planes = planes
+        outplanes = planes * self.expansion
+
+        self.conv1 = DoubleConv(inplanes, first_planes)
+        self.conv2 = DoubleConv(
+            first_planes, outplanes, stride=stride, dilation=dilation)
+        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.downsample = MaxPool2x2() if downsample else None
+        self.ReLU = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        residual = out
+        out = self.conv2(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(residual)
+
+        out = out + residual
+        out = self.ReLU(out)
+        return out
+
+
+class DenseCatAdd(nn.Layer):
+    def __init__(self, in_chn, out_chn):
+        super(DenseCatAdd, self).__init__()
+        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv_out = BasicConv(
+            in_chn,
+            out_chn,
+            kernel_size=1,
+            norm=nn.BatchNorm2D(
+                out_chn, momentum=bn_mom),
+            act=nn.ReLU())
+
+    def forward(self, x, y):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2 + x1)
+
+        y1 = self.conv1(y)
+        y2 = self.conv2(y1)
+        y3 = self.conv3(y2 + y1)
+
+        return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3)
+
+
+class DenseCatDiff(nn.Layer):
+    def __init__(self, in_chn, out_chn):
+        super(DenseCatDiff, self).__init__()
+        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+        self.conv_out = BasicConv(
+            in_ch=in_chn,
+            out_ch=out_chn,
+            kernel_size=1,
+            norm=nn.BatchNorm2D(
+                out_chn, momentum=bn_mom),
+            act=nn.ReLU())
+
+    def forward(self, x, y):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x1)
+        x3 = self.conv3(x2 + x1)
+
+        y1 = self.conv1(y)
+        y2 = self.conv2(y1)
+        y3 = self.conv3(y2 + y1)
+        out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3))
+        return out
+
+
+class DFModule(nn.Layer):
+    """Dense connection-based feature fusion module"""
+
+    def __init__(self, dim_in, dim_out, reduction=True):
+        super(DFModule, self).__init__()
+        if reduction:
+            self.reduction = Conv1x1(
+                dim_in,
+                dim_in // 2,
+                norm=nn.BatchNorm2D(
+                    dim_in // 2, momentum=bn_mom),
+                act=nn.ReLU())
+            dim_in = dim_in // 2
+        else:
+            self.reduction = None
+        self.cat1 = DenseCatAdd(dim_in, dim_out)
+        self.cat2 = DenseCatDiff(dim_in, dim_out)
+        self.conv1 = Conv3x3(
+            dim_out,
+            dim_out,
+            norm=nn.BatchNorm2D(
+                dim_out, momentum=bn_mom),
+            act=nn.ReLU())
+
+    def forward(self, x1, x2):
+        if self.reduction is not None:
+            x1 = self.reduction(x1)
+            x2 = self.reduction(x2)
+        x_add = self.cat1(x1, x2)
+        x_diff = self.cat2(x1, x2)
+        y = self.conv1(x_diff) + x_add
+        return y
+
+
+class FCCDN(nn.Layer):
+    """
+    The FCCDN implementation based on PaddlePaddle.
+
+    The original article refers to
+        Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
+        (https://arxiv.org/pdf/2105.10860.pdf).
+
+    Args:
+        in_channels (int): Number of input channels. Default: 3.
+        num_classes (int): Number of target classes. Default: 2.
+        os (int): Number of output stride. Default: 16.
+        use_se (bool): Whether to use SEModule. Default: True.
+    """
+
+    def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True):
+        super(FCCDN, self).__init__()
+        if os >= 16:
+            dilation_list = [1, 1, 1, 1]
+            stride_list = [2, 2, 2, 2]
+            pool_list = [True, True, True, True]
+        elif os == 8:
+            dilation_list = [2, 1, 1, 1]
+            stride_list = [1, 2, 2, 2]
+            pool_list = [False, True, True, True]
+        else:
+            dilation_list = [2, 2, 1, 1]
+            stride_list = [1, 1, 2, 2]
+            pool_list = [False, False, True, True]
+        se_list = [use_se, use_se, use_se, use_se]
+        channel_list = [256, 128, 64, 32]
+        # Encoder
+        self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3],
+                                 se_list[3], stride_list[3], dilation_list[3])
+        self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2],
+                                 se_list[2], stride_list[2], dilation_list[2])
+        self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1],
+                                 se_list[1], stride_list[1], dilation_list[1])
+        self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0],
+                                 se_list[0], stride_list[0], dilation_list[0])
+
+        # Center
+        self.center = NLFPN(channel_list[0], True)
+
+        # Decoder
+        self.decoder3 = Cat(channel_list[0],
+                            channel_list[1],
+                            channel_list[1],
+                            upsample=pool_list[0])
+        self.decoder2 = Cat(channel_list[1],
+                            channel_list[2],
+                            channel_list[2],
+                            upsample=pool_list[1])
+        self.decoder1 = Cat(channel_list[2],
+                            channel_list[3],
+                            channel_list[3],
+                            upsample=pool_list[2])
+
+        self.df1 = DFModule(channel_list[3], channel_list[3], True)
+        self.df2 = DFModule(channel_list[2], channel_list[2], True)
+        self.df3 = DFModule(channel_list[1], channel_list[1], True)
+        self.df4 = DFModule(channel_list[0], channel_list[0], True)
+
+        self.catc3 = Cat(channel_list[0],
+                         channel_list[1],
+                         channel_list[1],
+                         upsample=pool_list[0])
+        self.catc2 = Cat(channel_list[1],
+                         channel_list[2],
+                         channel_list[2],
+                         upsample=pool_list[1])
+        self.catc1 = Cat(channel_list[2],
+                         channel_list[3],
+                         channel_list[3],
+                         upsample=pool_list[2])
+
+        self.upsample_x2 = nn.Sequential(
+            nn.Conv2D(
+                channel_list[3], 8, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2D(
+                8, momentum=bn_mom),
+            nn.ReLU(),
+            nn.UpsamplingBilinear2D(scale_factor=2))
+
+        self.conv_out = nn.Conv2D(
+            8, num_classes, kernel_size=3, stride=1, padding=1)
+        self.conv_out_class = nn.Conv2D(
+            channel_list[3], 1, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, t1, t2):
+        e1_1 = self.block1(t1)
+        e2_1 = self.block2(e1_1)
+        e3_1 = self.block3(e2_1)
+        y1 = self.block4(e3_1)
+
+        e1_2 = self.block1(t2)
+        e2_2 = self.block2(e1_2)
+        e3_2 = self.block3(e2_2)
+        y2 = self.block4(e3_2)
+
+        y1 = self.center(y1)
+        y2 = self.center(y2)
+        c = self.df4(y1, y2)
+
+        y1 = self.decoder3(y1, e3_1)
+        y2 = self.decoder3(y2, e3_2)
+        c = self.catc3(c, self.df3(y1, y2))
+
+        y1 = self.decoder2(y1, e2_1)
+        y2 = self.decoder2(y2, e2_2)
+        c = self.catc2(c, self.df2(y1, y2))
+
+        y1 = self.decoder1(y1, e1_1)
+        y2 = self.decoder1(y2, e1_2)
+
+        c = self.catc1(c, self.df1(y1, y2))
+        y = self.conv_out(self.upsample_x2(c))
+
+        if self.training:
+            y1 = self.conv_out_class(y1)
+            y2 = self.conv_out_class(y2)
+            return [y, [y1, y2]]
+        else:
+            return [y]

+ 15 - 0
paddlers/rs_models/cd/losses/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .fccdn_loss import fccdn_ssl_loss

+ 170 - 0
paddlers/rs_models/cd/losses/fccdn_loss.py

@@ -0,0 +1,170 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class DiceLoss(nn.Layer):
+    def __init__(self, batch=True):
+        super(DiceLoss, self).__init__()
+        self.batch = batch
+
+    def soft_dice_coeff(self, y_pred, y_true):
+        smooth = 0.00001
+        if self.batch:
+            i = paddle.sum(y_true)
+            j = paddle.sum(y_pred)
+            intersection = paddle.sum(y_true * y_pred)
+        else:
+            i = y_true.sum(1).sum(1).sum(1)
+            j = y_pred.sum(1).sum(1).sum(1)
+            intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
+        score = (2. * intersection + smooth) / (i + j + smooth)
+        return score.mean()
+
+    def soft_dice_loss(self, y_pred, y_true):
+        loss = 1 - self.soft_dice_coeff(y_pred, y_true)
+        return loss
+
+    def forward(self, y_pred, y_true):
+        return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true)
+
+
+class MultiClassDiceLoss(nn.Layer):
+    def __init__(
+            self,
+            weight,
+            batch=True,
+            ignore_index=-1,
+            do_softmax=False,
+            **kwargs, ):
+        super(MultiClassDiceLoss, self).__init__()
+        self.ignore_index = ignore_index
+        self.weight = weight
+        self.do_softmax = do_softmax
+        self.binary_diceloss = DiceLoss(batch)
+
+    def forward(self, y_pred, y_true):
+        if self.do_softmax:
+            y_pred = paddle.nn.functional.softmax(y_pred, axis=1)
+        y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2)
+        total_loss = 0.0
+        tmp_i = 0.0
+        for i in range(y_pred.shape[1]):
+            if i != self.ignore_index:
+                diceloss = self.binary_diceloss(y_pred[:, i, :, :],
+                                                y_true[:, i, :, :])
+                total_loss += paddle.multiply(diceloss, self.weight[i])
+                tmp_i += 1.0
+        return total_loss / tmp_i
+
+
+class DiceBCELoss(nn.Layer):
+    """Binary change detection task loss"""
+
+    def __init__(self):
+        super(DiceBCELoss, self).__init__()
+        self.bce_loss = nn.BCELoss()
+        self.binnary_dice = DiceLoss()
+
+    def forward(self, scores, labels, do_sigmoid=True):
+        if len(scores.shape) > 3:
+            scores = scores.squeeze(1)
+        if len(labels.shape) > 3:
+            labels = labels.squeeze(1)
+        if do_sigmoid:
+            scores = paddle.nn.functional.sigmoid(scores.clone())
+        diceloss = self.binnary_dice(scores, labels)
+        bceloss = self.bce_loss(scores, labels)
+        return diceloss + bceloss
+
+
+class McDiceBCELoss(nn.Layer):
+    """Multi-class change detection task loss"""
+
+    def __init__(self, weight, do_sigmoid=True):
+        super(McDiceBCELoss, self).__init__()
+        self.ce_loss = nn.CrossEntropyLoss(weight)
+        self.dice = MultiClassDiceLoss(weight, do_sigmoid)
+
+    def forward(self, scores, labels):
+        if len(scores.shape) < 4:
+            scores = scores.unsqueeze(1)
+        if len(labels.shape) < 4:
+            labels = labels.unsqueeze(1)
+        diceloss = self.dice(scores, labels)
+        bceloss = self.ce_loss(scores, labels)
+        return diceloss + bceloss
+
+
+def fccdn_ssl_loss(logits_list, labels):
+    """
+    Self-supervised learning loss for change detection.
+
+    The original article refers to
+        Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
+        (https://arxiv.org/pdf/2105.10860.pdf).
+        
+    Args:
+        logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases.
+        labels (paddle.Tensor): Binary change labels.
+    """
+
+    # Create loss
+    criterion_ssl = DiceBCELoss()
+
+    # Get downsampled change map
+    h, w = logits_list[0].shape[-2], logits_list[0].shape[-1]
+    labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w])
+    labels_type = str(labels_downsample.dtype)
+    assert "int" in labels_type or "bool" in labels_type,\
+        f"Expected dtype of labels to be int or bool, but got {labels_type}"
+
+    # Seg map
+    out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone()
+    out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone()
+    out3 = out1.clone()
+    out4 = out2.clone()
+
+    out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1)
+    out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2)
+    out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3)
+    out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4)
+
+    pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5,
+                                     paddle.zeros_like(out1),
+                                     paddle.ones_like(out1))
+    pred_seg_post_tmp1 = paddle.where(out2 <= 0.5,
+                                      paddle.zeros_like(out2),
+                                      paddle.ones_like(out2))
+
+    pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5,
+                                     paddle.zeros_like(out3),
+                                     paddle.ones_like(out3))
+    pred_seg_post_tmp2 = paddle.where(out4 <= 0.5,
+                                      paddle.zeros_like(out4),
+                                      paddle.ones_like(out4))
+
+    # Seg loss
+    labels_downsample = labels_downsample.astype(paddle.float32)
+    loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False)
+    loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False)
+    loss_aux += 0.2 * criterion_ssl(
+        out3, labels_downsample - pred_seg_post_tmp2, False)
+    loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
+                                    False)
+
+    return loss_aux

+ 30 - 2
paddlers/tasks/change_detector.py

@@ -37,7 +37,7 @@ from .utils import seg_metrics as metrics
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
-    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
+    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN"
 ]
 
 
@@ -1055,7 +1055,7 @@ class ChangeStar(BaseChangeDetector):
         if self.use_mixed_loss is False:
             return {
                 # XXX: make sure the shallow copy works correctly here.
-                'types': [seglosses.CrossEntropyLoss()] * 4,
+                'types': [seg_losses.CrossEntropyLoss()] * 4,
                 'coef': [1.0] * 4
             }
         else:
@@ -1082,3 +1082,31 @@ class ChangeFormer(BaseChangeDetector):
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)
+
+
+class FCCDN(BaseChangeDetector):
+    def __init__(self,
+                 in_channels=3,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
+        params.update({'in_channels': in_channels})
+        super(FCCDN, self).__init__(
+            model_name='FCCDN',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
+
+    def default_loss(self):
+        if self.use_mixed_loss is False:
+            return {
+                'types':
+                [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss],
+                'coef': [1.0, 1.0]
+            }
+        else:
+            raise ValueError(
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+            )

+ 13 - 0
test_tipc/configs/cd/fccdn/fccdn.yaml

@@ -0,0 +1,13 @@
+# Basic configurations of FCCDN
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/fccdn/
+
+model: !Node
+       type: FCCDN
+
+learning_rate: 0.07
+lr_decay_power: 0.6
+log_interval_steps: 100
+save_interval_epochs: 3

+ 53 - 0
test_tipc/configs/cd/fccdn/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:cd:fccdn
+python:python
+gpu_list:0
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/fccdn/fccdn.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:fccdn
+null:null

+ 94 - 0
tutorials/train/change_detection/fccdn.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# 变化检测模型FCCDN训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/airchange/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/fccdn/'
+
+# 下载和解压AirChange数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 随机裁剪
+    T.RandomCrop(
+        # 裁剪区域将被缩放到256x256
+        crop_size=256,
+        # 裁剪区域的横纵比在0.5-2之间变动
+        aspect_ratio=[0.5, 2.0],
+        # 裁剪区域相对原始影像长宽比例在一定范围内变动,最小不低于原始长宽的1/5
+        scaling=[0.2, 1.0]),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 将数据归一化到[-1,1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=None,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=None,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+# 使用默认参数构建FCCDN模型
+# 目前已支持的模型及模型输入参数请参考:
+# https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
+model = pdrs.tasks.cd.FCCDN()
+
+# 执行模型训练
+model.train(
+    num_epochs=5,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=2,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)