12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import warnings
- import math
- from functools import partial
- import paddle as pd
- import paddle.nn as nn
- import paddle.nn.functional as F
- from .layers.pd_timm import DropPath, to_2tuple
def calc_product(*args):
    """Return the product of all positional arguments.

    Used throughout this file to turn (parts of) tensor shapes into element
    counts.

    Args:
        *args: One or more values supporting ``*``.

    Returns:
        ``args[0] * args[1] * ... * args[-1]``.

    Raises:
        ValueError: If called with no arguments.
    """
    if not args:
        raise ValueError("calc_product() requires at least one argument")
    result = args[0]
    for arg in args[1:]:
        # Out-of-place multiply: ``*=`` would modify a mutable first argument
        # (e.g. an array) that belongs to the caller.
        result = result * arg
    return result
def _make_norm(norm, num_features):
    """Return the 2-D normalization layer selected by `norm`, or None.

    Args:
        norm (str|None): None (no normalization), 'batch' or 'instance'.
        num_features (int): Number of channels to normalize.

    Raises:
        ValueError: If `norm` is not one of the supported values.
    """
    if norm is None:
        return None
    if norm == 'batch':
        return pd.nn.BatchNorm2D(num_features)
    if norm == 'instance':
        return pd.nn.InstanceNorm2D(num_features)
    raise ValueError("Unsupported norm: {!r}".format(norm))


def _make_activation(activation):
    """Return the activation layer named by `activation`, or None.

    Both 'no' and None mean "no activation".

    Raises:
        ValueError: If `activation` is not one of the supported names.
    """
    if activation is None or activation == 'no':
        return None
    # Paddle's ReLU/LeakyReLU have no `inplace` flag; a positional True would
    # land in their `name` argument, so it is not passed.
    if activation == 'relu':
        return pd.nn.ReLU()
    if activation == 'prelu':
        return pd.nn.PReLU()
    if activation == 'lrelu':
        return pd.nn.LeakyReLU(0.2)
    if activation == 'tanh':
        return pd.nn.Tanh()
    if activation == 'sigmoid':
        return pd.nn.Sigmoid()
    raise ValueError("Unsupported activation: {!r}".format(activation))


class ConvBlock(pd.nn.Layer):
    """Conv2D followed by optional normalization and optional activation.

    Args:
        input_size (int): Input channels.
        output_size (int): Output channels.
        kernel_size, stride, padding: Passed to `Conv2D`.
        bias (bool): Passed to `Conv2D` as `bias_attr`.
        activation (str|None): 'relu', 'prelu', 'lrelu', 'tanh', 'sigmoid',
            or 'no'/None for no activation. Default: 'prelu'.
        norm (str|None): 'batch', 'instance' or None. Default: None.

    Raises:
        ValueError: For an unsupported `norm` or `activation`.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='prelu',
                 norm=None):
        super(ConvBlock, self).__init__()
        self.conv = pd.nn.Conv2D(
            input_size,
            output_size,
            kernel_size,
            stride,
            padding,
            bias_attr=bias)
        self.norm = norm
        self.bn = _make_norm(norm, output_size)
        self.activation = activation
        self.act = _make_activation(activation)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        return out


class DeconvBlock(pd.nn.Layer):
    """Conv2DTranspose followed by optional normalization and activation.

    Args:
        input_size (int): Input channels.
        output_size (int): Output channels.
        kernel_size, stride, padding: Passed to `Conv2DTranspose`.
        bias (bool): Passed to `Conv2DTranspose` as `bias_attr`.
        activation (str|None): 'relu', 'prelu', 'lrelu', 'tanh', 'sigmoid',
            or 'no'/None for no activation. Default: 'prelu'.
        norm (str|None): 'batch', 'instance' or None. Default: None.

    Raises:
        ValueError: For an unsupported `norm` or `activation`.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 kernel_size=4,
                 stride=2,
                 padding=1,
                 bias=True,
                 activation='prelu',
                 norm=None):
        super(DeconvBlock, self).__init__()
        self.deconv = pd.nn.Conv2DTranspose(
            input_size,
            output_size,
            kernel_size,
            stride,
            padding,
            bias_attr=bias)
        self.norm = norm
        self.bn = _make_norm(norm, output_size)
        self.activation = activation
        self.act = _make_activation(activation)

    def forward(self, x):
        out = self.deconv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        return out
class ConvLayer(nn.Layer):
    """Single `nn.Conv2D` wrapped as a layer with attribute name `conv2d`."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv2d = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

    def forward(self, x):
        return self.conv2d(x)
class UpsampleConvLayer(pd.nn.Layer):
    """Single `nn.Conv2DTranspose` with padding fixed at 1.

    With kernel_size=4 and stride=2 (how the decoder uses it) the spatial
    size is doubled: (H - 1) * 2 - 2 + 4 = 2H.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv2d = nn.Conv2DTranspose(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1)

    def forward(self, x):
        return self.conv2d(x)
class ResidualBlock(pd.nn.Layer):
    """Residual unit: x + 0.1 * conv2(relu(conv1(x))).

    Both convolutions are 3x3, stride 1, padding 1, and keep `channels`.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        # The residual branch is down-scaled by 0.1 before being added back.
        scaled = self.conv2(hidden) * 0.1
        return pd.add(scaled, x)
class ChangeFormer(nn.Layer):
    """
    ChangeFormer change-detection network implemented with PaddlePaddle.

    Reference:
        Wele Gedara Chaminda Bandara, Vishal M. Patel, "A Transformer-Based
        Siamese Network for Change Detection"
        (https://arxiv.org/pdf/2201.01293.pdf).

    Both input images are encoded by the same (weight-shared) Transformer
    encoder; the decoder compares the two feature pyramids and predicts a
    change map.

    Args:
        in_channels (int, optional): Number of bands of the input images. Default: 3.
        num_classes (int, optional): Number of target classes. Default: 2.
        decoder_softmax (bool, optional): If True, the decoder applies a
            sigmoid (not a softmax) to its output. Default: False.
        embed_dim (int, optional): Embedding dimension of each decoder head. Default: 256.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 decoder_softmax=False,
                 embed_dim=256):
        super(ChangeFormer, self).__init__()

        # Fixed encoder hyper-parameters.
        self.embed_dims = [64, 128, 320, 512]
        self.depths = [3, 3, 4, 3]
        self.embedding_dim = embed_dim
        self.drop_rate = 0.1
        self.attn_drop = 0.1
        self.drop_path_rate = 0.1

        encoder_norm = partial(nn.LayerNorm, epsilon=1e-6)
        self.Tenc_x2 = EncoderTransformer_v3(
            img_size=256,
            patch_size=7,
            in_chans=in_channels,
            num_classes=num_classes,
            embed_dims=self.embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=self.drop_rate,
            attn_drop_rate=self.attn_drop,
            drop_path_rate=self.drop_path_rate,
            norm_layer=encoder_norm,
            depths=self.depths,
            sr_ratios=[8, 4, 2, 1])

        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select',
            in_index=[0, 1, 2, 3],
            align_corners=False,
            in_channels=self.embed_dims,
            embedding_dim=self.embedding_dim,
            output_nc=num_classes,
            decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16])

    def forward(self, x1, x2):
        """Return a one-element list holding the decoder's change map."""
        feats1 = self.Tenc_x2(x1)
        feats2 = self.Tenc_x2(x2)
        change_map = self.TDec_x2(feats1, feats2)
        return [change_map]
# Transformer Encoder producing features at 1/4, 1/8, 1/16 and 1/32 scales
class EncoderTransformer_v3(nn.Layer):
    """
    Four-stage hierarchical Transformer encoder used by ChangeFormer.

    Each stage does the following:
      1. embeds its input with an `OverlapPatchEmbed`;
      2. runs a stack of `Block`s over the tokens;
      3. applies a norm layer;
      4. reshapes the tokens back into an NCHW map.

    The first embedding is a 7x7 / stride-4 convolution and the other three
    use `patch_size` / stride 2. The four outputs are therefore at 1/4, 1/8,
    1/16 and 1/32 of the input resolution (for sizes divisible by 32).

    Args:
        img_size (int): Nominal input size, forwarded to the patch embeddings.
        patch_size (int): Kernel size of the stage-2..4 patch embeddings.
        in_chans (int): Input channels.
        num_classes (int): Stored as an attribute; not used by this class.
        embed_dims, num_heads, mlp_ratios, depths, sr_ratios (list): Per-stage
            settings (the first four entries are used).
        qkv_bias, qk_scale, drop_rate, attn_drop_rate: Forwarded to each `Block`.
        drop_path_rate (float): Upper end of the per-block drop-path rates,
            which are linearly spaced from 0 over all blocks.
        norm_layer (callable): Factory for the norm layers.
    """

    def __init__(self,
                 img_size=256,
                 patch_size=3,
                 in_chans=3,
                 num_classes=2,
                 embed_dims=[32, 64, 128, 256],
                 num_heads=[2, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 3, 6, 18],
                 sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.embed_dims = embed_dims

        # Patch embeddings: a 7x7/stride-4 stem, then three stride-2 stages.
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size, patch_size=7, stride=4,
            in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4, patch_size=patch_size, stride=2,
            in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=patch_size, stride=2,
            in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16, patch_size=patch_size, stride=2,
            in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # One drop-path rate per block, linearly spaced over all stages.
        rates = [r.item() for r in pd.linspace(0, drop_path_rate, sum(depths))]

        def build_stage(stage, first):
            # `first` is the index in `rates` of this stage's first block.
            return nn.LayerList([
                Block(
                    dim=embed_dims[stage],
                    num_heads=num_heads[stage],
                    mlp_ratio=mlp_ratios[stage],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=rates[first + k],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[stage]) for k in range(depths[stage])
            ])

        # Stage-1 (x1/4 scale)
        self.block1 = build_stage(0, 0)
        self.norm1 = norm_layer(embed_dims[0])
        # Stage-2 (x1/8 scale)
        self.block2 = build_stage(1, depths[0])
        self.norm2 = norm_layer(embed_dims[1])
        # Stage-3 (x1/16 scale)
        self.block3 = build_stage(2, depths[0] + depths[1])
        self.norm3 = norm_layer(embed_dims[2])
        # Stage-4 (x1/32 scale)
        self.block4 = build_stage(3, depths[0] + depths[1] + depths[2])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weight, zero bias.
        # LayerNorm: zero bias, unit weight.
        # Conv2D: normal(0, sqrt(2 / fan_out)) weight, zero bias, where
        #   fan_out = k_h * k_w * out_channels // groups.
        if isinstance(m, nn.Linear):
            nn.initializer.TruncatedNormal(std=.02)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.initializer.Constant(0)(m.bias)
            nn.initializer.Constant(1.0)(m.weight)
        elif isinstance(m, nn.Conv2D):
            kernel_h, kernel_w = m._kernel_size[0], m._kernel_size[1]
            fan_out = kernel_h * kernel_w * m._out_channels // m._groups
            nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)

    def reset_drop_path(self, drop_path_rate):
        """Re-spread drop-path rates linearly from 0 to `drop_path_rate`.

        Assigns `drop_path.drop_prob` on every block of the four stages.
        """
        rates = [
            r.item() for r in pd.linspace(0, drop_path_rate, sum(self.depths))
        ]
        stages = (self.block1, self.block2, self.block3, self.block4)
        offset = 0
        for blocks, depth in zip(stages, self.depths):
            for k in range(depth):
                blocks[k].drop_path.drop_prob = rates[offset + k]
            offset += depth

    def forward_features(self, x):
        """Return the list of four NCHW stage outputs, finest first."""
        batch = x.shape[0]
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )
        features = []
        for embed, blocks, norm in stages:
            x, height, width = embed(x)
            for blk in blocks:
                x = blk(x, height, width)
            x = norm(x)
            # Tokens (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W).
            channels = calc_product(*x.shape[1:]) // (height * width)
            x = x.reshape([batch, height, width, channels]).transpose(
                [0, 3, 1, 2])
            features.append(x)
        return features

    def forward(self, x):
        return self.forward_features(x)
class DecoderTransformer_v3(nn.Layer):
    """
    Transformer decoder of ChangeFormer.

    For each of the four encoder levels, it does the following:
      1. embeds the features of both images with a shared per-level `MLP`;
      2. fuses the two embeddings with a convolutional difference module;
      3. adds the x2-upsampled difference of the next coarser level.

    The four difference maps are then resized to the finest level,
    concatenated and fused. Two x2 transposed-conv + residual stages follow,
    and a 3x3 conv produces the change map.

    Args:
        input_transform (str): How `_transform_inputs` selects features.
            - 'multiple_select': the levels listed in `in_index`.
            - 'resize_concat': those levels resized to the first one and
              concatenated along channels.
            - anything else: `inputs[in_index]`.
        in_index (list[int]|int): Encoder levels to use.
        align_corners (bool): Used by `resize` in the 'resize_concat' path.
        in_channels (list[int]): Channels of the four encoder levels.
        embedding_dim (int): Channels of every decoder embedding.
        output_nc (int): Output channels of the change map.
        decoder_softmax (bool): If True, a sigmoid is applied to the returned
            map (despite the name, not a softmax).
        feature_strides (list[int]): Only validated here (same length as
            `in_channels`, first entry is the smallest).
    """

    def __init__(self,
                 input_transform='multiple_select',
                 in_index=[0, 1, 2, 3],
                 align_corners=True,
                 in_channels=[32, 64, 128, 256],
                 embedding_dim=64,
                 output_nc=2,
                 decoder_softmax=False,
                 feature_strides=[2, 4, 8, 16]):
        super(DecoderTransformer_v3, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]

        # Settings
        self.feature_strides = feature_strides
        self.input_transform = input_transform
        self.in_index = in_index
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # Per-level linear embeddings (shared by both images).
        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        # Convolutional difference layers: the two images' embeddings are
        # concatenated (2 * embedding_dim) and reduced to embedding_dim.
        self.diff_c4 = conv_diff(
            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(
            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(
            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(
            in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)

        # Per-level prediction heads (run in forward; outputs not returned).
        self.make_pred_c4 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)

        # Final linear fusion layer over the concatenated per-level maps.
        self.linear_fuse = nn.Sequential(
            nn.Conv2D(
                in_channels=self.embedding_dim * len(in_channels),
                out_channels=self.embedding_dim,
                kernel_size=1),
            nn.BatchNorm2D(self.embedding_dim))

        # Final prediction head: two x2 upsamplings, each with a residual block.
        self.convd2x = UpsampleConvLayer(
            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(
            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(
            self.embedding_dim,
            self.output_nc,
            kernel_size=3,
            stride=1,
            padding=1)

        # Final activation (a sigmoid, enabled by `decoder_softmax`).
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """
        Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            list[Tensor] | Tensor: The output depends on `input_transform`.
                - 'multiple_select': the features at `in_index`.
                - 'resize_concat': those features resized to the first one's
                  size and concatenated along axis 1.
                - otherwise: `inputs[self.in_index]`.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            # paddle.concat takes `axis`; torch's `dim` keyword is a TypeError here.
            return pd.concat(upsampled_inputs, axis=1)
        if self.input_transform == 'multiple_select':
            return [inputs[i] for i in self.in_index]
        return inputs[self.in_index]

    def _embed(self, linear, feat):
        """Apply a per-level `MLP` to NCHW `feat` and return an NCHW map."""
        n, _, h, w = feat.shape
        tokens = linear(feat).transpose([0, 2, 1])  # (N, C, H*W)
        return tokens.reshape(
            [n, calc_product(*tokens.shape[1:]) // (h * w), h, w])

    def _difference(self, linear, diff, feat1, feat2):
        """Embed both images' features with `linear` and fuse them with `diff`."""
        emb1 = self._embed(linear, feat1)
        emb2 = self._embed(linear, feat2)
        return diff(pd.concat((emb1, emb2), axis=1))

    def forward(self, inputs1, inputs2):
        # Select encoder features: four levels each, finest first
        # (1/4, 1/8, 1/16, 1/32 of the input with EncoderTransformer_v3).
        c1_1, c2_1, c3_1, c4_1 = self._transform_inputs(inputs1)
        c1_2, c2_2, c3_2, c4_2 = self._transform_inputs(inputs2)
        fuse_size = c1_2.shape[2:]
        outputs = []

        # Stage 4: x1/32 scale
        _c4 = self._difference(self.linear_c4, self.diff_c4, c4_1, c4_2)
        outputs.append(self.make_pred_c4(_c4))
        _c4_up = resize(
            _c4, size=fuse_size, mode='bilinear', align_corners=False)

        # Stage 3: x1/16 scale, plus the x2-upsampled stage-4 difference
        _c3 = self._difference(self.linear_c3, self.diff_c3, c3_1, c3_2) + \
            F.interpolate(_c4, scale_factor=2, mode="bilinear")
        outputs.append(self.make_pred_c3(_c3))
        _c3_up = resize(
            _c3, size=fuse_size, mode='bilinear', align_corners=False)

        # Stage 2: x1/8 scale, plus the x2-upsampled stage-3 difference
        _c2 = self._difference(self.linear_c2, self.diff_c2, c2_1, c2_2) + \
            F.interpolate(_c3, scale_factor=2, mode="bilinear")
        outputs.append(self.make_pred_c2(_c2))
        _c2_up = resize(
            _c2, size=fuse_size, mode='bilinear', align_corners=False)

        # Stage 1: x1/4 scale, plus the x2-upsampled stage-2 difference
        _c1 = self._difference(self.linear_c1, self.diff_c1, c1_1, c1_2) + \
            F.interpolate(_c2, scale_factor=2, mode="bilinear")
        outputs.append(self.make_pred_c1(_c1))

        # Linear fusion of the difference maps from all scales.
        _c = self.linear_fuse(pd.concat((_c4_up, _c3_up, _c2_up, _c1), axis=1))

        # Upsample x2 twice (to 1/2, then full scale), each with a residual block.
        x = self.dense_2x(self.convd2x(_c))
        x = self.dense_1x(self.convd1x(x))

        # Final prediction. Only this map is returned, so only it gets the
        # optional sigmoid; the per-stage heads above still run as before.
        cp = self.change_probability(x)
        if self.output_softmax:
            cp = self.active(cp)
        outputs.append(cp)
        return outputs[-1]
class OverlapPatchEmbed(nn.Layer):
    """
    Overlapping patch embedding: a strided Conv2D followed by LayerNorm.

    The convolution uses `patch_size` as kernel, `stride` as stride and
    half the kernel as padding, so neighbouring patches overlap whenever
    stride < patch_size.

    `H`, `W` and `num_patches` are computed as img_size // patch_size and are
    not used by `forward`.

    forward(x) returns a tuple of three values:
        - the tokens of shape (B, H*W, embed_dim);
        - H, the height of the convolution output;
        - W, the width of the convolution output.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.H = img_size[0] // patch_size[0]
        self.W = img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        half_kernel = (patch_size[0] // 2, patch_size[1] // 2)
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=half_kernel)
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weight, zero bias.
        # LayerNorm: zero bias, unit weight.
        # Conv2D: normal(0, sqrt(2 / fan_out)) weight, zero bias.
        if isinstance(m, nn.Linear):
            nn.initializer.TruncatedNormal(std=.02)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.initializer.Constant(0)(m.bias)
            nn.initializer.Constant(1.0)(m.weight)
        elif isinstance(m, nn.Conv2D):
            kernel_h, kernel_w = m._kernel_size[0], m._kernel_size[1]
            fan_out = kernel_h * kernel_w * m._out_channels // m._groups
            nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)

    def forward(self, x):
        feat = self.proj(x)
        _, _, height, width = feat.shape
        tokens = feat.flatten(2).transpose([0, 2, 1])
        return self.norm(tokens), height, width
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize `input` with `F.interpolate`, optionally warning about misalignment.

    Args:
        input (Tensor): Tensor to resize; its spatial size is read from
            `input.shape[2:]` (two spatial dims expected).
        size (list|tuple|None): Target spatial size (h, w).
        scale_factor (float|list|None): Scale factor, used when `size` is None.
        mode (str): Interpolation mode passed to `F.interpolate`.
        align_corners (bool|None): Passed to `F.interpolate`; None is treated
            as False (Paddle's default), since Paddle requires a bool here.
        warning (bool): If True, warn when upsampling with
            align_corners=True to a size that is not of the form
            out - 1 = n * (in - 1).

    Returns:
        Tensor: The resized tensor.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Only upsampling is checked: compare each output dim with the
        # matching input dim (the width check used to compare against
        # output_h).
        if output_h > input_h or output_w > input_w:
            if ((output_h > 1 and output_w > 1 and input_h > 1 and
                 input_w > 1) and (output_h - 1) % (input_h - 1) and
                    (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    if align_corners is None:
        align_corners = False
    return F.interpolate(input, size, scale_factor, mode, align_corners)
class Mlp(nn.Layer):
    """Feed-forward sub-layer of a Transformer `Block`.

    The layers are applied in this order:
        - Linear (in -> hidden);
        - depth-wise 3x3 conv over the H x W token grid;
        - activation and dropout;
        - Linear (hidden -> out) and dropout.

    `hidden_features` and `out_features` default to `in_features`.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.dwconv = DWConv(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weight, zero bias.
        # LayerNorm: zero bias, unit weight.
        # Conv2D: normal(0, sqrt(2 / fan_out)) weight, zero bias.
        if isinstance(m, nn.Linear):
            nn.initializer.TruncatedNormal(std=.02)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.initializer.Constant(0)(m.bias)
            nn.initializer.Constant(1.0)(m.weight)
        elif isinstance(m, nn.Conv2D):
            kernel_h, kernel_w = m._kernel_size[0], m._kernel_size[1]
            fan_out = kernel_h * kernel_w * m._out_channels // m._groups
            nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)

    def forward(self, x, H, W):
        hidden = self.drop(self.act(self.dwconv(self.fc1(x), H, W)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Layer):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Queries come from every token. When `sr_ratio > 1`, keys and values come
    from a copy of the token grid that is first downsampled by a Conv2D with
    kernel = stride = sr_ratio and then LayerNorm-ed. Otherwise they come
    from the tokens themselves.

    Args:
        dim (int): Token channels; must be divisible by `num_heads`.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the q/kv projections have a bias.
        qk_scale (float|None): Attention scale; defaults to head_dim ** -0.5.
        attn_drop (float): Dropout on the attention weights.
        proj_drop (float): Dropout after the output projection.
        sr_ratio (int): Spatial-reduction factor for keys/values.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weight, zero bias.
        # LayerNorm: zero bias, unit weight.
        # Conv2D: normal(0, sqrt(2 / fan_out)) weight, zero bias.
        if isinstance(m, nn.Linear):
            nn.initializer.TruncatedNormal(std=.02)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.initializer.Constant(0)(m.bias)
            nn.initializer.Constant(1.0)(m.weight)
        elif isinstance(m, nn.Conv2D):
            kernel_h, kernel_w = m._kernel_size[0], m._kernel_size[1]
            fan_out = kernel_h * kernel_w * m._out_channels // m._groups
            nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)

    def forward(self, x, H, W):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape([B, N, self.num_heads, head_dim]).transpose(
            [0, 2, 1, 3])

        if self.sr_ratio > 1:
            grid = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            grid = self.sr(grid)
            reduced = grid.reshape([B, C, calc_product(*grid.shape[2:])
                                    ]).transpose([0, 2, 1])
            kv_source = self.norm(reduced)
        else:
            kv_source = x

        kv = self.kv(kv_source)
        # (B, M, 2C) -> (2, B, heads, M, head_dim)
        kv = kv.reshape([
            B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
            head_dim
        ]).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        scores = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        weights = self.attn_drop(F.softmax(scores, axis=-1))
        out = (weights @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        return self.proj_drop(self.proj(out))
class Block(nn.Layer):
    """Pre-norm Transformer block.

    forward computes:
        x = x + drop_path(attn(norm1(x), H, W))
        x = x + drop_path(mlp(norm2(x), H, W))

    `drop_path` is a `DropPath` when drop_path > 0, otherwise `nn.Identity`.
    `_init_weights` is defined here but not called from this class's
    `__init__`.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weight, zero bias.
        # LayerNorm: zero bias, unit weight.
        # Conv2D: normal(0, sqrt(2 / fan_out)) weight, zero bias.
        if isinstance(m, nn.Linear):
            nn.initializer.TruncatedNormal(std=.02)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.initializer.Constant(0)(m.bias)
            nn.initializer.Constant(1.0)(m.weight)
        elif isinstance(m, nn.Conv2D):
            kernel_h, kernel_w = m._kernel_size[0], m._kernel_size[1]
            fan_out = kernel_h * kernel_w * m._out_channels // m._groups
            nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)

    def forward(self, x, H, W):
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_out)
class DWConv(nn.Layer):
    """Depth-wise 3x3 convolution (groups=dim) applied to a token sequence.

    The (B, N, C) tokens are reshaped to a (B, C, H, W) grid, convolved, and
    flattened back to (B, H*W, C).
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape
        grid = x.transpose([0, 2, 1]).reshape([batch, channels, H, W])
        return self.dwconv(grid).flatten(2).transpose([0, 2, 1])
- # Transformer Decoder
class MLP(nn.Layer):
    """
    Linear embedding of a feature map.

    The input is flattened from dim 2 on and transposed to
    (B, H*W, input_dim), then projected to (B, H*W, embed_dim).
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        tokens = x.flatten(2).transpose([0, 2, 1])
        return self.proj(tokens)
- # Difference Layer
def conv_diff(in_channels, out_channels):
    """Build the difference module.

    The layers are: conv3x3 -> ReLU -> BatchNorm -> conv3x3 -> ReLU.
    Both convolutions use padding 1, so the spatial size is kept.
    """
    layers = [
        nn.Conv2D(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2D(out_channels),
        nn.Conv2D(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
- # Intermediate prediction Layer
def make_prediction(in_channels, out_channels):
    """Build an intermediate prediction head.

    The layers are: conv3x3 -> ReLU -> BatchNorm -> conv3x3.
    Both convolutions use padding 1; there is no final activation.
    """
    layers = [
        nn.Conv2D(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2D(out_channels),
        nn.Conv2D(out_channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)
|