changestar.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. from paddlers.datasets.cd_dataset import MaskType
  17. from paddlers.rs_models.seg import FarSeg
  18. from .layers import Conv3x3, Identity
  # Shared ChangeStar architecture; concrete variants only supply the extractor.
class _ChangeStarBase(nn.Layer):
    """
    Common ChangeStar model: a shared feature extractor, a ChangeMixin change
    detection head and a segmentation head.

    Both input images go through the same extractor. The change head is
    evaluated on both temporal orderings of the features. In training mode
    the segmentation head is also applied to each image's features.

    Args:
        seg_model (nn.Layer): Feature extractor applied to both images. Element
            ``[0]`` of its output is used as the feature map, which must have
            ``mid_channels`` channels.
        num_classes (int): Number of channels of the change logits.
        mid_channels (int): Number of channels of the extracted feature maps.
        inner_channels (int): Number of filters in the ChangeMixin convolutions.
        num_convs (int): Number of hidden convolutional layers in ChangeMixin.
        scale_factor (float): Upsampling factor applied to every output.
    """

    # Framework-level flags. OUT_TYPES lists the mask type of each entry of
    # the training-mode output list, in order: two change maps (t1->t2 and
    # t2->t1), then the T1 and T2 segmentation maps. Presumably consumed by
    # the trainer; confirm there.
    USE_MULTITASK_DECODER = True
    OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)

    def __init__(self, seg_model, num_classes, mid_channels, inner_channels,
                 num_convs, scale_factor):
        super().__init__()
        # Sub-layers are created in a fixed order (extractor, change head,
        # segmentation head). This keeps parameter creation, and hence random
        # initialization, deterministic across refactors.
        self.extract = seg_model
        self.detect = ChangeMixin(
            in_ch=2 * mid_channels,  # features of both images are concatenated
            mid_ch=inner_channels,
            out_ch=num_classes,
            num_convs=num_convs,
            scale_factor=scale_factor)
        seg_conv = Conv3x3(mid_channels, 2)  # 2-channel segmentation logits
        seg_up = nn.UpsamplingBilinear2D(scale_factor=scale_factor)
        self.segment = nn.Sequential(seg_conv, seg_up)
        self.init_weight()

    def forward(self, t1, t2):
        """
        Run the model on an image pair.

        Args:
            t1: Image(s) of the first acquisition time.
            t2: Image(s) of the second acquisition time.

        Returns:
            list: ``[change_t1_t2]`` in eval mode. In training mode:
            ``[change_t1_t2, change_t2_t1, seg_t1, seg_t2]``, matching
            ``OUT_TYPES``.
        """
        feat1 = self.extract(t1)[0]
        feat2 = self.extract(t2)[0]
        # NOTE(review): the reverse-direction logit is computed here even in
        # eval mode, where it is then discarded.
        change_fwd, change_bwd = self.detect(feat1, feat2)
        if self.training:
            seg1 = self.segment(feat1)
            seg2 = self.segment(feat2)
            outputs = [change_fwd, change_bwd, seg1, seg2]
        else:
            outputs = [change_fwd]
        return outputs

    def init_weight(self):
        """Called at the end of ``__init__``; intentionally a no-op here."""
  # Temporal-symmetric change detection head used by ChangeStar.
class ChangeMixin(nn.Layer):
    """
    Change detection head that compares two feature maps.

    The two inputs are concatenated along axis 1 (presumably the channel axis,
    NCHW; confirm with callers). The result passes through ``num_convs``
    ``Conv3x3(norm=True, act=True)`` layers, a plain ``Conv3x3`` producing
    ``out_ch`` channels, and a bilinear upsampling layer. The same layers are
    applied to both temporal orderings of the inputs.

    Args:
        in_ch (int): Number of channels of the concatenated input, i.e. the sum
            of the channel counts of ``x1`` and ``x2``.
        out_ch (int): Number of output channels (change classes).
        mid_ch (int): Number of filters of the hidden convolutional layers.
        num_convs (int): Number of hidden conv-norm-act layers. Must be >= 1.
        scale_factor (float): Scaling factor of the output upsampling layer.

    Raises:
        ValueError: If ``num_convs`` is smaller than 1.
    """

    def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor):
        super(ChangeMixin, self).__init__()
        # The first hidden layer is always built. Without this check,
        # num_convs <= 0 would silently produce one hidden layer instead of
        # failing.
        if num_convs < 1:
            raise ValueError(
                "`num_convs` must be a positive integer, but got {}.".format(
                    num_convs))
        convs = [Conv3x3(in_ch, mid_ch, norm=True, act=True)]
        convs += [
            Conv3x3(
                mid_ch, mid_ch, norm=True, act=True)
            for _ in range(num_convs - 1)
        ]
        self.detect = nn.Sequential(
            *convs,
            Conv3x3(mid_ch, out_ch),
            nn.UpsamplingBilinear2D(scale_factor=scale_factor))

    def forward(self, x1, x2):
        """
        Compute change logits for both temporal orderings.

        Args:
            x1: Features of the first image.
            x2: Features of the second image.

        Returns:
            tuple: ``(pred12, pred21)``. ``pred12`` is computed from
            ``concat([x1, x2])``; ``pred21`` from ``concat([x2, x1])``, using
            the same (shared) layers.
        """
        pred12 = self.detect(paddle.concat([x1, x2], axis=1))
        pred21 = self.detect(paddle.concat([x2, x1], axis=1))
        return pred12, pred21
  # FarSeg adapter used by ChangeStar_FarSeg, defined once at module level.
class _FarSegWrapper(nn.Layer):
    """
    Expose a FarSeg model as a ChangeStar feature extractor.

    The wrapped model's ``cls_head`` is replaced in place by ``Identity``, so
    the wrapper returns FarSeg's output with the classification head bypassed
    (presumably the decoder feature maps; confirm against FarSeg).

    Args:
        seg_model (FarSeg): Model to wrap. It is modified in place.
    """

    def __init__(self, seg_model):
        super(_FarSegWrapper, self).__init__()
        # Keep the attribute name `_seg_model`: Paddle derives state-dict keys
        # from sub-layer attribute names, so renaming it would break
        # existing checkpoints.
        self._seg_model = seg_model
        self._seg_model.cls_head = Identity()

    def forward(self, x):
        return self._seg_model(x)


class ChangeStar_FarSeg(_ChangeStarBase):
    """
    The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle.

    The original article refers to
        Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object
        Change Detection in Remote Sensing Imagery"
        (https://arxiv.org/abs/2108.07002).

    Note that this implementation differs from the original code in two aspects:
    1. The encoder of the FarSeg model is ResNet50.
    2. We use conv-bn-relu instead of conv-relu-bn.

    The FarSeg model is built with ``in_channels=3``, so both input images are
    expected to have 3 channels. Its classification head is replaced by an
    identity mapping, and its decoder is configured to output ``mid_channels``
    channels. The ``num_classes`` channels of the change logits therefore
    come from the ChangeMixin head.

    Args:
        num_classes (int): Number of target classes.
        mid_channels (int, optional): Number of channels required by the
            ChangeMixin module. Default: 256.
        inner_channels (int, optional): Number of filters used in the
            convolutional layers in the ChangeMixin module. Default: 16.
        num_convs (int, optional): Number of convolutional layers used in the
            ChangeMixin module. Default: 4.
        scale_factor (float, optional): Scaling factor of the output upsampling
            layer. Default: 4.0.
    """

    def __init__(
            self,
            num_classes,
            mid_channels=256,
            inner_channels=16,
            num_convs=4,
            scale_factor=4.0, ):
        # TODO: Configurable FarSeg model
        seg_model = FarSeg(
            in_channels=3,
            num_classes=num_classes,
            decoder_out_channels=mid_channels)
        super(ChangeStar_FarSeg, self).__init__(
            seg_model=_FarSegWrapper(seg_model),
            num_classes=num_classes,
            mid_channels=mid_channels,
            inner_channels=inner_channels,
            num_convs=num_convs,
            scale_factor=scale_factor)
  # NOTE: Currently, ChangeStar = FarSeg + ChangeMixin + SegHead
# Public alias: constructing `ChangeStar(...)` builds a `ChangeStar_FarSeg`
# and accepts exactly the same arguments.
ChangeStar = ChangeStar_FarSeg