bit.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal

from .backbones import resnet
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .param_init import KaimingInitMixin


def calc_product(*args):
    if len(args) < 1:
        raise ValueError("calc_product expects at least one argument.")
    ret = args[0]
    for arg in args[1:]:
        ret *= arg
    return ret
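
# For example, calc_product(16, 16) == 256; the model below uses this helper
# to compute the flattened spatial length h * w of a feature map in one call.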


class BIT(nn.Layer):
    """
    The BIT implementation based on PaddlePaddle.

    The original article refers to
        H. Chen, et al., "Remote Sensing Image Change Detection With Transformers"
        (https://arxiv.org/abs/2103.00208).

    This implementation adopts pretrained encoders, as opposed to the original
    work where weights are randomly initialized.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        backbone (str, optional): The ResNet architecture that is used as the
            backbone. Currently, only 'resnet18' and 'resnet34' are supported.
            Default: 'resnet18'.
        n_stages (int, optional): The number of ResNet stages used in the
            backbone, which should be a value in {3, 4, 5}. Default: 4.
        use_tokenizer (bool, optional): Whether to use a tokenizer. Default: True.
        token_len (int, optional): The length of input tokens. Default: 4.
        pool_mode (str, optional): The pooling strategy to obtain input tokens
            when `use_tokenizer` is set to False. 'max' for global max pooling
            and 'avg' for global average pooling. Default: 'max'.
        pool_size (int, optional): The height and width of the pooled feature
            maps when `use_tokenizer` is set to False. Default: 2.
        enc_with_pos (bool, optional): Whether to add a learned positional
            embedding to the input feature sequence of the encoder. Default: True.
        enc_depth (int, optional): The number of attention blocks used in the
            encoder. Default: 1.
        enc_head_dim (int, optional): The embedding dimension of each encoder
            head. Default: 64.
        dec_depth (int, optional): The number of attention blocks used in the
            decoder. Default: 8.
        dec_head_dim (int, optional): The embedding dimension of each decoder
            head. Default: 8.

    Raises:
        ValueError: When an unsupported backbone type is specified, or the
            number of backbone stages is not 3, 4, or 5.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 backbone='resnet18',
                 n_stages=4,
                 use_tokenizer=True,
                 token_len=4,
                 pool_mode='max',
                 pool_size=2,
                 enc_with_pos=True,
                 enc_depth=1,
                 enc_head_dim=64,
                 dec_depth=8,
                 dec_head_dim=8,
                 **backbone_kwargs):
        super(BIT, self).__init__()

        # TODO: reduce hard-coded parameters
        DIM = 32
        MLP_DIM = 2 * DIM
        EBD_DIM = DIM

        self.backbone = Backbone(
            in_channels,
            EBD_DIM,
            arch=backbone,
            n_stages=n_stages,
            **backbone_kwargs)

        self.use_tokenizer = use_tokenizer
        if not use_tokenizer:
            # If a tokenizer is not to be used, then downsample the feature maps.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size
        else:
            self.conv_att = Conv1x1(32, token_len, bias=False)
            self.token_len = token_len

        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            self.enc_pos_embedding = self.create_parameter(
                shape=(1, self.token_len * 2, EBD_DIM),
                default_initializer=Normal())

        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.enc_head_dim = enc_head_dim
        self.dec_head_dim = dec_head_dim

        self.encoder = TransformerEncoder(
            dim=DIM,
            depth=enc_depth,
            n_heads=8,
            head_dim=enc_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.)
        self.decoder = TransformerDecoder(
            dim=DIM,
            depth=dec_depth,
            n_heads=8,
            head_dim=dec_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.,
            apply_softmax=True)

        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.conv_out = nn.Sequential(
            Conv3x3(
                EBD_DIM, EBD_DIM, norm=True, act=True),
            Conv3x3(EBD_DIM, num_classes))

    def _get_semantic_tokens(self, x):
        b, c = x.shape[:2]
        att_map = self.conv_att(x)
        att_map = att_map.reshape(
            (b, self.token_len, 1, calc_product(*att_map.shape[2:])))
        att_map = F.softmax(att_map, axis=-1)
        x = x.reshape((b, 1, c, att_map.shape[-1]))
        tokens = (x * att_map).sum(-1)
        return tokens
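
    # Shape sketch of the tokenizer above (illustrative): for x of shape
    # (b, c, h, w) and token_len = L, conv_att yields (b, L, h, w); after the
    # reshape and a softmax over the h*w axis, each token is a spatial
    # attention-weighted sum of pixel features, so tokens has shape (b, L, c).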

    def _get_reshaped_tokens(self, x):
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
        else:
            # Any other pooling mode leaves the feature map unpooled.
            x = x
        tokens = x.transpose((0, 2, 3, 1)).flatten(1, 2)
        return tokens
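
    # With the default pool_size=2, both pooling branches above yield
    # pool_size * pool_size = 4 tokens per image, i.e. tokens of shape
    # (b, 4, c), matching the token_len set in __init__ when use_tokenizer
    # is False.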

    def encode(self, x):
        if self.enc_with_pos:
            x += self.enc_pos_embedding
        x = self.encoder(x)
        return x

    def decode(self, x, m):
        b, c, h, w = x.shape
        x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
        x = self.decoder(x, m)
        x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
        return x

    def forward(self, t1, t2):
        # Extract features via shared backbone.
        x1 = self.backbone(t1)
        x2 = self.backbone(t2)

        # Tokenization
        if self.use_tokenizer:
            token1 = self._get_semantic_tokens(x1)
            token2 = self._get_semantic_tokens(x2)
        else:
            token1 = self._get_reshaped_tokens(x1)
            token2 = self._get_reshaped_tokens(x2)

        # Transformer encoder forward
        token = paddle.concat([token1, token2], axis=1)
        token = self.encode(token)
        token1, token2 = paddle.chunk(token, 2, axis=1)

        # Transformer decoder forward
        y1 = self.decode(x1, token1)
        y2 = self.decode(x2, token2)

        # Feature differencing
        y = paddle.abs(y1 - y2)
        y = self.upsample(y)

        # Classifier forward
        pred = self.conv_out(y)
        return [pred]

    def init_weight(self):
        # Use the default initialization method.
        pass


class Residual(nn.Layer):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class Residual2(nn.Layer):
    def __init__(self, fn):
        super(Residual2, self).__init__()
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        return self.fn(x1, x2, **kwargs) + x1


class PreNorm(nn.Layer):
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class PreNorm2(nn.Layer):
    def __init__(self, dim, fn):
        super(PreNorm2, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        return self.fn(self.norm(x1), self.norm(x2), **kwargs)


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        super(FeedForward, self).__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout_rate))


class CrossAttention(nn.Layer):
    def __init__(self,
                 dim,
                 n_heads=8,
                 head_dim=64,
                 dropout_rate=0.,
                 apply_softmax=True):
        super(CrossAttention, self).__init__()

        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = dim**-0.5
        self.apply_softmax = apply_softmax

        self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False)

        self.fc_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))

    def forward(self, x, ref):
        b, n = x.shape[:2]
        h = self.n_heads

        q = self.fc_q(x)
        k = self.fc_k(ref)
        v = self.fc_v(ref)

        q = q.reshape((b, n, h, self.head_dim)).transpose((0, 2, 1, 3))
        rn = ref.shape[1]
        k = k.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))
        v = v.reshape((b, rn, h, self.head_dim)).transpose((0, 2, 1, 3))

        mult = paddle.matmul(q, k, transpose_y=True) * self.scale

        if self.apply_softmax:
            mult = F.softmax(mult, axis=-1)

        out = paddle.matmul(mult, v)
        out = out.transpose((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)


class SelfAttention(CrossAttention):
    def forward(self, x):
        return super(SelfAttention, self).forward(x, x)
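

# A minimal shape check for the attention blocks above (a hypothetical helper
# added for illustration only; it is not part of the original model and is
# never called by it):
def _attention_shape_demo():
    # With dim=32, each of the 8 heads projects to head_dim=64 internally,
    # while fc_out maps the concatenated heads back to dim, so the output
    # keeps the input shape (batch, n_tokens, dim).
    att = SelfAttention(dim=32, n_heads=8, head_dim=64)
    x = paddle.rand((2, 8, 32))  # (batch, n_tokens, dim)
    y = att(x)
    assert tuple(y.shape) == (2, 8, 32)
    return y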


class TransformerEncoder(nn.Layer):
    def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.LayerList([])
        for _ in range(depth):
            self.layers.append(
                nn.LayerList([
                    Residual(
                        PreNorm(dim,
                                SelfAttention(dim, n_heads, head_dim,
                                              dropout_rate))),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
                ]))

    def forward(self, x):
        for att, ff in self.layers:
            x = att(x)
            x = ff(x)
        return x


class TransformerDecoder(nn.Layer):
    def __init__(self,
                 dim,
                 depth,
                 n_heads,
                 head_dim,
                 mlp_dim,
                 dropout_rate,
                 apply_softmax=True):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.LayerList([])
        for _ in range(depth):
            self.layers.append(
                nn.LayerList([
                    Residual2(
                        PreNorm2(dim,
                                 CrossAttention(dim, n_heads, head_dim,
                                                dropout_rate, apply_softmax))),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
                ]))

    def forward(self, x, m):
        for att, ff in self.layers:
            x = att(x, m)
            x = ff(x)
        return x
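

# In BIT, the decoder queries are the per-pixel features (flattened into a
# sequence by BIT.decode) while the keys and values are the compact semantic
# tokens refined by the encoder, so each pixel attends to only token_len
# tokens rather than to every other pixel.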


class Backbone(nn.Layer, KaimingInitMixin):
    def __init__(self,
                 in_ch,
                 out_ch=32,
                 arch='resnet18',
                 pretrained=True,
                 n_stages=5):
        super(Backbone, self).__init__()

        expand = 1
        strides = (2, 1, 2, 1, 1)
        if arch == 'resnet18':
            self.resnet = resnet.resnet18(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        elif arch == 'resnet34':
            self.resnet = resnet.resnet34(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        else:
            raise ValueError(f"Unsupported architecture: {arch}.")

        self.n_stages = n_stages
        if self.n_stages == 5:
            itm_ch = 512 * expand
        elif self.n_stages == 4:
            itm_ch = 256 * expand
        elif self.n_stages == 3:
            itm_ch = 128 * expand
        else:
            raise ValueError(
                f"n_stages must be 3, 4, or 5, but got {n_stages}.")

        self.upsample = nn.Upsample(scale_factor=2)
        self.conv_out = Conv3x3(itm_ch, out_ch)

        self._trim_resnet()

        if in_ch != 3:
            self.resnet.conv1 = nn.Conv2D(
                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        y = self.resnet.conv1(x)
        y = self.resnet.bn1(y)
        y = self.resnet.relu(y)
        y = self.resnet.maxpool(y)

        y = self.resnet.layer1(y)
        y = self.resnet.layer2(y)
        y = self.resnet.layer3(y)
        y = self.resnet.layer4(y)

        y = self.upsample(y)

        return self.conv_out(y)

    def _trim_resnet(self):
        if self.n_stages > 5:
            raise ValueError("n_stages must not exceed 5.")
        if self.n_stages < 5:
            self.resnet.layer4 = Identity()
        if self.n_stages <= 3:
            self.resnet.layer3 = Identity()
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()
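

# A minimal usage sketch (hedged): constructing BIT and running a bitemporal
# image pair through it. `pretrained=False` is forwarded to Backbone so no
# weight download is attempted; because of the relative imports above, this
# module must be imported from within its package.
if __name__ == '__main__':
    model = BIT(in_channels=3, num_classes=2, pretrained=False)
    t1 = paddle.rand((1, 3, 256, 256))
    t2 = paddle.rand((1, 3, 256, 256))
    pred = model(t1, t2)[0]
    # Features are extracted at 1/4 resolution and upsampled back by 4x,
    # so the logits match the input size: (1, 2, 256, 256).
    print(pred.shape)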