stanet.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from .backbones import resnet
  18. from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
  19. from .param_init import KaimingInitMixin
  20. class STANet(nn.Layer):
  21. """
  22. The STANet implementation based on PaddlePaddle.
  23. The original article refers to
  24. H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"
  25. (https://www.mdpi.com/2072-4292/12/10/1662).
  26. Note that this implementation differs from the original work in two aspects:
  27. 1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
  28. 2. A classification head is used in place of the original metric learning-based head to stablize the training process.
  29. Args:
  30. in_channels (int): The number of bands of the input images.
  31. num_classes (int): The number of target classes.
  32. att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
  33. ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values
  34. greater than 1, the input features will first be processed by an average pooling layer with the kernel size of
  35. `ds_factor`, before being used to calculate the attention scores. Default: 1.
  36. Raises:
  37. ValueError: When `att_type` has an illeagal value (unsupported attention type).
  38. """
  39. def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
  40. super(STANet, self).__init__()
  41. WIDTH = 64
  42. self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
  43. self.attend = build_sta_module(
  44. in_ch=WIDTH, att_type=att_type, ds=ds_factor)
  45. self.conv_out = nn.Sequential(
  46. Conv3x3(
  47. WIDTH, WIDTH, norm=True, act=True),
  48. Conv3x3(WIDTH, num_classes))
  49. self.init_weight()
  50. def forward(self, t1, t2):
  51. f1 = self.extract(t1)
  52. f2 = self.extract(t2)
  53. f1, f2 = self.attend(f1, f2)
  54. y = paddle.abs(f1 - f2)
  55. y = F.interpolate(
  56. y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
  57. pred = self.conv_out(y)
  58. return [pred]
  59. def init_weight(self):
  60. # Do nothing here as the encoder and decoder weights have already been initialized.
  61. # Note however that currently self.attend and self.conv_out use the default initilization method.
  62. pass
  63. def build_feat_extractor(in_ch, width):
  64. return nn.Sequential(Backbone(in_ch, 'resnet18'), Decoder(width))
  65. def build_sta_module(in_ch, att_type, ds):
  66. if att_type == 'BAM':
  67. return Attention(BAM(in_ch, ds))
  68. elif att_type == 'PAM':
  69. return Attention(PAM(in_ch, ds))
  70. else:
  71. raise ValueError
  72. class Backbone(nn.Layer, KaimingInitMixin):
  73. def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):
  74. super(Backbone, self).__init__()
  75. if arch == 'resnet18':
  76. self.resnet = resnet.resnet18(
  77. pretrained=pretrained,
  78. strides=strides,
  79. norm_layer=get_norm_layer())
  80. elif arch == 'resnet34':
  81. self.resnet = resnet.resnet34(
  82. pretrained=pretrained,
  83. strides=strides,
  84. norm_layer=get_norm_layer())
  85. elif arch == 'resnet50':
  86. self.resnet = resnet.resnet50(
  87. pretrained=pretrained,
  88. strides=strides,
  89. norm_layer=get_norm_layer())
  90. else:
  91. raise ValueError
  92. self._trim_resnet()
  93. if in_ch != 3:
  94. self.resnet.conv1 = nn.Conv2D(
  95. in_ch,
  96. 64,
  97. kernel_size=7,
  98. stride=strides[0],
  99. padding=3,
  100. bias_attr=False)
  101. if not pretrained:
  102. self.init_weight()
  103. def forward(self, x):
  104. x = self.resnet.conv1(x)
  105. x = self.resnet.bn1(x)
  106. x = self.resnet.relu(x)
  107. x = self.resnet.maxpool(x)
  108. x1 = self.resnet.layer1(x)
  109. x2 = self.resnet.layer2(x1)
  110. x3 = self.resnet.layer3(x2)
  111. x4 = self.resnet.layer4(x3)
  112. return x1, x2, x3, x4
  113. def _trim_resnet(self):
  114. self.resnet.avgpool = Identity()
  115. self.resnet.fc = Identity()
  116. class Decoder(nn.Layer, KaimingInitMixin):
  117. def __init__(self, f_ch):
  118. super(Decoder, self).__init__()
  119. self.dr1 = Conv1x1(64, 96, norm=True, act=True)
  120. self.dr2 = Conv1x1(128, 96, norm=True, act=True)
  121. self.dr3 = Conv1x1(256, 96, norm=True, act=True)
  122. self.dr4 = Conv1x1(512, 96, norm=True, act=True)
  123. self.conv_out = nn.Sequential(
  124. Conv3x3(
  125. 384, 256, norm=True, act=True),
  126. nn.Dropout(0.5),
  127. Conv1x1(
  128. 256, f_ch, norm=True, act=True))
  129. self.init_weight()
  130. def forward(self, feats):
  131. f1 = self.dr1(feats[0])
  132. f2 = self.dr2(feats[1])
  133. f3 = self.dr3(feats[2])
  134. f4 = self.dr4(feats[3])
  135. f2 = F.interpolate(
  136. f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
  137. f3 = F.interpolate(
  138. f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
  139. f4 = F.interpolate(
  140. f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
  141. x = paddle.concat([f1, f2, f3, f4], axis=1)
  142. y = self.conv_out(x)
  143. return y
  144. class BAM(nn.Layer):
  145. def __init__(self, in_ch, ds):
  146. super(BAM, self).__init__()
  147. self.ds = ds
  148. self.pool = nn.AvgPool2D(self.ds)
  149. self.val_ch = in_ch
  150. self.key_ch = in_ch // 8
  151. self.conv_q = Conv1x1(in_ch, self.key_ch)
  152. self.conv_k = Conv1x1(in_ch, self.key_ch)
  153. self.conv_v = Conv1x1(in_ch, self.val_ch)
  154. self.softmax = nn.Softmax(axis=-1)
  155. def forward(self, x):
  156. x = x.flatten(-2)
  157. x_rs = self.pool(x)
  158. b, c, h, w = paddle.shape(x_rs)
  159. query = self.conv_q(x_rs).reshape((b, -1, h * w)).transpose((0, 2, 1))
  160. key = self.conv_k(x_rs).reshape((b, -1, h * w))
  161. energy = paddle.bmm(query, key)
  162. energy = (self.key_ch**(-0.5)) * energy
  163. attention = self.softmax(energy)
  164. value = self.conv_v(x_rs).reshape((b, -1, w * h))
  165. out = paddle.bmm(value, attention.transpose((0, 2, 1)))
  166. out = out.reshape((b, c, h, w))
  167. out = F.interpolate(out, scale_factor=self.ds)
  168. out = out + x
  169. return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
  170. class PAMBlock(nn.Layer):
  171. def __init__(self, in_ch, scale=1, ds=1):
  172. super(PAMBlock, self).__init__()
  173. self.scale = scale
  174. self.ds = ds
  175. self.pool = nn.AvgPool2D(self.ds)
  176. self.val_ch = in_ch
  177. self.key_ch = in_ch // 8
  178. self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
  179. self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
  180. self.conv_v = Conv1x1(in_ch, self.val_ch)
  181. def forward(self, x):
  182. x_rs = self.pool(x)
  183. # Get query, key, and value.
  184. query = self.conv_q(x_rs)
  185. key = self.conv_k(x_rs)
  186. value = self.conv_v(x_rs)
  187. # Split the whole image into subregions.
  188. b, c, h, w = x_rs.shape
  189. query = self._split_subregions(query)
  190. key = self._split_subregions(key)
  191. value = self._split_subregions(value)
  192. # Perform subregion-wise attention.
  193. out = self._attend(query, key, value)
  194. # Stack subregions to reconstruct the whole image.
  195. out = self._recons_whole(out, b, c, h, w)
  196. out = F.interpolate(out, scale_factor=self.ds)
  197. return out
  198. def _attend(self, query, key, value):
  199. energy = paddle.bmm(query.transpose((0, 2, 1)),
  200. key) # batch matrix multiplication
  201. energy = (self.key_ch**(-0.5)) * energy
  202. attention = F.softmax(energy, axis=-1)
  203. out = paddle.bmm(value, attention.transpose((0, 2, 1)))
  204. return out
  205. def _split_subregions(self, x):
  206. b, c, h, w = x.shape
  207. assert h % self.scale == 0 and w % self.scale == 0
  208. x = x.reshape(
  209. (b, c, self.scale, h // self.scale, self.scale, w // self.scale))
  210. x = x.transpose((0, 2, 4, 1, 3, 5))
  211. x = x.reshape((b * self.scale * self.scale, c, -1))
  212. return x
  213. def _recons_whole(self, x, b, c, h, w):
  214. x = x.reshape(
  215. (b, self.scale, self.scale, c, h // self.scale, w // self.scale))
  216. x = x.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))
  217. return x
  218. class PAM(nn.Layer):
  219. def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):
  220. super(PAM, self).__init__()
  221. self.stages = nn.LayerList(
  222. [PAMBlock(
  223. in_ch, scale=s, ds=ds) for s in scales])
  224. self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)
  225. def forward(self, x):
  226. x = x.flatten(-2)
  227. res = [stage(x) for stage in self.stages]
  228. out = self.conv_out(paddle.concat(res, axis=1))
  229. return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
  230. class Attention(nn.Layer):
  231. def __init__(self, att):
  232. super(Attention, self).__init__()
  233. self.att = att
  234. def forward(self, x1, x2):
  235. x = paddle.stack([x1, x2], axis=-1)
  236. y = self.att(x)
  237. return y[..., 0], y[..., 1]