|  | @@ -0,0 +1,478 @@
 | 
											
												
													
														|  | 
 |  | +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +# Licensed under the Apache License, Version 2.0 (the "License");
 | 
											
												
													
														|  | 
 |  | +# you may not use this file except in compliance with the License.
 | 
											
												
													
														|  | 
 |  | +# You may obtain a copy of the License at
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +#     http://www.apache.org/licenses/LICENSE-2.0
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +# Unless required by applicable law or agreed to in writing, software
 | 
											
												
													
														|  | 
 |  | +# distributed under the License is distributed on an "AS IS" BASIS,
 | 
											
												
													
														|  | 
 |  | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
											
												
													
														|  | 
 |  | +# See the License for the specific language governing permissions and
 | 
											
												
													
														|  | 
 |  | +# limitations under the License.
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import paddle
 | 
											
												
													
														|  | 
 |  | +import paddle.nn as nn
 | 
											
												
													
														|  | 
 |  | +import paddle.nn.functional as F
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +bn_mom = 1 - 0.0003
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class NLBlock(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_channels):
 | 
											
												
													
														|  | 
 |  | +        super(NLBlock, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv_v = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_ch=in_channels,
 | 
											
												
													
														|  | 
 |  | +            out_ch=in_channels,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_channels, momentum=0.9))
 | 
											
												
													
														|  | 
 |  | +        self.W = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_ch=in_channels,
 | 
											
												
													
														|  | 
 |  | +            out_ch=in_channels,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_channels, momentum=0.9),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
 | 
											
												
													
														|  | 
 |  | +        value = self.conv_v(x)
 | 
											
												
													
														|  | 
 |  | +        value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]])
 | 
											
												
													
														|  | 
 |  | +        value = value.transpose([0, 2, 1])  # B * (H*W) * value_channels
 | 
											
												
													
														|  | 
 |  | +        key = x.reshape([batch_size, c, h * w])  # B * key_channels * (H*W)
 | 
											
												
													
														|  | 
 |  | +        query = x.reshape([batch_size, c, h * w])
 | 
											
												
													
														|  | 
 |  | +        query = query.transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        sim_map = paddle.matmul(query, key)  # B * (H*W) * (H*W)
 | 
											
												
													
														|  | 
 |  | +        sim_map = (c**-.5) * sim_map  # B * (H*W) * (H*W)
 | 
											
												
													
														|  | 
 |  | +        sim_map = nn.functional.softmax(sim_map, axis=-1)  # B * (H*W) * (H*W)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        context = paddle.matmul(sim_map, value)
 | 
											
												
													
														|  | 
 |  | +        context = context.transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        context = context.reshape([batch_size, c, *x.shape[2:]])
 | 
											
												
													
														|  | 
 |  | +        context = self.W(context)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return context
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class NLFPN(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """ Non-local feature parymid network"""
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_dim, reduction=True):
 | 
											
												
													
														|  | 
 |  | +        super(NLFPN, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        if reduction:
 | 
											
												
													
														|  | 
 |  | +            self.reduction = BasicConv(
 | 
											
												
													
														|  | 
 |  | +                in_ch=in_dim,
 | 
											
												
													
														|  | 
 |  | +                out_ch=in_dim // 4,
 | 
											
												
													
														|  | 
 |  | +                kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +                norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                    in_dim // 4, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +                act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +            self.re_reduction = BasicConv(
 | 
											
												
													
														|  | 
 |  | +                in_ch=in_dim // 4,
 | 
											
												
													
														|  | 
 |  | +                out_ch=in_dim,
 | 
											
												
													
														|  | 
 |  | +                kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +                norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                    in_dim, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +                act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +            in_dim = in_dim // 4
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            self.reduction = None
 | 
											
												
													
														|  | 
 |  | +            self.re_reduction = None
 | 
											
												
													
														|  | 
 |  | +        self.conv_e1 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_e2 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            in_dim * 2,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim * 2, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_e3 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim * 2,
 | 
											
												
													
														|  | 
 |  | +            in_dim * 4,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim * 4, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_d1 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_d2 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim * 2,
 | 
											
												
													
														|  | 
 |  | +            in_dim,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_d3 = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_dim * 4,
 | 
											
												
													
														|  | 
 |  | +            in_dim * 2,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                in_dim * 2, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.nl3 = NLBlock(in_dim * 2)
 | 
											
												
													
														|  | 
 |  | +        self.nl2 = NLBlock(in_dim)
 | 
											
												
													
														|  | 
 |  | +        self.nl1 = NLBlock(in_dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2)
 | 
											
												
													
														|  | 
 |  | +        self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        if self.reduction is not None:
 | 
											
												
													
														|  | 
 |  | +            x = self.reduction(x)
 | 
											
												
													
														|  | 
 |  | +        e1 = self.conv_e1(x)  # C,H,W
 | 
											
												
													
														|  | 
 |  | +        e2 = self.conv_e2(self.downsample_x2(e1))  # 2C,H/2,W/2
 | 
											
												
													
														|  | 
 |  | +        e3 = self.conv_e3(self.downsample_x2(e2))  # 4C,H/4,W/4
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        d3 = self.conv_d3(e3)  # 2C,H/4,W/4
 | 
											
												
													
														|  | 
 |  | +        nl = self.nl3(d3)
 | 
											
												
													
														|  | 
 |  | +        d3 = self.upsample_x2(paddle.multiply(d3, nl))  ##2C,H/2,W/2
 | 
											
												
													
														|  | 
 |  | +        d2 = self.conv_d2(e2 + d3)  # C,H/2,W/2
 | 
											
												
													
														|  | 
 |  | +        nl = self.nl2(d2)
 | 
											
												
													
														|  | 
 |  | +        d2 = self.upsample_x2(paddle.multiply(d2, nl))  # C,H,W
 | 
											
												
													
														|  | 
 |  | +        d1 = self.conv_d1(e1 + d2)
 | 
											
												
													
														|  | 
 |  | +        nl = self.nl1(d1)
 | 
											
												
													
														|  | 
 |  | +        d1 = paddle.multiply(d1, nl)  # C,H,W
 | 
											
												
													
														|  | 
 |  | +        if self.re_reduction is not None:
 | 
											
												
													
														|  | 
 |  | +            d1 = self.re_reduction(d1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return d1
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class Cat(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False):
 | 
											
												
													
														|  | 
 |  | +        super(Cat, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.do_upsample = upsample
 | 
											
												
													
														|  | 
 |  | +        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
 | 
											
												
													
														|  | 
 |  | +        self.conv2d = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_chn_high + in_chn_low,
 | 
											
												
													
														|  | 
 |  | +            out_chn,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, y):
 | 
											
												
													
														|  | 
 |  | +        if self.do_upsample:
 | 
											
												
													
														|  | 
 |  | +            x = self.upsample(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        x = paddle.concat((x, y), 1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return self.conv2d(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DoubleConv(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_chn, out_chn, stride=1, dilation=1):
 | 
											
												
													
														|  | 
 |  | +        super(DoubleConv, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv = nn.Sequential(
 | 
											
												
													
														|  | 
 |  | +            nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +                in_chn,
 | 
											
												
													
														|  | 
 |  | +                out_chn,
 | 
											
												
													
														|  | 
 |  | +                kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +                stride=stride,
 | 
											
												
													
														|  | 
 |  | +                dilation=dilation,
 | 
											
												
													
														|  | 
 |  | +                padding=dilation),
 | 
											
												
													
														|  | 
 |  | +            nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            nn.ReLU(),
 | 
											
												
													
														|  | 
 |  | +            nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, out_chn, kernel_size=3, stride=1, padding=1),
 | 
											
												
													
														|  | 
 |  | +            nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        x = self.conv(x)
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class SEModule(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, channels, reduction_channels):
 | 
											
												
													
														|  | 
 |  | +        super(SEModule, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.fc1 = nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            channels,
 | 
											
												
													
														|  | 
 |  | +            reduction_channels,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +            padding=0,
 | 
											
												
													
														|  | 
 |  | +            bias_attr=True)
 | 
											
												
													
														|  | 
 |  | +        self.ReLU = nn.ReLU()
 | 
											
												
													
														|  | 
 |  | +        self.fc2 = nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            reduction_channels,
 | 
											
												
													
														|  | 
 |  | +            channels,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +            padding=0,
 | 
											
												
													
														|  | 
 |  | +            bias_attr=True)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        x_se = x.reshape(
 | 
											
												
													
														|  | 
 |  | +            [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape(
 | 
											
												
													
														|  | 
 |  | +                [x.shape[0], x.shape[1], 1, 1])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        x_se = self.fc1(x_se)
 | 
											
												
													
														|  | 
 |  | +        x_se = self.ReLU(x_se)
 | 
											
												
													
														|  | 
 |  | +        x_se = self.fc2(x_se)
 | 
											
												
													
														|  | 
 |  | +        return x * F.sigmoid(x_se)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class BasicBlock(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    expansion = 1
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 inplanes,
 | 
											
												
													
														|  | 
 |  | +                 planes,
 | 
											
												
													
														|  | 
 |  | +                 downsample=None,
 | 
											
												
													
														|  | 
 |  | +                 use_se=False,
 | 
											
												
													
														|  | 
 |  | +                 stride=1,
 | 
											
												
													
														|  | 
 |  | +                 dilation=1):
 | 
											
												
													
														|  | 
 |  | +        super(BasicBlock, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        first_planes = planes
 | 
											
												
													
														|  | 
 |  | +        outplanes = planes * self.expansion
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.conv1 = DoubleConv(inplanes, first_planes)
 | 
											
												
													
														|  | 
 |  | +        self.conv2 = DoubleConv(
 | 
											
												
													
														|  | 
 |  | +            first_planes, outplanes, stride=stride, dilation=dilation)
 | 
											
												
													
														|  | 
 |  | +        self.se = SEModule(outplanes, planes // 4) if use_se else None
 | 
											
												
													
														|  | 
 |  | +        self.downsample = MaxPool2x2() if downsample else None
 | 
											
												
													
														|  | 
 |  | +        self.ReLU = nn.ReLU()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        out = self.conv1(x)
 | 
											
												
													
														|  | 
 |  | +        residual = out
 | 
											
												
													
														|  | 
 |  | +        out = self.conv2(out)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.se is not None:
 | 
											
												
													
														|  | 
 |  | +            out = self.se(out)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.downsample is not None:
 | 
											
												
													
														|  | 
 |  | +            residual = self.downsample(residual)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        out = out + residual
 | 
											
												
													
														|  | 
 |  | +        out = self.ReLU(out)
 | 
											
												
													
														|  | 
 |  | +        return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DenseCatAdd(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_chn, out_chn):
 | 
											
												
													
														|  | 
 |  | +        super(DenseCatAdd, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_out = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_chn,
 | 
											
												
													
														|  | 
 |  | +            out_chn,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, y):
 | 
											
												
													
														|  | 
 |  | +        x1 = self.conv1(x)
 | 
											
												
													
														|  | 
 |  | +        x2 = self.conv2(x1)
 | 
											
												
													
														|  | 
 |  | +        x3 = self.conv3(x2 + x1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.conv1(y)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.conv2(y1)
 | 
											
												
													
														|  | 
 |  | +        y3 = self.conv3(y2 + y1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DenseCatDiff(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_chn, out_chn):
 | 
											
												
													
														|  | 
 |  | +        super(DenseCatDiff, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +        self.conv_out = BasicConv(
 | 
											
												
													
														|  | 
 |  | +            in_ch=in_chn,
 | 
											
												
													
														|  | 
 |  | +            out_ch=out_chn,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=1,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                out_chn, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, y):
 | 
											
												
													
														|  | 
 |  | +        x1 = self.conv1(x)
 | 
											
												
													
														|  | 
 |  | +        x2 = self.conv2(x1)
 | 
											
												
													
														|  | 
 |  | +        x3 = self.conv3(x2 + x1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.conv1(y)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.conv2(y1)
 | 
											
												
													
														|  | 
 |  | +        y3 = self.conv3(y2 + y1)
 | 
											
												
													
														|  | 
 |  | +        out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3))
 | 
											
												
													
														|  | 
 |  | +        return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DFModule(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """Dense connection-based feature fusion module"""
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, dim_in, dim_out, reduction=True):
 | 
											
												
													
														|  | 
 |  | +        super(DFModule, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        if reduction:
 | 
											
												
													
														|  | 
 |  | +            self.reduction = Conv1x1(
 | 
											
												
													
														|  | 
 |  | +                dim_in,
 | 
											
												
													
														|  | 
 |  | +                dim_in // 2,
 | 
											
												
													
														|  | 
 |  | +                norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                    dim_in // 2, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +                act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +            dim_in = dim_in // 2
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            self.reduction = None
 | 
											
												
													
														|  | 
 |  | +        self.cat1 = DenseCatAdd(dim_in, dim_out)
 | 
											
												
													
														|  | 
 |  | +        self.cat2 = DenseCatDiff(dim_in, dim_out)
 | 
											
												
													
														|  | 
 |  | +        self.conv1 = Conv3x3(
 | 
											
												
													
														|  | 
 |  | +            dim_out,
 | 
											
												
													
														|  | 
 |  | +            dim_out,
 | 
											
												
													
														|  | 
 |  | +            norm=nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                dim_out, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            act=nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x1, x2):
 | 
											
												
													
														|  | 
 |  | +        if self.reduction is not None:
 | 
											
												
													
														|  | 
 |  | +            x1 = self.reduction(x1)
 | 
											
												
													
														|  | 
 |  | +            x2 = self.reduction(x2)
 | 
											
												
													
														|  | 
 |  | +        x_add = self.cat1(x1, x2)
 | 
											
												
													
														|  | 
 |  | +        x_diff = self.cat2(x1, x2)
 | 
											
												
													
														|  | 
 |  | +        y = self.conv1(x_diff) + x_add
 | 
											
												
													
														|  | 
 |  | +        return y
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class FCCDN(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +    The FCCDN implementation based on PaddlePaddle.
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    The original article refers to
 | 
											
												
													
														|  | 
 |  | +        Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
 | 
											
												
													
														|  | 
 |  | +        (https://arxiv.org/pdf/2105.10860.pdf).
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    Args:
 | 
											
												
													
														|  | 
 |  | +        in_channels (int): Number of input channels. Default: 3.
 | 
											
												
													
														|  | 
 |  | +        num_classes (int): Number of target classes. Default: 2.
 | 
											
												
													
														|  | 
 |  | +        os (int): Number of output stride. Default: 16.
 | 
											
												
													
														|  | 
 |  | +        use_se (bool): Whether to use SEModule. Default: True.
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True):
 | 
											
												
													
														|  | 
 |  | +        super(FCCDN, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        if os >= 16:
 | 
											
												
													
														|  | 
 |  | +            dilation_list = [1, 1, 1, 1]
 | 
											
												
													
														|  | 
 |  | +            stride_list = [2, 2, 2, 2]
 | 
											
												
													
														|  | 
 |  | +            pool_list = [True, True, True, True]
 | 
											
												
													
														|  | 
 |  | +        elif os == 8:
 | 
											
												
													
														|  | 
 |  | +            dilation_list = [2, 1, 1, 1]
 | 
											
												
													
														|  | 
 |  | +            stride_list = [1, 2, 2, 2]
 | 
											
												
													
														|  | 
 |  | +            pool_list = [False, True, True, True]
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            dilation_list = [2, 2, 1, 1]
 | 
											
												
													
														|  | 
 |  | +            stride_list = [1, 1, 2, 2]
 | 
											
												
													
														|  | 
 |  | +            pool_list = [False, False, True, True]
 | 
											
												
													
														|  | 
 |  | +        se_list = [use_se, use_se, use_se, use_se]
 | 
											
												
													
														|  | 
 |  | +        channel_list = [256, 128, 64, 32]
 | 
											
												
													
														|  | 
 |  | +        # Encoder
 | 
											
												
													
														|  | 
 |  | +        self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3],
 | 
											
												
													
														|  | 
 |  | +                                 se_list[3], stride_list[3], dilation_list[3])
 | 
											
												
													
														|  | 
 |  | +        self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2],
 | 
											
												
													
														|  | 
 |  | +                                 se_list[2], stride_list[2], dilation_list[2])
 | 
											
												
													
														|  | 
 |  | +        self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1],
 | 
											
												
													
														|  | 
 |  | +                                 se_list[1], stride_list[1], dilation_list[1])
 | 
											
												
													
														|  | 
 |  | +        self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0],
 | 
											
												
													
														|  | 
 |  | +                                 se_list[0], stride_list[0], dilation_list[0])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Center
 | 
											
												
													
														|  | 
 |  | +        self.center = NLFPN(channel_list[0], True)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Decoder
 | 
											
												
													
														|  | 
 |  | +        self.decoder3 = Cat(channel_list[0],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                            upsample=pool_list[0])
 | 
											
												
													
														|  | 
 |  | +        self.decoder2 = Cat(channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                            upsample=pool_list[1])
 | 
											
												
													
														|  | 
 |  | +        self.decoder1 = Cat(channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[3],
 | 
											
												
													
														|  | 
 |  | +                            channel_list[3],
 | 
											
												
													
														|  | 
 |  | +                            upsample=pool_list[2])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.df1 = DFModule(channel_list[3], channel_list[3], True)
 | 
											
												
													
														|  | 
 |  | +        self.df2 = DFModule(channel_list[2], channel_list[2], True)
 | 
											
												
													
														|  | 
 |  | +        self.df3 = DFModule(channel_list[1], channel_list[1], True)
 | 
											
												
													
														|  | 
 |  | +        self.df4 = DFModule(channel_list[0], channel_list[0], True)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.catc3 = Cat(channel_list[0],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                         upsample=pool_list[0])
 | 
											
												
													
														|  | 
 |  | +        self.catc2 = Cat(channel_list[1],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                         upsample=pool_list[1])
 | 
											
												
													
														|  | 
 |  | +        self.catc1 = Cat(channel_list[2],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[3],
 | 
											
												
													
														|  | 
 |  | +                         channel_list[3],
 | 
											
												
													
														|  | 
 |  | +                         upsample=pool_list[2])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.upsample_x2 = nn.Sequential(
 | 
											
												
													
														|  | 
 |  | +            nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +                channel_list[3], 8, kernel_size=3, stride=1, padding=1),
 | 
											
												
													
														|  | 
 |  | +            nn.BatchNorm2D(
 | 
											
												
													
														|  | 
 |  | +                8, momentum=bn_mom),
 | 
											
												
													
														|  | 
 |  | +            nn.ReLU(),
 | 
											
												
													
														|  | 
 |  | +            nn.UpsamplingBilinear2D(scale_factor=2))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.conv_out = nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            8, num_classes, kernel_size=3, stride=1, padding=1)
 | 
											
												
													
														|  | 
 |  | +        self.conv_out_class = nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            channel_list[3], 1, kernel_size=1, stride=1, padding=0)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, t1, t2):
 | 
											
												
													
														|  | 
 |  | +        e1_1 = self.block1(t1)
 | 
											
												
													
														|  | 
 |  | +        e2_1 = self.block2(e1_1)
 | 
											
												
													
														|  | 
 |  | +        e3_1 = self.block3(e2_1)
 | 
											
												
													
														|  | 
 |  | +        y1 = self.block4(e3_1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        e1_2 = self.block1(t2)
 | 
											
												
													
														|  | 
 |  | +        e2_2 = self.block2(e1_2)
 | 
											
												
													
														|  | 
 |  | +        e3_2 = self.block3(e2_2)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.block4(e3_2)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.center(y1)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.center(y2)
 | 
											
												
													
														|  | 
 |  | +        c = self.df4(y1, y2)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.decoder3(y1, e3_1)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.decoder3(y2, e3_2)
 | 
											
												
													
														|  | 
 |  | +        c = self.catc3(c, self.df3(y1, y2))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.decoder2(y1, e2_1)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.decoder2(y2, e2_2)
 | 
											
												
													
														|  | 
 |  | +        c = self.catc2(c, self.df2(y1, y2))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        y1 = self.decoder1(y1, e1_1)
 | 
											
												
													
														|  | 
 |  | +        y2 = self.decoder1(y2, e1_2)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        c = self.catc1(c, self.df1(y1, y2))
 | 
											
												
													
														|  | 
 |  | +        y = self.conv_out(self.upsample_x2(c))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.training:
 | 
											
												
													
														|  | 
 |  | +            y1 = self.conv_out_class(y1)
 | 
											
												
													
														|  | 
 |  | +            y2 = self.conv_out_class(y2)
 | 
											
												
													
														|  | 
 |  | +            return [y, [y1, y2]]
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            return [y]
 |