Browse Source

Change custom model

Bobholamovic 2 years ago
parent
commit
09d3dd1202

+ 73 - 73
examples/rs_research/README.md

@@ -43,64 +43,73 @@ python ../../tools/prepare_dataset/prepare_svcd.py \
 1. 巨大的参数量意味着巨大的存储开销。在许多实际场景中,硬件资源往往是有限的,过多的模型参数将给部署造成困难。
 2. 在数据有限的情况下,大模型更易遭受过拟合,其在实验数据集上看起来良好的结果也难以泛化到真实场景。
 
-本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力?基于这个观点,本案例的基本思路是设计一种基于网络迭代优化思想的深度学习变化检测算法。首先,构造一个轻量级的变化检测模型,并以其作为基础迭代单元。在每次迭代开始时,由上一次迭代输出的概率图以及原始的输入影像对构造新的输入,如此逐级实现coarse-to-fine优化。考虑到增加迭代单元的数量将使模型参数量成倍增加,在迭代过程中应始终复用同一迭代单元的参数以充分挖掘变化检测网络的拟合能力,迫使其学习到更加有效的特征。这一做法类似[循环神经网络](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490)。根据此思路可以绘制框图如下:
+本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,也即存在一部分“无用”的特征,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力,获取更多更加有效的特征?基于这个观点,本案例的基本思路是为现有的变化检测模型添加一个“插件式”的特征优化模块,在仅引入较少额外的参数数量的情况下,实现变化特征增强。本案例计划以变化检测领域经典的FC-Siam-diff[4]为baseline网络,利用时间、空间、通道注意力模块对网络的中间层特征进行优化,从而减小特征冗余,提升检测效果。在具体的模块设计方面,对于时间与通道维度,选用论文[5]中提出的通道注意力模块;对于空间维度,选用论文[5]中提出的空间注意力模块。
 
-![draft](draft.png)
+### 3.2 模型定义
 
-### 3.2 确定baseline模型
+#### 3.2.1 自定义模型组网
 
-科研工作往往需要“站在巨人的肩膀上”,在前人工作的基础上做“增量创新”。因此,对模型设计类工作而言,选用一个合适的baseline模型至关重要。考虑到本案例的出发点是解决现有模型参数量过大、冗余特征过多的问题,并且在拟定的解决方案中使用到了循环结构,用作baseline的网络结构必须足够轻量和高效(因为最直接的思路是使用baseline作为基础迭代单元)。为此,本案例选用Bitemporal Image Transformer(BIT)作为baseline。BIT是一个轻量级的深度学习变化检测模型,其基本结构如图所示:
-
-![bit](bit.png)
-
-BIT的核心思想在于,
-
-### 3.3 定义新模型
-
-确定了基本思路和baseline模型之后,可以绘制如下的算法整体框图:
-
-![framework](framework.png)
+在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。例如,本案例中,`custom_model.py`中定义了改进后的FC-EF结构,其核心部分实现如下:
+```python
+...
+# PaddleRS提供了许多开箱即用的模块,其中有对底层基础模块的封装(如conv-bn-relu结构等),也有注意力模块等较高层级的结构
+from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
+from paddlers.rs_models.cd.layers import ChannelAttention, SpatialAttention
 
-依据此框图,即可在。
+from attach_tools import Attach
 
-#### 3.3.1 自定义模型组网
+attach = Attach.to(paddlers.rs_models.cd)
 
-在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。例如,当前`custom_model.py`中定义了迭代版本的BIT模型`IterativeBIT`:
-```python
 @attach
-class IterativeBIT(nn.Layer):
-    def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
+class CustomModel(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 num_classes,
+                 att_types='cst',
+                 use_dropout=False):
         super().__init__()
-
-        if num_iters <= 0:
-            raise ValueError(
-                f"`num_iters` should have positive value, but got {num_iters}.")
-
-        self.num_iters = num_iters
-        self.gamma = gamma
-
-        if bit_kwargs is None:
-            bit_kwargs = dict()
-
-        if 'num_classes' in bit_kwargs:
-            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
-        bit_kwargs['num_classes'] = num_classes
-
-        self.bit = BIT(**bit_kwargs)
+        ...
+
+        # 从`att_types`参数中获取要使用的注意力类型
+        # 每个注意力模块都是可选的
+        if 'c' in att_types:
+            self.att_c = ChannelAttention(C4)
+        else:
+            self.att_c = Identity()
+        if 's' in att_types:
+            self.att_s = SpatialAttention()
+        else:
+            self.att_s = Identity()
+        # 时间注意力模块部分复用通道注意力的逻辑,在`forward()`中将具体解释
+        if 't' in att_types:
+            self.att_t = ChannelAttention(2, ratio=1)
+        else:
+            self.att_t = Identity()
+
+        self.init_weight()
 
     def forward(self, t1, t2):
-        rate_map = self._init_rate_map(t1.shape)
-
-        for it in range(self.num_iters):
-            # Construct inputs
-            x1 = self._constr_iter_input(t1, rate_map)
-            x2 = self._constr_iter_input(t2, rate_map)
-            # Get logits
-            logits_list = self.bit(x1, x2)
-            # Construct rate map
-            rate_map = self._constr_rate_map(logits_list[0])
-
-        return logits_list
+        ...
+        # 以下是本案例在FC-EF基础上新增的部分
+        # x43_1和x43_2分别是FC-EF的两路编码器提取的特征
+        # 首先使用通道和空间注意力模块对特征进行优化
+        x43_1 = self.att_c(x43_1) * x43_1
+        x43_1 = self.att_s(x43_1) * x43_1
+        x43_2 = self.att_c(x43_2) * x43_2
+        x43_2 = self.att_s(x43_2) * x43_2
+        # 为了复用通道注意力模块执行时间维度的注意力操作,首先将两个时相的特征堆叠
+        x43 = paddle.stack([x43_1, x43_2], axis=1)
+        # 堆叠后的x43形状为[b, t, c, h, w],其中b表示batch size,t为2(时相数目),c为通道数,h和w分别为特征图高宽
+        # 将t和c维度交换,输出tensor形状为[b, c, t, h, w]
+        x43 = paddle.transpose(x43, [0, 2, 1, 3, 4])
+        # 将b和c两个维度合并,输出tensor形状为[b*c, t, h, w]
+        x43 = paddle.flatten(x43, stop_axis=1)
+        # 此时,时间维度已经替代了原先的通道维度,将四维tensor输入ChannelAttention模块进行处理
+        x43 = self.att_t(x43) * x43
+        # 从处理结果中分离两个时相的信息
+        x43 = x43.reshape((x43_1.shape[0], -1, 2, *x43.shape[2:]))
+        x43_1, x43_2 = x43[:,:,0], x43[:,:,1]
+        ...
     ...
 ```
 
@@ -112,27 +121,27 @@ class IterativeBIT(nn.Layer):
 
 关于模型定义的更多细节请参考[文档](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/dev/dev_guide.md)。
 
-#### 3.3.2 自定义训练器
+#### 3.2.2 自定义训练器
 
-在`custom_trainer.py`中定义训练器。例如,当前`custom_trainer.py`中定义了与`IterativeBIT`模型对应的训练器:
+在`custom_trainer.py`中定义训练器。例如,本案例中,`custom_trainer.py`中定义了与`CustomModel`模型对应的训练器:
 ```python
 @attach
-class IterativeBIT(BaseChangeDetector):
+class CustomTrainer(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
-                 num_iters=1,
-                 gamma=0.1,
-                 bit_kwargs=None,
+                 in_channels=3,
+                 att_types='cst',
+                 use_dropout=False,
                  **params):
         params.update({
-            'num_iters': num_iters,
-            'gamma': gamma,
-            'bit_kwargs': bit_kwargs
+            'in_channels': in_channels,
+            'att_types': att_types,
+            'use_dropout': use_dropout
         })
         super().__init__(
-            model_name='IterativeBIT',
+            model_name='CustomModel',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             losses=losses,
@@ -149,27 +158,17 @@ class IterativeBIT(BaseChangeDetector):
 
 关于训练器的更多细节请参考[API文档](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/train.md)。
 
-### 3.4 进行参数分析与消融实验
+### 3.3 消融实验
 
-#### 3.4.1 实验设置
+#### 3.3.1 实验设置
 
-#### 3.4.2 编写配置文件
+#### 3.3.2 编写配置文件
 
-#### 3.4.3 实验结果
+#### 3.3.3 实验结果
 
 VisualDL、定量指标
 
-### 3.5 \*Magic Behind
-
-本小节涉及技术细节,对于本案例来说属于进阶内容,您可以选择性了解。
-
-#### 3.5.1 延迟属性绑定
-
-PaddleRS提供了,只需要。`attach_tools.Attach`对象自动。
-
-#### 3.5.2 非侵入式轻量级配置系统
-
-### 3.5 开展特征可视化实验
+### 3.4 特征可视化实验
 
 ## 4 对比实验
 
@@ -206,3 +205,4 @@ PaddleRS提供了,只需要。`attach_tools.Attach`对象自动。
 [2] Lebedev, M. A., et al. "CHANGE DETECTION IN REMOTE SENSING IMAGES USING CONDITIONAL ADVERSARIAL NETWORKS." *International Archives of the Photogrammetry, Remote Sensing & Spatial Information Sciences* 42.2 (2018).  
 [3] Chen, Hao, Zipeng Qi, and Zhenwei Shi. "Remote sensing image change detection with transformers." *IEEE Transactions on Geoscience and Remote Sensing* 60 (2021): 1-14.  
 [4] Daudt, Rodrigo Caye, Bertr Le Saux, and Alexandre Boulch. "Fully convolutional siamese networks for change detection." *2018 25th IEEE International Conference on Image Processing (ICIP)*. IEEE, 2018.  
+[5] Woo, Sanghyun, et al. "Cbam: Convolutional block attention module." *Proceedings of the European conference on computer vision (ECCV)*. 2018.

+ 192 - 74
examples/rs_research/custom_model.py

@@ -2,88 +2,206 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 import paddlers
-from paddlers.rs_models.cd import BIT
+from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
+from paddlers.rs_models.cd.layers import ChannelAttention, SpatialAttention
+
 from attach_tools import Attach
 
 attach = Attach.to(paddlers.rs_models.cd)
 
 
 @attach
-class IterativeBIT(BIT):
+class CustomModel(nn.Layer):
     def __init__(self,
-                 num_iters=1,
-                 feat_channels=32,
-                 num_classes=2,
-                 bit_kwargs=None):
-        if num_iters <= 0:
-            raise ValueError(
-                f"`num_iters` should have positive value, but got {num_iters}.")
-
-        self.num_iters = num_iters
-
-        if bit_kwargs is None:
-            bit_kwargs = dict()
-
-        if 'num_classes' in bit_kwargs:
-            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
-        bit_kwargs['num_classes'] = num_classes
-
-        super().__init__(**bit_kwargs)
+                 in_channels,
+                 num_classes,
+                 att_types='cst',
+                 use_dropout=False):
+        super(CustomModel, self).__init__()
+
+        C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
+
+        self.use_dropout = use_dropout
+
+        self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
+        self.do11 = self._make_dropout()
+        self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
+        self.do12 = self._make_dropout()
+        self.pool1 = MaxPool2x2()
+
+        self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
+        self.do21 = self._make_dropout()
+        self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
+        self.do22 = self._make_dropout()
+        self.pool2 = MaxPool2x2()
+
+        self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
+        self.do31 = self._make_dropout()
+        self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
+        self.do32 = self._make_dropout()
+        self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
+        self.do33 = self._make_dropout()
+        self.pool3 = MaxPool2x2()
+
+        self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
+        self.do41 = self._make_dropout()
+        self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
+        self.do42 = self._make_dropout()
+        self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
+        self.do43 = self._make_dropout()
+        self.pool4 = MaxPool2x2()
+
+        self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
+
+        self.conv43d = Conv3x3(C5, C4, norm=True, act=True)
+        self.do43d = self._make_dropout()
+        self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
+        self.do42d = self._make_dropout()
+        self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
+        self.do41d = self._make_dropout()
+
+        self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
+
+        self.conv33d = Conv3x3(C4, C3, norm=True, act=True)
+        self.do33d = self._make_dropout()
+        self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
+        self.do32d = self._make_dropout()
+        self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
+        self.do31d = self._make_dropout()
+
+        self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
+
+        self.conv22d = Conv3x3(C3, C2, norm=True, act=True)
+        self.do22d = self._make_dropout()
+        self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
+        self.do21d = self._make_dropout()
+
+        self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
+
+        self.conv12d = Conv3x3(C2, C1, norm=True, act=True)
+        self.do12d = self._make_dropout()
+        self.conv11d = Conv3x3(C1, num_classes)
+
+        if 'c' in att_types:
+            self.att_c = ChannelAttention(C4)
+        else:
+            self.att_c = Identity()
+        if 's' in att_types:
+            self.att_s = SpatialAttention()
+        else:
+            self.att_s = Identity()
+        if 't' in att_types:
+            self.att_t = ChannelAttention(2, ratio=1)
+        else:
+            self.att_t = Identity()
 
-        self.conv_fuse = nn.Sequential(
-            nn.Conv2D(feat_channels + 1, feat_channels, 1), nn.Sigmoid())
+        self.init_weight()
 
     def forward(self, t1, t2):
-        # Extract features via shared backbone.
-        x1 = self.backbone(t1)
-        x2 = self.backbone(t2)
-
-        # Tokenization
-        if self.use_tokenizer:
-            token1 = self._get_semantic_tokens(x1)
-            token2 = self._get_semantic_tokens(x2)
+        # Encode t1
+        # Stage 1
+        x11 = self.do11(self.conv11(t1))
+        x12_1 = self.do12(self.conv12(x11))
+        x1p = self.pool1(x12_1)
+
+        # Stage 2
+        x21 = self.do21(self.conv21(x1p))
+        x22_1 = self.do22(self.conv22(x21))
+        x2p = self.pool2(x22_1)
+
+        # Stage 3
+        x31 = self.do31(self.conv31(x2p))
+        x32 = self.do32(self.conv32(x31))
+        x33_1 = self.do33(self.conv33(x32))
+        x3p = self.pool3(x33_1)
+
+        # Stage 4
+        x41 = self.do41(self.conv41(x3p))
+        x42 = self.do42(self.conv42(x41))
+        x43_1 = self.do43(self.conv43(x42))
+        x4p = self.pool4(x43_1)
+
+        # Encode t2
+        # Stage 1
+        x11 = self.do11(self.conv11(t2))
+        x12_2 = self.do12(self.conv12(x11))
+        x1p = self.pool1(x12_2)
+
+        # Stage 2
+        x21 = self.do21(self.conv21(x1p))
+        x22_2 = self.do22(self.conv22(x21))
+        x2p = self.pool2(x22_2)
+
+        # Stage 3
+        x31 = self.do31(self.conv31(x2p))
+        x32 = self.do32(self.conv32(x31))
+        x33_2 = self.do33(self.conv33(x32))
+        x3p = self.pool3(x33_2)
+
+        # Stage 4
+        x41 = self.do41(self.conv41(x3p))
+        x42 = self.do42(self.conv42(x41))
+        x43_2 = self.do43(self.conv43(x42))
+        x4p = self.pool4(x43_2)
+
+        # Attend
+        x43_1 = self.att_c(x43_1) * x43_1
+        x43_1 = self.att_s(x43_1) * x43_1
+        x43_2 = self.att_c(x43_2) * x43_2
+        x43_2 = self.att_s(x43_2) * x43_2
+        x43 = paddle.stack([x43_1, x43_2], axis=1)
+        x43 = paddle.transpose(x43, [0, 2, 1, 3, 4])
+        x43 = paddle.flatten(x43, stop_axis=1)
+        x43 = self.att_t(x43) * x43
+        x43 = x43.reshape((x43_1.shape[0], -1, 2, *x43.shape[2:]))
+        x43_1, x43_2 = x43[:, :, 0], x43[:, :, 1]
+
+        # Decode
+        # Stage 4d
+        x4d = self.upconv4(x4p)
+        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
+                x43_1.shape[2] - x4d.shape[2])
+        x4d = F.pad(x4d, pad=pad4, mode='replicate')
+        x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1)
+        x43d = self.do43d(self.conv43d(x4d))
+        x42d = self.do42d(self.conv42d(x43d))
+        x41d = self.do41d(self.conv41d(x42d))
+
+        # Stage 3d
+        x3d = self.upconv3(x41d)
+        pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
+                x33_1.shape[2] - x3d.shape[2])
+        x3d = F.pad(x3d, pad=pad3, mode='replicate')
+        x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1)
+        x33d = self.do33d(self.conv33d(x3d))
+        x32d = self.do32d(self.conv32d(x33d))
+        x31d = self.do31d(self.conv31d(x32d))
+
+        # Stage 2d
+        x2d = self.upconv2(x31d)
+        pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
+                x22_1.shape[2] - x2d.shape[2])
+        x2d = F.pad(x2d, pad=pad2, mode='replicate')
+        x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1)
+        x22d = self.do22d(self.conv22d(x2d))
+        x21d = self.do21d(self.conv21d(x22d))
+
+        # Stage 1d
+        x1d = self.upconv1(x21d)
+        pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
+                x12_1.shape[2] - x1d.shape[2])
+        x1d = F.pad(x1d, pad=pad1, mode='replicate')
+        x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1)
+        x12d = self.do12d(self.conv12d(x1d))
+        x11d = self.conv11d(x12d)
+
+        return [x11d]
+
+    def init_weight(self):
+        pass
+
+    def _make_dropout(self):
+        if self.use_dropout:
+            return nn.Dropout2D(p=0.2)
         else:
-            token1 = self._get_reshaped_tokens(x1)
-            token2 = self._get_reshaped_tokens(x2)
-
-        # Transformer encoder forward
-        token = paddle.concat([token1, token2], axis=1)
-        token = self.encode(token)
-        token1, token2 = paddle.chunk(token, 2, axis=1)
-
-        # Get initial rate map
-        rate_map = self._init_rate_map(x1.shape)
-
-        for it in range(self.num_iters):
-            # Construct inputs
-            x1_iter = self._constr_iter_input(x1, rate_map)
-            x2_iter = self._constr_iter_input(x2, rate_map)
-
-            # Transformer decoder forward
-            y1 = self.decode(x1_iter, token1)
-            y2 = self.decode(x2_iter, token2)
-
-            # Feature differencing
-            y = paddle.abs(y1 - y2)
-
-            # Construct rate map
-            rate_map = self._constr_rate_map(y)
-
-        y = self.upsample(y)
-        pred = self.conv_out(y)
-
-        return [pred]
-
-    def _init_rate_map(self, im_shape):
-        b, _, h, w = im_shape
-        return paddle.full((b, 1, h, w), 0.5)
-
-    def _constr_iter_input(self, x, rate_map):
-        return self.conv_fuse(paddle.concat([x, rate_map], axis=1))
-
-    def _constr_rate_map(self, x):
-        rate_map = x.mean(1, keepdim=True).detach()  # Cut off gradient workflow
-        # min-max normalization
-        rate_map -= rate_map.min()
-        rate_map /= rate_map.max()
-        return rate_map
+            return Identity()

+ 8 - 8
examples/rs_research/custom_trainer.py

@@ -7,22 +7,22 @@ attach = Attach.to(paddlers.tasks.change_detector)
 
 
 @attach
-class IterativeBIT(BaseChangeDetector):
+class CustomTrainer(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
-                 num_iters=1,
-                 feat_channels=32,
-                 bit_kwargs=None,
+                 in_channels=3,
+                 att_types='cst',
+                 use_dropout=False,
                  **params):
         params.update({
-            'num_iters': num_iters,
-            'feat_channels': feat_channels,
-            'bit_kwargs': bit_kwargs
+            'in_channels': in_channels,
+            'att_types': att_types,
+            'use_dropout': use_dropout
         })
         super().__init__(
-            model_name='IterativeBIT',
+            model_name='CustomModel',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             losses=losses,