Lin Manhui, 2 years ago
parent
commit
129520aaae

+ 1 - 0
docs/intro/model_zoo.md

@@ -18,6 +18,7 @@ All models currently supported by PaddleRS are listed below (models marked with \* are dedicated remote sensing models
 | Change Detection | \*FC-Siam-conc | Yes |
 | Change Detection | \*FC-Siam-diff | Yes |
 | Change Detection | \*FCCDN | Yes |
+| Change Detection | \*P2V-CD | Yes |
 | Change Detection | \*SNUNet | Yes |
 | Change Detection | \*STANet | Yes |
 | Scene Classification | CondenseNet V2 | Yes |

+ 1 - 0
paddlers/rs_models/cd/__init__.py

@@ -24,4 +24,5 @@ from .fc_siam_conc import FCSiamConc
 from .fc_siam_diff import FCSiamDiff
 from .changeformer import ChangeFormer
 from .fccdn import FCCDN
+from .p2v import P2V
 from .losses import fccdn_ssl_loss
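With this export in place, the new network can be imported alongside the other change detection models. A minimal sketch of using the raw network (not the trainer wrapper); the batch size and 256x256 input shape are illustrative assumptions:

```python
import paddle
from paddlers.rs_models.cd import P2V

# Build the raw network and run a forward pass on random bi-temporal inputs.
model = P2V(in_channels=3, num_classes=2)
model.eval()
t1 = paddle.rand([1, 3, 256, 256])
t2 = paddle.rand([1, 3, 256, 256])
preds = model(t1, t2)  # in eval mode, a single-element list of logits
print(preds[0].shape)  # expected: [1, 2, 256, 256]
```

In training mode the forward pass instead returns `[pred, pred_v]`, where `pred_v` is the auxiliary prediction from the video branch (see `p2v.py` below).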

+ 320 - 0
paddlers/rs_models/cd/p2v.py

@@ -0,0 +1,320 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .layers import Conv1x1, Conv3x3, MaxPool2x2
+
+
+class SimpleResBlock(nn.Layer):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.conv1 = Conv3x3(in_ch, out_ch, norm=True, act=True)
+        self.conv2 = Conv3x3(out_ch, out_ch, norm=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return F.relu(x + self.conv2(x))
+
+
+class ResBlock(nn.Layer):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.conv1 = Conv3x3(in_ch, out_ch, norm=True, act=True)
+        self.conv2 = Conv3x3(out_ch, out_ch, norm=True, act=True)
+        self.conv3 = Conv3x3(out_ch, out_ch, norm=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return F.relu(x + self.conv3(self.conv2(x)))
+
+
+class DecBlock(nn.Layer):
+    def __init__(self, in_ch1, in_ch2, out_ch):
+        super().__init__()
+        self.conv_fuse = SimpleResBlock(in_ch1 + in_ch2, out_ch)
+
+    def forward(self, x1, x2):
+        x2 = F.interpolate(x2, size=x1.shape[2:])
+        x = paddle.concat([x1, x2], axis=1)
+        return self.conv_fuse(x)
+
+
+class BasicConv3D(nn.Layer):
+    def __init__(self,
+                 in_ch,
+                 out_ch,
+                 kernel_size,
+                 bias='auto',
+                 bn=False,
+                 act=False,
+                 **kwargs):
+        super().__init__()
+        seq = []
+        if kernel_size >= 2:
+            seq.append(nn.Pad3D(kernel_size // 2, mode='constant'))
+        seq.append(
+            nn.Conv3D(
+                in_ch,
+                out_ch,
+                kernel_size,
+                padding=0,
+                bias_attr=(False if bn else None) if bias == 'auto' else bias,
+                **kwargs))
+        if bn:
+            seq.append(nn.BatchNorm3D(out_ch))
+        if act:
+            seq.append(nn.ReLU())
+        self.seq = nn.Sequential(*seq)
+
+    def forward(self, x):
+        return self.seq(x)
+
+
+class Conv3x3x3(BasicConv3D):
+    def __init__(self,
+                 in_ch,
+                 out_ch,
+                 bias='auto',
+                 bn=False,
+                 act=False,
+                 **kwargs):
+        super().__init__(in_ch, out_ch, 3, bias=bias, bn=bn, act=act, **kwargs)
+
+
+class ResBlock3D(nn.Layer):
+    def __init__(self, in_ch, out_ch, itm_ch, stride=1, ds=None):
+        super().__init__()
+        self.conv1 = BasicConv3D(
+            in_ch, itm_ch, 1, bn=True, act=True, stride=stride)
+        self.conv2 = Conv3x3x3(itm_ch, itm_ch, bn=True, act=True)
+        self.conv3 = BasicConv3D(itm_ch, out_ch, 1, bn=True, act=False)
+        self.ds = ds
+
+    def forward(self, x):
+        res = x
+        y = self.conv1(x)
+        y = self.conv2(y)
+        y = self.conv3(y)
+        if self.ds is not None:
+            res = self.ds(res)
+        y = F.relu(y + res)
+        return y
+
+
+class PairEncoder(nn.Layer):
+    def __init__(self, in_ch, enc_chs=(16, 32, 64), add_chs=(0, 0)):
+        super().__init__()
+
+        self.n_layers = 3
+
+        self.conv1 = SimpleResBlock(2 * in_ch, enc_chs[0])
+        self.pool1 = MaxPool2x2()
+
+        self.conv2 = SimpleResBlock(enc_chs[0] + add_chs[0], enc_chs[1])
+        self.pool2 = MaxPool2x2()
+
+        self.conv3 = ResBlock(enc_chs[1] + add_chs[1], enc_chs[2])
+        self.pool3 = MaxPool2x2()
+
+    def forward(self, x1, x2, add_feats=None):
+        x = paddle.concat([x1, x2], axis=1)
+        feats = [x]
+
+        for i in range(self.n_layers):
+            conv = getattr(self, f'conv{i+1}')
+            if i > 0 and add_feats is not None:
+                add_feat = F.interpolate(add_feats[i - 1], size=x.shape[2:])
+                x = paddle.concat([x, add_feat], axis=1)
+            x = conv(x)
+            pool = getattr(self, f'pool{i+1}')
+            x = pool(x)
+            feats.append(x)
+
+        return feats
+
+
+class VideoEncoder(nn.Layer):
+    def __init__(self, in_ch, enc_chs=(64, 128)):
+        super().__init__()
+        if in_ch != 3:
+            raise NotImplementedError(
+                "`VideoEncoder` currently only supports 3-channel input.")
+
+        self.n_layers = 2
+        self.expansion = 4
+        self.tem_scales = (1.0, 0.5)
+
+        self.stem = nn.Sequential(
+            nn.Conv3D(
+                3,
+                enc_chs[0],
+                kernel_size=(3, 9, 9),
+                stride=(1, 4, 4),
+                padding=(1, 4, 4),
+                bias_attr=False),
+            nn.BatchNorm3D(enc_chs[0]),
+            nn.ReLU())
+        exps = self.expansion
+        self.layer1 = nn.Sequential(
+            ResBlock3D(
+                enc_chs[0],
+                enc_chs[0] * exps,
+                enc_chs[0],
+                ds=BasicConv3D(
+                    enc_chs[0], enc_chs[0] * exps, 1, bn=True)),
+            ResBlock3D(enc_chs[0] * exps, enc_chs[0] * exps, enc_chs[0]))
+        self.layer2 = nn.Sequential(
+            ResBlock3D(
+                enc_chs[0] * exps,
+                enc_chs[1] * exps,
+                enc_chs[1],
+                stride=(2, 2, 2),
+                ds=BasicConv3D(
+                    enc_chs[0] * exps,
+                    enc_chs[1] * exps,
+                    1,
+                    stride=(2, 2, 2),
+                    bn=True)),
+            ResBlock3D(enc_chs[1] * exps, enc_chs[1] * exps, enc_chs[1]))
+
+    def forward(self, x):
+        feats = [x]
+
+        x = self.stem(x)
+        for i in range(self.n_layers):
+            layer = getattr(self, f'layer{i+1}')
+            x = layer(x)
+            feats.append(x)
+
+        return feats
+
+
+class SimpleDecoder(nn.Layer):
+    def __init__(self, itm_ch, enc_chs, dec_chs, num_classes=1):
+        super().__init__()
+
+        enc_chs = enc_chs[::-1]
+        self.conv_bottom = Conv3x3(itm_ch, itm_ch, norm=True, act=True)
+        self.blocks = nn.LayerList([
+            DecBlock(in_ch1, in_ch2, out_ch)
+            for in_ch1, in_ch2, out_ch in zip(enc_chs, (itm_ch, ) +
+                                              dec_chs[:-1], dec_chs)
+        ])
+        self.conv_out = Conv1x1(dec_chs[-1], num_classes)
+
+    def forward(self, x, feats):
+        feats = feats[::-1]
+
+        x = self.conv_bottom(x)
+
+        for feat, blk in zip(feats, self.blocks):
+            x = blk(feat, x)
+
+        y = self.conv_out(x)
+
+        return y
+
+
+class P2V(nn.Layer):
+    """
+    The P2V-CD implementation based on PaddlePaddle.
+
+    The original article refers to
+        M. Lin, et al. "Transition Is a Process: Pair-to-Video Change Detection Networks 
+        for Very High Resolution Remote Sensing Images"
+        (https://ieeexplore.ieee.org/document/9975266).
+
+    Args:
+        in_channels (int): Number of bands of the input images.
+        num_classes (int): Number of target classes.
+        video_len (int, optional): Number of frames of the constructed pseudo video. 
+            Default: 8.
+        pair_encoder_channels (tuple[int], optional): Output channels of each block in the 
+            spatial (pair) encoder. Default: (32, 64, 128).
+        video_encoder_channels (tuple[int], optional): Output channels of each block in the
+            temporal (video) encoder. Default: (64, 128).
+        decoder_channels (tuple[int], optional): Output channels of each block in the 
+            decoder. Default: (256, 128, 64, 32).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 num_classes,
+                 video_len=8,
+                 pair_encoder_channels=(32, 64, 128),
+                 video_encoder_channels=(64, 128),
+                 decoder_channels=(256, 128, 64, 32)):
+        super().__init__()
+        if video_len < 2:
+            raise ValueError(
+                f"`video_len` must be no less than 2, but got {video_len}.")
+        self.video_len = video_len
+        self.encoder_v = VideoEncoder(
+            in_channels, enc_chs=video_encoder_channels)
+        video_encoder_channels = tuple(ch * self.encoder_v.expansion
+                                       for ch in video_encoder_channels)
+        self.encoder_p = PairEncoder(
+            in_channels,
+            enc_chs=pair_encoder_channels,
+            add_chs=video_encoder_channels)
+        self.conv_out_v = Conv1x1(video_encoder_channels[-1], num_classes)
+        self.convs_video = nn.LayerList([
+            Conv1x1(
+                2 * ch, ch, norm=True, act=True)
+            for ch in video_encoder_channels
+        ])
+        self.decoder = SimpleDecoder(
+            pair_encoder_channels[-1],
+            (2 * in_channels, ) + pair_encoder_channels, decoder_channels,
+            num_classes)
+
+    def forward(self, t1, t2):
+        frames = self.pair_to_video(t1, t2)
+        feats_v = self.encoder_v(frames.transpose((0, 2, 1, 3, 4)))
+        feats_v.pop(0)
+
+        for i, feat in enumerate(feats_v):
+            feats_v[i] = self.convs_video[i](self.tem_aggr(feat))
+
+        feats_p = self.encoder_p(t1, t2, feats_v)
+
+        pred = self.decoder(feats_p[-1], feats_p)
+
+        if self.training:
+            pred_v = self.conv_out_v(feats_v[-1])
+            pred_v = F.interpolate(pred_v, size=pred.shape[2:])
+            return [pred, pred_v]
+        else:
+            return [pred]
+
+    def pair_to_video(self, im1, im2, rate_map=None):
+        def _interpolate(im1, im2, rate_map, video_len):
+            delta = 1.0 / (video_len - 1)
+            delta_map = rate_map * delta
+            steps = paddle.arange(
+                end=video_len, dtype='float32').reshape((1, -1, 1, 1, 1))
+            interped = im1.unsqueeze(1) + (
+                (im2 - im1) * delta_map).unsqueeze(1) * steps
+            return interped
+
+        if rate_map is None:
+            rate_map = paddle.ones_like(im1[:, 0:1])
+        frames = _interpolate(im1, im2, rate_map, self.video_len)
+        return frames
+
+    def tem_aggr(self, f):
+        return paddle.concat(
+            [paddle.mean(
+                f, axis=2), paddle.max(f, axis=2)], axis=1)
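The `pair_to_video` method above is the heart of P2V-CD: the bi-temporal pair is expanded into a pseudo-video of `video_len` frames by per-pixel linear interpolation (optionally modulated by a `rate_map`), the 3D `VideoEncoder` extracts temporal features from it, and `tem_aggr` collapses the temporal axis by concatenating the per-frame mean and max. A minimal standalone sketch of the interpolation, using toy shapes and a uniform rate map:

```python
import paddle

# Frame k equals im1 + k / (T - 1) * (im2 - im1): a per-pixel linear
# walk from the first image to the second over T frames.
im1 = paddle.zeros([1, 3, 4, 4])
im2 = paddle.ones([1, 3, 4, 4])
num_frames = 8
steps = paddle.arange(
    end=num_frames, dtype='float32').reshape((1, -1, 1, 1, 1))
frames = im1.unsqueeze(1) + (
    (im2 - im1) / (num_frames - 1)).unsqueeze(1) * steps
print(frames.shape)  # [1, 8, 3, 4, 4], i.e. (N, T, C, H, W)
print(float(frames[0, 0].mean()), float(frames[0, -1].mean()))  # 0.0 1.0
```

The first frame reproduces `im1` and the last reproduces `im2`; in the model, `forward` then transposes the result to `(N, C, T, H, W)` before feeding the video encoder.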

+ 34 - 5
paddlers/tasks/change_detector.py

@@ -39,7 +39,7 @@ from .utils.slider_predict import slider_predict
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
-    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN"
+    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN", "P2V"
 ]
 
 
@@ -950,7 +950,7 @@ class DSIFN(BaseChangeDetector):
             }
         else:
             raise ValueError(
-                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}."
             )
 
 
@@ -986,7 +986,7 @@ class DSAMNet(BaseChangeDetector):
             }
         else:
             raise ValueError(
-                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}."
             )
 
 
@@ -1022,7 +1022,7 @@ class ChangeStar(BaseChangeDetector):
             }
         else:
             raise ValueError(
-                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}."
             )
 
 
@@ -1072,5 +1072,34 @@ class FCCDN(BaseChangeDetector):
             }
         else:
             raise ValueError(
-                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}."
+            )
+
+
+class P2V(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 in_channels=3,
+                 video_len=8,
+                 **params):
+        params.update({'in_channels': in_channels, 'video_len': video_len})
+        super(P2V, self).__init__(
+            model_name='P2V',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
+
+    def default_loss(self):
+        if self.use_mixed_loss is False:
+            return {
+                'types':
+                [seg_losses.CrossEntropyLoss(), seg_losses.CrossEntropyLoss()],
+                'coef': [1.0, 0.4]
+            }
+        else:
+            raise ValueError(
+                f"Currently `use_mixed_loss` must be set to False for {self.__class__}."
             )
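At the task level, `P2V` follows the same pattern as the other change detection trainers: `use_mixed_loss` must remain `False`, and the default loss combines two cross-entropy terms with coefficients 1.0 and 0.4, matching the `[pred, pred_v]` outputs the network produces during training. A minimal sketch of trainer construction (the arguments shown are the defaults above):

```python
import paddlers as pdrs

# Build the P2V trainer with its default configuration; `video_len`
# is forwarded to the underlying network via `params`.
model = pdrs.tasks.cd.P2V(
    num_classes=2,
    use_mixed_loss=False,  # currently must be False; True raises ValueError
    in_channels=3,
    video_len=8)
```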

+ 8 - 0
test_tipc/configs/cd/p2v/p2v_airchange.yaml

@@ -0,0 +1,8 @@
+# Configurations of P2V-CD with AirChange dataset
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/p2v/
+
+model: !Node
+    type: P2V

+ 8 - 0
test_tipc/configs/cd/p2v/p2v_levircd.yaml

@@ -0,0 +1,8 @@
+# Configurations of P2V-CD with LEVIR-CD dataset
+
+_base_: ../_base_/levircd.yaml
+
+save_dir: ./test_tipc/output/cd/p2v/
+
+model: !Node
+    type: P2V

+ 53 - 0
test_tipc/configs/cd/p2v/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:cd:p2v
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8
+--model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/cd/p2v/p2v_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/p2v/p2v_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/p2v/p2v_levircd.yaml
+train_model_name:best_model
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train cd
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--config:null
+--save_log_path:null
+--benchmark:True
+--model_name:p2v
+null:null

+ 22 - 1
tests/rs_models/test_cd_models.py

@@ -21,7 +21,7 @@ __all__ = [
     'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
     'TestDSIFNModel', 'TestFCEarlyFusionModel', 'TestFCSiamConcModel',
     'TestFCSiamDiffModel', 'TestSNUNetModel', 'TestSTANetModel',
-    'TestChangeFormerModel', 'TestFCCDNModel'
+    'TestChangeFormerModel', 'TestFCCDNModel', 'TestP2VModel'
 ]
 
 
@@ -252,3 +252,24 @@ class TestFCCDNModel(TestCDModel):
             tar_c2, tar_c2, [self.get_zeros_array(8), tar_c2[1]],
             [self.get_zeros_array(2)]
         ]
+
+
+class TestP2VModel(TestCDModel):
+    MODEL_CLASS = paddlers.rs_models.cd.P2V
+
+    def set_specs(self):
+        base_spec = dict(in_channels=3, num_classes=2)
+        self.specs = [
+            base_spec,
+            dict(in_channels=3, num_classes=8),
+            dict(**base_spec, video_len=4),
+            dict(**base_spec, _phase='eval', _stop_grad=True)
+        ]   # yapf: disable
+
+    def set_targets(self):
+        # Avoid allocating large amounts of memory
+        tar_c2 = [self.get_zeros_array(2)] * 2
+        self.targets = [
+            tar_c2, [self.get_zeros_array(8)] * 2, tar_c2,
+            [self.get_zeros_array(2)]
+        ]

+ 1 - 0
tutorials/train/README.md

@@ -13,6 +13,7 @@
 |change_detection/fc_siam_conc.py | Change Detection | FC-Siam-conc |
 |change_detection/fc_siam_diff.py | Change Detection | FC-Siam-diff |
 |change_detection/fccdn.py | Change Detection | FCCDN |
+|change_detection/p2v.py | Change Detection | P2V-CD |
 |change_detection/snunet.py | Change Detection | SNUNet |
 |change_detection/stanet.py | Change Detection | STANet |
 |classification/condensenetv2.py | Scene Classification | CondenseNet V2 |

+ 94 - 0
tutorials/train/change_detection/p2v.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# Example script for training the P2V-CD change detection model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/airchange/'
+# Path of the `file_list` file for the training set
+TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
+# Path of the `file_list` file for the validation set
+EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/p2v/'
+
+# Download and decompress the AirChange dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Multiple transforms are combined with Compose and executed sequentially in the given order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the images
+    T.DecodeImg(),
+    # Random cropping
+    T.RandomCrop(
+        # The cropped region will be rescaled to 256x256
+        crop_size=256,
+        # The aspect ratio of the cropped region varies between 0.5 and 2
+        aspect_ratio=[0.5, 2.0],
+        # The size of the cropped region relative to the original image varies within a range, no smaller than 1/5 of the original width and height
+        scaling=[0.2, 1.0]),
+    # Apply random horizontal flipping with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Normalize the data to [-1, 1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Data normalization in the validation phase must be identical to that used in training
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
+])
+
+# Build the datasets used for training and validation, respectively
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=None,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=None,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+# Build a P2V model with default parameters
+# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
+model = pdrs.tasks.cd.P2V()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=3,
+    # Log once every `log_interval_steps` iterations
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # Whether to use the early stopping strategy, which terminates training early when accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Resume training from a specified checkpoint
+    resume_checkpoint=None)
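Once training completes, the saved best model can be reloaded for inference. A hedged sketch, assuming the `load_model`/`predict` interface shared by PaddleRS tasks; the image paths below are hypothetical placeholders:

```python
# Reload the best checkpoint and predict on one bi-temporal pair.
# The sample paths are hypothetical placeholders.
model = pdrs.tasks.load_model(EXP_DIR + 'best_model')
result = model.predict(
    img_file=('./data/airchange/A/sample.png',
              './data/airchange/B/sample.png'),
    transforms=eval_transforms)
# For change detection tasks, `result` is expected to contain the
# predicted change map, e.g. under the 'label_map' key.
```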