123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from .backbones import resnet
- from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
- from .param_init import KaimingInitMixin
class STANet(nn.Layer):
    """
    STANet for change detection, implemented with PaddlePaddle.

    Reference:
        H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing
        Image Change Detection" (https://www.mdpi.com/2072-4292/12/10/1662).

    This implementation deviates from the original work in two respects:
        1. Multiple dilation rates are not used in layer 4 of the ResNet backbone.
        2. A classification head replaces the original metric learning-based head to stabilize the
           training process.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'.
            Default: 'BAM'.
        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is
            greater than 1, the input features are first processed by an average pooling layer with a
            kernel size of `ds_factor` before the attention scores are calculated. Default: 1.

    Raises:
        ValueError: When `att_type` has an illegal value (unsupported attention type).
    """

    def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
        super().__init__()

        # Channel width shared by the feature extractor, the attention module and the head.
        feat_ch = 64

        self.extract = build_feat_extractor(in_ch=in_channels, width=feat_ch)
        self.attend = build_sta_module(
            in_ch=feat_ch, att_type=att_type, ds=ds_factor)

        head_layers = [
            Conv3x3(feat_ch, feat_ch, norm=True, act=True),
            Conv3x3(feat_ch, num_classes),
        ]
        self.conv_out = nn.Sequential(*head_layers)

        self.init_weight()

    def forward(self, t1, t2):
        """
        Predict the change map for a pair of images.

        Args:
            t1: The first input image.
            t2: The second input image.

        Returns:
            list: A single-element list holding the output of the classification head.
        """
        # Both images go through the same (weight-shared) extractor, then are attended jointly.
        feat_t1, feat_t2 = self.attend(self.extract(t1), self.extract(t2))

        # Absolute feature difference, resized to the spatial size of `t1` (dims 2 onwards).
        diff = paddle.abs(feat_t1 - feat_t2)
        diff = F.interpolate(
            diff,
            size=paddle.shape(t1)[2:],
            mode='bilinear',
            align_corners=True)

        return [self.conv_out(diff)]

    def init_weight(self):
        """Intentionally a no-op; see the notes in the body."""
        # The encoder and decoder weights have already been initialized at this point.
        # Note, however, that self.attend and self.conv_out currently rely on the
        # default initialization method.
        pass
def build_feat_extractor(in_ch, width):
    """
    Build the feature extractor: a 'resnet18' `Backbone` followed by a `Decoder`.

    Args:
        in_ch (int): Passed to `Backbone` as its input channel count.
        width (int): Passed to `Decoder` as its output channel count.

    Returns:
        nn.Sequential: The backbone and the decoder chained together.
    """
    backbone = Backbone(in_ch, 'resnet18')
    decoder = Decoder(width)
    return nn.Sequential(backbone, decoder)
def build_sta_module(in_ch, att_type, ds):
    """
    Build the attention module selected by `att_type`, wrapped in `Attention`.

    Args:
        in_ch (int): Channel count passed to the attention module.
        att_type (str): Either 'BAM' or 'PAM'.
        ds (int): Downsampling factor passed to the attention module.

    Returns:
        Attention: The wrapper holding the selected attention module.

    Raises:
        ValueError: When `att_type` is neither 'BAM' nor 'PAM'.
    """
    if att_type == 'BAM':
        att = BAM(in_ch, ds)
    elif att_type == 'PAM':
        att = PAM(in_ch, ds)
    else:
        # Name the offending value and the valid options so that a
        # misconfiguration can be diagnosed directly from the traceback.
        raise ValueError(
            "Unsupported attention type: {!r}. Supported types are 'BAM' and 'PAM'.".
            format(att_type))
    return Attention(att)
class Backbone(nn.Layer, KaimingInitMixin):
    """
    ResNet feature extractor that returns the outputs of all four residual stages.

    Args:
        in_ch (int): The number of input channels. When it is not 3, the first convolution of the
            ResNet is replaced by a newly created one accepting `in_ch` channels.
        arch (str): The ResNet variant. One of 'resnet18', 'resnet34' and 'resnet50'.
        pretrained (bool, optional): Forwarded to the ResNet constructor. When False, `init_weight()`
            is called after construction. Default: True.
        strides (tuple, optional): Forwarded to the ResNet constructor; `strides[0]` is also used as
            the stride of the replacement first convolution. Default: (2, 1, 2, 2, 2).

    Raises:
        ValueError: When `arch` is not one of the supported variants.
    """

    # Names of the constructors in the `resnet` module that this class can build.
    _SUPPORTED_ARCHS = ('resnet18', 'resnet34', 'resnet50')

    def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):
        super(Backbone, self).__init__()

        if arch not in self._SUPPORTED_ARCHS:
            raise ValueError(
                "Unsupported backbone architecture: {!r}. Supported architectures are {}.".
                format(arch, ', '.join(repr(a) for a in self._SUPPORTED_ARCHS)))

        # All supported variants share the same constructor signature, so the
        # constructor is looked up by name instead of repeating the call per variant.
        build_resnet = getattr(resnet, arch)
        self.resnet = build_resnet(
            pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())

        self._trim_resnet()

        if in_ch != 3:
            # The replacement layer is newly created, so whatever weights the
            # constructor gave the original `conv1` are not kept for this layer.
            self.resnet.conv1 = nn.Conv2D(
                in_ch,
                64,
                kernel_size=7,
                stride=strides[0],
                padding=3,
                bias_attr=False)

        if not pretrained:
            # NOTE(review): `init_weight` is not defined in this class; presumably it is
            # provided by KaimingInitMixin -- confirm.
            self.init_weight()

    def forward(self, x):
        """
        Run the stem and the four residual stages.

        Returns:
            tuple: The outputs of `layer1`, `layer2`, `layer3` and `layer4`, in that order.
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x1 = self.resnet.layer1(x)
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        return x1, x2, x3, x4

    def _trim_resnet(self):
        """Replace the pooling and fully connected layers, which `forward` never calls, with `Identity`."""
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()
class Decoder(nn.Layer, KaimingInitMixin):
    """
    Fuse four feature maps into a single map with `f_ch` channels.

    Each input is reduced to 96 channels by a 1x1 convolution, resized to the spatial size of the
    first reduced map, concatenated along the channel axis, and passed through the output block.

    Args:
        f_ch (int): The number of channels of the output feature map.
    """

    def __init__(self, f_ch):
        super().__init__()

        # Every input is projected to this many channels before fusion.
        reduced_ch = 96

        # Input widths are 64, 128, 256 and 512 channels respectively.
        self.dr1 = Conv1x1(64, reduced_ch, norm=True, act=True)
        self.dr2 = Conv1x1(128, reduced_ch, norm=True, act=True)
        self.dr3 = Conv1x1(256, reduced_ch, norm=True, act=True)
        self.dr4 = Conv1x1(512, reduced_ch, norm=True, act=True)

        # The concatenation of the four reduced maps has 4 * 96 = 384 channels.
        fuse_layers = [
            Conv3x3(4 * reduced_ch, 256, norm=True, act=True),
            nn.Dropout(0.5),
            Conv1x1(256, f_ch, norm=True, act=True),
        ]
        self.conv_out = nn.Sequential(*fuse_layers)

        # NOTE(review): `init_weight` is not defined in this class; presumably it is
        # provided by KaimingInitMixin -- confirm.
        self.init_weight()

    def forward(self, feats):
        """
        Args:
            feats: An indexable collection whose first four items are the feature maps to fuse.

        Returns:
            The fused feature map produced by the output block.
        """
        reducers = (self.dr1, self.dr2, self.dr3, self.dr4)
        reduced = [reducer(feats[i]) for i, reducer in enumerate(reducers)]

        # The first map defines the target resolution; the remaining ones are resized to match it.
        base = reduced[0]
        aligned = [base] + [
            F.interpolate(
                feat,
                size=paddle.shape(base)[2:],
                mode='bilinear',
                align_corners=True) for feat in reduced[1:]
        ]

        fused = paddle.concat(aligned, axis=1)
        return self.conv_out(fused)
class BAM(nn.Layer):
    """
    Attention module that attends over all positions of the (optionally pooled) input at once.

    Args:
        in_ch (int): The number of input channels. Values keep `in_ch` channels, while queries and
            keys use `in_ch // 8` channels.
        ds (int): Kernel size of the average pooling applied before attention, and the scale factor
            used to upsample the attention output afterwards.
    """

    def __init__(self, in_ch, ds):
        super().__init__()

        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch)
        self.conv_k = Conv1x1(in_ch, self.key_ch)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        """
        Args:
            x: Input whose last two axes are merged into one before attention is applied
                (`Attention.forward` passes the two feature maps stacked along a trailing axis).

        Returns:
            The attended features plus the input, with the last axis split back into (size // 2, 2).
        """
        # Merge the last two axes so that the input can be treated as a single 4-D map.
        x = x.flatten(-2)
        pooled = self.pool(x)

        b, c, h, w = paddle.shape(pooled)
        num_pos = h * w

        q = self.conv_q(pooled).reshape((b, -1, num_pos)).transpose((0, 2, 1))
        k = self.conv_k(pooled).reshape((b, -1, num_pos))
        v = self.conv_v(pooled).reshape((b, -1, num_pos))

        # Scaled dot-product scores between every pair of positions.
        scale = self.key_ch**(-0.5)
        scores = scale * paddle.bmm(q, k)
        weights = self.softmax(scores)

        attended = paddle.bmm(v, weights.transpose((0, 2, 1)))
        attended = attended.reshape((b, c, h, w))

        # Undo the pooling, then add the residual connection.
        attended = F.interpolate(attended, scale_factor=self.ds)
        out = attended + x

        # Split the merged last axis back into (size // 2, 2).
        leading = tuple(out.shape[:-1])
        return out.reshape(leading + (out.shape[-1] // 2, 2))
class PAMBlock(nn.Layer):
    """
    Attention computed independently inside each cell of a `scale` x `scale` grid of subregions.

    Args:
        in_ch (int): The number of input channels. Values keep `in_ch` channels, while queries and
            keys use `in_ch // 8` channels.
        scale (int, optional): The number of subregions along each spatial axis. Default: 1.
        ds (int, optional): Kernel size of the average pooling applied before attention, and the
            scale factor used to upsample the result afterwards. Default: 1.
    """

    def __init__(self, in_ch, scale=1, ds=1):
        super().__init__()

        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

    def forward(self, x):
        """Pool `x`, attend within each subregion, reassemble the full map and upsample it by `ds`."""
        pooled = self.pool(x)

        # Project to query, key and value.
        q = self.conv_q(pooled)
        k = self.conv_k(pooled)
        v = self.conv_v(pooled)

        # Cut each projection into subregions folded into the batch axis.
        b, c, h, w = pooled.shape
        q, k, v = (self._split_subregions(t) for t in (q, k, v))

        # Attention is evaluated separately for every subregion.
        attended = self._attend(q, k, v)

        # Put the subregions back in place to recover the whole map.
        whole = self._recons_whole(attended, b, c, h, w)
        return F.interpolate(whole, scale_factor=self.ds)

    def _attend(self, query, key, value):
        """Scaled dot-product attention over the flattened positions of each batch item."""
        scores = paddle.bmm(query.transpose((0, 2, 1)), key)
        scores = (self.key_ch**(-0.5)) * scores
        weights = F.softmax(scores, axis=-1)
        return paddle.bmm(value, weights.transpose((0, 2, 1)))

    def _split_subregions(self, x):
        """Reshape (b, c, h, w) into (b * scale * scale, c, positions per subregion)."""
        s = self.scale
        b, c, h, w = x.shape
        assert h % s == 0 and w % s == 0

        grid = x.reshape((b, c, s, h // s, s, w // s))
        # Move both grid axes next to the batch axis so that they can be folded into it.
        grid = grid.transpose((0, 2, 4, 1, 3, 5))
        return grid.reshape((b * s * s, c, -1))

    def _recons_whole(self, x, b, c, h, w):
        """Inverse of `_split_subregions`: rebuild the (b, c, h, w) map from its subregions."""
        s = self.scale
        grid = x.reshape((b, s, s, c, h // s, w // s))
        grid = grid.transpose((0, 3, 1, 4, 2, 5))
        return grid.reshape((b, c, h, w))
class PAM(nn.Layer):
    """
    Attention module built from several `PAMBlock`s, one per entry of `scales`.

    Args:
        in_ch (int): The number of input (and output) channels.
        ds (int): Downsampling factor passed to every `PAMBlock`.
        scales (tuple, optional): The `scale` value of each `PAMBlock`. Default: (1, 2, 4, 8).
    """

    def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):
        super().__init__()

        blocks = [PAMBlock(in_ch, scale=s, ds=ds) for s in scales]
        self.stages = nn.LayerList(blocks)
        # Fuses the channel-wise concatenation of all block outputs back to `in_ch` channels.
        self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)

    def forward(self, x):
        """
        Args:
            x: Input whose last two axes are merged into one before the blocks are applied
                (`Attention.forward` passes the two feature maps stacked along a trailing axis).

        Returns:
            The fused block outputs, with the last axis split back into (size // 2, 2).
        """
        # Merge the last two axes so that the input can be treated as a single 4-D map.
        x = x.flatten(-2)

        stage_outs = []
        for stage in self.stages:
            stage_outs.append(stage(x))

        fused = self.conv_out(paddle.concat(stage_outs, axis=1))

        # Split the merged last axis back into (size // 2, 2).
        leading = tuple(fused.shape[:-1])
        return fused.reshape(leading + (fused.shape[-1] // 2, 2))
class Attention(nn.Layer):
    """
    Apply one attention module jointly to a pair of inputs.

    Args:
        att (nn.Layer): The attention module, called on the two inputs stacked along a new last axis.
    """

    def __init__(self, att):
        super().__init__()
        self.att = att

    def forward(self, x1, x2):
        """Stack `x1` and `x2` on a new trailing axis, attend, and return the two slices of the result."""
        stacked = paddle.stack([x1, x2], axis=-1)
        attended = self.att(stacked)

        out1 = attended[..., 0]
        out2 = attended[..., 1]
        return out1, out2
|