Browse Source

[Feature] Add STANet

Bobholamovic 3 years ago
parent
commit
91c64d923a

+ 2 - 1
paddlers/models/cd/models/__init__.py

@@ -15,4 +15,5 @@
 from .cdnet import CDNet
 from .unet_ef import UNetEarlyFusion
 from .unet_siamconc import UNetSiamConc
-from .unet_siamdiff import UNetSiamDiff
+from .unet_siamdiff import UNetSiamDiff
+from .stanet import STANet

+ 13 - 0
paddlers/models/cd/models/backbones/__init__.py

@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 358 - 0
paddlers/models/cd/models/backbones/resnet.py

@@ -0,0 +1,358 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py
+## Original head information
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from paddle.utils.download import get_weights_path_from_url
+
+__all__ = []
+
+model_urls = {
+    'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
+                 'cf548f46534aa3560945be4b95cd11c4'),
+    'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
+                 '8d2275cf8706028345f78ac0e1d31969'),
+    'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
+                 'ca6f485ee1ab0492d38f323885b0ad80'),
+    'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
+                  '02f35f034ca3858e1e54d4036443c92d'),
+    'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
+                  '7ad16a2f1e7333859ff986138630fd7a'),
+}
+
+
+class BasicBlock(nn.Layer):
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 groups=1,
+                 base_width=64,
+                 dilation=1,
+                 norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2D
+
+        if dilation > 1:
+            raise NotImplementedError(
+                "Dilation > 1 not supported in BasicBlock")
+
+        self.conv1 = nn.Conv2D(
+            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class BottleneckBlock(nn.Layer):
+
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 groups=1,
+                 base_width=64,
+                 dilation=1,
+                 norm_layer=None):
+        super(BottleneckBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2D
+        width = int(planes * (base_width / 64.)) * groups
+
+        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
+        self.bn1 = norm_layer(width)
+
+        self.conv2 = nn.Conv2D(
+            width,
+            width,
+            3,
+            padding=dilation,
+            stride=stride,
+            groups=groups,
+            dilation=dilation,
+            bias_attr=False)
+        self.bn2 = norm_layer(width)
+
+        self.conv3 = nn.Conv2D(
+            width, planes * self.expansion, 1, bias_attr=False)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU()
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Layer):
+    """ResNet model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        Block (BasicBlock|BottleneckBlock): block module of model.
+        depth (int): layers of resnet, default: 50.
+        num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer 
+                            will not be defined. Default: 1000.
+        with_pool (bool): use pool before the last fc layer or not. Default: True.
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import ResNet
+            from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
+            resnet50 = ResNet(BottleneckBlock, 50)
+            resnet18 = ResNet(BasicBlock, 18)
+    """
+
+    def __init__(self, block, depth, num_classes=1000, with_pool=True, strides=(1,1,2,2,2), norm_layer=None):
+        super(ResNet, self).__init__()
+        layer_cfg = {
+            18: [2, 2, 2, 2],
+            34: [3, 4, 6, 3],
+            50: [3, 4, 6, 3],
+            101: [3, 4, 23, 3],
+            152: [3, 8, 36, 3]
+        }
+        layers = layer_cfg[depth]
+        self.num_classes = num_classes
+        self.with_pool = with_pool
+        self._norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+
+        self.conv1 = nn.Conv2D(
+            3,
+            self.inplanes,
+            kernel_size=7,
+            stride=strides[0],
+            padding=3,
+            bias_attr=False)
+        self.bn1 = self._norm_layer(self.inplanes)
+        self.relu = nn.ReLU()
+        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[1])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4])
+        if with_pool:
+            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
+
+        if num_classes > 0:
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2D(
+                    self.inplanes,
+                    planes * block.expansion,
+                    1,
+                    stride=stride,
+                    bias_attr=False),
+                norm_layer(planes * block.expansion), )
+
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, stride, downsample, 1, 64,
+                  previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        if self.with_pool:
+            x = self.avgpool(x)
+
+        if self.num_classes > 0:
+            x = paddle.flatten(x, 1)
+            x = self.fc(x)
+
+        return x
+
+
+def _resnet(arch, Block, depth, pretrained, **kwargs):
+    model = ResNet(Block, depth, **kwargs)
+    if pretrained:
+        assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
+            arch)
+        weight_path = get_weights_path_from_url(model_urls[arch][0],
+                                                model_urls[arch][1])
+
+        param = paddle.load(weight_path)
+        model.set_dict(param)
+
+    return model
+
+
+def resnet18(pretrained=False, **kwargs):
+    """ResNet 18-layer model
+    
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet18
+            # build model
+            model = resnet18()
+            # build model and load imagenet pretrained weight
+            # model = resnet18(pretrained=True)
+    """
+    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
+
+
+def resnet34(pretrained=False, **kwargs):
+    """ResNet 34-layer model
+    
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet34
+            # build model
+            model = resnet34()
+            # build model and load imagenet pretrained weight
+            # model = resnet34(pretrained=True)
+    """
+    return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
+
+
+def resnet50(pretrained=False, **kwargs):
+    """ResNet 50-layer model
+    
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet50
+            # build model
+            model = resnet50()
+            # build model and load imagenet pretrained weight
+            # model = resnet50(pretrained=True)
+    """
+    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
+
+
+def resnet101(pretrained=False, **kwargs):
+    """ResNet 101-layer model
+    
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet101
+            # build model
+            model = resnet101()
+            # build model and load imagenet pretrained weight
+            # model = resnet101(pretrained=True)
+    """
+    return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
+
+
+def resnet152(pretrained=False, **kwargs):
+    """ResNet 152-layer model
+    
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet152
+            # build model
+            model = resnet152()
+            # build model and load imagenet pretrained weight
+            # model = resnet152(pretrained=True)
+    """
+    return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)

+ 2 - 1
paddlers/models/cd/models/layers/blocks.py

@@ -19,7 +19,8 @@ __all__ = [
     'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7', 
     'MaxPool2x2', 'MaxUnPool2x2', 
     'ConvTransposed3x3',
-    'Identity'
+    'Identity',
+    'get_norm_layer', 'get_act_layer'
 ]
 
 

+ 297 - 0
paddlers/models/cd/models/stanet.py

@@ -0,0 +1,297 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+from .backbones import resnet
+from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
+from .param_init import KaimingInitMixin
+
+
+class STANet(nn.Layer):
+    """
+    The STANet implementation based on PaddlePaddle.
+
+    The original article refers to
+    H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"
+    (https://www.mdpi.com/2072-4292/12/10/1662)
+
+    Note that this implementation differs from the original work in two aspects:
+    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
+    2. A classification head is used in place of the original metric learning-based head to stablize the training process.
+
+    Args:
+        in_channels (int): The number of bands of the input images.
+        num_classes (int): The number of target classes.
+        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
+        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values 
+            greater than 1, the input features will first be processed by an average pooling layer with the kernel size of 
+            `ds_factor`, before being used to calculate the attention scores. Default: 1.
+
+    Raises:
+        ValueError: When `att_type` has an illeagal value (unsupported attention type).
+    """
+    def __init__(
+        self, 
+        in_channels, 
+        num_classes, 
+        att_type='BAM', 
+        ds_factor=1
+    ):
+        super().__init__()
+
+        WIDTH = 64
+
+        self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
+        self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor)
+        self.conv_out = nn.Sequential(
+            Conv3x3(WIDTH, WIDTH, norm=True, act=True),
+            Conv3x3(WIDTH, num_classes)
+        )
+
+        self.init_weight()
+
+    def forward(self, t1, t2):
+        f1 = self.extract(t1)
+        f2 = self.extract(t2)
+
+        f1, f2 = self.attend(f1, f2)
+
+        y = paddle.abs(f1- f2)
+        y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True)
+
+        pred = self.conv_out(y)
+        return pred,
+
+    def init_weight(self):
+        # Do nothing here as the encoder and decoder weights have already been initialized.
+        # Note however that currently self.attend and self.conv_out use the default initilization method.
+        pass
+
+
+def build_feat_extractor(in_ch, width):
+    return nn.Sequential(
+        Backbone(in_ch, 'resnet18'),
+        Decoder(width)
+    )
+
+
+def build_sta_module(in_ch, att_type, ds):
+    if att_type == 'BAM':
+        return Attention(BAM(in_ch, ds))
+    elif att_type == 'PAM':
+        return Attention(PAM(in_ch, ds))
+    else:
+        raise ValueError
+
+
+class Backbone(nn.Layer, KaimingInitMixin):
+    def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)):
+        super().__init__()
+
+        if arch == 'resnet18':
+            self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+        elif arch == 'resnet34':
+            self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+        elif arch == 'resnet50':
+            self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+        else:
+            raise ValueError
+
+        self._trim_resnet()
+
+        if in_ch != 3:
+            self.resnet.conv1 = nn.Conv2D(
+                in_ch, 
+                64,
+                kernel_size=7,
+                stride=strides[0],
+                padding=3,
+                bias_attr=False
+            )
+
+        if not pretrained:
+            self.init_weight()
+
+    def forward(self, x):
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)
+
+        x1 = self.resnet.layer1(x)
+        x2 = self.resnet.layer2(x1)
+        x3 = self.resnet.layer3(x2)
+        x4 = self.resnet.layer4(x3)
+
+        return x1, x2, x3, x4
+
+    def _trim_resnet(self):
+        self.resnet.avgpool = Identity()
+        self.resnet.fc = Identity()
+
+
+class Decoder(nn.Layer, KaimingInitMixin):
+    def __init__(self, f_ch):
+        super().__init__()
+        self.dr1 = Conv1x1(64, 96, norm=True, act=True)
+        self.dr2 = Conv1x1(128, 96, norm=True, act=True)
+        self.dr3 = Conv1x1(256, 96, norm=True, act=True)
+        self.dr4 = Conv1x1(512, 96, norm=True, act=True)
+        self.conv_out = nn.Sequential(
+            Conv3x3(384, 256, norm=True, act=True),
+            nn.Dropout(0.5),
+            Conv1x1(256, f_ch, norm=True, act=True)
+        )
+
+        self.init_weight()
+
+    def forward(self, feats):
+        f1 = self.dr1(feats[0])
+        f2 = self.dr2(feats[1])
+        f3 = self.dr3(feats[2])
+        f4 = self.dr4(feats[3])
+
+        f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True)
+        f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True)
+        f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True)
+
+        x = paddle.concat([f1, f2, f3, f4], axis=1)
+        y = self.conv_out(x)
+
+        return y
+
+
+class BAM(nn.Layer):
+    def __init__(self, in_ch, ds):
+        super().__init__()
+
+        self.ds = ds
+        self.pool = nn.AvgPool2D(self.ds)
+
+        self.val_ch = in_ch
+        self.key_ch = in_ch // 8
+        self.conv_q = Conv1x1(in_ch, self.key_ch)
+        self.conv_k = Conv1x1(in_ch, self.key_ch)
+        self.conv_v = Conv1x1(in_ch, self.val_ch)
+
+        self.softmax = nn.Softmax(axis=-1)
+
+    def forward(self, x):
+        x = x.flatten(-2)
+        x_rs = self.pool(x)
+        
+        b, c, h, w = x_rs.shape
+        query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
+        key = self.conv_k(x_rs).reshape((b,-1,h*w))
+        energy = paddle.bmm(query, key)
+        energy = (self.key_ch**(-0.5)) * energy
+
+        attention = self.softmax(energy)
+
+        value = self.conv_v(x_rs).reshape((b,-1,w*h))
+
+        out = paddle.bmm(value, attention.transpose((0,2,1)))
+        out = out.reshape((b,c,h,w))
+
+        out = F.interpolate(out, scale_factor=self.ds)
+        out = out + x
+        return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
+
+
+class PAMBlock(nn.Layer):
+    def __init__(self, in_ch, scale=1, ds=1):
+        super().__init__()
+
+        self.scale = scale
+        self.ds = ds
+        self.pool = nn.AvgPool2D(self.ds)
+
+        self.val_ch = in_ch
+        self.key_ch = in_ch // 8
+        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
+        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
+        self.conv_v = Conv1x1(in_ch, self.val_ch)
+
+    def forward(self, x):
+        x_rs = self.pool(x)
+
+        # Get query, key, and value.
+        query = self.conv_q(x_rs)
+        key = self.conv_k(x_rs)
+        value = self.conv_v(x_rs)
+        
+        # Split the whole image into subregions.
+        b, c, h, w = x_rs.shape
+        query = self._split_subregions(query)
+        key = self._split_subregions(key)
+        value = self._split_subregions(value)
+        
+        # Perform subregion-wise attention.
+        out = self._attend(query, key, value)
+
+        # Stack subregions to reconstruct the whole image.
+        out = self._recons_whole(out, b, c, h, w)
+        out = F.interpolate(out, scale_factor=self.ds)
+        return out
+
+    def _attend(self, query, key, value):
+        energy = paddle.bmm(query.transpose((0,2,1)), key)  # batch matrix multiplication
+        energy = (self.key_ch**(-0.5)) * energy
+        attention = F.softmax(energy, axis=-1)
+        out = paddle.bmm(value, attention.transpose((0,2,1)))
+        return out
+
+    def _split_subregions(self, x):
+        b, c, h, w = x.shape
+        assert h % self.scale == 0 and w % self.scale == 0
+        x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
+        x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))
+        return x
+
+    def _recons_whole(self, x, b, c, h, w):
+        x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale))
+        x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w))
+        return x
+
+
+class PAM(nn.Layer):
+    def __init__(self, in_ch, ds, scales=(1,2,4,8)):
+        super().__init__()
+
+        self.stages = nn.LayerList([
+            PAMBlock(in_ch, scale=s, ds=ds)
+            for s in scales
+        ])
+        self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False)
+
+    def forward(self, x):
+        x = x.flatten(-2)
+        res = [stage(x) for stage in self.stages]
+        out = self.conv_out(paddle.concat(res, axis=1))
+        return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
+
+
+class Attention(nn.Layer):
+    def __init__(self, att):
+        super().__init__()
+        self.att = att
+
+    def forward(self, x1, x2):
+        x = paddle.stack([x1, x2], axis=-1)
+        y = self.att(x)
+        return y[...,0], y[...,1]

+ 21 - 1
paddlers/tasks/changedetector.py

@@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.transforms import ImgDecoder, Resize
 import paddlers.models.cd as cd
 
-__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff"]
+__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet"]
 
 
 class BaseChangeDetector(BaseModel):
@@ -716,4 +716,24 @@ class UNetSiamDiff(BaseChangeDetector):
             model_name='UNetSiamDiff',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            **params)
+
+
+class STANet(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 in_channels=3,
+                 att_type='BAM',
+                 ds_factor=1,
+                 **params):
+        params.update({
+            'in_channels': in_channels,
+            'att_type': att_type,
+            'ds_factor': ds_factor
+        })
+        super(STANet, self).__init__(
+            model_name='STANet',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
             **params)