Răsfoiți Sursa

[Re-request][Feature] Add ChangeStar model (#23)

Lin Manhui 3 ani în urmă
părinte
comite
099d5ac8c5

+ 1 - 0
paddlers/custom_models/cd/models/__init__.py

@@ -18,6 +18,7 @@ from .dsifn import DSIFN
 from .stanet import STANet
 from .snunet import SNUNet
 from .dsamnet import DSAMNet
+from .changestar import ChangeStar
 from .unet_ef import UNetEarlyFusion
 from .unet_siamconc import UNetSiamConc
 from .unet_siamdiff import UNetSiamDiff

+ 158 - 0
paddlers/custom_models/cd/models/changestar.py

@@ -0,0 +1,158 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlers.datasets.cd_dataset import MaskType
+from paddlers.custom_models.seg.models import FarSeg
+from .layers import Conv3x3, Identity
+
+
+class _ChangeStarBase(nn.Layer):
+
+    USE_MULTITASK_DECODER = True
+    OUT_TYPES = (
+        MaskType.CD, 
+        MaskType.CD, 
+        MaskType.SEG_T1, 
+        MaskType.SEG_T2
+    )
+
+    def __init__(
+        self,
+        seg_model,
+        num_classes,
+        mid_channels,
+        inner_channels,
+        num_convs,
+        scale_factor
+    ):
+        super().__init__()
+
+        self.extract = seg_model
+        self.detect = ChangeMixin(
+            in_ch=mid_channels*2,
+            out_ch=num_classes,
+            mid_ch=inner_channels,
+            num_convs=num_convs,
+            scale_factor=scale_factor
+        )
+        self.segment = nn.Sequential(
+            Conv3x3(mid_channels, 2),
+            nn.UpsamplingBilinear2D(scale_factor=scale_factor)
+        )
+
+        self.init_weight()
+
+    def forward(self, t1, t2):
+        x1 = self.extract(t1)[0]
+        x2 = self.extract(t2)[0]
+        logit12, logit21 = self.detect(x1, x2)
+
+        if not self.training:
+            logit_list = [logit12]
+        else:
+            logit1 = self.segment(x1)
+            logit2 = self.segment(x2)
+            logit_list = [logit12, logit21, logit1, logit2]
+
+        return logit_list
+
+    def init_weight(self):
+        pass
+
+
+class ChangeMixin(nn.Layer):
+    def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor):
+        super().__init__()
+        convs = [Conv3x3(in_ch, mid_ch, norm=True, act=True)]
+        convs += [
+            Conv3x3(mid_ch, mid_ch, norm=True, act=True) 
+            for _ in range(num_convs-1)
+        ]
+        self.detect = nn.Sequential(
+            *convs,
+            Conv3x3(mid_ch, out_ch),
+            nn.UpsamplingBilinear2D(scale_factor=scale_factor)
+        )
+
+    def forward(self, x1, x2):
+        pred12 = self.detect(paddle.concat([x1, x2], axis=1))
+        pred21 = self.detect(paddle.concat([x2, x1], axis=1))
+        return pred12, pred21
+
+
+class ChangeStar_FarSeg(_ChangeStarBase):
+    """
+    The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle.
+
+    The original article refers to
+        Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery"
+        (https://arxiv.org/abs/2108.07002).
+    
+    Note that this implementation differs from the original code in two aspects:
+    1. The encoder of the FarSeg model is ResNet50.
+    2. We use conv-bn-relu instead of conv-relu-bn.
+
+    Args:
+        num_classes (int): The number of target classes.
+        mid_channels (int, optional): The number of channels required by the ChangeMixin module. Default: 256.
+        inner_channels (int, optional): The number of filters used in the convolutional layers in the ChangeMixin module. 
+            Default: 16.
+        num_convs (int, optional): The number of convolutional layers used in the ChangeMixin module. Default: 4.
+        scale_factor (float, optional): The scaling factor of the output upsampling layer. Default: 4.0.
+    """
+
+    def __init__(
+        self, 
+        num_classes,
+        mid_channels=256,
+        inner_channels=16,
+        num_convs=4,
+        scale_factor=4.0,
+    ):
+        # TODO: Configurable FarSeg model
+        class _FarSegWrapper(nn.Layer):
+            def __init__(self, seg_model):
+                super().__init__()
+                self._seg_model = seg_model
+                self._seg_model.cls_pred_conv = Identity()
+
+            def forward(self, x):
+                feat_list = self._seg_model.en(x)
+                fpn_feat_list = self._seg_model.fpn(feat_list)
+                if self._seg_model.scene_relation:
+                    c5 = feat_list[-1]
+                    c6 = self._seg_model.gap(c5)
+                    refined_fpn_feat_list = self._seg_model.sr(c6, fpn_feat_list)
+                else:
+                    refined_fpn_feat_list = fpn_feat_list
+                final_feat = self._seg_model.decoder(refined_fpn_feat_list)
+                return [final_feat]
+
+        seg_model = FarSeg(out_ch=mid_channels)
+
+        super().__init__(
+            seg_model=_FarSegWrapper(seg_model),
+            num_classes=num_classes,
+            mid_channels=mid_channels,
+            inner_channels=inner_channels,
+            num_convs=num_convs,
+            scale_factor=scale_factor
+        )
+
+# NOTE: Currently, ChangeStar = FarSeg + ChangeMixin + SegHead
+ChangeStar = ChangeStar_FarSeg

+ 35 - 1
paddlers/tasks/changedetector.py

@@ -44,6 +44,8 @@ __all__ = [
     "BIT", 
     "SNUNet", 
     "DSIFN", 
+    "DSAMNet", 
+    "ChangeStar"
     "DSAMNet"
 ]
 
@@ -873,4 +875,36 @@ class DSAMNet(BaseChangeDetector):
                 'coef': [1.0, 0.05, 0.05]
             }
         else:
-            return super().default_loss()
+            return super().default_loss()
+
+
+class ChangeStar(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 mid_channels=256,
+                 inner_channels=16,
+                 num_convs=4,
+                 scale_factor=4.0,
+                 **params):
+        params.update({
+            'mid_channels': mid_channels,
+            'inner_channels': inner_channels,
+            'num_convs': num_convs,
+            'scale_factor': scale_factor
+        })
+        super(ChangeStar, self).__init__(
+            model_name='ChangeStar',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)
+
+    def default_loss(self):
+        if self.use_mixed_loss is False:
+            return {
+                # XXX: make sure the shallow copy works correctly here.
+                'types': [paddleseg.models.CrossEntropyLoss()]*4,
+                'coef': [1.0]*4
+            }
+        else:
+            return super().default_loss()