changeformer.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. import math
  16. from functools import partial
  17. import paddle as pd
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from .layers.pd_timm import DropPath, to_2tuple
  21. def calc_product(*args):
  22. if len(args) < 1:
  23. raise ValueError
  24. ret = args[0]
  25. for arg in args[1:]:
  26. ret *= arg
  27. return ret
  28. class ConvBlock(pd.nn.Layer):
  29. def __init__(self,
  30. input_size,
  31. output_size,
  32. kernel_size=3,
  33. stride=1,
  34. padding=1,
  35. bias=True,
  36. activation='prelu',
  37. norm=None):
  38. super(ConvBlock, self).__init__()
  39. self.conv = pd.nn.Conv2D(
  40. input_size,
  41. output_size,
  42. kernel_size,
  43. stride,
  44. padding,
  45. bias_attr=bias)
  46. self.norm = norm
  47. if self.norm == 'batch':
  48. self.bn = pd.nn.BatchNorm2D(output_size)
  49. elif self.norm == 'instance':
  50. self.bn = pd.nn.InstanceNorm2D(output_size)
  51. self.activation = activation
  52. if self.activation == 'relu':
  53. self.act = pd.nn.ReLU(True)
  54. elif self.activation == 'prelu':
  55. self.act = pd.nn.PReLU()
  56. elif self.activation == 'lrelu':
  57. self.act = pd.nn.LeakyReLU(0.2, True)
  58. elif self.activation == 'tanh':
  59. self.act = pd.nn.Tanh()
  60. elif self.activation == 'sigmoid':
  61. self.act = pd.nn.Sigmoid()
  62. def forward(self, x):
  63. if self.norm is not None:
  64. out = self.bn(self.conv(x))
  65. else:
  66. out = self.conv(x)
  67. if self.activation != 'no':
  68. return self.act(out)
  69. else:
  70. return out
  71. class DeconvBlock(pd.nn.Layer):
  72. def __init__(self,
  73. input_size,
  74. output_size,
  75. kernel_size=4,
  76. stride=2,
  77. padding=1,
  78. bias=True,
  79. activation='prelu',
  80. norm=None):
  81. super(DeconvBlock, self).__init__()
  82. self.deconv = pd.nn.Conv2DTranspose(
  83. input_size,
  84. output_size,
  85. kernel_size,
  86. stride,
  87. padding,
  88. bias_attr=bias)
  89. self.norm = norm
  90. if self.norm == 'batch':
  91. self.bn = pd.nn.BatchNorm2D(output_size)
  92. elif self.norm == 'instance':
  93. self.bn = pd.nn.InstanceNorm2D(output_size)
  94. self.activation = activation
  95. if self.activation == 'relu':
  96. self.act = pd.nn.ReLU(True)
  97. elif self.activation == 'prelu':
  98. self.act = pd.nn.PReLU()
  99. elif self.activation == 'lrelu':
  100. self.act = pd.nn.LeakyReLU(0.2, True)
  101. elif self.activation == 'tanh':
  102. self.act = pd.nn.Tanh()
  103. elif self.activation == 'sigmoid':
  104. self.act = pd.nn.Sigmoid()
  105. def forward(self, x):
  106. if self.norm is not None:
  107. out = self.bn(self.deconv(x))
  108. else:
  109. out = self.deconv(x)
  110. if self.activation is not None:
  111. return self.act(out)
  112. else:
  113. return out
  114. class ConvLayer(nn.Layer):
  115. def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
  116. super(ConvLayer, self).__init__()
  117. self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size, stride,
  118. padding)
  119. def forward(self, x):
  120. out = self.conv2d(x)
  121. return out
  122. class UpsampleConvLayer(pd.nn.Layer):
  123. def __init__(self, in_channels, out_channels, kernel_size, stride):
  124. super(UpsampleConvLayer, self).__init__()
  125. self.conv2d = nn.Conv2DTranspose(
  126. in_channels, out_channels, kernel_size, stride=stride, padding=1)
  127. def forward(self, x):
  128. out = self.conv2d(x)
  129. return out
  130. class ResidualBlock(pd.nn.Layer):
  131. def __init__(self, channels):
  132. super(ResidualBlock, self).__init__()
  133. self.conv1 = ConvLayer(
  134. channels, channels, kernel_size=3, stride=1, padding=1)
  135. self.conv2 = ConvLayer(
  136. channels, channels, kernel_size=3, stride=1, padding=1)
  137. self.relu = nn.ReLU()
  138. def forward(self, x):
  139. residual = x
  140. out = self.relu(self.conv1(x))
  141. out = self.conv2(out) * 0.1
  142. out = pd.add(out, residual)
  143. return out
  144. class ChangeFormer(nn.Layer):
  145. """
  146. The ChangeFormer implementation based on PaddlePaddle.
  147. The original article refers to
  148. Wele Gedara Chaminda Bandara, Vishal M. Patel., "A TRANSFORMER-BASED SIAMESE NETWORK FOR CHANGE DETECTION"
  149. (https://arxiv.org/pdf/2201.01293.pdf).
  150. Args:
  151. in_channels (int): Number of bands of the input images. Default: 3.
  152. num_classes (int): Number of target classes. Default: 2.
  153. decoder_softmax (bool, optional): Use softmax after decode or not. Default: False.
  154. embed_dim (int, optional): Embedding dimension of each decoder head. Default: 256.
  155. """
  156. def __init__(self,
  157. in_channels=3,
  158. num_classes=2,
  159. decoder_softmax=False,
  160. embed_dim=256):
  161. super(ChangeFormer, self).__init__()
  162. # Transformer Encoder
  163. self.embed_dims = [64, 128, 320, 512]
  164. self.depths = [3, 3, 4, 3]
  165. self.embedding_dim = embed_dim
  166. self.drop_rate = 0.1
  167. self.attn_drop = 0.1
  168. self.drop_path_rate = 0.1
  169. self.Tenc_x2 = EncoderTransformer_v3(
  170. img_size=256,
  171. patch_size=7,
  172. in_chans=in_channels,
  173. num_classes=num_classes,
  174. embed_dims=self.embed_dims,
  175. num_heads=[1, 2, 4, 8],
  176. mlp_ratios=[4, 4, 4, 4],
  177. qkv_bias=True,
  178. qk_scale=None,
  179. drop_rate=self.drop_rate,
  180. attn_drop_rate=self.attn_drop,
  181. drop_path_rate=self.drop_path_rate,
  182. norm_layer=partial(
  183. nn.LayerNorm, epsilon=1e-6),
  184. depths=self.depths,
  185. sr_ratios=[8, 4, 2, 1])
  186. # Transformer Decoder
  187. self.TDec_x2 = DecoderTransformer_v3(
  188. input_transform='multiple_select',
  189. in_index=[0, 1, 2, 3],
  190. align_corners=False,
  191. in_channels=self.embed_dims,
  192. embedding_dim=self.embedding_dim,
  193. output_nc=num_classes,
  194. decoder_softmax=decoder_softmax,
  195. feature_strides=[2, 4, 8, 16])
  196. def forward(self, x1, x2):
  197. [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
  198. cp = self.TDec_x2(fx1, fx2)
  199. return [cp]
# Transformer Encoder with x2, x4, x8, x16 scales
  201. class EncoderTransformer_v3(nn.Layer):
  202. def __init__(self,
  203. img_size=256,
  204. patch_size=3,
  205. in_chans=3,
  206. num_classes=2,
  207. embed_dims=[32, 64, 128, 256],
  208. num_heads=[2, 2, 4, 8],
  209. mlp_ratios=[4, 4, 4, 4],
  210. qkv_bias=True,
  211. qk_scale=None,
  212. drop_rate=0.,
  213. attn_drop_rate=0.,
  214. drop_path_rate=0.,
  215. norm_layer=nn.LayerNorm,
  216. depths=[3, 3, 6, 18],
  217. sr_ratios=[8, 4, 2, 1]):
  218. super().__init__()
  219. self.num_classes = num_classes
  220. self.depths = depths
  221. self.embed_dims = embed_dims
  222. # Patch embedding definitions
  223. self.patch_embed1 = OverlapPatchEmbed(
  224. img_size=img_size,
  225. patch_size=7,
  226. stride=4,
  227. in_chans=in_chans,
  228. embed_dim=embed_dims[0])
  229. self.patch_embed2 = OverlapPatchEmbed(
  230. img_size=img_size // 4,
  231. patch_size=patch_size,
  232. stride=2,
  233. in_chans=embed_dims[0],
  234. embed_dim=embed_dims[1])
  235. self.patch_embed3 = OverlapPatchEmbed(
  236. img_size=img_size // 8,
  237. patch_size=patch_size,
  238. stride=2,
  239. in_chans=embed_dims[1],
  240. embed_dim=embed_dims[2])
  241. self.patch_embed4 = OverlapPatchEmbed(
  242. img_size=img_size // 16,
  243. patch_size=patch_size,
  244. stride=2,
  245. in_chans=embed_dims[2],
  246. embed_dim=embed_dims[3])
  247. # Stage-1 (x1/4 scale)
  248. dpr = [x.item() for x in pd.linspace(0, drop_path_rate, sum(depths))]
  249. cur = 0
  250. self.block1 = nn.LayerList([
  251. Block(
  252. dim=embed_dims[0],
  253. num_heads=num_heads[0],
  254. mlp_ratio=mlp_ratios[0],
  255. qkv_bias=qkv_bias,
  256. qk_scale=qk_scale,
  257. drop=drop_rate,
  258. attn_drop=attn_drop_rate,
  259. drop_path=dpr[cur + i],
  260. norm_layer=norm_layer,
  261. sr_ratio=sr_ratios[0]) for i in range(depths[0])
  262. ])
  263. self.norm1 = norm_layer(embed_dims[0])
  264. # Stage-2 (x1/8 scale)
  265. cur += depths[0]
  266. self.block2 = nn.LayerList([
  267. Block(
  268. dim=embed_dims[1],
  269. num_heads=num_heads[1],
  270. mlp_ratio=mlp_ratios[1],
  271. qkv_bias=qkv_bias,
  272. qk_scale=qk_scale,
  273. drop=drop_rate,
  274. attn_drop=attn_drop_rate,
  275. drop_path=dpr[cur + i],
  276. norm_layer=norm_layer,
  277. sr_ratio=sr_ratios[1]) for i in range(depths[1])
  278. ])
  279. self.norm2 = norm_layer(embed_dims[1])
  280. # Stage-3 (x1/16 scale)
  281. cur += depths[1]
  282. self.block3 = nn.LayerList([
  283. Block(
  284. dim=embed_dims[2],
  285. num_heads=num_heads[2],
  286. mlp_ratio=mlp_ratios[2],
  287. qkv_bias=qkv_bias,
  288. qk_scale=qk_scale,
  289. drop=drop_rate,
  290. attn_drop=attn_drop_rate,
  291. drop_path=dpr[cur + i],
  292. norm_layer=norm_layer,
  293. sr_ratio=sr_ratios[2]) for i in range(depths[2])
  294. ])
  295. self.norm3 = norm_layer(embed_dims[2])
  296. # Stage-4 (x1/32 scale)
  297. cur += depths[2]
  298. self.block4 = nn.LayerList([
  299. Block(
  300. dim=embed_dims[3],
  301. num_heads=num_heads[3],
  302. mlp_ratio=mlp_ratios[3],
  303. qkv_bias=qkv_bias,
  304. qk_scale=qk_scale,
  305. drop=drop_rate,
  306. attn_drop=attn_drop_rate,
  307. drop_path=dpr[cur + i],
  308. norm_layer=norm_layer,
  309. sr_ratio=sr_ratios[3]) for i in range(depths[3])
  310. ])
  311. self.norm4 = norm_layer(embed_dims[3])
  312. self.apply(self._init_weights)
  313. def _init_weights(self, m):
  314. if isinstance(m, nn.Linear):
  315. trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
  316. trunc_normal_op(m.weight)
  317. if isinstance(m, nn.Linear) and m.bias is not None:
  318. init_bias = nn.initializer.Constant(0)
  319. init_bias(m.bias)
  320. elif isinstance(m, nn.LayerNorm):
  321. init_bias = nn.initializer.Constant(0)
  322. init_bias(m.bias)
  323. init_weight = nn.initializer.Constant(1.0)
  324. init_weight(m.weight)
  325. elif isinstance(m, nn.Conv2D):
  326. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  327. fan_out //= m._groups
  328. init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
  329. init_weight(m.weight)
  330. if m.bias is not None:
  331. init_bias = nn.initializer.Constant(0)
  332. init_bias(m.bias)
  333. def reset_drop_path(self, drop_path_rate):
  334. dpr = [
  335. x.item() for x in pd.linspace(0, drop_path_rate, sum(self.depths))
  336. ]
  337. cur = 0
  338. for i in range(self.depths[0]):
  339. self.block1[i].drop_path.drop_prob = dpr[cur + i]
  340. cur += self.depths[0]
  341. for i in range(self.depths[1]):
  342. self.block2[i].drop_path.drop_prob = dpr[cur + i]
  343. cur += self.depths[1]
  344. for i in range(self.depths[2]):
  345. self.block3[i].drop_path.drop_prob = dpr[cur + i]
  346. cur += self.depths[2]
  347. for i in range(self.depths[3]):
  348. self.block4[i].drop_path.drop_prob = dpr[cur + i]
  349. def forward_features(self, x):
  350. B = x.shape[0]
  351. outs = []
  352. # Stage 1
  353. x1, H1, W1 = self.patch_embed1(x)
  354. for i, blk in enumerate(self.block1):
  355. x1 = blk(x1, H1, W1)
  356. x1 = self.norm1(x1)
  357. x1 = x1.reshape(
  358. [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
  359. [0, 3, 1, 2])
  360. outs.append(x1)
  361. # Stage 2
  362. x1, H1, W1 = self.patch_embed2(x1)
  363. for i, blk in enumerate(self.block2):
  364. x1 = blk(x1, H1, W1)
  365. x1 = self.norm2(x1)
  366. x1 = x1.reshape(
  367. [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
  368. [0, 3, 1, 2])
  369. outs.append(x1)
  370. # Stage 3
  371. x1, H1, W1 = self.patch_embed3(x1)
  372. for i, blk in enumerate(self.block3):
  373. x1 = blk(x1, H1, W1)
  374. x1 = self.norm3(x1)
  375. x1 = x1.reshape(
  376. [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
  377. [0, 3, 1, 2])
  378. outs.append(x1)
  379. # Stage 4
  380. x1, H1, W1 = self.patch_embed4(x1)
  381. for i, blk in enumerate(self.block4):
  382. x1 = blk(x1, H1, W1)
  383. x1 = self.norm4(x1)
  384. x1 = x1.reshape(
  385. [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
  386. [0, 3, 1, 2])
  387. outs.append(x1)
  388. return outs
  389. def forward(self, x):
  390. x = self.forward_features(x)
  391. return x
  392. class DecoderTransformer_v3(nn.Layer):
  393. """
  394. Transformer Decoder
  395. """
  396. def __init__(self,
  397. input_transform='multiple_select',
  398. in_index=[0, 1, 2, 3],
  399. align_corners=True,
  400. in_channels=[32, 64, 128, 256],
  401. embedding_dim=64,
  402. output_nc=2,
  403. decoder_softmax=False,
  404. feature_strides=[2, 4, 8, 16]):
  405. super(DecoderTransformer_v3, self).__init__()
  406. assert len(feature_strides) == len(in_channels)
  407. assert min(feature_strides) == feature_strides[0]
  408. # Settings
  409. self.feature_strides = feature_strides
  410. self.input_transform = input_transform
  411. self.in_index = in_index
  412. self.align_corners = align_corners
  413. self.in_channels = in_channels
  414. self.embedding_dim = embedding_dim
  415. self.output_nc = output_nc
  416. c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
  417. # MLP decoder heads
  418. self.linear_c4 = MLP(input_dim=c4_in_channels,
  419. embed_dim=self.embedding_dim)
  420. self.linear_c3 = MLP(input_dim=c3_in_channels,
  421. embed_dim=self.embedding_dim)
  422. self.linear_c2 = MLP(input_dim=c2_in_channels,
  423. embed_dim=self.embedding_dim)
  424. self.linear_c1 = MLP(input_dim=c1_in_channels,
  425. embed_dim=self.embedding_dim)
  426. # Convolutional Difference Layers
  427. self.diff_c4 = conv_diff(
  428. in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
  429. self.diff_c3 = conv_diff(
  430. in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
  431. self.diff_c2 = conv_diff(
  432. in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
  433. self.diff_c1 = conv_diff(
  434. in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
  435. # Take outputs from middle of the encoder
  436. self.make_pred_c4 = make_prediction(
  437. in_channels=self.embedding_dim, out_channels=self.output_nc)
  438. self.make_pred_c3 = make_prediction(
  439. in_channels=self.embedding_dim, out_channels=self.output_nc)
  440. self.make_pred_c2 = make_prediction(
  441. in_channels=self.embedding_dim, out_channels=self.output_nc)
  442. self.make_pred_c1 = make_prediction(
  443. in_channels=self.embedding_dim, out_channels=self.output_nc)
  444. # Final linear fusion layer
  445. self.linear_fuse = nn.Sequential(
  446. nn.Conv2D(
  447. in_channels=self.embedding_dim * len(in_channels),
  448. out_channels=self.embedding_dim,
  449. kernel_size=1),
  450. nn.BatchNorm2D(self.embedding_dim))
  451. # Final predction head
  452. self.convd2x = UpsampleConvLayer(
  453. self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
  454. self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
  455. self.convd1x = UpsampleConvLayer(
  456. self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
  457. self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
  458. self.change_probability = ConvLayer(
  459. self.embedding_dim,
  460. self.output_nc,
  461. kernel_size=3,
  462. stride=1,
  463. padding=1)
  464. # Final activation
  465. self.output_softmax = decoder_softmax
  466. self.active = nn.Sigmoid()
  467. def _transform_inputs(self, inputs):
  468. """
  469. Transform inputs for decoder.
  470. Args:
  471. inputs (list[Tensor]): List of multi-level img features.
  472. Returns:
  473. Tensor: The transformed inputs
  474. """
  475. if self.input_transform == 'resize_concat':
  476. inputs = [inputs[i] for i in self.in_index]
  477. upsampled_inputs = [
  478. resize(
  479. input=x,
  480. size=inputs[0].shape[2:],
  481. mode='bilinear',
  482. align_corners=self.align_corners) for x in inputs
  483. ]
  484. inputs = pd.concat(upsampled_inputs, dim=1)
  485. elif self.input_transform == 'multiple_select':
  486. inputs = [inputs[i] for i in self.in_index]
  487. else:
  488. inputs = inputs[self.in_index]
  489. return inputs
  490. def forward(self, inputs1, inputs2):
  491. # Transforming encoder features (select layers)
  492. x_1 = self._transform_inputs(inputs1) # len=4, 1/2, 1/4, 1/8, 1/16
  493. x_2 = self._transform_inputs(inputs2) # len=4, 1/2, 1/4, 1/8, 1/16
  494. # img1 and img2 features
  495. c1_1, c2_1, c3_1, c4_1 = x_1
  496. c1_2, c2_2, c3_2, c4_2 = x_2
  497. ############## MLP decoder on C1-C4 ###########
  498. n, _, h, w = c4_1.shape
  499. outputs = []
  500. # Stage 4: x1/32 scale
  501. _c4_1 = self.linear_c4(c4_1).transpose([0, 2, 1])
  502. _c4_1 = _c4_1.reshape([
  503. n, calc_product(*_c4_1.shape[1:]) //
  504. (c4_1.shape[2] * c4_1.shape[3]), c4_1.shape[2], c4_1.shape[3]
  505. ])
  506. _c4_2 = self.linear_c4(c4_2).transpose([0, 2, 1])
  507. _c4_2 = _c4_2.reshape([
  508. n, calc_product(*_c4_2.shape[1:]) //
  509. (c4_2.shape[2] * c4_2.shape[3]), c4_2.shape[2], c4_2.shape[3]
  510. ])
  511. _c4 = self.diff_c4(pd.concat((_c4_1, _c4_2), axis=1))
  512. p_c4 = self.make_pred_c4(_c4)
  513. outputs.append(p_c4)
  514. _c4_up = resize(
  515. _c4, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
  516. # Stage 3: x1/16 scale
  517. _c3_1 = self.linear_c3(c3_1).transpose([0, 2, 1])
  518. _c3_1 = _c3_1.reshape([
  519. n, calc_product(*_c3_1.shape[1:]) //
  520. (c3_1.shape[2] * c3_1.shape[3]), c3_1.shape[2], c3_1.shape[3]
  521. ])
  522. _c3_2 = self.linear_c3(c3_2).transpose([0, 2, 1])
  523. _c3_2 = _c3_2.reshape([
  524. n, calc_product(*_c3_2.shape[1:]) //
  525. (c3_2.shape[2] * c3_2.shape[3]), c3_2.shape[2], c3_2.shape[3]
  526. ])
  527. _c3 = self.diff_c3(pd.concat((_c3_1, _c3_2), axis=1)) + \
  528. F.interpolate(_c4, scale_factor=2, mode="bilinear")
  529. p_c3 = self.make_pred_c3(_c3)
  530. outputs.append(p_c3)
  531. _c3_up = resize(
  532. _c3, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
  533. # Stage 2: x1/8 scale
  534. _c2_1 = self.linear_c2(c2_1).transpose([0, 2, 1])
  535. _c2_1 = _c2_1.reshape([
  536. n, calc_product(*_c2_1.shape[1:]) //
  537. (c2_1.shape[2] * c2_1.shape[3]), c2_1.shape[2], c2_1.shape[3]
  538. ])
  539. _c2_2 = self.linear_c2(c2_2).transpose([0, 2, 1])
  540. _c2_2 = _c2_2.reshape([
  541. n, calc_product(*_c2_2.shape[1:]) //
  542. (c2_2.shape[2] * c2_2.shape[3]), c2_2.shape[2], c2_2.shape[3]
  543. ])
  544. _c2 = self.diff_c2(pd.concat((_c2_1, _c2_2), axis=1)) + \
  545. F.interpolate(_c3, scale_factor=2, mode="bilinear")
  546. p_c2 = self.make_pred_c2(_c2)
  547. outputs.append(p_c2)
  548. _c2_up = resize(
  549. _c2, size=c1_2.shape[2:], mode='bilinear', align_corners=False)
  550. # Stage 1: x1/4 scale
  551. _c1_1 = self.linear_c1(c1_1).transpose([0, 2, 1])
  552. _c1_1 = _c1_1.reshape([
  553. n, calc_product(*_c1_1.shape[1:]) //
  554. (c1_1.shape[2] * c1_1.shape[3]), c1_1.shape[2], c1_1.shape[3]
  555. ])
  556. _c1_2 = self.linear_c1(c1_2).transpose([0, 2, 1])
  557. _c1_2 = _c1_2.reshape([
  558. n, calc_product(*_c1_2.shape[1:]) //
  559. (c1_2.shape[2] * c1_2.shape[3]), c1_2.shape[2], c1_2.shape[3]
  560. ])
  561. _c1 = self.diff_c1(pd.concat((_c1_1, _c1_2), axis=1)) + \
  562. F.interpolate(_c2, scale_factor=2, mode="bilinear")
  563. p_c1 = self.make_pred_c1(_c1)
  564. outputs.append(p_c1)
  565. # Linear Fusion of difference image from all scales
  566. _c = self.linear_fuse(pd.concat((_c4_up, _c3_up, _c2_up, _c1), axis=1))
  567. # Upsampling x2 (x1/2 scale)
  568. x = self.convd2x(_c)
  569. # Residual block
  570. x = self.dense_2x(x)
  571. # Upsampling x2 (x1 scale)
  572. x = self.convd1x(x)
  573. # Residual block
  574. x = self.dense_1x(x)
  575. # Final prediction
  576. cp = self.change_probability(x)
  577. outputs.append(cp)
  578. if self.output_softmax:
  579. temp = outputs
  580. outputs = []
  581. for pred in temp:
  582. outputs.append(self.active(pred))
  583. return outputs[-1]
  584. class OverlapPatchEmbed(nn.Layer):
  585. """
  586. Image to Patch Embedding
  587. """
  588. def __init__(self,
  589. img_size=224,
  590. patch_size=7,
  591. stride=4,
  592. in_chans=3,
  593. embed_dim=768):
  594. super().__init__()
  595. img_size = to_2tuple(img_size)
  596. patch_size = to_2tuple(patch_size)
  597. self.img_size = img_size
  598. self.patch_size = patch_size
  599. self.H, self.W = img_size[0] // patch_size[0], img_size[
  600. 1] // patch_size[1]
  601. self.num_patches = self.H * self.W
  602. self.proj = nn.Conv2D(
  603. in_chans,
  604. embed_dim,
  605. kernel_size=patch_size,
  606. stride=stride,
  607. padding=(patch_size[0] // 2, patch_size[1] // 2))
  608. self.norm = nn.LayerNorm(embed_dim)
  609. self.apply(self._init_weights)
  610. def _init_weights(self, m):
  611. if isinstance(m, nn.Linear):
  612. trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
  613. trunc_normal_op(m.weight)
  614. if isinstance(m, nn.Linear) and m.bias is not None:
  615. init_bias = nn.initializer.Constant(0)
  616. init_bias(m.bias)
  617. elif isinstance(m, nn.LayerNorm):
  618. init_bias = nn.initializer.Constant(0)
  619. init_bias(m.bias)
  620. init_weight = nn.initializer.Constant(1.0)
  621. init_weight(m.weight)
  622. elif isinstance(m, nn.Conv2D):
  623. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  624. fan_out //= m._groups
  625. init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
  626. init_weight(m.weight)
  627. if m.bias is not None:
  628. init_bias = nn.initializer.Constant(0)
  629. init_bias(m.bias)
  630. def forward(self, x):
  631. x = self.proj(x)
  632. _, _, H, W = x.shape
  633. x = x.flatten(2).transpose([0, 2, 1])
  634. x = self.norm(x)
  635. return x, H, W
  636. def resize(input,
  637. size=None,
  638. scale_factor=None,
  639. mode='nearest',
  640. align_corners=None,
  641. warning=True):
  642. if warning:
  643. if size is not None and align_corners:
  644. input_h, input_w = tuple(int(x) for x in input.shape[2:])
  645. output_h, output_w = tuple(int(x) for x in size)
  646. if output_h > input_h or output_w > output_h:
  647. if ((output_h > 1 and output_w > 1 and input_h > 1 and
  648. input_w > 1) and (output_h - 1) % (input_h - 1) and
  649. (output_w - 1) % (input_w - 1)):
  650. warnings.warn(
  651. f'When align_corners={align_corners}, '
  652. 'the output would more aligned if '
  653. f'input size {(input_h, input_w)} is `x+1` and '
  654. f'out size {(output_h, output_w)} is `nx+1`')
  655. return F.interpolate(input, size, scale_factor, mode, align_corners)
  656. class Mlp(nn.Layer):
  657. def __init__(self,
  658. in_features,
  659. hidden_features=None,
  660. out_features=None,
  661. act_layer=nn.GELU,
  662. drop=0.):
  663. super().__init__()
  664. out_features = out_features or in_features
  665. hidden_features = hidden_features or in_features
  666. self.fc1 = nn.Linear(in_features, hidden_features)
  667. self.dwconv = DWConv(hidden_features)
  668. self.act = act_layer()
  669. self.fc2 = nn.Linear(hidden_features, out_features)
  670. self.drop = nn.Dropout(drop)
  671. self.apply(self._init_weights)
  672. def _init_weights(self, m):
  673. if isinstance(m, nn.Linear):
  674. trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
  675. trunc_normal_op(m.weight)
  676. if isinstance(m, nn.Linear) and m.bias is not None:
  677. init_bias = nn.initializer.Constant(0)
  678. init_bias(m.bias)
  679. elif isinstance(m, nn.LayerNorm):
  680. init_bias = nn.initializer.Constant(0)
  681. init_bias(m.bias)
  682. init_weight = nn.initializer.Constant(1.0)
  683. init_weight(m.weight)
  684. elif isinstance(m, nn.Conv2D):
  685. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  686. fan_out //= m._groups
  687. init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
  688. init_weight(m.weight)
  689. if m.bias is not None:
  690. init_bias = nn.initializer.Constant(0)
  691. init_bias(m.bias)
  692. def forward(self, x, H, W):
  693. x = self.fc1(x)
  694. x = self.dwconv(x, H, W)
  695. x = self.act(x)
  696. x = self.drop(x)
  697. x = self.fc2(x)
  698. x = self.drop(x)
  699. return x
  700. class Attention(nn.Layer):
  701. def __init__(self,
  702. dim,
  703. num_heads=8,
  704. qkv_bias=False,
  705. qk_scale=None,
  706. attn_drop=0.,
  707. proj_drop=0.,
  708. sr_ratio=1):
  709. super().__init__()
  710. assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
  711. self.dim = dim
  712. self.num_heads = num_heads
  713. head_dim = dim // num_heads
  714. self.scale = qk_scale or head_dim**-0.5
  715. self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
  716. self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
  717. self.attn_drop = nn.Dropout(attn_drop)
  718. self.proj = nn.Linear(dim, dim)
  719. self.proj_drop = nn.Dropout(proj_drop)
  720. self.sr_ratio = sr_ratio
  721. if sr_ratio > 1:
  722. self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
  723. self.norm = nn.LayerNorm(dim)
  724. self.apply(self._init_weights)
  725. def _init_weights(self, m):
  726. if isinstance(m, nn.Linear):
  727. trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
  728. trunc_normal_op(m.weight)
  729. if isinstance(m, nn.Linear) and m.bias is not None:
  730. init_bias = nn.initializer.Constant(0)
  731. init_bias(m.bias)
  732. elif isinstance(m, nn.LayerNorm):
  733. init_bias = nn.initializer.Constant(0)
  734. init_bias(m.bias)
  735. init_weight = nn.initializer.Constant(1.0)
  736. init_weight(m.weight)
  737. elif isinstance(m, nn.Conv2D):
  738. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  739. fan_out //= m._groups
  740. init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
  741. init_weight(m.weight)
  742. if m.bias is not None:
  743. init_bias = nn.initializer.Constant(0)
  744. init_bias(m.bias)
  745. def forward(self, x, H, W):
  746. B, N, C = x.shape
  747. q = self.q(x).reshape([B, N, self.num_heads,
  748. C // self.num_heads]).transpose([0, 2, 1, 3])
  749. if self.sr_ratio > 1:
  750. x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
  751. x_ = self.sr(x_)
  752. x_ = x_.reshape([B, C, calc_product(*x_.shape[2:])]).transpose(
  753. [0, 2, 1])
  754. x_ = self.norm(x_)
  755. kv = self.kv(x_)
  756. kv = kv.reshape([
  757. B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
  758. C // self.num_heads
  759. ]).transpose([2, 0, 3, 1, 4])
  760. else:
  761. kv = self.kv(x)
  762. kv = kv.reshape([
  763. B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
  764. C // self.num_heads
  765. ]).transpose([2, 0, 3, 1, 4])
  766. k, v = kv[0], kv[1]
  767. attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
  768. attn = F.softmax(attn, axis=-1)
  769. attn = self.attn_drop(attn)
  770. x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
  771. x = self.proj(x)
  772. x = self.proj_drop(x)
  773. return x
  774. class Block(nn.Layer):
  775. def __init__(self,
  776. dim,
  777. num_heads,
  778. mlp_ratio=4.,
  779. qkv_bias=False,
  780. qk_scale=None,
  781. drop=0.,
  782. attn_drop=0.,
  783. drop_path=0.,
  784. act_layer=nn.GELU,
  785. norm_layer=nn.LayerNorm,
  786. sr_ratio=1):
  787. super().__init__()
  788. self.norm1 = norm_layer(dim)
  789. self.attn = Attention(
  790. dim,
  791. num_heads=num_heads,
  792. qkv_bias=qkv_bias,
  793. qk_scale=qk_scale,
  794. attn_drop=attn_drop,
  795. proj_drop=drop,
  796. sr_ratio=sr_ratio)
  797. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity(
  798. )
  799. self.norm2 = norm_layer(dim)
  800. mlp_hidden_dim = int(dim * mlp_ratio)
  801. self.mlp = Mlp(in_features=dim,
  802. hidden_features=mlp_hidden_dim,
  803. act_layer=act_layer,
  804. drop=drop)
  805. def _init_weights(self, m):
  806. if isinstance(m, nn.Linear):
  807. trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
  808. trunc_normal_op(m.weight)
  809. if isinstance(m, nn.Linear) and m.bias is not None:
  810. init_bias = nn.initializer.Constant(0)
  811. init_bias(m.bias)
  812. elif isinstance(m, nn.LayerNorm):
  813. init_bias = nn.initializer.Constant(0)
  814. init_bias(m.bias)
  815. init_weight = nn.initializer.Constant(1.0)
  816. init_weight(m.weight)
  817. elif isinstance(m, nn.Conv2D):
  818. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  819. fan_out //= m._groups
  820. init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
  821. init_weight(m.weight)
  822. if m.bias is not None:
  823. init_bias = nn.initializer.Constant(0)
  824. init_bias(m.bias)
  825. def forward(self, x, H, W):
  826. x = x + self.drop_path(self.attn(self.norm1(x), H, W))
  827. x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
  828. return x
  829. class DWConv(nn.Layer):
  830. def __init__(self, dim=768):
  831. super(DWConv, self).__init__()
  832. self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)
  833. def forward(self, x, H, W):
  834. B, N, C = x.shape
  835. x = x.transpose([0, 2, 1]).reshape([B, C, H, W])
  836. x = self.dwconv(x)
  837. x = x.flatten(2).transpose([0, 2, 1])
  838. return x
  839. # Transformer Decoder
  840. class MLP(nn.Layer):
  841. """
  842. Linear Embedding
  843. """
  844. def __init__(self, input_dim=2048, embed_dim=768):
  845. super().__init__()
  846. self.proj = nn.Linear(input_dim, embed_dim)
  847. def forward(self, x):
  848. x = x.flatten(2).transpose([0, 2, 1])
  849. x = self.proj(x)
  850. return x
  851. # Difference Layer
  852. def conv_diff(in_channels, out_channels):
  853. return nn.Sequential(
  854. nn.Conv2D(
  855. in_channels, out_channels, kernel_size=3, padding=1),
  856. nn.ReLU(),
  857. nn.BatchNorm2D(out_channels),
  858. nn.Conv2D(
  859. out_channels, out_channels, kernel_size=3, padding=1),
  860. nn.ReLU())
  861. # Intermediate prediction Layer
  862. def make_prediction(in_channels, out_channels):
  863. return nn.Sequential(
  864. nn.Conv2D(
  865. in_channels, out_channels, kernel_size=3, padding=1),
  866. nn.ReLU(),
  867. nn.BatchNorm2D(out_channels),
  868. nn.Conv2D(
  869. out_channels, out_channels, kernel_size=3, padding=1))