Browse Source

[论文复现赛] FarSeg (#45)

* Update FarSeg

* Update FarSeg

* Compatible initialization
ucsk 2 years ago
parent
commit
bc5b536ed8
3 changed files with 179 additions and 178 deletions
  1. 4 16
      paddlers/rs_models/cd/changestar.py
  2. 167 159
      paddlers/rs_models/seg/farseg.py
  3. 8 3
      tests/rs_models/test_seg_models.py

+ 4 - 16
paddlers/rs_models/cd/changestar.py

@@ -14,7 +14,6 @@
 
 import paddle
 import paddle.nn as nn
-import paddle.nn.functional as F
 
 from paddlers.datasets.cd_dataset import MaskType
 from paddlers.rs_models.seg import FarSeg
@@ -22,7 +21,6 @@ from .layers import Conv3x3, Identity
 
 
 class _ChangeStarBase(nn.Layer):
-
     USE_MULTITASK_DECODER = True
     OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)
 
@@ -118,22 +116,12 @@ class ChangeStar_FarSeg(_ChangeStarBase):
             def __init__(self, seg_model):
                 super(_FarSegWrapper, self).__init__()
                 self._seg_model = seg_model
-                self._seg_model.cls_pred_conv = Identity()
+                self._seg_model.cls_head = Identity()
 
             def forward(self, x):
-                feat_list = self._seg_model.en(x)
-                fpn_feat_list = self._seg_model.fpn(feat_list)
-                if self._seg_model.scene_relation:
-                    c5 = feat_list[-1]
-                    c6 = self._seg_model.gap(c5)
-                    refined_fpn_feat_list = self._seg_model.sr(c6,
-                                                               fpn_feat_list)
-                else:
-                    refined_fpn_feat_list = fpn_feat_list
-                final_feat = self._seg_model.decoder(refined_fpn_feat_list)
-                return [final_feat]
-
-        seg_model = FarSeg(out_ch=mid_channels)
+                return self._seg_model(x)
+
+        seg_model = FarSeg(decoder_out_channels=mid_channels)
 
         super(ChangeStar_FarSeg, self).__init__(
             seg_model=_FarSegWrapper(seg_model),

+ 167 - 159
paddlers/rs_models/seg/farseg.py

@@ -20,25 +20,79 @@ import math
 
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddle.vision.models import resnet50
-from paddle import nn
-import paddle.nn.functional as F
+from paddle.vision.models import resnet
 
-from .layers import (Identity, ConvReLU, kaiming_normal_init, constant_init)
+from paddlers.models.ppdet.modeling import initializer as init
 
 
-class FPN(nn.Layer):
-    """
-    Module that adds FPN on top of a list of feature maps.
-    The feature maps are currently supposed to be in increasing depth
-        order, and must be consecutive.
-    """
+class FPNConvBlock(nn.Conv2D):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 dilation=1):
+        super(FPNConvBlock, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=dilation * (kernel_size - 1) // 2,
+            dilation=dilation)
+        init.kaiming_uniform_(self.weight, a=1)
+        init.constant_(self.bias, value=0)
+
 
+class DefaultConvBlock(nn.Conv2D):
     def __init__(self,
-                 in_channels_list,
+                 in_channels,
                  out_channels,
-                 conv_block=ConvReLU,
-                 top_blocks=None):
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias_attr=None):
+        super(DefaultConvBlock, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias_attr=bias_attr)
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+
+class ResNetEncoder(nn.Layer):
+    def __init__(self, backbone='resnet50', in_channels=3, pretrained=True):
+        super(ResNetEncoder, self).__init__()
+        self.resnet = getattr(resnet, backbone)(pretrained=pretrained)
+        if in_channels != 3:
+            self.resnet.conv1 = nn.Conv2D(
+                in_channels, 64, 7, stride=2, padding=3, bias_attr=False)
+
+        for layer in self.resnet.sublayers():
+            if isinstance(layer, (nn.BatchNorm2D, nn.SyncBatchNorm)):
+                layer._momentum = 0.1
+
+    def forward(self, x):
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)
+
+        c2 = self.resnet.layer1(x)
+        c3 = self.resnet.layer2(c2)
+        c4 = self.resnet.layer3(c3)
+        c5 = self.resnet.layer4(c4)
+
+        return [c2, c3, c4, c5]
+
+
+class FPN(nn.Layer):
+    def __init__(self, in_channels_list, out_channels, conv_block=FPNConvBlock):
         super(FPN, self).__init__()
 
         inner_blocks = []
@@ -46,17 +100,10 @@ class FPN(nn.Layer):
         for idx, in_channels in enumerate(in_channels_list, 1):
             if in_channels == 0:
                 continue
-            inner_block_module = conv_block(in_channels, out_channels, 1)
-            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
-            for module in [inner_block_module, layer_block_module]:
-                for m in module.sublayers():
-                    if isinstance(m, nn.Conv2D):
-                        kaiming_normal_init(m.weight)
-            inner_blocks.append(inner_block_module)
-            layer_blocks.append(layer_block_module)
+            inner_blocks.append(conv_block(in_channels, out_channels, 1))
+            layer_blocks.append(conv_block(out_channels, out_channels, 3, 1))
         self.inner_blocks = nn.LayerList(inner_blocks)
         self.layer_blocks = nn.LayerList(layer_blocks)
-        self.top_blocks = top_blocks
 
     def forward(self, x):
         last_inner = self.inner_blocks[-1](x[-1])
@@ -69,80 +116,55 @@ class FPN(nn.Layer):
             inner_lateral = inner_block(feature)
             last_inner = inner_lateral + inner_top_down
             results.insert(0, layer_block(last_inner))
-        if isinstance(self.top_blocks, LastLevelP6P7):
-            last_results = self.top_blocks(x[-1], results[-1])
-            results.extend(last_results)
-        elif isinstance(self.top_blocks, LastLevelMaxPool):
-            last_results = self.top_blocks(results[-1])
-            results.extend(last_results)
         return tuple(results)
 
 
-class LastLevelMaxPool(nn.Layer):
-    def forward(self, x):
-        return [F.max_pool2d(x, 1, 2, 0)]
-
-
-class LastLevelP6P7(nn.Layer):
-    """
-    This module is used in RetinaNet to generate extra layers, P6 and P7.
-    """
-
-    def __init__(self, in_channels, out_channels):
-        super(LastLevelP6P7, self).__init__()
-        self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
-        self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
-        for module in [self.p6, self.p7]:
-            for m in module.sublayers():
-                kaiming_normal_init(m.weight)
-                constant_init(m.bias, value=0)
-        self.use_P5 = in_channels == out_channels
-
-    def forward(self, c5, p5):
-        x = p5 if self.use_P5 else c5
-        p6 = self.p6(x)
-        p7 = self.p7(F.relu(p6))
-        return [p6, p7]
-
-
-class SceneRelation(nn.Layer):
+class FSRelation(nn.Layer):
     def __init__(self,
                  in_channels,
-                 channel_list,
+                 channels_list,
                  out_channels,
-                 scale_aware_proj=True):
-        super(SceneRelation, self).__init__()
+                 scale_aware_proj=True,
+                 conv_block=DefaultConvBlock):
+        super(FSRelation, self).__init__()
+
         self.scale_aware_proj = scale_aware_proj
-        if scale_aware_proj:
+        if self.scale_aware_proj:
             self.scene_encoder = nn.LayerList([
                 nn.Sequential(
-                    nn.Conv2D(in_channels, out_channels, 1),
-                    nn.ReLU(), nn.Conv2D(out_channels, out_channels, 1))
-                for _ in range(len(channel_list))
+                    conv_block(in_channels, out_channels, 1),
+                    nn.ReLU(), conv_block(out_channels, out_channels, 1))
+                for _ in range(len(channels_list))
             ])
         else:
-            # 2mlp
             self.scene_encoder = nn.Sequential(
-                nn.Conv2D(in_channels, out_channels, 1),
-                nn.ReLU(),
-                nn.Conv2D(out_channels, out_channels, 1), )
+                conv_block(in_channels, out_channels, 1),
+                nn.ReLU(), conv_block(out_channels, out_channels, 1))
+
         self.content_encoders = nn.LayerList()
         self.feature_reencoders = nn.LayerList()
-        for c in channel_list:
+        for channel in channels_list:
             self.content_encoders.append(
                 nn.Sequential(
-                    nn.Conv2D(c, out_channels, 1),
-                    nn.BatchNorm2D(out_channels), nn.ReLU()))
+                    conv_block(
+                        channel, out_channels, 1, bias_attr=True),
+                    nn.BatchNorm2D(
+                        out_channels, momentum=0.1),
+                    nn.ReLU()))
             self.feature_reencoders.append(
                 nn.Sequential(
-                    nn.Conv2D(c, out_channels, 1),
-                    nn.BatchNorm2D(out_channels), nn.ReLU()))
+                    conv_block(
+                        channel, out_channels, 1, bias_attr=True),
+                    nn.BatchNorm2D(
+                        out_channels, momentum=0.1),
+                    nn.ReLU()))
+
         self.normalizer = nn.Sigmoid()
 
-    def forward(self, scene_feature, features: list):
+    def forward(self, scene_feature, feature_list):
         content_feats = [
             c_en(p_feat)
-            for c_en, p_feat in zip(self.content_encoders, features)
+            for c_en, p_feat in zip(self.content_encoders, feature_list)
         ]
         if self.scale_aware_proj:
             scene_feats = [op(scene_feature) for op in self.scene_encoder]
@@ -157,7 +179,8 @@ class SceneRelation(nn.Layer):
                 for cf in content_feats
             ]
         p_feats = [
-            op(p_feat) for op, p_feat in zip(self.feature_reencoders, features)
+            op(p_feat)
+            for op, p_feat in zip(self.feature_reencoders, feature_list)
         ]
         refined_feats = [r * p for r, p in zip(relations, p_feats)]
         return refined_feats
@@ -167,71 +190,40 @@ class AsymmetricDecoder(nn.Layer):
     def __init__(self,
                  in_channels,
                  out_channels,
-                 in_feat_output_strides=(4, 8, 16, 32),
-                 out_feat_output_stride=4,
-                 norm_fn=nn.BatchNorm2D,
-                 num_groups_gn=None):
+                 in_feature_output_strides=(4, 8, 16, 32),
+                 out_feature_output_stride=4,
+                 conv_block=DefaultConvBlock):
         super(AsymmetricDecoder, self).__init__()
-        if norm_fn == nn.BatchNorm2D:
-            norm_fn_args = dict(num_features=out_channels)
-        elif norm_fn == nn.GroupNorm:
-            if num_groups_gn is None:
-                raise ValueError(
-                    'When norm_fn is nn.GroupNorm, num_groups_gn is needed.')
-            norm_fn_args = dict(
-                num_groups=num_groups_gn, num_channels=out_channels)
-        else:
-            raise ValueError('Type of {} is not support.'.format(type(norm_fn)))
+
         self.blocks = nn.LayerList()
-        for in_feat_os in in_feat_output_strides:
-            num_upsample = int(math.log2(int(in_feat_os))) - int(
-                math.log2(int(out_feat_output_stride)))
+        for in_feature_output_stride in in_feature_output_strides:
+            num_upsample = int(math.log2(int(in_feature_output_stride))) - int(
+                math.log2(int(out_feature_output_stride)))
             num_layers = num_upsample if num_upsample != 0 else 1
             self.blocks.append(
                 nn.Sequential(*[
                     nn.Sequential(
-                        nn.Conv2D(
+                        conv_block(
                             in_channels if idx == 0 else out_channels,
                             out_channels,
                             3,
                             1,
                             1,
                             bias_attr=False),
-                        norm_fn(**norm_fn_args)
-                        if norm_fn is not None else Identity(),
+                        nn.BatchNorm2D(
+                            out_channels, momentum=0.1),
                         nn.ReLU(),
                         nn.UpsamplingBilinear2D(scale_factor=2) if num_upsample
-                        != 0 else Identity(), ) for idx in range(num_layers)
+                        != 0 else nn.Identity(), ) for idx in range(num_layers)
                 ]))
 
-    def forward(self, feat_list: list):
-        inner_feat_list = []
+    def forward(self, feature_list):
+        inner_feature_list = []
         for idx, block in enumerate(self.blocks):
-            decoder_feat = block(feat_list[idx])
-            inner_feat_list.append(decoder_feat)
-        out_feat = sum(inner_feat_list) / 4.
-        return out_feat
-
-
-class ResNet50Encoder(nn.Layer):
-    def __init__(self, in_ch=3, pretrained=True):
-        super(ResNet50Encoder, self).__init__()
-        self.resnet = resnet50(pretrained=pretrained)
-        if in_ch != 3:
-            self.resnet.conv1 = nn.Conv2D(
-                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)
-
-    def forward(self, inputs):
-        x = inputs
-        x = self.resnet.conv1(x)
-        x = self.resnet.bn1(x)
-        x = self.resnet.relu(x)
-        x = self.resnet.maxpool(x)
-        c2 = self.resnet.layer1(x)
-        c3 = self.resnet.layer2(c2)
-        c4 = self.resnet.layer3(c3)
-        c5 = self.resnet.layer4(c4)
-        return [c2, c3, c4, c5]
+            decoder_feature = block(feature_list[idx])
+            inner_feature_list.append(decoder_feature)
+        out_feature = sum(inner_feature_list) / len(inner_feature_list)
+        return out_feature
 
 
 class FarSeg(nn.Layer):
@@ -239,50 +231,66 @@ class FarSeg(nn.Layer):
     The FarSeg implementation based on PaddlePaddle.
 
     The original article refers to
-    Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution 
-        Remote Sensing Imagery"
-    (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
+    Zheng Z, Zhong Y, Wang J, et al. Foreground-aware relation network for geospatial object segmentation in
+    high spatial resolution remote sensing imagery[C]//Proceedings of the IEEE/CVF conference on computer vision
+    and pattern recognition. 2020: 4096-4105.
 
     Args:
-        in_channels (int, optional): Number of bands of the input images. Default: 3.
-        num_classes (int, optional): Number of target classes. Default: 16.
-        fpn_ch_list (list[int]|tuple[int], optional): Channel list of the FPN. Default: (256, 512, 1024, 2048).
-        mid_ch (int, optional): Output channels of the FPN. Default: 256.
-        out_ch (int, optional): Output channels of the decoder. Default: 128.
-        sr_ch_list (list[int]|tuple[int], optional): Channel list of the foreground-scene relation module. Default: (256, 256, 256, 256).
-        pretrained_encoder (bool, optional): Whether to use a pretrained encoder. Default: True.
+        in_channels (int): The number of image channels for the input model. Default: 3.
+        num_classes (int): The unique number of target classes. Default: 16.
+        backbone (str): A backbone network, models available in `paddle.vision.models.resnet`. Default: resnet50.
+        backbone_pretrained (bool): Whether the backbone network uses IMAGENET pretrained weights. Default: True.
+        fpn_out_channels (int): The number of channels output by the feature pyramid network. Default: 256.
+        fsr_out_channels (int): The number of channels output by the F-S relation module. Default: 256.
+        scale_aware_proj (bool): Whether to use scale awareness in F-S relation module. Default: True.
+        decoder_out_channels (int): The number of channels output by the decoder. Default: 128.
     """
 
     def __init__(self,
                  in_channels=3,
                  num_classes=16,
-                 fpn_ch_list=(256, 512, 1024, 2048),
-                 mid_ch=256,
-                 out_ch=128,
-                 sr_ch_list=(256, 256, 256, 256),
-                 pretrained_encoder=True):
+                 backbone='resnet50',
+                 backbone_pretrained=True,
+                 fpn_out_channels=256,
+                 fsr_out_channels=256,
+                 scale_aware_proj=True,
+                 decoder_out_channels=128):
         super(FarSeg, self).__init__()
-        self.en = ResNet50Encoder(in_channels, pretrained_encoder)
-        self.fpn = FPN(in_channels_list=fpn_ch_list, out_channels=mid_ch)
+
+        backbone = backbone.lower()
+        self.encoder = ResNetEncoder(
+            backbone=backbone,
+            in_channels=in_channels,
+            pretrained=backbone_pretrained)
+
+        fpn_max_in_channels = 2048
+        if backbone in ['resnet18', 'resnet34']:
+            fpn_max_in_channels = 512
+        self.fpn = FPN(in_channels_list=[
+            fpn_max_in_channels // (2**(3 - i)) for i in range(4)
+        ],
+                       out_channels=fpn_out_channels)
+        self.gap = nn.AdaptiveAvgPool2D(1)
+        self.fsr = FSRelation(
+            in_channels=fpn_max_in_channels,
+            channels_list=[fpn_out_channels] * 4,
+            out_channels=fsr_out_channels,
+            scale_aware_proj=scale_aware_proj)
+
         self.decoder = AsymmetricDecoder(
-            in_channels=mid_ch, out_channels=out_ch)
-        self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1)
-        self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4)
-        self.scene_relation = True if sr_ch_list is not None else False
-        if self.scene_relation:
-            self.gap = nn.AdaptiveAvgPool2D(1)
-            self.sr = SceneRelation(fpn_ch_list[-1], sr_ch_list, mid_ch)
+            in_channels=fsr_out_channels, out_channels=decoder_out_channels)
+
+        self.cls_head = nn.Sequential(
+            DefaultConvBlock(decoder_out_channels, num_classes, 1),
+            nn.UpsamplingBilinear2D(scale_factor=4))
 
     def forward(self, x):
-        feat_list = self.en(x)
-        fpn_feat_list = self.fpn(feat_list)
-        if self.scene_relation:
-            c5 = feat_list[-1]
-            c6 = self.gap(c5)
-            refined_fpn_feat_list = self.sr(c6, fpn_feat_list)
-        else:
-            refined_fpn_feat_list = fpn_feat_list
-        final_feat = self.decoder(refined_fpn_feat_list)
-        cls_pred = self.cls_pred_conv(final_feat)
-        cls_pred = self.upsample4x_op(cls_pred)
-        return [cls_pred]
+        feature_list = self.encoder(x)
+
+        fpn_feature_list = self.fpn(feature_list)
+        scene_feature = self.gap(feature_list[-1])
+        refined_feature_list = self.fsr(scene_feature, fpn_feature_list)
+
+        feature = self.decoder(refined_feature_list)
+        logit = self.cls_head(feature)
+        return [logit]

+ 8 - 3
tests/rs_models/test_seg_models.py

@@ -53,10 +53,15 @@ class TestFarSegModel(TestSegModel):
 
     def set_specs(self):
         self.specs = [
-            dict(), dict(num_classes=20), dict(pretrained_encoder=False),
-            dict(in_channels=10)
+            dict(), dict(
+                in_channels=6, num_classes=10), dict(
+                    backbone='resnet18', backbone_pretrained=False), dict(
+                        fpn_out_channels=128,
+                        fsr_out_channels=64,
+                        decoder_out_channels=32), dict(scale_aware_proj=False)
         ]
 
     def set_targets(self):
-        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(20)],
+        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
+                        [self.get_zeros_array(16)], [self.get_zeros_array(16)],
                         [self.get_zeros_array(16)]]