Просмотр исходного кода

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleRS into update_docs

Bobholamovic 2 лет назад
Родитель
Сommit
42d795b9a2

+ 1 - 0
paddlers/rs_models/cd/__init__.py

@@ -22,3 +22,4 @@ from .changestar import ChangeStar
 from .fc_ef import FCEarlyFusion
 from .fc_siam_conc import FCSiamConc
 from .fc_siam_diff import FCSiamDiff
+from .changeformer import ChangeFormer

+ 1001 - 0
paddlers/rs_models/cd/changeformer.py

@@ -0,0 +1,1001 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+import math
+from functools import partial
+
+import paddle as pd
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .layers.pd_timm import DropPath, to_2tuple
+
+
+def calc_product(*args):
+    if len(args) < 1:
+        raise ValueError
+    ret = args[0]
+    for arg in args[1:]:
+        ret *= arg
+    return ret
+
+
+class ConvBlock(pd.nn.Layer):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=True,
+                 activation='prelu',
+                 norm=None):
+        super(ConvBlock, self).__init__()
+        self.conv = pd.nn.Conv2D(
+            input_size,
+            output_size,
+            kernel_size,
+            stride,
+            padding,
+            bias_attr=bias)
+
+        self.norm = norm
+        if self.norm == 'batch':
+            self.bn = pd.nn.BatchNorm2D(output_size)
+        elif self.norm == 'instance':
+            self.bn = pd.nn.InstanceNorm2D(output_size)
+
+        self.activation = activation
+        if self.activation == 'relu':
+            self.act = pd.nn.ReLU(True)
+        elif self.activation == 'prelu':
+            self.act = pd.nn.PReLU()
+        elif self.activation == 'lrelu':
+            self.act = pd.nn.LeakyReLU(0.2, True)
+        elif self.activation == 'tanh':
+            self.act = pd.nn.Tanh()
+        elif self.activation == 'sigmoid':
+            self.act = pd.nn.Sigmoid()
+
+    def forward(self, x):
+        if self.norm is not None:
+            out = self.bn(self.conv(x))
+        else:
+            out = self.conv(x)
+
+        if self.activation != 'no':
+            return self.act(out)
+        else:
+            return out
+
+
+class DeconvBlock(pd.nn.Layer):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=True,
+                 activation='prelu',
+                 norm=None):
+        super(DeconvBlock, self).__init__()
+        self.deconv = pd.nn.Conv2DTranspose(
+            input_size,
+            output_size,
+            kernel_size,
+            stride,
+            padding,
+            bias_attr=bias)
+
+        self.norm = norm
+        if self.norm == 'batch':
+            self.bn = pd.nn.BatchNorm2D(output_size)
+        elif self.norm == 'instance':
+            self.bn = pd.nn.InstanceNorm2D(output_size)
+
+        self.activation = activation
+        if self.activation == 'relu':
+            self.act = pd.nn.ReLU(True)
+        elif self.activation == 'prelu':
+            self.act = pd.nn.PReLU()
+        elif self.activation == 'lrelu':
+            self.act = pd.nn.LeakyReLU(0.2, True)
+        elif self.activation == 'tanh':
+            self.act = pd.nn.Tanh()
+        elif self.activation == 'sigmoid':
+            self.act = pd.nn.Sigmoid()
+
+    def forward(self, x):
+        if self.norm is not None:
+            out = self.bn(self.deconv(x))
+        else:
+            out = self.deconv(x)
+
+        if self.activation is not None:
+            return self.act(out)
+        else:
+            return out
+
+
+class ConvLayer(nn.Layer):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
+        super(ConvLayer, self).__init__()
+        self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size, stride,
+                                padding)
+
+    def forward(self, x):
+        out = self.conv2d(x)
+        return out
+
+
+class UpsampleConvLayer(pd.nn.Layer):
+    def __init__(self, in_channels, out_channels, kernel_size, stride):
+        super(UpsampleConvLayer, self).__init__()
+        self.conv2d = nn.Conv2DTranspose(
+            in_channels, out_channels, kernel_size, stride=stride, padding=1)
+
+    def forward(self, x):
+        out = self.conv2d(x)
+        return out
+
+
+class ResidualBlock(pd.nn.Layer):
+    def __init__(self, channels):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = ConvLayer(
+            channels, channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = ConvLayer(
+            channels, channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        residual = x
+        out = self.relu(self.conv1(x))
+        out = self.conv2(out) * 0.1
+        out = pd.add(out, residual)
+        return out
+
+
+class ChangeFormer(nn.Layer):
+    """
+    The ChangeFormer implementation based on PaddlePaddle.
+
+    The original article refers to
+        Wele Gedara Chaminda Bandara, Vishal M. Patel., "A TRANSFORMER-BASED SIAMESE NETWORK FOR CHANGE DETECTION"
+        (https://arxiv.org/pdf/2201.01293.pdf).
+
+    Args:
+        in_channels (int): Number of bands of the input images. Default: 3.
+        num_classes (int): Number of target classes. Default: 2.
+        decoder_softmax (bool, optional): Use softmax after decode or not. Default: False.
+        embed_dim (int, optional): Embedding dimension of each decoder head. Default: 256.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 num_classes=2,
+                 decoder_softmax=False,
+                 embed_dim=256):
+        super(ChangeFormer, self).__init__()
+
+        # Transformer Encoder
+        self.embed_dims = [64, 128, 320, 512]
+        self.depths = [3, 3, 4, 3]
+        self.embedding_dim = embed_dim
+        self.drop_rate = 0.1
+        self.attn_drop = 0.1
+        self.drop_path_rate = 0.1
+
+        self.Tenc_x2 = EncoderTransformer_v3(
+            img_size=256,
+            patch_size=7,
+            in_chans=in_channels,
+            num_classes=num_classes,
+            embed_dims=self.embed_dims,
+            num_heads=[1, 2, 4, 8],
+            mlp_ratios=[4, 4, 4, 4],
+            qkv_bias=True,
+            qk_scale=None,
+            drop_rate=self.drop_rate,
+            attn_drop_rate=self.attn_drop,
+            drop_path_rate=self.drop_path_rate,
+            norm_layer=partial(
+                nn.LayerNorm, epsilon=1e-6),
+            depths=self.depths,
+            sr_ratios=[8, 4, 2, 1])
+
+        # Transformer Decoder
+        self.TDec_x2 = DecoderTransformer_v3(
+            input_transform='multiple_select',
+            in_index=[0, 1, 2, 3],
+            align_corners=False,
+            in_channels=self.embed_dims,
+            embedding_dim=self.embedding_dim,
+            output_nc=num_classes,
+            decoder_softmax=decoder_softmax,
+            feature_strides=[2, 4, 8, 16])
+
+    def forward(self, x1, x2):
+        [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
+
+        cp = self.TDec_x2(fx1, fx2)
+
+        return [cp]
+
+
+# Transormer Ecoder with x2, x4, x8, x16 scales
+class EncoderTransformer_v3(nn.Layer):
+    def __init__(self,
+                 img_size=256,
+                 patch_size=3,
+                 in_chans=3,
+                 num_classes=2,
+                 embed_dims=[32, 64, 128, 256],
+                 num_heads=[2, 2, 4, 8],
+                 mlp_ratios=[4, 4, 4, 4],
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer=nn.LayerNorm,
+                 depths=[3, 3, 6, 18],
+                 sr_ratios=[8, 4, 2, 1]):
+        super().__init__()
+        self.num_classes = num_classes
+        self.depths = depths
+        self.embed_dims = embed_dims
+
+        # patch embedding definitions
+        self.patch_embed1 = OverlapPatchEmbed(
+            img_size=img_size,
+            patch_size=7,
+            stride=4,
+            in_chans=in_chans,
+            embed_dim=embed_dims[0])
+        self.patch_embed2 = OverlapPatchEmbed(
+            img_size=img_size // 4,
+            patch_size=patch_size,
+            stride=2,
+            in_chans=embed_dims[0],
+            embed_dim=embed_dims[1])
+        self.patch_embed3 = OverlapPatchEmbed(
+            img_size=img_size // 8,
+            patch_size=patch_size,
+            stride=2,
+            in_chans=embed_dims[1],
+            embed_dim=embed_dims[2])
+        self.patch_embed4 = OverlapPatchEmbed(
+            img_size=img_size // 16,
+            patch_size=patch_size,
+            stride=2,
+            in_chans=embed_dims[2],
+            embed_dim=embed_dims[3])
+
+        # Stage-1 (x1/4 scale)
+        dpr = [x.item() for x in pd.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        self.block1 = nn.LayerList([
+            Block(
+                dim=embed_dims[0],
+                num_heads=num_heads[0],
+                mlp_ratio=mlp_ratios[0],
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[cur + i],
+                norm_layer=norm_layer,
+                sr_ratio=sr_ratios[0]) for i in range(depths[0])
+        ])
+        self.norm1 = norm_layer(embed_dims[0])
+
+        # Stage-2 (x1/8 scale)
+        cur += depths[0]
+        self.block2 = nn.LayerList([
+            Block(
+                dim=embed_dims[1],
+                num_heads=num_heads[1],
+                mlp_ratio=mlp_ratios[1],
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[cur + i],
+                norm_layer=norm_layer,
+                sr_ratio=sr_ratios[1]) for i in range(depths[1])
+        ])
+        self.norm2 = norm_layer(embed_dims[1])
+
+        # Stage-3 (x1/16 scale)
+        cur += depths[1]
+        self.block3 = nn.LayerList([
+            Block(
+                dim=embed_dims[2],
+                num_heads=num_heads[2],
+                mlp_ratio=mlp_ratios[2],
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[cur + i],
+                norm_layer=norm_layer,
+                sr_ratio=sr_ratios[2]) for i in range(depths[2])
+        ])
+        self.norm3 = norm_layer(embed_dims[2])
+
+        # Stage-4 (x1/32 scale)
+        cur += depths[2]
+        self.block4 = nn.LayerList([
+            Block(
+                dim=embed_dims[3],
+                num_heads=num_heads[3],
+                mlp_ratio=mlp_ratios[3],
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[cur + i],
+                norm_layer=norm_layer,
+                sr_ratio=sr_ratios[3]) for i in range(depths[3])
+        ])
+        self.norm4 = norm_layer(embed_dims[3])
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
+            trunc_normal_op(m.weight)
+
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+        elif isinstance(m, nn.LayerNorm):
+            init_bias = nn.initializer.Constant(0)
+            init_bias(m.bias)
+
+            init_weight = nn.initializer.Constant(1.0)
+            init_weight(m.weight)
+
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
+            init_weight(m.weight)
+
+            if m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+    def reset_drop_path(self, drop_path_rate):
+        dpr = [
+            x.item() for x in pd.linspace(0, drop_path_rate, sum(self.depths))
+        ]
+        cur = 0
+        for i in range(self.depths[0]):
+            self.block1[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[0]
+        for i in range(self.depths[1]):
+            self.block2[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[1]
+        for i in range(self.depths[2]):
+            self.block3[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[2]
+        for i in range(self.depths[3]):
+            self.block4[i].drop_path.drop_prob = dpr[cur + i]
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        outs = []
+
+        # stage 1
+        x1, H1, W1 = self.patch_embed1(x)
+        for i, blk in enumerate(self.block1):
+            x1 = blk(x1, H1, W1)
+        x1 = self.norm1(x1)
+        x1 = x1.reshape(
+            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
+                [0, 3, 1, 2])
+        outs.append(x1)
+
+        # stage 2
+        x1, H1, W1 = self.patch_embed2(x1)
+        for i, blk in enumerate(self.block2):
+            x1 = blk(x1, H1, W1)
+        x1 = self.norm2(x1)
+        x1 = x1.reshape(
+            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
+                [0, 3, 1, 2])
+        outs.append(x1)
+
+        # stage 3
+        x1, H1, W1 = self.patch_embed3(x1)
+        for i, blk in enumerate(self.block3):
+            x1 = blk(x1, H1, W1)
+        x1 = self.norm3(x1)
+        x1 = x1.reshape(
+            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
+                [0, 3, 1, 2])
+        outs.append(x1)
+
+        # stage 4
+        x1, H1, W1 = self.patch_embed4(x1)
+        for i, blk in enumerate(self.block4):
+            x1 = blk(x1, H1, W1)
+        x1 = self.norm4(x1)
+        x1 = x1.reshape(
+            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
+                [0, 3, 1, 2])
+        outs.append(x1)
+        return outs
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+class DecoderTransformer_v3(nn.Layer):
+    """
+    Transformer Decoder
+    """
+
+    def __init__(self,
+                 input_transform='multiple_select',
+                 in_index=[0, 1, 2, 3],
+                 align_corners=True,
+                 in_channels=[32, 64, 128, 256],
+                 embedding_dim=64,
+                 output_nc=2,
+                 decoder_softmax=False,
+                 feature_strides=[2, 4, 8, 16]):
+        super(DecoderTransformer_v3, self).__init__()
+        # assert
+        assert len(feature_strides) == len(in_channels)
+        assert min(feature_strides) == feature_strides[0]
+
+        # settings
+        self.feature_strides = feature_strides
+        self.input_transform = input_transform
+        self.in_index = in_index
+        self.align_corners = align_corners
+        self.in_channels = in_channels
+        self.embedding_dim = embedding_dim
+        self.output_nc = output_nc
+        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
+
+        # MLP decoder heads
+        self.linear_c4 = MLP(input_dim=c4_in_channels,
+                             embed_dim=self.embedding_dim)
+        self.linear_c3 = MLP(input_dim=c3_in_channels,
+                             embed_dim=self.embedding_dim)
+        self.linear_c2 = MLP(input_dim=c2_in_channels,
+                             embed_dim=self.embedding_dim)
+        self.linear_c1 = MLP(input_dim=c1_in_channels,
+                             embed_dim=self.embedding_dim)
+
+        # convolutional Difference Layers
+        self.diff_c4 = conv_diff(
+            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
+        self.diff_c3 = conv_diff(
+            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
+        self.diff_c2 = conv_diff(
+            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
+        self.diff_c1 = conv_diff(
+            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
+
+        # taking outputs from middle of the encoder
+        self.make_pred_c4 = make_prediction(
+            in_channels=self.embedding_dim, out_channels=self.output_nc)
+        self.make_pred_c3 = make_prediction(
+            in_channels=self.embedding_dim, out_channels=self.output_nc)
+        self.make_pred_c2 = make_prediction(
+            in_channels=self.embedding_dim, out_channels=self.output_nc)
+        self.make_pred_c1 = make_prediction(
+            in_channels=self.embedding_dim, out_channels=self.output_nc)
+
+        # Final linear fusion layer
+        self.linear_fuse = nn.Sequential(
+            nn.Conv2D(
+                in_channels=self.embedding_dim * len(in_channels),
+                out_channels=self.embedding_dim,
+                kernel_size=1),
+            nn.BatchNorm2D(self.embedding_dim))
+
+        # Final predction head
+        self.convd2x = UpsampleConvLayer(
+            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
+        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
+        self.convd1x = UpsampleConvLayer(
+            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
+        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
+        self.change_probability = ConvLayer(
+            self.embedding_dim,
+            self.output_nc,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+
+        # Final activation
+        self.output_softmax = decoder_softmax
+        self.active = nn.Sigmoid()
+
+    def _transform_inputs(self, inputs):
+        """
+        Transform inputs for decoder.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+        Returns:
+            Tensor: The transformed inputs
+        """
+
+        if self.input_transform == 'resize_concat':
+            inputs = [inputs[i] for i in self.in_index]
+            upsampled_inputs = [
+                resize(
+                    input=x,
+                    size=inputs[0].shape[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners) for x in inputs
+            ]
+            inputs = pd.concat(upsampled_inputs, dim=1)
+        elif self.input_transform == 'multiple_select':
+            inputs = [inputs[i] for i in self.in_index]
+        else:
+            inputs = inputs[self.in_index]
+
+        return inputs
+
+    def forward(self, inputs1, inputs2):
+        # Transforming encoder features (select layers)
+        x_1 = self._transform_inputs(inputs1)  # len=4, 1/2, 1/4, 1/8, 1/16
+        x_2 = self._transform_inputs(inputs2)  # len=4, 1/2, 1/4, 1/8, 1/16
+
+        # img1 and img2 features
+        c1_1, c2_1, c3_1, c4_1 = x_1
+        c1_2, c2_2, c3_2, c4_2 = x_2
+
+        ############## MLP decoder on C1-C4 ###########
+        n, _, h, w = c4_1.shape
+
+        outputs = []
+        # Stage 4: x1/32 scale
+        _c4_1 = self.linear_c4(c4_1).transpose([0, 2, 1])
+        _c4_1 = _c4_1.reshape([
+            n, calc_product(*_c4_1.shape[1:]) //
+            (c4_1.shape[2] * c4_1.shape[3]), c4_1.shape[2], c4_1.shape[3]
+        ])
+        _c4_2 = self.linear_c4(c4_2).transpose([0, 2, 1])
+        _c4_2 = _c4_2.reshape([
+            n, calc_product(*_c4_2.shape[1:]) //
+            (c4_2.shape[2] * c4_2.shape[3]), c4_2.shape[2], c4_2.shape[3]
+        ])
+        _c4 = self.diff_c4(pd.concat((_c4_1, _c4_2), axis=1))
+        p_c4 = self.make_pred_c4(_c4)
+        outputs.append(p_c4)
+        _c4_up = resize(
+            _c4, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
+
+        # Stage 3: x1/16 scale
+        _c3_1 = self.linear_c3(c3_1).transpose([0, 2, 1])
+        _c3_1 = _c3_1.reshape([
+            n, calc_product(*_c3_1.shape[1:]) //
+            (c3_1.shape[2] * c3_1.shape[3]), c3_1.shape[2], c3_1.shape[3]
+        ])
+        _c3_2 = self.linear_c3(c3_2).transpose([0, 2, 1])
+        _c3_2 = _c3_2.reshape([
+            n, calc_product(*_c3_2.shape[1:]) //
+            (c3_2.shape[2] * c3_2.shape[3]), c3_2.shape[2], c3_2.shape[3]
+        ])
+        _c3 = self.diff_c3(pd.concat((_c3_1, _c3_2), axis=1)) + \
+            F.interpolate(_c4, scale_factor=2, mode="bilinear")
+        p_c3 = self.make_pred_c3(_c3)
+        outputs.append(p_c3)
+        _c3_up = resize(
+            _c3, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
+
+        # Stage 2: x1/8 scale
+        _c2_1 = self.linear_c2(c2_1).transpose([0, 2, 1])
+        _c2_1 = _c2_1.reshape([
+            n, calc_product(*_c2_1.shape[1:]) //
+            (c2_1.shape[2] * c2_1.shape[3]), c2_1.shape[2], c2_1.shape[3]
+        ])
+        _c2_2 = self.linear_c2(c2_2).transpose([0, 2, 1])
+        _c2_2 = _c2_2.reshape([
+            n, calc_product(*_c2_2.shape[1:]) //
+            (c2_2.shape[2] * c2_2.shape[3]), c2_2.shape[2], c2_2.shape[3]
+        ])
+        _c2 = self.diff_c2(pd.concat((_c2_1, _c2_2), axis=1)) + \
+            F.interpolate(_c3, scale_factor=2, mode="bilinear")
+        p_c2 = self.make_pred_c2(_c2)
+        outputs.append(p_c2)
+        _c2_up = resize(
+            _c2, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
+
+        # Stage 1: x1/4 scale
+        _c1_1 = self.linear_c1(c1_1).transpose([0, 2, 1])
+        _c1_1 = _c1_1.reshape([
+            n, calc_product(*_c1_1.shape[1:]) //
+            (c1_1.shape[2] * c1_1.shape[3]), c1_1.shape[2], c1_1.shape[3]
+        ])
+        _c1_2 = self.linear_c1(c1_2).transpose([0, 2, 1])
+        _c1_2 = _c1_2.reshape([
+            n, calc_product(*_c1_2.shape[1:]) //
+            (c1_2.shape[2] * c1_2.shape[3]), c1_2.shape[2], c1_2.shape[3]
+        ])
+        _c1 = self.diff_c1(pd.concat((_c1_1, _c1_2), axis=1)) + \
+            F.interpolate(_c2, scale_factor=2, mode="bilinear")
+        p_c1 = self.make_pred_c1(_c1)
+        outputs.append(p_c1)
+
+        # Linear Fusion of difference image from all scales
+        _c = self.linear_fuse(pd.concat((_c4_up, _c3_up, _c2_up, _c1), axis=1))
+
+        # Upsampling x2 (x1/2 scale)
+        x = self.convd2x(_c)
+        # Residual block
+        x = self.dense_2x(x)
+        # Upsampling x2 (x1 scale)
+        x = self.convd1x(x)
+        # Residual block
+        x = self.dense_1x(x)
+
+        # Final prediction
+        cp = self.change_probability(x)
+
+        outputs.append(cp)
+
+        if self.output_softmax:
+            temp = outputs
+            outputs = []
+            for pred in temp:
+                outputs.append(self.active(pred))
+
+        return outputs[-1]
+
+
+class OverlapPatchEmbed(nn.Layer):
+    """ 
+    Image to Patch Embedding
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=7,
+                 stride=4,
+                 in_chans=3,
+                 embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.H, self.W = img_size[0] // patch_size[0], img_size[
+            1] // patch_size[1]
+        self.num_patches = self.H * self.W
+        self.proj = nn.Conv2D(
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=(patch_size[0] // 2, patch_size[1] // 2))
+        self.norm = nn.LayerNorm(embed_dim)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
+            trunc_normal_op(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            init_bias = nn.initializer.Constant(0)
+            init_bias(m.bias)
+            init_weight = nn.initializer.Constant(1.0)
+            init_weight(m.weight)
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
+            init_weight(m.weight)
+            if m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+    def forward(self, x):
+        x = self.proj(x)
+        _, _, H, W = x.shape
+        x = x.flatten(2).transpose([0, 2, 1])
+        x = self.norm(x)
+
+        return x, H, W
+
+
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > output_h:
+                if ((output_h > 1 and output_w > 1 and input_h > 1 and
+                     input_w > 1) and (output_h - 1) % (input_h - 1) and
+                    (output_w - 1) % (input_w - 1)):
+                    warnings.warn(
+                        f'When align_corners={align_corners}, '
+                        'the output would more aligned if '
+                        f'input size {(input_h, input_w)} is `x+1` and '
+                        f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Mlp(nn.Layer):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.dwconv = DWConv(hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
+            trunc_normal_op(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            init_bias = nn.initializer.Constant(0)
+            init_bias(m.bias)
+            init_weight = nn.initializer.Constant(1.0)
+            init_weight(m.weight)
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
+            init_weight(m.weight)
+            if m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+    def forward(self, x, H, W):
+        x = self.fc1(x)
+        x = self.dwconv(x, H, W)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Layer):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 sr_ratio=1):
+        super().__init__()
+        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+        self.dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
+        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.sr_ratio = sr_ratio
+        if sr_ratio > 1:
+            self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+            self.norm = nn.LayerNorm(dim)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
+            trunc_normal_op(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            init_bias = nn.initializer.Constant(0)
+            init_bias(m.bias)
+            init_weight = nn.initializer.Constant(1.0)
+            init_weight(m.weight)
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
+            init_weight(m.weight)
+            if m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+    def forward(self, x, H, W):
+        B, N, C = x.shape
+        q = self.q(x).reshape([B, N, self.num_heads,
+                               C // self.num_heads]).transpose([0, 2, 1, 3])
+
+        if self.sr_ratio > 1:
+            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
+            x_ = self.sr(x_)
+            x_ = x_.reshape([B, C, calc_product(*x_.shape[2:])]).transpose(
+                [0, 2, 1])
+            x_ = self.norm(x_)
+            kv = self.kv(x_)
+            kv = kv.reshape([
+                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
+                C // self.num_heads
+            ]).transpose([2, 0, 3, 1, 4])
+        else:
+            kv = self.kv(x)
+            kv = kv.reshape([
+                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
+                C // self.num_heads
+            ]).transpose([2, 0, 3, 1, 4])
+        k, v = kv[0], kv[1]
+
+        attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
+        attn = F.softmax(attn, axis=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class Block(nn.Layer):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 sr_ratio=1):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            sr_ratio=sr_ratio)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity(
+        )
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
+            trunc_normal_op(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            init_bias = nn.initializer.Constant(0)
+            init_bias(m.bias)
+            init_weight = nn.initializer.Constant(1.0)
+            init_weight(m.weight)
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
+            init_weight(m.weight)
+            if m.bias is not None:
+                init_bias = nn.initializer.Constant(0)
+                init_bias(m.bias)
+
+    def forward(self, x, H, W):
+        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+        return x
+
+
+class DWConv(nn.Layer):
+    def __init__(self, dim=768):
+        super(DWConv, self).__init__()
+        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)
+
+    def forward(self, x, H, W):
+        B, N, C = x.shape
+        x = x.transpose([0, 2, 1]).reshape([B, C, H, W])
+        x = self.dwconv(x)
+        x = x.flatten(2).transpose([0, 2, 1])
+
+        return x
+
+
+# Transformer Decoder
+class MLP(nn.Layer):
+    """
+    Linear Embedding
+    """
+
+    def __init__(self, input_dim=2048, embed_dim=768):
+        super().__init__()
+        self.proj = nn.Linear(input_dim, embed_dim)
+
+    def forward(self, x):
+        x = x.flatten(2).transpose([0, 2, 1])
+        x = self.proj(x)
+        return x
+
+
+# Difference Layer
+def conv_diff(in_channels, out_channels):
+    return nn.Sequential(
+        nn.Conv2D(
+            in_channels, out_channels, kernel_size=3, padding=1),
+        nn.ReLU(),
+        nn.BatchNorm2D(out_channels),
+        nn.Conv2D(
+            out_channels, out_channels, kernel_size=3, padding=1),
+        nn.ReLU())
+
+
+# Intermediate prediction Layer
+def make_prediction(in_channels, out_channels):
+    return nn.Sequential(
+        nn.Conv2D(
+            in_channels, out_channels, kernel_size=3, padding=1),
+        nn.ReLU(),
+        nn.BatchNorm2D(out_channels),
+        nn.Conv2D(
+            out_channels, out_channels, kernel_size=3, padding=1))

+ 71 - 0
paddlers/rs_models/cd/layers/pd_timm.py

@@ -0,0 +1,71 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License..
+
+from itertools import repeat
+import collections.abc
+
+import paddle
+import paddle.nn as nn
+"""
+Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth
+"""
+
+
+class DropPath(nn.Layer):
+    """DropPath class"""
+
+    def __init__(self, drop_prob=None):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def drop_path(self, inputs):
+        """drop path op
+        Args:
+            input: tensor with arbitrary shape
+            drop_prob: float number of drop path probability, default: 0.0
+            training: bool, if current mode is training, default: False
+        Returns:
+            output: output tensor after drop path
+        """
+        # if prob is 0 or eval mode, return original input
+        if self.drop_prob == 0. or not self.training:
+            return inputs
+        keep_prob = 1 - self.drop_prob
+        keep_prob = paddle.to_tensor(keep_prob, dtype='float32')
+        shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1
+                                               )  # shape=(N, 1, 1, 1)
+        random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
+        random_tensor = random_tensor.floor()  # mask
+        output = inputs.divide(
+            keep_prob) * random_tensor  # divide to keep same output expectation
+        return output
+
+    def forward(self, inputs):
+        return self.drop_path(inputs)
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple

+ 21 - 1
paddlers/tasks/change_detector.py

@@ -36,7 +36,7 @@ from .utils import seg_metrics as metrics
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
-    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar"
+    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
 ]
 
 
@@ -1041,3 +1041,23 @@ class ChangeStar(BaseChangeDetector):
             raise ValueError(
                 f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
             )
+
+
+class ChangeFormer(BaseChangeDetector):
+    def __init__(self,
+                 in_channels=3,
+                 num_classes=2,
+                 decoder_softmax=False,
+                 embed_dim=256,
+                 use_mixed_loss=False,
+                 **params):
+        params.update({
+            'in_channels': in_channels,
+            'embed_dim': embed_dim,
+            'decoder_softmax': decoder_softmax
+        })
+        super(ChangeFormer, self).__init__(
+            model_name='ChangeFormer',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)

+ 8 - 0
test_tipc/configs/cd/changeformer/changeformer.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of ChangeFormer
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/changeformer/
+
+model: !Node
+       type: ChangeFormer

+ 53 - 0
test_tipc/configs/cd/changeformer/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:cd:changeformer
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/changeformer/changeformer.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:changeformer
+null:null

+ 14 - 1
tests/rs_models/test_cd_models.py

@@ -21,7 +21,8 @@ from rs_models.test_model import TestModel
 __all__ = [
     'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
     'TestDSIFNModel', 'TestFCEarlyFusionModel', 'TestFCSiamConcModel',
-    'TestFCSiamDiffModel', 'TestSNUNetModel', 'TestSTANetModel'
+    'TestFCSiamDiffModel', 'TestSNUNetModel', 'TestSTANetModel',
+    'TestChangeFormerModel'
 ]
 
 
@@ -215,6 +216,18 @@ class TestSTANetModel(TestCDModel):
         ]   # yapf: disable
 
 
+class TestChangeFormerModel(TestCDModel):
+    MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
+
+    def set_specs(self):
+        base_spec = dict(in_channels=3, num_classes=2)
+        self.specs = [
+            base_spec,
+            dict(**base_spec, decoder_softmax=True),
+            dict(**base_spec, embed_dim=56)
+        ]   # yapf: disable
+
+
 # HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine.
 # Currently, we do not perform this test.
 if platform.system() == 'Windows':

+ 94 - 0
tutorials/train/change_detection/changeformer.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# 变化检测模型ChangeFormer训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/airchange/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/changeformer/'
+
+# 下载和解压AirChange数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 随机裁剪
+    T.RandomCrop(
+        # 裁剪区域将被缩放到256x256
+        crop_size=256,
+        # 裁剪区域的横纵比在0.5-2之间变动
+        aspect_ratio=[0.5, 2.0],
+        # 裁剪区域相对原始影像长宽比例在一定范围内变动,最小不低于原始长宽的1/5
+        scaling=[0.2, 1.0]),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 将数据归一化到[-1,1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=None,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=None,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+# 使用默认参数构建ChangeFormer模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
+model = pdrs.tasks.cd.ChangeFormer()
+
+# 执行模型训练
+model.train(
+    num_epochs=5,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=3,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)