123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- import paddle
- import paddle.nn as nn
- from paddlers.datasets.cd_dataset import MaskType
- from paddlers.rs_models.seg import FarSeg
- from .layers import Conv3x3, Identity
class _ChangeStarBase(nn.Layer):
    """
    Shared skeleton of the ChangeStar models.

    One feature extractor is applied to both input images. The two resulting
    feature maps are passed to a ChangeMixin head, which produces change
    logits for both temporal orderings. In training mode a segmentation head
    is additionally applied to each feature map on its own.

    Args:
        seg_model (nn.Layer): Feature extractor. It is called on one image at
            a time, and the first element of its output is used as the
            feature map.
        num_classes (int): Number of output channels of the change head.
        mid_channels (int): Number of channels of the feature map produced by
            `seg_model`.
        inner_channels (int): Number of channels of the hidden convolutions in
            the change head.
        num_convs (int): Number of normalized, activated convolutions in the
            change head (the final projection is not counted).
        scale_factor (float): Scaling factor of the bilinear upsampling layers
            that close both heads.
    """

    # NOTE(review): these two class attributes are not read anywhere in this
    # block; presumably the training framework consumes them. The four
    # entries of OUT_TYPES line up with the four logits returned by
    # `forward` in training mode -- confirm against the caller.
    USE_MULTITASK_DECODER = True
    OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)

    def __init__(self, seg_model, num_classes, mid_channels, inner_channels,
                 num_convs, scale_factor):
        super().__init__()

        self.extract = seg_model

        # The change head receives both feature maps concatenated along the
        # channel axis, hence twice the number of input channels.
        self.detect = ChangeMixin(
            in_ch=mid_channels * 2,
            out_ch=num_classes,
            mid_ch=inner_channels,
            num_convs=num_convs,
            scale_factor=scale_factor)

        # The segmentation head always has two output channels.
        seg_projection = Conv3x3(mid_channels, 2)
        seg_upsample = nn.UpsamplingBilinear2D(scale_factor=scale_factor)
        self.segment = nn.Sequential(seg_projection, seg_upsample)

        self.init_weight()

    def forward(self, t1, t2):
        """
        Run the model on a pair of images.

        Args:
            t1: Input of the first temporal phase.
            t2: Input of the second temporal phase.

        Returns:
            list: In evaluation mode, a one-element list holding the t1->t2
                change logits. In training mode, a four-element list: the
                t1->t2 change logits, the t2->t1 change logits, and the
                segmentation logits of `t1` and of `t2`.
        """
        feat_t1 = self.extract(t1)[0]
        feat_t2 = self.extract(t2)[0]
        change_12, change_21 = self.detect(feat_t1, feat_t2)

        if self.training:
            seg_t1 = self.segment(feat_t1)
            seg_t2 = self.segment(feat_t2)
            return [change_12, change_21, seg_t1, seg_t2]

        # The segmentation head is skipped entirely outside of training.
        return [change_12]

    def init_weight(self):
        """Hook called at the end of `__init__`; does nothing here."""
class ChangeMixin(nn.Layer):
    """
    Change detection head operating on a pair of feature maps.

    The head is a stack of `num_convs` normalized, activated 3x3 convolutions
    followed by a plain 3x3 projection to `out_ch` channels and a bilinear
    upsampling layer. The same stack is applied to both channel-wise
    concatenation orders of the two inputs.

    Args:
        in_ch (int): Number of channels of the concatenated input.
        out_ch (int): Number of output channels.
        mid_ch (int): Number of channels of the hidden convolutions.
        num_convs (int): Number of normalized, activated convolutions.
        scale_factor (float): Scaling factor of the upsampling layer.
    """

    def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor):
        super().__init__()

        # First convolution maps the input width to the hidden width; the
        # remaining `num_convs - 1` hidden convolutions keep that width.
        layers = [Conv3x3(in_ch, mid_ch, norm=True, act=True)]
        for _ in range(num_convs - 1):
            layers.append(Conv3x3(mid_ch, mid_ch, norm=True, act=True))

        layers.append(Conv3x3(mid_ch, out_ch))
        layers.append(nn.UpsamplingBilinear2D(scale_factor=scale_factor))

        self.detect = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """
        Predict change logits for both orderings of the inputs.

        Args:
            x1: Feature map of the first temporal phase.
            x2: Feature map of the second temporal phase.

        Returns:
            tuple: `(pred12, pred21)`, the outputs of the head for the
                inputs concatenated along axis 1 as `[x1, x2]` and as
                `[x2, x1]`, respectively.
        """
        forward_pair = paddle.concat([x1, x2], axis=1)
        pred12 = self.detect(forward_pair)

        backward_pair = paddle.concat([x2, x1], axis=1)
        pred21 = self.detect(backward_pair)

        return pred12, pred21
class ChangeStar_FarSeg(_ChangeStarBase):
    """
    The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle.

    The original article refers to
        Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised
        Object Change Detection in Remote Sensing Imagery"
        (https://arxiv.org/abs/2108.07002).

    Note that this implementation differs from the original code in two
    aspects:
        1. The encoder of the FarSeg model is ResNet50.
        2. We use conv-bn-relu instead of conv-relu-bn.

    Args:
        num_classes (int): Number of target classes.
        mid_channels (int, optional): Number of channels required by the
            ChangeMixin module. Default: 256.
        inner_channels (int, optional): Number of filters used in the
            convolutional layers in the ChangeMixin module. Default: 16.
        num_convs (int, optional): Number of convolutional layers used in the
            ChangeMixin module. Default: 4.
        scale_factor (float, optional): Scaling factor of the output
            upsampling layer. Default: 4.0.
    """

    def __init__(self,
                 num_classes,
                 mid_channels=256,
                 inner_channels=16,
                 num_convs=4,
                 scale_factor=4.0):
        class _FarSegWrapper(nn.Layer):
            """FarSeg with its classification head replaced by Identity."""

            def __init__(self, seg_model):
                super().__init__()
                # Swap out the classification head so that calling the model
                # no longer applies it.
                seg_model.cls_head = Identity()
                self._seg_model = seg_model

            def forward(self, x):
                return self._seg_model(x)

        # The FarSeg decoder width is tied to `mid_channels`, the channel
        # count the base class sizes both of its heads for.
        backbone = FarSeg(
            in_channels=3,
            num_classes=num_classes,
            decoder_out_channels=mid_channels)
        extractor = _FarSegWrapper(backbone)

        super().__init__(
            seg_model=extractor,
            num_classes=num_classes,
            mid_channels=mid_channels,
            inner_channels=inner_channels,
            num_convs=num_convs,
            scale_factor=scale_factor)
# `ChangeStar` is an alias of the FarSeg-based implementation.
ChangeStar = ChangeStar_FarSeg
|