snunet.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention
from .param_init import KaimingInitMixin


class SNUNet(nn.Layer, KaimingInitMixin):
    """
    The SNUNet implementation based on PaddlePaddle.

    The original article refers to
        S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images"
        (https://ieeexplore.ieee.org/document/9355573).

    Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        width (int, optional): The output channels of the first convolutional layer. Default: 32.
    """

    def __init__(self, in_channels, num_classes, width=32):
        super().__init__()

        filters = (width, width * 2, width * 4, width * 8, width * 16)
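
        # Shared (Siamese) encoder: both temporal inputs are processed by the
        # same convolution and pooling layers defined below.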
        self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0])
        self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1])
        self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2])
        self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3])
        self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4])
        self.down1 = MaxPool2x2()
        self.down2 = MaxPool2x2()
        self.down3 = MaxPool2x2()
        self.down4 = MaxPool2x2()
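
        # Dense (UNet++-style) decoder: up{i}_{j} upsamples the level-i features
        # of decoder column j so they can be concatenated with the shallower
        # nodes at the same level.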
        self.up1_0 = Up(filters[1])
        self.up2_0 = Up(filters[2])
        self.up3_0 = Up(filters[3])
        self.up4_0 = Up(filters[4])
        self.conv0_1 = ConvBlockNested(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_1 = ConvBlockNested(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.conv2_1 = ConvBlockNested(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.conv3_1 = ConvBlockNested(filters[3] * 2 + filters[4], filters[3], filters[3])

        self.up1_1 = Up(filters[1])
        self.up2_1 = Up(filters[2])
        self.up3_1 = Up(filters[3])
        self.conv0_2 = ConvBlockNested(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_2 = ConvBlockNested(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.conv2_2 = ConvBlockNested(filters[2] * 3 + filters[3], filters[2], filters[2])

        self.up1_2 = Up(filters[1])
        self.up2_2 = Up(filters[2])
        self.conv0_3 = ConvBlockNested(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.conv1_3 = ConvBlockNested(filters[1] * 4 + filters[2], filters[1], filters[1])

        self.up1_3 = Up(filters[1])
        self.conv0_4 = ConvBlockNested(filters[0] * 5 + filters[1], filters[0], filters[0])
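
        # Ensemble Channel Attention Module (ECAM, from the paper referenced
        # above): intra-branch attention over the summed decoder heads and
        # inter-branch attention over their concatenation.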
        self.ca_intra = ChannelAttention(filters[0], ratio=4)
        self.ca_inter = ChannelAttention(filters[0] * 4, ratio=16)
        self.conv_out = Conv1x1(filters[0] * 4, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
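        # t1 and t2 are the bi-temporal input images; the encoder weights are
        # shared between the two branches.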
        x0_0_t1 = self.conv0_0(t1)
        x1_0_t1 = self.conv1_0(self.down1(x0_0_t1))
        x2_0_t1 = self.conv2_0(self.down2(x1_0_t1))
        x3_0_t1 = self.conv3_0(self.down3(x2_0_t1))
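
        # Second temporal branch; only this branch is propagated down to the
        # deepest encoder level (x4_0).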
        x0_0_t2 = self.conv0_0(t2)
        x1_0_t2 = self.conv1_0(self.down1(x0_0_t2))
        x2_0_t2 = self.conv2_0(self.down2(x1_0_t2))
        x3_0_t2 = self.conv3_0(self.down3(x2_0_t2))
        x4_0_t2 = self.conv4_0(self.down4(x3_0_t2))
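
        # Dense skip pathways: each decoder node concatenates the same-level
        # features of both dates, any earlier decoder nodes at that level, and
        # the upsampled output of the node one level deeper.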
        x0_1 = self.conv0_1(paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1))
        x1_1 = self.conv1_1(paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1))
        x0_2 = self.conv0_2(paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1))
        x2_1 = self.conv2_1(paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1))
        x1_2 = self.conv1_2(paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1))
        x3_1 = self.conv3_1(paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1))
        x2_2 = self.conv2_2(paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))
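
        # ECAM fusion: the intra attention map computed on the sum of the four
        # decoder heads is tiled across their concatenation, which is then
        # reweighted by the inter attention map.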
        out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1)
        intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0)
        m_intra = self.ca_intra(intra)
        out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1, 4, 1, 1)))

        pred = self.conv_out(out)
        return pred,


class ConvBlockNested(nn.Layer):
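    # Two 3x3 conv + norm + ReLU layers with a residual connection taken from
    # the output of the first convolution; the addition in `forward` therefore
    # requires out_ch == mid_ch, which holds at every call site above.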
    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.act = nn.ReLU()
        self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1)
        self.bn1 = make_norm(mid_ch)
        self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = make_norm(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        identity = x
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.bn2(x)
        output = self.act(x + identity)
        return output


class Up(nn.Layer):
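    # 2x spatial upsampling: bilinear interpolation by default (which, as noted
    # in the SNUNet docstring, differs from the paper), or a 2x2 transposed
    # convolution when use_conv is True.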
    def __init__(self, in_ch, use_conv=False):
        super().__init__()
        if use_conv:
            self.up = nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2)
        else:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.up(x)
        return x
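

# A minimal smoke test (not part of the original file; the input shapes are
# illustrative). Because of the relative imports above, run it as a module from
# within the containing package (python -m <package>.snunet), not as a script.
if __name__ == '__main__':
    model = SNUNet(in_channels=3, num_classes=2)
    t1 = paddle.rand([1, 3, 256, 256])
    t2 = paddle.rand([1, 3, 256, 256])
    pred, = model(t1, t2)
    print(pred.shape)  # expected: [1, 2, 256, 256]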