# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on https://github.com/Z-Zheng/FarSeg
# The copyright of Z-Zheng/FarSeg is as follows:
# Apache License (see https://github.com/Z-Zheng/FarSeg/blob/master/LICENSE for details).

import math

import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import resnet

from paddlers.models.ppdet.modeling import initializer as init


class FPNConvBlock(nn.Conv2D):
    """Conv2D block for FPN layers, with Kaiming uniform initialization (a=1)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1):
        super(FPNConvBlock, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation)
        init.kaiming_uniform_(self.weight, a=1)
        init.constant_(self.bias, value=0)


class DefaultConvBlock(nn.Conv2D):
    """Conv2D block that reproduces the default initialization of torch.nn.Conv2d."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias_attr=None):
        super(DefaultConvBlock, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias_attr=bias_attr)
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)


class ResNetEncoder(nn.Layer):
    """ResNet backbone returning multi-scale features c2-c5 (output strides 4, 8, 16, and 32)."""

    def __init__(self, backbone='resnet50', in_channels=3, pretrained=True):
        super(ResNetEncoder, self).__init__()
        self.resnet = getattr(resnet, backbone)(pretrained=pretrained)
        if in_channels != 3:
            # Replace the stem convolution when the input is not 3-channel RGB.
            self.resnet.conv1 = nn.Conv2D(
                in_channels, 64, 7, stride=2, padding=3, bias_attr=False)
        # Adjust the momentum of all BN layers in the backbone.
        for layer in self.resnet.sublayers():
            if isinstance(layer, (nn.BatchNorm2D, nn.SyncBatchNorm)):
                layer._momentum = 0.1

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        c2 = self.resnet.layer1(x)
        c3 = self.resnet.layer2(c2)
        c4 = self.resnet.layer3(c3)
        c5 = self.resnet.layer4(c4)
        return [c2, c3, c4, c5]


class FPN(nn.Layer):
    """Feature pyramid network that fuses the multi-scale backbone features top-down."""

    def __init__(self, in_channels_list, out_channels, conv_block=FPNConvBlock):
        super(FPN, self).__init__()
        inner_blocks = []
        layer_blocks = []
        for in_channels in in_channels_list:
            if in_channels == 0:
                continue
            # A 1x1 lateral conv and a 3x3 output conv per pyramid level.
            inner_blocks.append(conv_block(in_channels, out_channels, 1))
            layer_blocks.append(conv_block(out_channels, out_channels, 3, 1))
        self.inner_blocks = nn.LayerList(inner_blocks)
        self.layer_blocks = nn.LayerList(layer_blocks)

    def forward(self, x):
        last_inner = self.inner_blocks[-1](x[-1])
        results = [self.layer_blocks[-1](last_inner)]
        # Walk the remaining levels from deep to shallow, merging each lateral
        # feature with the 2x-upsampled top-down feature.
        for i, feature in enumerate(x[-2::-1]):
            inner_block = self.inner_blocks[len(self.inner_blocks) - 2 - i]
            layer_block = self.layer_blocks[len(self.layer_blocks) - 2 - i]
            inner_top_down = F.interpolate(
                last_inner, scale_factor=2, mode="nearest")
            inner_lateral = inner_block(feature)
            last_inner = inner_lateral + inner_top_down
            results.insert(0, layer_block(last_inner))
        return tuple(results)


class FSRelation(nn.Layer):
    """Foreground-scene (F-S) relation module that reweights each pyramid feature by its per-pixel similarity to a global scene embedding."""

    def __init__(self,
                 in_channels,
                 channels_list,
                 out_channels,
                 scale_aware_proj=True,
                 conv_block=DefaultConvBlock):
        super(FSRelation, self).__init__()
        self.scale_aware_proj = scale_aware_proj
        if self.scale_aware_proj:
            # One scene projector per pyramid level.
            self.scene_encoder = nn.LayerList([
                nn.Sequential(
                    conv_block(in_channels, out_channels, 1),
                    nn.ReLU(), conv_block(out_channels, out_channels, 1))
                for _ in range(len(channels_list))
            ])
        else:
            self.scene_encoder = nn.Sequential(
                conv_block(in_channels, out_channels, 1),
                nn.ReLU(), conv_block(out_channels, out_channels, 1))
        self.content_encoders = nn.LayerList()
        self.feature_reencoders = nn.LayerList()
        for channel in channels_list:
            self.content_encoders.append(
                nn.Sequential(
                    conv_block(
                        channel, out_channels, 1, bias_attr=True),
                    nn.BatchNorm2D(
                        out_channels, momentum=0.1),
                    nn.ReLU()))
            self.feature_reencoders.append(
                nn.Sequential(
                    conv_block(
                        channel, out_channels, 1, bias_attr=True),
                    nn.BatchNorm2D(
                        out_channels, momentum=0.1),
                    nn.ReLU()))
        self.normalizer = nn.Sigmoid()

    def forward(self, scene_feature, feature_list):
        content_feats = [
            c_en(p_feat)
            for c_en, p_feat in zip(self.content_encoders, feature_list)
        ]
        # Each relation map is the channel-wise inner product of the scene
        # embedding and a content feature, squashed to (0, 1) by a sigmoid.
        if self.scale_aware_proj:
            scene_feats = [op(scene_feature) for op in self.scene_encoder]
            relations = [
                self.normalizer((sf * cf).sum(axis=1, keepdim=True))
                for sf, cf in zip(scene_feats, content_feats)
            ]
        else:
            scene_feat = self.scene_encoder(scene_feature)
            relations = [
                self.normalizer((scene_feat * cf).sum(axis=1, keepdim=True))
                for cf in content_feats
            ]
        p_feats = [
            op(p_feat)
            for op, p_feat in zip(self.feature_reencoders, feature_list)
        ]
        refined_feats = [r * p for r, p in zip(relations, p_feats)]
        return refined_feats
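
# Shape sketch for `FSRelation.forward` (illustrative assumptions: a batch of
# N images, a 2048-channel scene feature, and pyramid features of spatial size
# Hk x Wk at level k):
#   scene_feature:    [N, 2048, 1, 1]
#   content_feats[k]: [N, out_channels, Hk, Wk]
#   relations[k]:     [N, 1, Hk, Wk]  (sigmoid of the channel-wise inner
#                     product, broadcast over the spatial dimensions)
#   refined_feats[k]: [N, out_channels, Hk, Wk]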


class AsymmetricDecoder(nn.Layer):
    """Decoder that upsamples every input feature to a common output stride and averages them."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 in_feature_output_strides=(4, 8, 16, 32),
                 out_feature_output_stride=4,
                 conv_block=DefaultConvBlock):
        super(AsymmetricDecoder, self).__init__()
        self.blocks = nn.LayerList()
        for in_feature_output_stride in in_feature_output_strides:
            # Each 2x upsampling step gets its own conv-BN-ReLU stage; a level
            # that is already at the output stride gets a single stage.
            num_upsample = int(math.log2(int(in_feature_output_stride))) - int(
                math.log2(int(out_feature_output_stride)))
            num_layers = num_upsample if num_upsample != 0 else 1
            self.blocks.append(
                nn.Sequential(*[
                    nn.Sequential(
                        conv_block(
                            in_channels if idx == 0 else out_channels,
                            out_channels,
                            3,
                            1,
                            1,
                            bias_attr=False),
                        nn.BatchNorm2D(
                            out_channels, momentum=0.1),
                        nn.ReLU(),
                        nn.UpsamplingBilinear2D(scale_factor=2)
                        if num_upsample != 0 else nn.Identity())
                    for idx in range(num_layers)
                ]))

    def forward(self, feature_list):
        inner_feature_list = []
        for idx, block in enumerate(self.blocks):
            decoder_feature = block(feature_list[idx])
            inner_feature_list.append(decoder_feature)
        out_feature = sum(inner_feature_list) / len(inner_feature_list)
        return out_feature


class FarSeg(nn.Layer):
    """
    The FarSeg implementation based on PaddlePaddle.

    The original article refers to
    Zheng Z, Zhong Y, Wang J, et al. Foreground-aware relation network for geospatial object segmentation in
    high spatial resolution remote sensing imagery[C]//Proceedings of the IEEE/CVF Conference on Computer Vision
    and Pattern Recognition. 2020: 4096-4105.

    Args:
        in_channels (int): Number of input channels.
        num_classes (int): Number of target classes.
        backbone (str, optional): Backbone network, one of the models available in `paddle.vision.models.resnet`.
            Default: 'resnet50'.
        backbone_pretrained (bool, optional): Whether the backbone network uses ImageNet pretrained weights.
            Default: True.
        fpn_out_channels (int, optional): Number of channels output by the feature pyramid network. Default: 256.
        fsr_out_channels (int, optional): Number of channels output by the F-S relation module. Default: 256.
        scale_aware_proj (bool, optional): Whether to use scale-aware projection in the F-S relation module.
            Default: True.
        decoder_out_channels (int, optional): Number of channels output by the decoder. Default: 128.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 backbone='resnet50',
                 backbone_pretrained=True,
                 fpn_out_channels=256,
                 fsr_out_channels=256,
                 scale_aware_proj=True,
                 decoder_out_channels=128):
        super(FarSeg, self).__init__()
        backbone = backbone.lower()
        self.encoder = ResNetEncoder(
            backbone=backbone,
            in_channels=in_channels,
            pretrained=backbone_pretrained)
        # Basic-block backbones end with 512 channels; bottleneck backbones
        # (resnet50 and deeper) end with 2048.
        fpn_max_in_channels = 2048
        if backbone in ['resnet18', 'resnet34']:
            fpn_max_in_channels = 512
        self.fpn = FPN(
            in_channels_list=[
                fpn_max_in_channels // (2**(3 - i)) for i in range(4)
            ],
            out_channels=fpn_out_channels)
        self.gap = nn.AdaptiveAvgPool2D(1)
        self.fsr = FSRelation(
            in_channels=fpn_max_in_channels,
            channels_list=[fpn_out_channels] * 4,
            out_channels=fsr_out_channels,
            scale_aware_proj=scale_aware_proj)
        self.decoder = AsymmetricDecoder(
            in_channels=fsr_out_channels, out_channels=decoder_out_channels)
        self.cls_head = nn.Sequential(
            DefaultConvBlock(decoder_out_channels, num_classes, 1),
            nn.UpsamplingBilinear2D(scale_factor=4))

    def forward(self, x):
        feature_list = self.encoder(x)
        fpn_feature_list = self.fpn(feature_list)
        # The scene embedding is a global average pooling of the deepest
        # backbone feature (c5).
        scene_feature = self.gap(feature_list[-1])
        refined_feature_list = self.fsr(scene_feature, fpn_feature_list)
        feature = self.decoder(refined_feature_list)
        logit = self.cls_head(feature)
        return [logit]
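

# A minimal smoke-test sketch, not part of the original module: the input
# size, batch size, and class count below are illustrative assumptions.
# `backbone_pretrained=False` skips the ImageNet weight download. Since the
# classification head upsamples the stride-4 decoder output by 4x, the logits
# recover the input resolution.
if __name__ == '__main__':
    import paddle

    model = FarSeg(in_channels=3, num_classes=16, backbone_pretrained=False)
    images = paddle.randn([2, 3, 256, 256])
    logit = model(images)[0]  # `forward` returns a single-element list.
    print(logit.shape)  # Expected: [2, 16, 256, 256]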