瀏覽代碼

Update examples/rs_research

Bobholamovic 2 年之前
父節點
當前提交
5f353e6c51

+ 122 - 12
examples/rs_research/README.md

@@ -34,32 +34,136 @@ python ../../tools/prepare_dataset/prepare_svcd.py \
 
 ### 3.1 问题分析与思路拟定
 
-科学研究是为了解决实际问题的,本案例也不例外。本案例的研究动机如下:随着深度学习技术应用的不断深入,变化检测领域涌现了许多。与之相对应的是,模型的参数量也越来越大。
+随着深度学习技术应用的不断深入,近年来,变化检测领域涌现了许多基于全卷积神经网络(fully convolutional network, FCN)的遥感影像变化检测算法。与基于特征和基于影像块的方法相比,基于FCN的方法具有处理效率高、依赖超参数少等优势,但其缺点在于参数量往往较大,因而对训练样本的数量更为依赖。尽管中、大型变化检测数据集的数量与日俱增,训练样本日益丰富,但深度学习变化检测模型的参数量也越来越大。下图显示了从2018年到2021年一些已发表的文献中提出的基于FCN的变化检测模型的参数量与其在SVCD数据集上取得的F1分数(柱状图中bar的高度与模型参数量成正比):
 
-[近年来变化检测模型]()
+![params_versus_f1](params_versus_f1.png)
 
-诚然,。
+诚然,增大参数数量在大多数情况下等同于增加模型容量,而模型容量的增加意味着模型拟合能力的提升,从而有助于模型在实验数据集上取得更高的精度指标但是,“更大”一定意味着“更好”吗?答案显然是否定的。在实际应用中,“更大”的遥感影像变化检测模型常常遭遇如下问题:
 
-1. 存储开销。
-2. 过拟合。
+1. 巨大的参数量意味着巨大的存储开销。在许多实际场景中,硬件资源往往是有限的,过多的模型参数将给部署造成困难。
+2. 在数据有限的情况下,大模型更易遭受过拟合,其在实验数据集上看起来良好的结果也难以泛化到真实场景
 
-为了解决上述问题,本案例拟提出一种基于网络迭代优化思想的深度学习变化检测算法。本案例的基本思路是,构造一个轻量级的变化检测模型,并以其作为基础迭代单元。每次迭代开始时,由上一次迭代输出的概率图以及原始的输入影像对构造新的输入,实现coarse-to-fine优化。考虑到增加迭代单元的数量将使模型参数量成倍增加,在迭代过程中始终复用同一迭代单元的参数,充分挖掘变化检测网络的拟合能力,迫使其学习到更加有效的特征。这一做法类似[循环神经网络](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490)。根据此思路可以绘制框图如下:
+本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力?基于这个观点,本案例的基本思路是设计一种基于网络迭代优化思想的深度学习变化检测算法。首先,构造一个轻量级的变化检测模型,并以其作为基础迭代单元。每次迭代开始时,由上一次迭代输出的概率图以及原始的输入影像对构造新的输入,如此逐级实现coarse-to-fine优化。考虑到增加迭代单元的数量将使模型参数量成倍增加,在迭代过程中应始终复用同一迭代单元的参数以充分挖掘变化检测网络的拟合能力,迫使其学习到更加有效的特征。这一做法类似[循环神经网络](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490)。根据此思路可以绘制框图如下:
 
-[思路展示]()
+![draft](draft.png)
 
-### 3.2 确定baseline
+### 3.2 确定baseline模型
 
-科研工作往往需要“站在巨人的肩膀上”,在前人工作的基础上做“增量创新”。因此,对模型设计类工作而言,选用一个合适的baseline网络至关重要。考虑到本案例的出发点是解决,并且使用了。
+科研工作往往需要“站在巨人的肩膀上”,在前人工作的基础上做“增量创新”。因此,对模型设计类工作而言,选用一个合适的baseline模型至关重要。考虑到本案例的出发点是解决现有模型参数量过大、冗余特征过多的问题,并且在拟定的解决方案中使用到了循环结构,用作baseline的网络结构必须足够轻量和高效(因为最直接的思路是使用baseline作为基础迭代单元)。为此,本案例选用Bitemporal Image Transformer(BIT)作为baseline。BIT是一个轻量级的深度学习变化检测模型,其基本结构如图所示:
+
+![bit](bit.png)
+
+BIT的核心思想在于,
 
 ### 3.3 定义新模型
 
-[算法整体框图]()
+确定了基本思路和baseline模型之后,可以绘制如下的算法整体框图:
+
+![framework](framework.png)
+
+依据此框图,即可在。
+
+#### 3.3.1 自定义模型组网
+
+在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。例如,当前`custom_model.py`中定义了迭代版本的BIT模型`IterativeBIT`:
+```python
+@attach
+class IterativeBIT(nn.Layer):
+    def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
+        super().__init__()
+
+        if num_iters <= 0:
+            raise ValueError(f"`num_iters` should have positive value, but got {num_iters}.")
+
+        self.num_iters = num_iters
+        self.gamma = gamma
+
+        if bit_kwargs is None:
+            bit_kwargs = dict()
+
+        if 'num_classes' in bit_kwargs:
+            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
+        bit_kwargs['num_classes'] = num_classes
+
+        self.bit = BIT(**bit_kwargs)
+
+    def forward(self, t1, t2):
+        rate_map = self._init_rate_map(t1.shape)
+
+        for it in range(self.num_iters):
+            # Construct inputs
+            x1 = self._constr_iter_input(t1, rate_map)
+            x2 = self._constr_iter_input(t2, rate_map)
+            # Get logits
+            logits_list = self.bit(x1, x2)
+            # Construct rate map
+            prob_map = F.softmax(logits_list[0], axis=1)
+            rate_map = self._constr_rate_map(prob_map)
+
+        return logits_list
+    ...
+```
+
+在编写组网相关代码时请注意以下两点:
+
+1. 所有模型必须为`paddle.nn.Layer`的子类;
+2. 包含模型整体逻辑结构的最外层模块须用`@attach`装饰;
+3. 对于变化检测任务,`forward()`方法除`self`参数外还接受两个参数`t1`、`t2`,分别表示第一时相和第二时相影像。
+
+关于模型定义的更多细节请参考[API文档]()。
+
+#### 3.3.2 自定义训练器
+
+在`custom_trainer.py`中定义训练器。例如,当前`custom_trainer.py`中定义了与`IterativeBIT`模型对应的训练器:
+```python
+@attach
+class IterativeBIT(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 num_iters=1,
+                 gamma=0.1,
+                 bit_kwargs=None,
+                 **params):
+        params.update({
+            'num_iters': num_iters,
+            'gamma': gamma,
+            'bit_kwargs': bit_kwargs
+        })
+        super().__init__(
+            model_name='IterativeBIT',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
+```
+
+在编写训练器定义相关代码时请注意以下两点:
+
+1. 对于变化检测任务,训练器必须为`paddlers.tasks.cd.BaseChangeDetector`的子类;
+2. 与模型一样,训练器也须用`@attach`装饰;
+3. 训练器和模型可以同名。
+
+关于训练器定义的更多细节请参考[API文档]()。
 
 ### 3.4 进行参数分析与消融实验
 
 #### 3.4.1 实验设置
 
-#### 3.4.2 实验结果
+#### 3.4.2 编写配置文件
+
+#### 3.4.3 实验结果
+
+### 3.5 \*Magic Behind
+
+本小节涉及技术细节,对于本案例来说属于进阶内容,您可以选择性了解。
+
+#### 3.5.1 延迟属性绑定
+
+PaddleRS提供了,只需要。`attach_tools.Attach`对象自动。
+
+#### 3.5.2 非侵入式轻量级配置系统
 
 ### 3.5 开展特征可视化实验
 
@@ -75,10 +179,16 @@ python ../../tools/prepare_dataset/prepare_svcd.py \
 
 #### 4.3.2 SVCD数据集上的对比结果
 
-精度、FLOPs、运行时间
+精度
 
 ## 5 总结与展望
 
+### 5.1 总结
+
+### 5.2 展望
+
+耗时,模型大小,FLOPs
+
 ## 参考文献
 
 > [1] Chen, Hao, and Zhenwei Shi. "A spatial-temporal attention-based method and a new dataset for remote sensing image change detection." *Remote Sensing* 12.10 (2020): 1662.  

+ 20 - 0
examples/rs_research/attach_tools.py

@@ -0,0 +1,20 @@
+class Attach(object):
+    def __init__(self, dst):
+        self.dst = dst
+
+    def __call__(self, obj, name=None):
+        if name is None:
+            # Automatically get names of functions and classes
+            name = obj.__name__
+        if hasattr(self.dst, name):
+            raise RuntimeError(
+                f"{self.dst} already has the attribute {name}, which is {getattr(self.dst, name)}."
+            )
+        setattr(self.dst, name, obj)
+        if hasattr(self.dst, '__all__'):
+            self.dst.__all__.append(name)
+        return obj
+
+    @staticmethod
+    def to(dst):
+        return Attach(dst)

+ 253 - 0
examples/rs_research/config_utils.py

@@ -0,0 +1,253 @@
+#!/usr/bin/env python
+
+import argparse
+import os.path as osp
+from collections.abc import Mapping
+
+import yaml
+
+
+def _chain_maps(*maps):
+    chained = dict()
+    keys = set().union(*maps)
+    for key in keys:
+        vals = [m[key] for m in maps if key in m]
+        if isinstance(vals[0], Mapping):
+            chained[key] = _chain_maps(*vals)
+        else:
+            chained[key] = vals[0]
+    return chained
+
+
+def read_config(config_path):
+    with open(config_path, 'r', encoding='utf-8') as f:
+        cfg = yaml.safe_load(f)
+    return cfg or {}
+
+
+def parse_configs(cfg_path, inherit=True):
+    if inherit:
+        cfgs = []
+        cfgs.append(read_config(cfg_path))
+        while cfgs[-1].get('_base_'):
+            base_path = cfgs[-1].pop('_base_')
+            curr_dir = osp.dirname(cfg_path)
+            cfgs.append(
+                read_config(osp.normpath(osp.join(curr_dir, base_path))))
+        return _chain_maps(*cfgs)
+    else:
+        return read_config(cfg_path)
+
+
+def _cfg2args(cfg, parser, prefix=''):
+    node_keys = set()
+    for k, v in cfg.items():
+        opt = prefix + k
+        if isinstance(v, list):
+            if len(v) == 0:
+                parser.add_argument(
+                    '--' + opt, type=object, nargs='*', default=v)
+            else:
+                # Only apply to homogeneous lists
+                if isinstance(v[0], CfgNode):
+                    node_keys.add(opt)
+                parser.add_argument(
+                    '--' + opt, type=type(v[0]), nargs='*', default=v)
+        elif isinstance(v, dict):
+            # Recursively parse a dict
+            _, new_node_keys = _cfg2args(v, parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, CfgNode):
+            node_keys.add(opt)
+            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, bool):
+            parser.add_argument('--' + opt, action='store_true', default=v)
+        else:
+            parser.add_argument('--' + opt, type=type(v), default=v)
+    return parser, node_keys
+
+
+def _args2cfg(cfg, args, node_keys):
+    args = vars(args)
+    for k, v in args.items():
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            dict_[k] = v
+        else:
+            cfg[k] = v
+
+    for k in node_keys:
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            v = dict_[k]
+            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+        else:
+            v = cfg[k]
+            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+
+    return cfg
+
+
+def parse_args(*args, **kwargs):
+    cfg_parser = argparse.ArgumentParser(add_help=False)
+    cfg_parser.add_argument('--config', type=str, default='')
+    cfg_parser.add_argument('--inherit_off', action='store_true')
+    cfg_args = cfg_parser.parse_known_args()[0]
+    cfg_path = cfg_args.config
+    inherit_on = not cfg_args.inherit_off
+
+    # Main parser
+    parser = argparse.ArgumentParser(
+        conflict_handler='resolve', parents=[cfg_parser])
+    # Global settings
+    parser.add_argument('cmd', choices=['train', 'eval'])
+    parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])
+
+    # Data
+    parser.add_argument('--datasets', type=dict, default={})
+    parser.add_argument('--transforms', type=dict, default={})
+    parser.add_argument('--download_on', action='store_true')
+    parser.add_argument('--download_url', type=str, default='')
+    parser.add_argument('--download_path', type=str, default='./')
+
+    # Optimizer
+    parser.add_argument('--optimizer', type=dict, default={})
+
+    # Training related
+    parser.add_argument('--num_epochs', type=int, default=100)
+    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--save_interval_epochs', type=int, default=1)
+    parser.add_argument('--log_interval_steps', type=int, default=1)
+    parser.add_argument('--save_dir', default='../exp/')
+    parser.add_argument('--learning_rate', type=float, default=0.01)
+    parser.add_argument('--early_stop', action='store_true')
+    parser.add_argument('--early_stop_patience', type=int, default=5)
+    parser.add_argument('--use_vdl', action='store_true')
+    parser.add_argument('--resume_checkpoint', type=str)
+    parser.add_argument('--train', type=dict, default={})
+
+    # Loss
+    parser.add_argument('--losses', type=dict, nargs='+', default={})
+
+    # Model
+    parser.add_argument('--model', type=dict, default={})
+
+    if osp.exists(cfg_path):
+        cfg = parse_configs(cfg_path, inherit_on)
+        parser, node_keys = _cfg2args(cfg, parser, '')
+        node_keys = sorted(node_keys, reverse=True)
+        args = parser.parse_args(*args, **kwargs)
+        return _args2cfg(dict(), args, node_keys)
+    elif cfg_path != '':
+        raise FileNotFoundError
+    else:
+        args = parser.parse_args()
+        return _args2cfg(dict(), args, set())
+
+
+class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
+    def __call__(cls, obj):
+        if isinstance(obj, CfgNode):
+            return obj
+        return super(_CfgNodeMeta, cls).__call__(obj)
+
+
+class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
+    yaml_tag = u'!Node'
+    yaml_loader = yaml.SafeLoader
+    # By default use a lexical scope
+    ctx = globals()
+
+    def __init__(self, dict_):
+        super().__init__()
+        self.type = dict_['type']
+        self.args = dict_.get('args', [])
+        self.module = dict_.get('module', '')
+
+    @classmethod
+    def set_context(cls, ctx):
+        # TODO: Implement dynamic scope with inspect.stack()
+        old_ctx = cls.ctx
+        cls.ctx = ctx
+        return old_ctx
+
+    def build_object(self, mod=None):
+        if mod is None:
+            mod = self._get_module(self.module)
+        cls = getattr(mod, self.type)
+        if isinstance(self.args, list):
+            args = build_objects(self.args)
+            obj = cls(*args)
+        elif isinstance(self.args, dict):
+            args = build_objects(self.args)
+            obj = cls(**args)
+        else:
+            raise NotImplementedError
+        return obj
+
+    def _get_module(self, s):
+        mod = None
+        while s:
+            idx = s.find('.')
+            if idx == -1:
+                next_ = s
+                s = ''
+            else:
+                next_ = s[:idx]
+                s = s[idx + 1:]
+            if mod is None:
+                mod = self.ctx[next_]
+            else:
+                mod = getattr(mod, next_)
+        return mod
+
+    @staticmethod
+    def build_objects(cfg, mod=None):
+        if isinstance(cfg, list):
+            return [CfgNode.build_objects(c, mod=mod) for c in cfg]
+        elif isinstance(cfg, CfgNode):
+            return cfg.build_object(mod=mod)
+        elif isinstance(cfg, dict):
+            return {
+                k: CfgNode.build_objects(
+                    v, mod=mod)
+                for k, v in cfg.items()
+            }
+        else:
+            return cfg
+
+    def __repr__(self):
+        return f"(type={self.type}, args={self.args}, module={self.module or ' '})"
+
+    @classmethod
+    def from_yaml(cls, loader, node):
+        map_ = loader.construct_mapping(node)
+        return cls(map_)
+
+    def items(self):
+        yield from [('type', self.type), ('args', self.args), ('module',
+                                                               self.module)]
+
+    def to_dict(self):
+        return dict(self.items())
+
+
+def build_objects(cfg, mod=None):
+    return CfgNode.build_objects(cfg, mod=mod)

+ 6 - 0
examples/rs_research/configs/levircd/bit.yaml

@@ -0,0 +1,6 @@
+_base_: ./levircd.yaml
+
+save_dir: ./exp/bit/
+
+model: !Node
+    type: BIT

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma01.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter2_gamma01/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 2
+        gamma: 0.1
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma02.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter2_gamma02/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 2
+        gamma: 0.2
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma05.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter2_gamma05/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 2
+        gamma: 0.5
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma01.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter3_gamma01/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 3
+        gamma: 0.1
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma02.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter3_gamma02/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 3
+        gamma: 0.2
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma05.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter3_gamma05/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 3
+        gamma: 0.5
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 12 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma10.yaml

@@ -0,0 +1,12 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/custom_model/iter3_gamma10/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 3
+        gamma: 1.0
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 4

+ 74 - 0
examples/rs_research/configs/levircd/levircd.yaml

@@ -0,0 +1,74 @@
+# Basic configurations of LEVIR-CD dataset
+
+datasets:
+    train: !Node
+        type: CDDataset
+        args: 
+            data_dir: ./data/levircd/
+            file_list: ./data/levircd/train.txt
+            label_list: null
+            num_workers: 2
+            shuffle: True
+            with_seg_labels: False
+            binarize_labels: True
+    eval: !Node
+        type: CDDataset
+        args:
+            data_dir: ./data/levircd/
+            file_list: ./data/levircd/val.txt
+            label_list: null
+            num_workers: 0
+            shuffle: False
+            with_seg_labels: False
+            binarize_labels: True
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomFlipOrRotate
+          args:
+            probs: [0.35, 0.35]
+            probsf: [0.5, 0.5, 0, 0, 0]
+            probsr: [0.33, 0.34, 0.33]
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['eval']
+download_on: False
+
+num_epochs: 40
+train_batch_size: 8
+optimizer: !Node
+    type: Adam
+    args:
+        learning_rate: !Node
+            type: StepDecay
+            module: paddle.optimizer.lr
+            args:
+                learning_rate: 0.002
+                step_size: 30
+                gamma: 0.2
+save_interval_epochs: 10
+log_interval_steps: 500
+save_dir: ./exp/
+learning_rate: 0.002
+early_stop: False
+early_stop_patience: 5
+use_vdl: True
+resume_checkpoint: ''

+ 58 - 0
examples/rs_research/custom_model.py

@@ -0,0 +1,58 @@
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddlers
+from paddlers.rs_models.cd import BIT
+from attach_tools import Attach
+
+attach = Attach.to(paddlers.rs_models.cd)
+
+
+@attach
+class IterativeBIT(nn.Layer):
+    def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
+        super().__init__()
+
+        if num_iters <= 0:
+            raise ValueError(
+                f"`num_iters` should have positive value, but got {num_iters}.")
+
+        self.num_iters = num_iters
+        self.gamma = gamma
+
+        if bit_kwargs is None:
+            bit_kwargs = dict()
+
+        if 'num_classes' in bit_kwargs:
+            raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
+        bit_kwargs['num_classes'] = num_classes
+
+        self.bit = BIT(**bit_kwargs)
+
+    def forward(self, t1, t2):
+        rate_map = self._init_rate_map(t1.shape)
+
+        for it in range(self.num_iters):
+            # Construct inputs
+            x1 = self._constr_iter_input(t1, rate_map)
+            x2 = self._constr_iter_input(t2, rate_map)
+            # Get logits
+            logits_list = self.bit(x1, x2)
+            # Construct rate map
+            prob_map = F.softmax(logits_list[0], axis=1)
+            rate_map = self._constr_rate_map(prob_map)
+
+        return logits_list
+
+    def _constr_iter_input(self, im, rate_map):
+        return paddle.concat([im.rate_map], axis=1)
+
+    def _init_rate_map(self, im_shape):
+        b, _, h, w = im_shape
+        return paddle.zeros((b, 1, h, w))
+
+    def _constr_rate_map(self, prob_map):
+        if prob_map.shape[1] != 2:
+            raise ValueError(
+                f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
+        return (prob_map[:, 1:2] * self.gamma)

+ 29 - 0
examples/rs_research/custom_trainer.py

@@ -0,0 +1,29 @@
+import paddlers
+from paddlers.tasks.change_detector import BaseChangeDetector
+
+from attach_tools import Attach
+
+attach = Attach.to(paddlers.tasks.change_detector)
+
+
+@attach
+class IterativeBIT(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 num_iters=1,
+                 gamma=0.1,
+                 bit_kwargs=None,
+                 **params):
+        params.update({
+            'num_iters': num_iters,
+            'gamma': gamma,
+            'bit_kwargs': bit_kwargs
+        })
+        super().__init__(
+            model_name='IterativeBIT',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)

二進制
examples/rs_research/params_versus_f1.png


+ 115 - 0
examples/rs_research/run_task.py

@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+
+import os
+
+import paddle
+import paddlers
+from paddlers import transforms as T
+
+import custom_model
+import custom_trainer
+from config_utils import parse_args, build_objects, CfgNode
+
+
+def format_cfg(cfg, indent=0):
+    s = ''
+    if isinstance(cfg, dict):
+        for i, (k, v) in enumerate(sorted(cfg.items())):
+            s += ' ' * indent + str(k) + ': '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, list):
+        for i, v in enumerate(cfg):
+            s += ' ' * indent + '- '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, CfgNode):
+        s += ' ' * indent + f"type: {cfg.type}" + '\n'
+        s += ' ' * indent + f"module: {cfg.module}" + '\n'
+        s += ' ' * indent + 'args: \n' + format_cfg(cfg.args, indent + 1)
+    return s
+
+
+if __name__ == '__main__':
+    CfgNode.set_context(globals())
+
+    cfg = parse_args()
+    print(format_cfg(cfg))
+
+    # Automatically download data
+    if cfg['download_on']:
+        paddlers.utils.download_and_decompress(
+            cfg['download_url'], path=cfg['download_path'])
+
+    if cfg['cmd'] == 'train':
+        if not isinstance(cfg['datasets']['train'].args, dict):
+            raise ValueError("args of train dataset must be a dict!")
+        if cfg['datasets']['train'].args.get('transforms', None) is not None:
+            raise ValueError(
+                "Found key 'transforms' in args of train dataset and the value is not None."
+            )
+        train_transforms = T.Compose(
+            build_objects(
+                cfg['transforms']['train'], mod=T))
+        # Inplace modification
+        cfg['datasets']['train'].args['transforms'] = train_transforms
+        train_dataset = build_objects(
+            cfg['datasets']['train'], mod=paddlers.datasets)
+    if not isinstance(cfg['datasets']['eval'].args, dict):
+        raise ValueError("args of eval dataset must be a dict!")
+    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
+        raise ValueError(
+            "Found key 'transforms' in args of eval dataset and the value is not None."
+        )
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # Inplace modification
+    cfg['datasets']['eval'].args['transforms'] = eval_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+
+    model = build_objects(
+        cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
+
+    if cfg['cmd'] == 'train':
+        if cfg['optimizer']:
+            if len(cfg['optimizer'].args) == 0:
+                cfg['optimizer'].args = {}
+            if not isinstance(cfg['optimizer'].args, dict):
+                raise TypeError("args of optimizer must be a dict!")
+            if cfg['optimizer'].args.get('parameters', None) is not None:
+                raise ValueError(
+                    "Found key 'parameters' in args of optimizer and the value is not None."
+                )
+            cfg['optimizer'].args['parameters'] = model.net.parameters()
+            optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)
+        else:
+            optimizer = None
+
+        model.train(
+            num_epochs=cfg['num_epochs'],
+            train_dataset=train_dataset,
+            train_batch_size=cfg['train_batch_size'],
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=cfg['save_interval_epochs'],
+            log_interval_steps=cfg['log_interval_steps'],
+            save_dir=cfg['save_dir'],
+            learning_rate=cfg['learning_rate'],
+            early_stop=cfg['early_stop'],
+            early_stop_patience=cfg['early_stop_patience'],
+            use_vdl=cfg['use_vdl'],
+            resume_checkpoint=cfg['resume_checkpoint'] or None,
+            **cfg['train'])
+    elif cfg['cmd'] == 'eval':
+        state_dict = paddle.load(
+            os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
+        model.net.set_state_dict(state_dict)
+        res = model.evaluate(eval_dataset)
+        print(res)

+ 0 - 0
examples/rs_research/configs/levircd/custom_model.yaml → examples/rs_research/scripts/run_benchmark.sh


+ 0 - 0
examples/rs_research/test.py → examples/rs_research/scripts/run_parameter_analysis.sh


+ 0 - 0
examples/rs_research/train.py


+ 4 - 4
test_tipc/configs/seg/unet/unet.yaml

@@ -5,7 +5,7 @@ _base_: ../_base_/rsseg.yaml
 save_dir: ./test_tipc/output/seg/unet/
 
 model: !Node
-       type: UNet
-       args:
-           input_channel: 10
-           num_classes: 5
+    type: UNet
+        args:
+            input_channel: 10
+            num_classes: 5