|  | @@ -0,0 +1,1001 @@
 | 
											
												
													
														|  | 
 |  | +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +# Licensed under the Apache License, Version 2.0 (the "License");
 | 
											
												
													
														|  | 
 |  | +# you may not use this file except in compliance with the License.
 | 
											
												
													
														|  | 
 |  | +# You may obtain a copy of the License at
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +#    http://www.apache.org/licenses/LICENSE-2.0
 | 
											
												
													
														|  | 
 |  | +#
 | 
											
												
													
														|  | 
 |  | +# Unless required by applicable law or agreed to in writing, software
 | 
											
												
													
														|  | 
 |  | +# distributed under the License is distributed on an "AS IS" BASIS,
 | 
											
												
													
														|  | 
 |  | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
											
												
													
														|  | 
 |  | +# See the License for the specific language governing permissions and
 | 
											
												
													
														|  | 
 |  | +# limitations under the License.
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import warnings
 | 
											
												
													
														|  | 
 |  | +import math
 | 
											
												
													
														|  | 
 |  | +from functools import partial
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import paddle as pd
 | 
											
												
													
														|  | 
 |  | +import paddle.nn as nn
 | 
											
												
													
														|  | 
 |  | +import paddle.nn.functional as F
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +from .layers.pd_timm import DropPath, to_2tuple
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +def calc_product(*args):
 | 
											
												
													
														|  | 
 |  | +    if len(args) < 1:
 | 
											
												
													
														|  | 
 |  | +        raise ValueError
 | 
											
												
													
														|  | 
 |  | +    ret = args[0]
 | 
											
												
													
														|  | 
 |  | +    for arg in args[1:]:
 | 
											
												
													
														|  | 
 |  | +        ret *= arg
 | 
											
												
													
														|  | 
 |  | +    return ret
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ConvBlock(pd.nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 input_size,
 | 
											
												
													
														|  | 
 |  | +                 output_size,
 | 
											
												
													
														|  | 
 |  | +                 kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +                 stride=1,
 | 
											
												
													
														|  | 
 |  | +                 padding=1,
 | 
											
												
													
														|  | 
 |  | +                 bias=True,
 | 
											
												
													
														|  | 
 |  | +                 activation='prelu',
 | 
											
												
													
														|  | 
 |  | +                 norm=None):
 | 
											
												
													
														|  | 
 |  | +        super(ConvBlock, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv = pd.nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            input_size,
 | 
											
												
													
														|  | 
 |  | +            output_size,
 | 
											
												
													
														|  | 
 |  | +            kernel_size,
 | 
											
												
													
														|  | 
 |  | +            stride,
 | 
											
												
													
														|  | 
 |  | +            padding,
 | 
											
												
													
														|  | 
 |  | +            bias_attr=bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.norm = norm
 | 
											
												
													
														|  | 
 |  | +        if self.norm == 'batch':
 | 
											
												
													
														|  | 
 |  | +            self.bn = pd.nn.BatchNorm2D(output_size)
 | 
											
												
													
														|  | 
 |  | +        elif self.norm == 'instance':
 | 
											
												
													
														|  | 
 |  | +            self.bn = pd.nn.InstanceNorm2D(output_size)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.activation = activation
 | 
											
												
													
														|  | 
 |  | +        if self.activation == 'relu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.ReLU(True)
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'prelu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.PReLU()
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'lrelu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.LeakyReLU(0.2, True)
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'tanh':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.Tanh()
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'sigmoid':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.Sigmoid()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        if self.norm is not None:
 | 
											
												
													
														|  | 
 |  | +            out = self.bn(self.conv(x))
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            out = self.conv(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.activation != 'no':
 | 
											
												
													
														|  | 
 |  | +            return self.act(out)
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DeconvBlock(pd.nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 input_size,
 | 
											
												
													
														|  | 
 |  | +                 output_size,
 | 
											
												
													
														|  | 
 |  | +                 kernel_size=4,
 | 
											
												
													
														|  | 
 |  | +                 stride=2,
 | 
											
												
													
														|  | 
 |  | +                 padding=1,
 | 
											
												
													
														|  | 
 |  | +                 bias=True,
 | 
											
												
													
														|  | 
 |  | +                 activation='prelu',
 | 
											
												
													
														|  | 
 |  | +                 norm=None):
 | 
											
												
													
														|  | 
 |  | +        super(DeconvBlock, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.deconv = pd.nn.Conv2DTranspose(
 | 
											
												
													
														|  | 
 |  | +            input_size,
 | 
											
												
													
														|  | 
 |  | +            output_size,
 | 
											
												
													
														|  | 
 |  | +            kernel_size,
 | 
											
												
													
														|  | 
 |  | +            stride,
 | 
											
												
													
														|  | 
 |  | +            padding,
 | 
											
												
													
														|  | 
 |  | +            bias_attr=bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.norm = norm
 | 
											
												
													
														|  | 
 |  | +        if self.norm == 'batch':
 | 
											
												
													
														|  | 
 |  | +            self.bn = pd.nn.BatchNorm2D(output_size)
 | 
											
												
													
														|  | 
 |  | +        elif self.norm == 'instance':
 | 
											
												
													
														|  | 
 |  | +            self.bn = pd.nn.InstanceNorm2D(output_size)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.activation = activation
 | 
											
												
													
														|  | 
 |  | +        if self.activation == 'relu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.ReLU(True)
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'prelu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.PReLU()
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'lrelu':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.LeakyReLU(0.2, True)
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'tanh':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.Tanh()
 | 
											
												
													
														|  | 
 |  | +        elif self.activation == 'sigmoid':
 | 
											
												
													
														|  | 
 |  | +            self.act = pd.nn.Sigmoid()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        if self.norm is not None:
 | 
											
												
													
														|  | 
 |  | +            out = self.bn(self.deconv(x))
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            out = self.deconv(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.activation is not None:
 | 
											
												
													
														|  | 
 |  | +            return self.act(out)
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ConvLayer(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
 | 
											
												
													
														|  | 
 |  | +        super(ConvLayer, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size, stride,
 | 
											
												
													
														|  | 
 |  | +                                padding)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        out = self.conv2d(x)
 | 
											
												
													
														|  | 
 |  | +        return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class UpsampleConvLayer(pd.nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, in_channels, out_channels, kernel_size, stride):
 | 
											
												
													
														|  | 
 |  | +        super(UpsampleConvLayer, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv2d = nn.Conv2DTranspose(
 | 
											
												
													
														|  | 
 |  | +            in_channels, out_channels, kernel_size, stride=stride, padding=1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        out = self.conv2d(x)
 | 
											
												
													
														|  | 
 |  | +        return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ResidualBlock(pd.nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, channels):
 | 
											
												
													
														|  | 
 |  | +        super(ResidualBlock, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.conv1 = ConvLayer(
 | 
											
												
													
														|  | 
 |  | +            channels, channels, kernel_size=3, stride=1, padding=1)
 | 
											
												
													
														|  | 
 |  | +        self.conv2 = ConvLayer(
 | 
											
												
													
														|  | 
 |  | +            channels, channels, kernel_size=3, stride=1, padding=1)
 | 
											
												
													
														|  | 
 |  | +        self.relu = nn.ReLU()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        residual = x
 | 
											
												
													
														|  | 
 |  | +        out = self.relu(self.conv1(x))
 | 
											
												
													
														|  | 
 |  | +        out = self.conv2(out) * 0.1
 | 
											
												
													
														|  | 
 |  | +        out = pd.add(out, residual)
 | 
											
												
													
														|  | 
 |  | +        return out
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ChangeFormer(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +    The ChangeFormer implementation based on PaddlePaddle.
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    The original article refers to
 | 
											
												
													
														|  | 
 |  | +        Wele Gedara Chaminda Bandara, Vishal M. Patel., "A TRANSFORMER-BASED SIAMESE NETWORK FOR CHANGE DETECTION"
 | 
											
												
													
														|  | 
 |  | +        (https://arxiv.org/pdf/2201.01293.pdf).
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    Args:
 | 
											
												
													
														|  | 
 |  | +        in_channels (int): Number of bands of the input images. Default: 3.
 | 
											
												
													
														|  | 
 |  | +        num_classes (int): Number of target classes. Default: 2.
 | 
											
												
													
														|  | 
 |  | +        decoder_softmax (bool, optional): Use softmax after decode or not. Default: False.
 | 
											
												
													
														|  | 
 |  | +        embed_dim (int, optional): Embedding dimension of each decoder head. Default: 256.
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 in_channels=3,
 | 
											
												
													
														|  | 
 |  | +                 num_classes=2,
 | 
											
												
													
														|  | 
 |  | +                 decoder_softmax=False,
 | 
											
												
													
														|  | 
 |  | +                 embed_dim=256):
 | 
											
												
													
														|  | 
 |  | +        super(ChangeFormer, self).__init__()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Transformer Encoder
 | 
											
												
													
														|  | 
 |  | +        self.embed_dims = [64, 128, 320, 512]
 | 
											
												
													
														|  | 
 |  | +        self.depths = [3, 3, 4, 3]
 | 
											
												
													
														|  | 
 |  | +        self.embedding_dim = embed_dim
 | 
											
												
													
														|  | 
 |  | +        self.drop_rate = 0.1
 | 
											
												
													
														|  | 
 |  | +        self.attn_drop = 0.1
 | 
											
												
													
														|  | 
 |  | +        self.drop_path_rate = 0.1
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.Tenc_x2 = EncoderTransformer_v3(
 | 
											
												
													
														|  | 
 |  | +            img_size=256,
 | 
											
												
													
														|  | 
 |  | +            patch_size=7,
 | 
											
												
													
														|  | 
 |  | +            in_chans=in_channels,
 | 
											
												
													
														|  | 
 |  | +            num_classes=num_classes,
 | 
											
												
													
														|  | 
 |  | +            embed_dims=self.embed_dims,
 | 
											
												
													
														|  | 
 |  | +            num_heads=[1, 2, 4, 8],
 | 
											
												
													
														|  | 
 |  | +            mlp_ratios=[4, 4, 4, 4],
 | 
											
												
													
														|  | 
 |  | +            qkv_bias=True,
 | 
											
												
													
														|  | 
 |  | +            qk_scale=None,
 | 
											
												
													
														|  | 
 |  | +            drop_rate=self.drop_rate,
 | 
											
												
													
														|  | 
 |  | +            attn_drop_rate=self.attn_drop,
 | 
											
												
													
														|  | 
 |  | +            drop_path_rate=self.drop_path_rate,
 | 
											
												
													
														|  | 
 |  | +            norm_layer=partial(
 | 
											
												
													
														|  | 
 |  | +                nn.LayerNorm, epsilon=1e-6),
 | 
											
												
													
														|  | 
 |  | +            depths=self.depths,
 | 
											
												
													
														|  | 
 |  | +            sr_ratios=[8, 4, 2, 1])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Transformer Decoder
 | 
											
												
													
														|  | 
 |  | +        self.TDec_x2 = DecoderTransformer_v3(
 | 
											
												
													
														|  | 
 |  | +            input_transform='multiple_select',
 | 
											
												
													
														|  | 
 |  | +            in_index=[0, 1, 2, 3],
 | 
											
												
													
														|  | 
 |  | +            align_corners=False,
 | 
											
												
													
														|  | 
 |  | +            in_channels=self.embed_dims,
 | 
											
												
													
														|  | 
 |  | +            embedding_dim=self.embedding_dim,
 | 
											
												
													
														|  | 
 |  | +            output_nc=num_classes,
 | 
											
												
													
														|  | 
 |  | +            decoder_softmax=decoder_softmax,
 | 
											
												
													
														|  | 
 |  | +            feature_strides=[2, 4, 8, 16])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x1, x2):
 | 
											
												
													
														|  | 
 |  | +        [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        cp = self.TDec_x2(fx1, fx2)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return [cp]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +# Transormer Ecoder with x2, x4, x8, x16 scales
 | 
											
												
													
														|  | 
 |  | +class EncoderTransformer_v3(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 img_size=256,
 | 
											
												
													
														|  | 
 |  | +                 patch_size=3,
 | 
											
												
													
														|  | 
 |  | +                 in_chans=3,
 | 
											
												
													
														|  | 
 |  | +                 num_classes=2,
 | 
											
												
													
														|  | 
 |  | +                 embed_dims=[32, 64, 128, 256],
 | 
											
												
													
														|  | 
 |  | +                 num_heads=[2, 2, 4, 8],
 | 
											
												
													
														|  | 
 |  | +                 mlp_ratios=[4, 4, 4, 4],
 | 
											
												
													
														|  | 
 |  | +                 qkv_bias=True,
 | 
											
												
													
														|  | 
 |  | +                 qk_scale=None,
 | 
											
												
													
														|  | 
 |  | +                 drop_rate=0.,
 | 
											
												
													
														|  | 
 |  | +                 attn_drop_rate=0.,
 | 
											
												
													
														|  | 
 |  | +                 drop_path_rate=0.,
 | 
											
												
													
														|  | 
 |  | +                 norm_layer=nn.LayerNorm,
 | 
											
												
													
														|  | 
 |  | +                 depths=[3, 3, 6, 18],
 | 
											
												
													
														|  | 
 |  | +                 sr_ratios=[8, 4, 2, 1]):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        self.num_classes = num_classes
 | 
											
												
													
														|  | 
 |  | +        self.depths = depths
 | 
											
												
													
														|  | 
 |  | +        self.embed_dims = embed_dims
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # patch embedding definitions
 | 
											
												
													
														|  | 
 |  | +        self.patch_embed1 = OverlapPatchEmbed(
 | 
											
												
													
														|  | 
 |  | +            img_size=img_size,
 | 
											
												
													
														|  | 
 |  | +            patch_size=7,
 | 
											
												
													
														|  | 
 |  | +            stride=4,
 | 
											
												
													
														|  | 
 |  | +            in_chans=in_chans,
 | 
											
												
													
														|  | 
 |  | +            embed_dim=embed_dims[0])
 | 
											
												
													
														|  | 
 |  | +        self.patch_embed2 = OverlapPatchEmbed(
 | 
											
												
													
														|  | 
 |  | +            img_size=img_size // 4,
 | 
											
												
													
														|  | 
 |  | +            patch_size=patch_size,
 | 
											
												
													
														|  | 
 |  | +            stride=2,
 | 
											
												
													
														|  | 
 |  | +            in_chans=embed_dims[0],
 | 
											
												
													
														|  | 
 |  | +            embed_dim=embed_dims[1])
 | 
											
												
													
														|  | 
 |  | +        self.patch_embed3 = OverlapPatchEmbed(
 | 
											
												
													
														|  | 
 |  | +            img_size=img_size // 8,
 | 
											
												
													
														|  | 
 |  | +            patch_size=patch_size,
 | 
											
												
													
														|  | 
 |  | +            stride=2,
 | 
											
												
													
														|  | 
 |  | +            in_chans=embed_dims[1],
 | 
											
												
													
														|  | 
 |  | +            embed_dim=embed_dims[2])
 | 
											
												
													
														|  | 
 |  | +        self.patch_embed4 = OverlapPatchEmbed(
 | 
											
												
													
														|  | 
 |  | +            img_size=img_size // 16,
 | 
											
												
													
														|  | 
 |  | +            patch_size=patch_size,
 | 
											
												
													
														|  | 
 |  | +            stride=2,
 | 
											
												
													
														|  | 
 |  | +            in_chans=embed_dims[2],
 | 
											
												
													
														|  | 
 |  | +            embed_dim=embed_dims[3])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage-1 (x1/4 scale)
 | 
											
												
													
														|  | 
 |  | +        dpr = [x.item() for x in pd.linspace(0, drop_path_rate, sum(depths))]
 | 
											
												
													
														|  | 
 |  | +        cur = 0
 | 
											
												
													
														|  | 
 |  | +        self.block1 = nn.LayerList([
 | 
											
												
													
														|  | 
 |  | +            Block(
 | 
											
												
													
														|  | 
 |  | +                dim=embed_dims[0],
 | 
											
												
													
														|  | 
 |  | +                num_heads=num_heads[0],
 | 
											
												
													
														|  | 
 |  | +                mlp_ratio=mlp_ratios[0],
 | 
											
												
													
														|  | 
 |  | +                qkv_bias=qkv_bias,
 | 
											
												
													
														|  | 
 |  | +                qk_scale=qk_scale,
 | 
											
												
													
														|  | 
 |  | +                drop=drop_rate,
 | 
											
												
													
														|  | 
 |  | +                attn_drop=attn_drop_rate,
 | 
											
												
													
														|  | 
 |  | +                drop_path=dpr[cur + i],
 | 
											
												
													
														|  | 
 |  | +                norm_layer=norm_layer,
 | 
											
												
													
														|  | 
 |  | +                sr_ratio=sr_ratios[0]) for i in range(depths[0])
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        self.norm1 = norm_layer(embed_dims[0])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage-2 (x1/8 scale)
 | 
											
												
													
														|  | 
 |  | +        cur += depths[0]
 | 
											
												
													
														|  | 
 |  | +        self.block2 = nn.LayerList([
 | 
											
												
													
														|  | 
 |  | +            Block(
 | 
											
												
													
														|  | 
 |  | +                dim=embed_dims[1],
 | 
											
												
													
														|  | 
 |  | +                num_heads=num_heads[1],
 | 
											
												
													
														|  | 
 |  | +                mlp_ratio=mlp_ratios[1],
 | 
											
												
													
														|  | 
 |  | +                qkv_bias=qkv_bias,
 | 
											
												
													
														|  | 
 |  | +                qk_scale=qk_scale,
 | 
											
												
													
														|  | 
 |  | +                drop=drop_rate,
 | 
											
												
													
														|  | 
 |  | +                attn_drop=attn_drop_rate,
 | 
											
												
													
														|  | 
 |  | +                drop_path=dpr[cur + i],
 | 
											
												
													
														|  | 
 |  | +                norm_layer=norm_layer,
 | 
											
												
													
														|  | 
 |  | +                sr_ratio=sr_ratios[1]) for i in range(depths[1])
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        self.norm2 = norm_layer(embed_dims[1])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage-3 (x1/16 scale)
 | 
											
												
													
														|  | 
 |  | +        cur += depths[1]
 | 
											
												
													
														|  | 
 |  | +        self.block3 = nn.LayerList([
 | 
											
												
													
														|  | 
 |  | +            Block(
 | 
											
												
													
														|  | 
 |  | +                dim=embed_dims[2],
 | 
											
												
													
														|  | 
 |  | +                num_heads=num_heads[2],
 | 
											
												
													
														|  | 
 |  | +                mlp_ratio=mlp_ratios[2],
 | 
											
												
													
														|  | 
 |  | +                qkv_bias=qkv_bias,
 | 
											
												
													
														|  | 
 |  | +                qk_scale=qk_scale,
 | 
											
												
													
														|  | 
 |  | +                drop=drop_rate,
 | 
											
												
													
														|  | 
 |  | +                attn_drop=attn_drop_rate,
 | 
											
												
													
														|  | 
 |  | +                drop_path=dpr[cur + i],
 | 
											
												
													
														|  | 
 |  | +                norm_layer=norm_layer,
 | 
											
												
													
														|  | 
 |  | +                sr_ratio=sr_ratios[2]) for i in range(depths[2])
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        self.norm3 = norm_layer(embed_dims[2])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage-4 (x1/32 scale)
 | 
											
												
													
														|  | 
 |  | +        cur += depths[2]
 | 
											
												
													
														|  | 
 |  | +        self.block4 = nn.LayerList([
 | 
											
												
													
														|  | 
 |  | +            Block(
 | 
											
												
													
														|  | 
 |  | +                dim=embed_dims[3],
 | 
											
												
													
														|  | 
 |  | +                num_heads=num_heads[3],
 | 
											
												
													
														|  | 
 |  | +                mlp_ratio=mlp_ratios[3],
 | 
											
												
													
														|  | 
 |  | +                qkv_bias=qkv_bias,
 | 
											
												
													
														|  | 
 |  | +                qk_scale=qk_scale,
 | 
											
												
													
														|  | 
 |  | +                drop=drop_rate,
 | 
											
												
													
														|  | 
 |  | +                attn_drop=attn_drop_rate,
 | 
											
												
													
														|  | 
 |  | +                drop_path=dpr[cur + i],
 | 
											
												
													
														|  | 
 |  | +                norm_layer=norm_layer,
 | 
											
												
													
														|  | 
 |  | +                sr_ratio=sr_ratios[3]) for i in range(depths[3])
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        self.norm4 = norm_layer(embed_dims[3])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.apply(self._init_weights)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _init_weights(self, m):
 | 
											
												
													
														|  | 
 |  | +        if isinstance(m, nn.Linear):
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op(m.weight)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            if isinstance(m, nn.Linear) and m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.LayerNorm):
 | 
											
												
													
														|  | 
 |  | +            init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +            init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Constant(1.0)
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.Conv2D):
 | 
											
												
													
														|  | 
 |  | +            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
 | 
											
												
													
														|  | 
 |  | +            fan_out //= m._groups
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            if m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def reset_drop_path(self, drop_path_rate):
 | 
											
												
													
														|  | 
 |  | +        dpr = [
 | 
											
												
													
														|  | 
 |  | +            x.item() for x in pd.linspace(0, drop_path_rate, sum(self.depths))
 | 
											
												
													
														|  | 
 |  | +        ]
 | 
											
												
													
														|  | 
 |  | +        cur = 0
 | 
											
												
													
														|  | 
 |  | +        for i in range(self.depths[0]):
 | 
											
												
													
														|  | 
 |  | +            self.block1[i].drop_path.drop_prob = dpr[cur + i]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        cur += self.depths[0]
 | 
											
												
													
														|  | 
 |  | +        for i in range(self.depths[1]):
 | 
											
												
													
														|  | 
 |  | +            self.block2[i].drop_path.drop_prob = dpr[cur + i]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        cur += self.depths[1]
 | 
											
												
													
														|  | 
 |  | +        for i in range(self.depths[2]):
 | 
											
												
													
														|  | 
 |  | +            self.block3[i].drop_path.drop_prob = dpr[cur + i]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        cur += self.depths[2]
 | 
											
												
													
														|  | 
 |  | +        for i in range(self.depths[3]):
 | 
											
												
													
														|  | 
 |  | +            self.block4[i].drop_path.drop_prob = dpr[cur + i]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward_features(self, x):
 | 
											
												
													
														|  | 
 |  | +        B = x.shape[0]
 | 
											
												
													
														|  | 
 |  | +        outs = []
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # stage 1
 | 
											
												
													
														|  | 
 |  | +        x1, H1, W1 = self.patch_embed1(x)
 | 
											
												
													
														|  | 
 |  | +        for i, blk in enumerate(self.block1):
 | 
											
												
													
														|  | 
 |  | +            x1 = blk(x1, H1, W1)
 | 
											
												
													
														|  | 
 |  | +        x1 = self.norm1(x1)
 | 
											
												
													
														|  | 
 |  | +        x1 = x1.reshape(
 | 
											
												
													
														|  | 
 |  | +            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
 | 
											
												
													
														|  | 
 |  | +                [0, 3, 1, 2])
 | 
											
												
													
														|  | 
 |  | +        outs.append(x1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # stage 2
 | 
											
												
													
														|  | 
 |  | +        x1, H1, W1 = self.patch_embed2(x1)
 | 
											
												
													
														|  | 
 |  | +        for i, blk in enumerate(self.block2):
 | 
											
												
													
														|  | 
 |  | +            x1 = blk(x1, H1, W1)
 | 
											
												
													
														|  | 
 |  | +        x1 = self.norm2(x1)
 | 
											
												
													
														|  | 
 |  | +        x1 = x1.reshape(
 | 
											
												
													
														|  | 
 |  | +            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
 | 
											
												
													
														|  | 
 |  | +                [0, 3, 1, 2])
 | 
											
												
													
														|  | 
 |  | +        outs.append(x1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # stage 3
 | 
											
												
													
														|  | 
 |  | +        x1, H1, W1 = self.patch_embed3(x1)
 | 
											
												
													
														|  | 
 |  | +        for i, blk in enumerate(self.block3):
 | 
											
												
													
														|  | 
 |  | +            x1 = blk(x1, H1, W1)
 | 
											
												
													
														|  | 
 |  | +        x1 = self.norm3(x1)
 | 
											
												
													
														|  | 
 |  | +        x1 = x1.reshape(
 | 
											
												
													
														|  | 
 |  | +            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
 | 
											
												
													
														|  | 
 |  | +                [0, 3, 1, 2])
 | 
											
												
													
														|  | 
 |  | +        outs.append(x1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # stage 4
 | 
											
												
													
														|  | 
 |  | +        x1, H1, W1 = self.patch_embed4(x1)
 | 
											
												
													
														|  | 
 |  | +        for i, blk in enumerate(self.block4):
 | 
											
												
													
														|  | 
 |  | +            x1 = blk(x1, H1, W1)
 | 
											
												
													
														|  | 
 |  | +        x1 = self.norm4(x1)
 | 
											
												
													
														|  | 
 |  | +        x1 = x1.reshape(
 | 
											
												
													
														|  | 
 |  | +            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
 | 
											
												
													
														|  | 
 |  | +                [0, 3, 1, 2])
 | 
											
												
													
														|  | 
 |  | +        outs.append(x1)
 | 
											
												
													
														|  | 
 |  | +        return outs
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        x = self.forward_features(x)
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DecoderTransformer_v3(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +    Transformer Decoder
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 input_transform='multiple_select',
 | 
											
												
													
														|  | 
 |  | +                 in_index=[0, 1, 2, 3],
 | 
											
												
													
														|  | 
 |  | +                 align_corners=True,
 | 
											
												
													
														|  | 
 |  | +                 in_channels=[32, 64, 128, 256],
 | 
											
												
													
														|  | 
 |  | +                 embedding_dim=64,
 | 
											
												
													
														|  | 
 |  | +                 output_nc=2,
 | 
											
												
													
														|  | 
 |  | +                 decoder_softmax=False,
 | 
											
												
													
														|  | 
 |  | +                 feature_strides=[2, 4, 8, 16]):
 | 
											
												
													
														|  | 
 |  | +        super(DecoderTransformer_v3, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        # assert
 | 
											
												
													
														|  | 
 |  | +        assert len(feature_strides) == len(in_channels)
 | 
											
												
													
														|  | 
 |  | +        assert min(feature_strides) == feature_strides[0]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # settings
 | 
											
												
													
														|  | 
 |  | +        self.feature_strides = feature_strides
 | 
											
												
													
														|  | 
 |  | +        self.input_transform = input_transform
 | 
											
												
													
														|  | 
 |  | +        self.in_index = in_index
 | 
											
												
													
														|  | 
 |  | +        self.align_corners = align_corners
 | 
											
												
													
														|  | 
 |  | +        self.in_channels = in_channels
 | 
											
												
													
														|  | 
 |  | +        self.embedding_dim = embedding_dim
 | 
											
												
													
														|  | 
 |  | +        self.output_nc = output_nc
 | 
											
												
													
														|  | 
 |  | +        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # MLP decoder heads
 | 
											
												
													
														|  | 
 |  | +        self.linear_c4 = MLP(input_dim=c4_in_channels,
 | 
											
												
													
														|  | 
 |  | +                             embed_dim=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.linear_c3 = MLP(input_dim=c3_in_channels,
 | 
											
												
													
														|  | 
 |  | +                             embed_dim=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.linear_c2 = MLP(input_dim=c2_in_channels,
 | 
											
												
													
														|  | 
 |  | +                             embed_dim=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.linear_c1 = MLP(input_dim=c1_in_channels,
 | 
											
												
													
														|  | 
 |  | +                             embed_dim=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # convolutional Difference Layers
 | 
											
												
													
														|  | 
 |  | +        self.diff_c4 = conv_diff(
 | 
											
												
													
														|  | 
 |  | +            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.diff_c3 = conv_diff(
 | 
											
												
													
														|  | 
 |  | +            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.diff_c2 = conv_diff(
 | 
											
												
													
														|  | 
 |  | +            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +        self.diff_c1 = conv_diff(
 | 
											
												
													
														|  | 
 |  | +            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # taking outputs from middle of the encoder
 | 
											
												
													
														|  | 
 |  | +        self.make_pred_c4 = make_prediction(
 | 
											
												
													
														|  | 
 |  | +            in_channels=self.embedding_dim, out_channels=self.output_nc)
 | 
											
												
													
														|  | 
 |  | +        self.make_pred_c3 = make_prediction(
 | 
											
												
													
														|  | 
 |  | +            in_channels=self.embedding_dim, out_channels=self.output_nc)
 | 
											
												
													
														|  | 
 |  | +        self.make_pred_c2 = make_prediction(
 | 
											
												
													
														|  | 
 |  | +            in_channels=self.embedding_dim, out_channels=self.output_nc)
 | 
											
												
													
														|  | 
 |  | +        self.make_pred_c1 = make_prediction(
 | 
											
												
													
														|  | 
 |  | +            in_channels=self.embedding_dim, out_channels=self.output_nc)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Final linear fusion layer
 | 
											
												
													
														|  | 
 |  | +        self.linear_fuse = nn.Sequential(
 | 
											
												
													
														|  | 
 |  | +            nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +                in_channels=self.embedding_dim * len(in_channels),
 | 
											
												
													
														|  | 
 |  | +                out_channels=self.embedding_dim,
 | 
											
												
													
														|  | 
 |  | +                kernel_size=1),
 | 
											
												
													
														|  | 
 |  | +            nn.BatchNorm2D(self.embedding_dim))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Final predction head
 | 
											
												
													
														|  | 
 |  | +        self.convd2x = UpsampleConvLayer(
 | 
											
												
													
														|  | 
 |  | +            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
 | 
											
												
													
														|  | 
 |  | +        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
 | 
											
												
													
														|  | 
 |  | +        self.convd1x = UpsampleConvLayer(
 | 
											
												
													
														|  | 
 |  | +            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
 | 
											
												
													
														|  | 
 |  | +        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
 | 
											
												
													
														|  | 
 |  | +        self.change_probability = ConvLayer(
 | 
											
												
													
														|  | 
 |  | +            self.embedding_dim,
 | 
											
												
													
														|  | 
 |  | +            self.output_nc,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=3,
 | 
											
												
													
														|  | 
 |  | +            stride=1,
 | 
											
												
													
														|  | 
 |  | +            padding=1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Final activation
 | 
											
												
													
														|  | 
 |  | +        self.output_softmax = decoder_softmax
 | 
											
												
													
														|  | 
 |  | +        self.active = nn.Sigmoid()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _transform_inputs(self, inputs):
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Transform inputs for decoder.
 | 
											
												
													
														|  | 
 |  | +        Args:
 | 
											
												
													
														|  | 
 |  | +            inputs (list[Tensor]): List of multi-level img features.
 | 
											
												
													
														|  | 
 |  | +        Returns:
 | 
											
												
													
														|  | 
 |  | +            Tensor: The transformed inputs
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.input_transform == 'resize_concat':
 | 
											
												
													
														|  | 
 |  | +            inputs = [inputs[i] for i in self.in_index]
 | 
											
												
													
														|  | 
 |  | +            upsampled_inputs = [
 | 
											
												
													
														|  | 
 |  | +                resize(
 | 
											
												
													
														|  | 
 |  | +                    input=x,
 | 
											
												
													
														|  | 
 |  | +                    size=inputs[0].shape[2:],
 | 
											
												
													
														|  | 
 |  | +                    mode='bilinear',
 | 
											
												
													
														|  | 
 |  | +                    align_corners=self.align_corners) for x in inputs
 | 
											
												
													
														|  | 
 |  | +            ]
 | 
											
												
													
														|  | 
 |  | +            inputs = pd.concat(upsampled_inputs, dim=1)
 | 
											
												
													
														|  | 
 |  | +        elif self.input_transform == 'multiple_select':
 | 
											
												
													
														|  | 
 |  | +            inputs = [inputs[i] for i in self.in_index]
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            inputs = inputs[self.in_index]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return inputs
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, inputs1, inputs2):
 | 
											
												
													
														|  | 
 |  | +        # Transforming encoder features (select layers)
 | 
											
												
													
														|  | 
 |  | +        x_1 = self._transform_inputs(inputs1)  # len=4, 1/2, 1/4, 1/8, 1/16
 | 
											
												
													
														|  | 
 |  | +        x_2 = self._transform_inputs(inputs2)  # len=4, 1/2, 1/4, 1/8, 1/16
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # img1 and img2 features
 | 
											
												
													
														|  | 
 |  | +        c1_1, c2_1, c3_1, c4_1 = x_1
 | 
											
												
													
														|  | 
 |  | +        c1_2, c2_2, c3_2, c4_2 = x_2
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        ############## MLP decoder on C1-C4 ###########
 | 
											
												
													
														|  | 
 |  | +        n, _, h, w = c4_1.shape
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        outputs = []
 | 
											
												
													
														|  | 
 |  | +        # Stage 4: x1/32 scale
 | 
											
												
													
														|  | 
 |  | +        _c4_1 = self.linear_c4(c4_1).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c4_1 = _c4_1.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c4_1.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c4_1.shape[2] * c4_1.shape[3]), c4_1.shape[2], c4_1.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c4_2 = self.linear_c4(c4_2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c4_2 = _c4_2.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c4_2.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c4_2.shape[2] * c4_2.shape[3]), c4_2.shape[2], c4_2.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c4 = self.diff_c4(pd.concat((_c4_1, _c4_2), axis=1))
 | 
											
												
													
														|  | 
 |  | +        p_c4 = self.make_pred_c4(_c4)
 | 
											
												
													
														|  | 
 |  | +        outputs.append(p_c4)
 | 
											
												
													
														|  | 
 |  | +        _c4_up = resize(
 | 
											
												
													
														|  | 
 |  | +            _c4, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage 3: x1/16 scale
 | 
											
												
													
														|  | 
 |  | +        _c3_1 = self.linear_c3(c3_1).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c3_1 = _c3_1.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c3_1.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c3_1.shape[2] * c3_1.shape[3]), c3_1.shape[2], c3_1.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c3_2 = self.linear_c3(c3_2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c3_2 = _c3_2.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c3_2.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c3_2.shape[2] * c3_2.shape[3]), c3_2.shape[2], c3_2.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c3 = self.diff_c3(pd.concat((_c3_1, _c3_2), axis=1)) + \
 | 
											
												
													
														|  | 
 |  | +            F.interpolate(_c4, scale_factor=2, mode="bilinear")
 | 
											
												
													
														|  | 
 |  | +        p_c3 = self.make_pred_c3(_c3)
 | 
											
												
													
														|  | 
 |  | +        outputs.append(p_c3)
 | 
											
												
													
														|  | 
 |  | +        _c3_up = resize(
 | 
											
												
													
														|  | 
 |  | +            _c3, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage 2: x1/8 scale
 | 
											
												
													
														|  | 
 |  | +        _c2_1 = self.linear_c2(c2_1).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c2_1 = _c2_1.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c2_1.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c2_1.shape[2] * c2_1.shape[3]), c2_1.shape[2], c2_1.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c2_2 = self.linear_c2(c2_2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c2_2 = _c2_2.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c2_2.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c2_2.shape[2] * c2_2.shape[3]), c2_2.shape[2], c2_2.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c2 = self.diff_c2(pd.concat((_c2_1, _c2_2), axis=1)) + \
 | 
											
												
													
														|  | 
 |  | +            F.interpolate(_c3, scale_factor=2, mode="bilinear")
 | 
											
												
													
														|  | 
 |  | +        p_c2 = self.make_pred_c2(_c2)
 | 
											
												
													
														|  | 
 |  | +        outputs.append(p_c2)
 | 
											
												
													
														|  | 
 |  | +        _c2_up = resize(
 | 
											
												
													
														|  | 
 |  | +            _c2, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Stage 1: x1/4 scale
 | 
											
												
													
														|  | 
 |  | +        _c1_1 = self.linear_c1(c1_1).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c1_1 = _c1_1.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c1_1.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c1_1.shape[2] * c1_1.shape[3]), c1_1.shape[2], c1_1.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c1_2 = self.linear_c1(c1_2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        _c1_2 = _c1_2.reshape([
 | 
											
												
													
														|  | 
 |  | +            n, calc_product(*_c1_2.shape[1:]) //
 | 
											
												
													
														|  | 
 |  | +            (c1_2.shape[2] * c1_2.shape[3]), c1_2.shape[2], c1_2.shape[3]
 | 
											
												
													
														|  | 
 |  | +        ])
 | 
											
												
													
														|  | 
 |  | +        _c1 = self.diff_c1(pd.concat((_c1_1, _c1_2), axis=1)) + \
 | 
											
												
													
														|  | 
 |  | +            F.interpolate(_c2, scale_factor=2, mode="bilinear")
 | 
											
												
													
														|  | 
 |  | +        p_c1 = self.make_pred_c1(_c1)
 | 
											
												
													
														|  | 
 |  | +        outputs.append(p_c1)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Linear Fusion of difference image from all scales
 | 
											
												
													
														|  | 
 |  | +        _c = self.linear_fuse(pd.concat((_c4_up, _c3_up, _c2_up, _c1), axis=1))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Upsampling x2 (x1/2 scale)
 | 
											
												
													
														|  | 
 |  | +        x = self.convd2x(_c)
 | 
											
												
													
														|  | 
 |  | +        # Residual block
 | 
											
												
													
														|  | 
 |  | +        x = self.dense_2x(x)
 | 
											
												
													
														|  | 
 |  | +        # Upsampling x2 (x1 scale)
 | 
											
												
													
														|  | 
 |  | +        x = self.convd1x(x)
 | 
											
												
													
														|  | 
 |  | +        # Residual block
 | 
											
												
													
														|  | 
 |  | +        x = self.dense_1x(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Final prediction
 | 
											
												
													
														|  | 
 |  | +        cp = self.change_probability(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        outputs.append(cp)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.output_softmax:
 | 
											
												
													
														|  | 
 |  | +            temp = outputs
 | 
											
												
													
														|  | 
 |  | +            outputs = []
 | 
											
												
													
														|  | 
 |  | +            for pred in temp:
 | 
											
												
													
														|  | 
 |  | +                outputs.append(self.active(pred))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return outputs[-1]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class OverlapPatchEmbed(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """ 
 | 
											
												
													
														|  | 
 |  | +    Image to Patch Embedding
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 img_size=224,
 | 
											
												
													
														|  | 
 |  | +                 patch_size=7,
 | 
											
												
													
														|  | 
 |  | +                 stride=4,
 | 
											
												
													
														|  | 
 |  | +                 in_chans=3,
 | 
											
												
													
														|  | 
 |  | +                 embed_dim=768):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        img_size = to_2tuple(img_size)
 | 
											
												
													
														|  | 
 |  | +        patch_size = to_2tuple(patch_size)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.img_size = img_size
 | 
											
												
													
														|  | 
 |  | +        self.patch_size = patch_size
 | 
											
												
													
														|  | 
 |  | +        self.H, self.W = img_size[0] // patch_size[0], img_size[
 | 
											
												
													
														|  | 
 |  | +            1] // patch_size[1]
 | 
											
												
													
														|  | 
 |  | +        self.num_patches = self.H * self.W
 | 
											
												
													
														|  | 
 |  | +        self.proj = nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            in_chans,
 | 
											
												
													
														|  | 
 |  | +            embed_dim,
 | 
											
												
													
														|  | 
 |  | +            kernel_size=patch_size,
 | 
											
												
													
														|  | 
 |  | +            stride=stride,
 | 
											
												
													
														|  | 
 |  | +            padding=(patch_size[0] // 2, patch_size[1] // 2))
 | 
											
												
													
														|  | 
 |  | +        self.norm = nn.LayerNorm(embed_dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.apply(self._init_weights)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _init_weights(self, m):
 | 
											
												
													
														|  | 
 |  | +        if isinstance(m, nn.Linear):
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if isinstance(m, nn.Linear) and m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.LayerNorm):
 | 
											
												
													
														|  | 
 |  | +            init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +            init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Constant(1.0)
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.Conv2D):
 | 
											
												
													
														|  | 
 |  | +            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
 | 
											
												
													
														|  | 
 |  | +            fan_out //= m._groups
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        x = self.proj(x)
 | 
											
												
													
														|  | 
 |  | +        _, _, H, W = x.shape
 | 
											
												
													
														|  | 
 |  | +        x = x.flatten(2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        x = self.norm(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return x, H, W
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +def resize(input,
 | 
											
												
													
														|  | 
 |  | +           size=None,
 | 
											
												
													
														|  | 
 |  | +           scale_factor=None,
 | 
											
												
													
														|  | 
 |  | +           mode='nearest',
 | 
											
												
													
														|  | 
 |  | +           align_corners=None,
 | 
											
												
													
														|  | 
 |  | +           warning=True):
 | 
											
												
													
														|  | 
 |  | +    if warning:
 | 
											
												
													
														|  | 
 |  | +        if size is not None and align_corners:
 | 
											
												
													
														|  | 
 |  | +            input_h, input_w = tuple(int(x) for x in input.shape[2:])
 | 
											
												
													
														|  | 
 |  | +            output_h, output_w = tuple(int(x) for x in size)
 | 
											
												
													
														|  | 
 |  | +            if output_h > input_h or output_w > output_h:
 | 
											
												
													
														|  | 
 |  | +                if ((output_h > 1 and output_w > 1 and input_h > 1 and
 | 
											
												
													
														|  | 
 |  | +                     input_w > 1) and (output_h - 1) % (input_h - 1) and
 | 
											
												
													
														|  | 
 |  | +                    (output_w - 1) % (input_w - 1)):
 | 
											
												
													
														|  | 
 |  | +                    warnings.warn(
 | 
											
												
													
														|  | 
 |  | +                        f'When align_corners={align_corners}, '
 | 
											
												
													
														|  | 
 |  | +                        'the output would more aligned if '
 | 
											
												
													
														|  | 
 |  | +                        f'input size {(input_h, input_w)} is `x+1` and '
 | 
											
												
													
														|  | 
 |  | +                        f'out size {(output_h, output_w)} is `nx+1`')
 | 
											
												
													
														|  | 
 |  | +    return F.interpolate(input, size, scale_factor, mode, align_corners)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class Mlp(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 in_features,
 | 
											
												
													
														|  | 
 |  | +                 hidden_features=None,
 | 
											
												
													
														|  | 
 |  | +                 out_features=None,
 | 
											
												
													
														|  | 
 |  | +                 act_layer=nn.GELU,
 | 
											
												
													
														|  | 
 |  | +                 drop=0.):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        out_features = out_features or in_features
 | 
											
												
													
														|  | 
 |  | +        hidden_features = hidden_features or in_features
 | 
											
												
													
														|  | 
 |  | +        self.fc1 = nn.Linear(in_features, hidden_features)
 | 
											
												
													
														|  | 
 |  | +        self.dwconv = DWConv(hidden_features)
 | 
											
												
													
														|  | 
 |  | +        self.act = act_layer()
 | 
											
												
													
														|  | 
 |  | +        self.fc2 = nn.Linear(hidden_features, out_features)
 | 
											
												
													
														|  | 
 |  | +        self.drop = nn.Dropout(drop)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.apply(self._init_weights)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _init_weights(self, m):
 | 
											
												
													
														|  | 
 |  | +        if isinstance(m, nn.Linear):
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if isinstance(m, nn.Linear) and m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.LayerNorm):
 | 
											
												
													
														|  | 
 |  | +            init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +            init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Constant(1.0)
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.Conv2D):
 | 
											
												
													
														|  | 
 |  | +            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
 | 
											
												
													
														|  | 
 |  | +            fan_out //= m._groups
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, H, W):
 | 
											
												
													
														|  | 
 |  | +        x = self.fc1(x)
 | 
											
												
													
														|  | 
 |  | +        x = self.dwconv(x, H, W)
 | 
											
												
													
														|  | 
 |  | +        x = self.act(x)
 | 
											
												
													
														|  | 
 |  | +        x = self.drop(x)
 | 
											
												
													
														|  | 
 |  | +        x = self.fc2(x)
 | 
											
												
													
														|  | 
 |  | +        x = self.drop(x)
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class Attention(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 dim,
 | 
											
												
													
														|  | 
 |  | +                 num_heads=8,
 | 
											
												
													
														|  | 
 |  | +                 qkv_bias=False,
 | 
											
												
													
														|  | 
 |  | +                 qk_scale=None,
 | 
											
												
													
														|  | 
 |  | +                 attn_drop=0.,
 | 
											
												
													
														|  | 
 |  | +                 proj_drop=0.,
 | 
											
												
													
														|  | 
 |  | +                 sr_ratio=1):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.dim = dim
 | 
											
												
													
														|  | 
 |  | +        self.num_heads = num_heads
 | 
											
												
													
														|  | 
 |  | +        head_dim = dim // num_heads
 | 
											
												
													
														|  | 
 |  | +        self.scale = qk_scale or head_dim**-0.5
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
 | 
											
												
													
														|  | 
 |  | +        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
 | 
											
												
													
														|  | 
 |  | +        self.attn_drop = nn.Dropout(attn_drop)
 | 
											
												
													
														|  | 
 |  | +        self.proj = nn.Linear(dim, dim)
 | 
											
												
													
														|  | 
 |  | +        self.proj_drop = nn.Dropout(proj_drop)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.sr_ratio = sr_ratio
 | 
											
												
													
														|  | 
 |  | +        if sr_ratio > 1:
 | 
											
												
													
														|  | 
 |  | +            self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
 | 
											
												
													
														|  | 
 |  | +            self.norm = nn.LayerNorm(dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.apply(self._init_weights)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _init_weights(self, m):
 | 
											
												
													
														|  | 
 |  | +        if isinstance(m, nn.Linear):
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if isinstance(m, nn.Linear) and m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.LayerNorm):
 | 
											
												
													
														|  | 
 |  | +            init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +            init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Constant(1.0)
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.Conv2D):
 | 
											
												
													
														|  | 
 |  | +            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
 | 
											
												
													
														|  | 
 |  | +            fan_out //= m._groups
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, H, W):
 | 
											
												
													
														|  | 
 |  | +        B, N, C = x.shape
 | 
											
												
													
														|  | 
 |  | +        q = self.q(x).reshape([B, N, self.num_heads,
 | 
											
												
													
														|  | 
 |  | +                               C // self.num_heads]).transpose([0, 2, 1, 3])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if self.sr_ratio > 1:
 | 
											
												
													
														|  | 
 |  | +            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
 | 
											
												
													
														|  | 
 |  | +            x_ = self.sr(x_)
 | 
											
												
													
														|  | 
 |  | +            x_ = x_.reshape([B, C, calc_product(*x_.shape[2:])]).transpose(
 | 
											
												
													
														|  | 
 |  | +                [0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +            x_ = self.norm(x_)
 | 
											
												
													
														|  | 
 |  | +            kv = self.kv(x_)
 | 
											
												
													
														|  | 
 |  | +            kv = kv.reshape([
 | 
											
												
													
														|  | 
 |  | +                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
 | 
											
												
													
														|  | 
 |  | +                C // self.num_heads
 | 
											
												
													
														|  | 
 |  | +            ]).transpose([2, 0, 3, 1, 4])
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            kv = self.kv(x)
 | 
											
												
													
														|  | 
 |  | +            kv = kv.reshape([
 | 
											
												
													
														|  | 
 |  | +                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
 | 
											
												
													
														|  | 
 |  | +                C // self.num_heads
 | 
											
												
													
														|  | 
 |  | +            ]).transpose([2, 0, 3, 1, 4])
 | 
											
												
													
														|  | 
 |  | +        k, v = kv[0], kv[1]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
 | 
											
												
													
														|  | 
 |  | +        attn = F.softmax(attn, axis=-1)
 | 
											
												
													
														|  | 
 |  | +        attn = self.attn_drop(attn)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
 | 
											
												
													
														|  | 
 |  | +        x = self.proj(x)
 | 
											
												
													
														|  | 
 |  | +        x = self.proj_drop(x)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class Block(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self,
 | 
											
												
													
														|  | 
 |  | +                 dim,
 | 
											
												
													
														|  | 
 |  | +                 num_heads,
 | 
											
												
													
														|  | 
 |  | +                 mlp_ratio=4.,
 | 
											
												
													
														|  | 
 |  | +                 qkv_bias=False,
 | 
											
												
													
														|  | 
 |  | +                 qk_scale=None,
 | 
											
												
													
														|  | 
 |  | +                 drop=0.,
 | 
											
												
													
														|  | 
 |  | +                 attn_drop=0.,
 | 
											
												
													
														|  | 
 |  | +                 drop_path=0.,
 | 
											
												
													
														|  | 
 |  | +                 act_layer=nn.GELU,
 | 
											
												
													
														|  | 
 |  | +                 norm_layer=nn.LayerNorm,
 | 
											
												
													
														|  | 
 |  | +                 sr_ratio=1):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        self.norm1 = norm_layer(dim)
 | 
											
												
													
														|  | 
 |  | +        self.attn = Attention(
 | 
											
												
													
														|  | 
 |  | +            dim,
 | 
											
												
													
														|  | 
 |  | +            num_heads=num_heads,
 | 
											
												
													
														|  | 
 |  | +            qkv_bias=qkv_bias,
 | 
											
												
													
														|  | 
 |  | +            qk_scale=qk_scale,
 | 
											
												
													
														|  | 
 |  | +            attn_drop=attn_drop,
 | 
											
												
													
														|  | 
 |  | +            proj_drop=drop,
 | 
											
												
													
														|  | 
 |  | +            sr_ratio=sr_ratio)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity(
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +        self.norm2 = norm_layer(dim)
 | 
											
												
													
														|  | 
 |  | +        mlp_hidden_dim = int(dim * mlp_ratio)
 | 
											
												
													
														|  | 
 |  | +        self.mlp = Mlp(in_features=dim,
 | 
											
												
													
														|  | 
 |  | +                       hidden_features=mlp_hidden_dim,
 | 
											
												
													
														|  | 
 |  | +                       act_layer=act_layer,
 | 
											
												
													
														|  | 
 |  | +                       drop=drop)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _init_weights(self, m):
 | 
											
												
													
														|  | 
 |  | +        if isinstance(m, nn.Linear):
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
 | 
											
												
													
														|  | 
 |  | +            trunc_normal_op(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if isinstance(m, nn.Linear) and m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.LayerNorm):
 | 
											
												
													
														|  | 
 |  | +            init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +            init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Constant(1.0)
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +        elif isinstance(m, nn.Conv2D):
 | 
											
												
													
														|  | 
 |  | +            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
 | 
											
												
													
														|  | 
 |  | +            fan_out //= m._groups
 | 
											
												
													
														|  | 
 |  | +            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
 | 
											
												
													
														|  | 
 |  | +            init_weight(m.weight)
 | 
											
												
													
														|  | 
 |  | +            if m.bias is not None:
 | 
											
												
													
														|  | 
 |  | +                init_bias = nn.initializer.Constant(0)
 | 
											
												
													
														|  | 
 |  | +                init_bias(m.bias)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, H, W):
 | 
											
												
													
														|  | 
 |  | +        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
 | 
											
												
													
														|  | 
 |  | +        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class DWConv(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, dim=768):
 | 
											
												
													
														|  | 
 |  | +        super(DWConv, self).__init__()
 | 
											
												
													
														|  | 
 |  | +        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x, H, W):
 | 
											
												
													
														|  | 
 |  | +        B, N, C = x.shape
 | 
											
												
													
														|  | 
 |  | +        x = x.transpose([0, 2, 1]).reshape([B, C, H, W])
 | 
											
												
													
														|  | 
 |  | +        x = self.dwconv(x)
 | 
											
												
													
														|  | 
 |  | +        x = x.flatten(2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +# Transformer Decoder
 | 
											
												
													
														|  | 
 |  | +class MLP(nn.Layer):
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +    Linear Embedding
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, input_dim=2048, embed_dim=768):
 | 
											
												
													
														|  | 
 |  | +        super().__init__()
 | 
											
												
													
														|  | 
 |  | +        self.proj = nn.Linear(input_dim, embed_dim)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def forward(self, x):
 | 
											
												
													
														|  | 
 |  | +        x = x.flatten(2).transpose([0, 2, 1])
 | 
											
												
													
														|  | 
 |  | +        x = self.proj(x)
 | 
											
												
													
														|  | 
 |  | +        return x
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +# Difference Layer
 | 
											
												
													
														|  | 
 |  | +def conv_diff(in_channels, out_channels):
 | 
											
												
													
														|  | 
 |  | +    return nn.Sequential(
 | 
											
												
													
														|  | 
 |  | +        nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            in_channels, out_channels, kernel_size=3, padding=1),
 | 
											
												
													
														|  | 
 |  | +        nn.ReLU(),
 | 
											
												
													
														|  | 
 |  | +        nn.BatchNorm2D(out_channels),
 | 
											
												
													
														|  | 
 |  | +        nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            out_channels, out_channels, kernel_size=3, padding=1),
 | 
											
												
													
														|  | 
 |  | +        nn.ReLU())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +# Intermediate prediction Layer
 | 
											
												
													
														|  | 
 |  | +def make_prediction(in_channels, out_channels):
 | 
											
												
													
														|  | 
 |  | +    return nn.Sequential(
 | 
											
												
													
														|  | 
 |  | +        nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            in_channels, out_channels, kernel_size=3, padding=1),
 | 
											
												
													
														|  | 
 |  | +        nn.ReLU(),
 | 
											
												
													
														|  | 
 |  | +        nn.BatchNorm2D(out_channels),
 | 
											
												
													
														|  | 
 |  | +        nn.Conv2D(
 | 
											
												
													
														|  | 
 |  | +            out_channels, out_channels, kernel_size=3, padding=1))
 |