# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal

from .backbones import resnet
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .param_init import KaimingInitMixin
def calc_product(*args):
    """Return the product of all positional arguments.

    Used in this module to flatten tensor shape tuples (e.g. H*W) into a
    single length value.

    Args:
        *args: One or more values supporting multiplication (typically ints).

    Returns:
        The product ``args[0] * args[1] * ... * args[-1]``.

    Raises:
        ValueError: If no arguments are given.
    """
    if not args:
        raise ValueError("calc_product() requires at least one argument.")
    return functools.reduce(operator.mul, args)
class BIT(nn.Layer):
    """
    The BIT implementation based on PaddlePaddle.

    The original article refers to
        H. Chen, et al., "Remote Sensing Image Change Detection With Transformers"
        (https://arxiv.org/abs/2103.00208).

    This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly initialized.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and
            'resnet34' are supported. Default: 'resnet18'.
        n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}.
            Default: 4.
        use_tokenizer (bool, optional): Use a tokenizer or not. Default: True.
        token_len (int, optional): The length of input tokens. Default: 4.
        pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False. 'max'
            for global max pooling and 'avg' for global average pooling. Default: 'max'.
        pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to False.
            Default: 2.
        enc_with_pos (bool, optional): Whether to add learned positional embedding to the input feature sequence of the
            encoder. Default: True.
        enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1.
        enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64.
        dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8.
        dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8.

    Raises:
        ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 backbone='resnet18',
                 n_stages=4,
                 use_tokenizer=True,
                 token_len=4,
                 pool_mode='max',
                 pool_size=2,
                 enc_with_pos=True,
                 enc_depth=1,
                 enc_head_dim=64,
                 dec_depth=8,
                 dec_head_dim=8,
                 **backbone_kwargs):
        super(BIT, self).__init__()

        # TODO: reduce hard-coded parameters
        DIM = 32  # Channel dim of the backbone output == token embedding dim.
        MLP_DIM = 2 * DIM  # Hidden dim of the Transformer feed-forward layers.
        EBD_DIM = DIM

        self.backbone = Backbone(
            in_channels,
            EBD_DIM,
            arch=backbone,
            n_stages=n_stages,
            **backbone_kwargs)

        self.use_tokenizer = use_tokenizer
        if not use_tokenizer:
            # If a tokenizer is not to be used, then downsample the feature maps
            # with spatial pooling instead; each pooled pixel becomes one token.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size
        else:
            # 1x1 conv producing one spatial attention map per token.
            self.conv_att = Conv1x1(32, token_len, bias=False)
            self.token_len = token_len

        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            # Learned positional embedding; the sequence holds tokens from both
            # temporal images, hence length token_len * 2.
            self.enc_pos_embedding = self.create_parameter(
                shape=(1, self.token_len * 2, EBD_DIM),
                default_initializer=Normal())

        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.enc_head_dim = enc_head_dim
        self.dec_head_dim = dec_head_dim

        self.encoder = TransformerEncoder(
            dim=DIM,
            depth=enc_depth,
            n_heads=8,
            head_dim=enc_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.)
        self.decoder = TransformerDecoder(
            dim=DIM,
            depth=dec_depth,
            n_heads=8,
            head_dim=dec_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.,
            apply_softmax=True)

        # Recover the resolution reduced by the backbone (see Backbone.forward,
        # which upsamples once; this x4 restores the input size).
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.conv_out = nn.Sequential(
            Conv3x3(
                EBD_DIM, EBD_DIM, norm=True, act=True),
            Conv3x3(EBD_DIM, num_classes))

    def _get_semantic_tokens(self, x):
        # x: feature map of shape (b, c, h, w). Returns (b, token_len, c)
        # tokens, each a spatial-attention-weighted average of the features.
        b, c = x.shape[:2]
        att_map = self.conv_att(x)
        # (b, token_len, 1, h*w): one flattened attention map per token.
        att_map = att_map.reshape(
            (b, self.token_len, 1, calc_product(*att_map.shape[2:])))
        # Normalize each attention map over the spatial positions.
        att_map = F.softmax(att_map, axis=-1)
        x = x.reshape((b, 1, c, att_map.shape[-1]))
        # Weighted sum over spatial positions -> (b, token_len, c).
        tokens = (x * att_map).sum(-1)
        return tokens

    def _get_reshaped_tokens(self, x):
        # Pool the feature map to (pool_size, pool_size), then flatten the
        # spatial grid into a token sequence of shape (b, pool_size**2, c).
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
        else:
            # NOTE(review): unrecognized pool_mode silently skips pooling, so
            # the token count becomes h*w instead of self.token_len — confirm
            # this fall-through is intentional.
            x = x
        tokens = x.transpose((0, 2, 3, 1)).flatten(1, 2)
        return tokens

    def encode(self, x):
        # x: token sequence (b, token_len*2, c) from both temporal images.
        if self.enc_with_pos:
            x += self.enc_pos_embedding
        x = self.encoder(x)
        return x

    def decode(self, x, m):
        # Cross-attend the per-pixel feature sequence x to the token memory m,
        # then restore the (b, c, h, w) layout.
        b, c, h, w = x.shape
        x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
        x = self.decoder(x, m)
        x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
        return x

    def forward(self, t1, t2):
        # Extract features via shared backbone.
        x1 = self.backbone(t1)
        x2 = self.backbone(t2)

        # Tokenization
        if self.use_tokenizer:
            token1 = self._get_semantic_tokens(x1)
            token2 = self._get_semantic_tokens(x2)
        else:
            token1 = self._get_reshaped_tokens(x1)
            token2 = self._get_reshaped_tokens(x2)

        # Transformer encoder forward: both images' tokens are encoded jointly,
        # then split back into their per-image halves.
        token = paddle.concat([token1, token2], axis=1)
        token = self.encode(token)
        token1, token2 = paddle.chunk(token, 2, axis=1)

        # Transformer decoder forward
        y1 = self.decode(x1, token1)
        y2 = self.decode(x2, token2)

        # Feature differencing
        y = paddle.abs(y1 - y2)
        y = self.upsample(y)

        # Classifier forward
        pred = self.conv_out(y)
        return [pred]

    def init_weight(self):
        # Use the default initialization method.
        pass
class Residual(nn.Layer):
    """Skip-connection wrapper: ``forward(x)`` returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
class Residual2(nn.Layer):
    """Skip-connection wrapper for two-input callables.

    ``forward(x1, x2)`` returns ``fn(x1, x2, **kwargs) + x1`` — the residual
    is taken from the first input only.
    """

    def __init__(self, fn):
        super(Residual2, self).__init__()
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        out = self.fn(x1, x2, **kwargs)
        return out + x1
class PreNorm(nn.Layer):
    """Pre-norm wrapper: applies LayerNorm to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class PreNorm2(nn.Layer):
    """Pre-norm wrapper for two-input callables.

    The same LayerNorm instance normalizes both inputs before ``fn`` is called.
    """

    def __init__(self, dim, fn):
        super(PreNorm2, self).__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x1, x2, **kwargs):
        n1 = self.norm(x1)
        n2 = self.norm(x2)
        return self.fn(n1, n2, **kwargs)
class FeedForward(nn.Sequential):
    """Transformer MLP block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout_rate),
        ]
        super(FeedForward, self).__init__(*layers)
class CrossAttention(nn.Layer):
    """Multi-head attention with queries from ``x`` and keys/values from ``ref``.

    NOTE(review): the attention logits are scaled by ``dim ** -0.5`` (the model
    dimension), not the more common ``head_dim ** -0.5`` — preserved as-is.
    When ``apply_softmax`` is False, the raw scaled scores are used directly.
    """

    def __init__(self,
                 dim,
                 n_heads=8,
                 head_dim=64,
                 dropout_rate=0.,
                 apply_softmax=True):
        super(CrossAttention, self).__init__()

        inner_dim = n_heads * head_dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = dim**-0.5
        self.apply_softmax = apply_softmax

        # Separate bias-free projections for queries, keys, and values.
        self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))

    def _split_heads(self, t, batch, seq_len):
        # (batch, seq, n_heads*head_dim) -> (batch, n_heads, seq, head_dim)
        t = t.reshape((batch, seq_len, self.n_heads, self.head_dim))
        return t.transpose((0, 2, 1, 3))

    def forward(self, x, ref):
        batch, q_len = x.shape[:2]
        kv_len = ref.shape[1]

        q = self._split_heads(self.fc_q(x), batch, q_len)
        k = self._split_heads(self.fc_k(ref), batch, kv_len)
        v = self._split_heads(self.fc_v(ref), batch, kv_len)

        scores = paddle.matmul(q, k, transpose_y=True) * self.scale
        if self.apply_softmax:
            scores = F.softmax(scores, axis=-1)

        out = paddle.matmul(scores, v)
        # (batch, n_heads, q_len, head_dim) -> (batch, q_len, n_heads*head_dim)
        out = out.transpose((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)
class SelfAttention(CrossAttention):
    """Cross attention specialized to attend the input to itself."""

    def forward(self, x):
        return super().forward(x, x)
class TransformerEncoder(nn.Layer):
    """Stack of pre-norm residual blocks: self-attention followed by an MLP."""

    def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
        super(TransformerEncoder, self).__init__()
        blocks = []
        for _ in range(depth):
            att_block = Residual(
                PreNorm(dim,
                        SelfAttention(dim, n_heads, head_dim, dropout_rate)))
            ff_block = Residual(
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
            blocks.append(nn.LayerList([att_block, ff_block]))
        self.layers = nn.LayerList(blocks)

    def forward(self, x):
        for att, ff in self.layers:
            x = ff(att(x))
        return x
class TransformerDecoder(nn.Layer):
    """Stack of pre-norm residual blocks: cross-attention to a token memory
    ``m`` followed by an MLP.
    """

    def __init__(self,
                 dim,
                 depth,
                 n_heads,
                 head_dim,
                 mlp_dim,
                 dropout_rate,
                 apply_softmax=True):
        super(TransformerDecoder, self).__init__()
        blocks = []
        for _ in range(depth):
            att_block = Residual2(
                PreNorm2(dim,
                         CrossAttention(dim, n_heads, head_dim, dropout_rate,
                                        apply_softmax)))
            ff_block = Residual(
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
            blocks.append(nn.LayerList([att_block, ff_block]))
        self.layers = nn.LayerList(blocks)

    def forward(self, x, m):
        for att, ff in self.layers:
            x = ff(att(x, m))
        return x
class Backbone(nn.Layer, KaimingInitMixin):
    """ResNet-based feature extractor used by BIT.

    Runs the ResNet stem and residual stages (stages beyond `n_stages` are
    replaced with identity), upsamples by 2, and projects to `out_ch` channels
    with a 3x3 conv.

    Args:
        in_ch (int): Number of input channels. If not 3, the first conv is
            replaced (losing its pretrained weights).
        out_ch (int, optional): Number of output channels. Default: 32.
        arch (str, optional): 'resnet18' or 'resnet34'. Default: 'resnet18'.
        pretrained (bool, optional): Whether to load pretrained weights; when
            False, weights are (re)initialized via KaimingInitMixin.
        n_stages (int, optional): Number of ResNet stages to keep, in {3,4,5}.

    Raises:
        ValueError: On an unsupported `arch` or `n_stages` value.
    """

    def __init__(self,
                 in_ch,
                 out_ch=32,
                 arch='resnet18',
                 pretrained=True,
                 n_stages=5):
        super(Backbone, self).__init__()

        # Channel expansion factor of the residual blocks; 1 is consistent with
        # the basic blocks used by resnet18/34 (the only supported archs).
        expand = 1
        # Custom per-stage strides (stem, layer1..layer4); the 1s in later
        # stages keep a higher-resolution feature map than standard ResNet.
        strides = (2, 1, 2, 1, 1)
        if arch == 'resnet18':
            self.resnet = resnet.resnet18(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        elif arch == 'resnet34':
            self.resnet = resnet.resnet34(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        else:
            raise ValueError

        self.n_stages = n_stages

        # Channel count of the last kept stage (512/256/128 for stages 5/4/3).
        if self.n_stages == 5:
            itm_ch = 512 * expand
        elif self.n_stages == 4:
            itm_ch = 256 * expand
        elif self.n_stages == 3:
            itm_ch = 128 * expand
        else:
            raise ValueError

        self.upsample = nn.Upsample(scale_factor=2)
        self.conv_out = Conv3x3(itm_ch, out_ch)

        self._trim_resnet()

        # Replace the stem conv to accept in_ch bands; this conv is freshly
        # initialized even when pretrained weights were loaded.
        if in_ch != 3:
            self.resnet.conv1 = nn.Conv2D(
                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        # Manually run the ResNet stem and stages (trimmed stages are
        # Identity), skipping the classifier head entirely.
        y = self.resnet.conv1(x)
        y = self.resnet.bn1(y)
        y = self.resnet.relu(y)
        y = self.resnet.maxpool(y)

        y = self.resnet.layer1(y)
        y = self.resnet.layer2(y)
        y = self.resnet.layer3(y)
        y = self.resnet.layer4(y)

        y = self.upsample(y)

        return self.conv_out(y)

    def _trim_resnet(self):
        # Replace unused trailing stages and the classifier head with
        # Identity so forward() can call them unconditionally.
        if self.n_stages > 5:
            raise ValueError

        if self.n_stages < 5:
            self.resnet.layer4 = Identity()

        if self.n_stages <= 3:
            self.resnet.layer3 = Identity()

        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()