
Merge pull request #29 from Bobholamovic/add_examples

[Example] Update Example Docs and Add Research Example
cc 2 years ago
parent
commit
4eaea16c0d
45 changed files, with 2538 additions and 118 deletions
  1. 1 2
      README.md
  2. 26 26
      docs/data/coco_tools.md
  3. 1 1
      docs/intro/transforms.md
  4. 33 0
      examples/README.md
  5. 2 0
      examples/rs_research/.gitignore
  6. 468 0
      examples/rs_research/README.md
  7. 35 0
      examples/rs_research/attach_tools.py
  8. 267 0
      examples/rs_research/config_utils.py
  9. 8 0
      examples/rs_research/configs/levircd/ablation/custom_model_c.yaml
  10. 8 0
      examples/rs_research/configs/levircd/ablation/custom_model_t.yaml
  11. 6 0
      examples/rs_research/configs/levircd/custom_model.yaml
  12. 6 0
      examples/rs_research/configs/levircd/fc_ef.yaml
  13. 6 0
      examples/rs_research/configs/levircd/fc_siam_conc.yaml
  14. 6 0
      examples/rs_research/configs/levircd/fc_siam_diff.yaml
  15. 74 0
      examples/rs_research/configs/levircd/levircd.yaml
  16. 241 0
      examples/rs_research/custom_model.py
  17. 79 0
      examples/rs_research/custom_trainer.py
  18. 82 0
      examples/rs_research/predict_cd.py
  19. 129 0
      examples/rs_research/run_task.py
  20. 17 0
      examples/rs_research/scripts/run_ablation.sh
  21. 22 0
      examples/rs_research/scripts/run_benchmark.sh
  22. 148 0
      examples/rs_research/tools/analyze_model.py
  23. 75 0
      examples/rs_research/tools/collect_imgs.py
  24. 228 0
      examples/rs_research/tools/visualize_feats.py
  25. 111 0
      examples/rs_research/train_cd.py
  26. 5 5
      paddlers/tasks/change_detector.py
  27. 3 3
      paddlers/transforms/operators.py
  28. 4 1
      test_tipc/common_func.sh
  29. 3 2
      test_tipc/config_utils.py
  30. 62 0
      test_tipc/configs/cd/_base_/levircd.yaml
  31. 1 1
      test_tipc/configs/cd/bit/bit.yaml
  32. 8 0
      test_tipc/configs/cd/bit/bit_airchange.yaml
  33. 8 0
      test_tipc/configs/cd/bit/bit_levircd.yaml
  34. 3 3
      test_tipc/configs/cd/bit/train_infer_python.txt
  35. 1 1
      test_tipc/configs/cd/changeformer/changeformer.yaml
  36. 1 1
      test_tipc/configs/clas/_base_/ucmerced.yaml
  37. 3 3
      test_tipc/configs/clas/hrnet/hrnet.yaml
  38. 3 3
      test_tipc/configs/det/ppyolo/ppyolo.yaml
  39. 1 1
      test_tipc/configs/seg/unet/unet.yaml
  40. 11 1
      test_tipc/prepare.sh
  41. 14 18
      test_tipc/run_task.py
  42. 40 46
      test_tipc/test_train_inference_python.sh
  43. 215 0
      tools/prepare_dataset/common.py
  44. 42 0
      tools/prepare_dataset/prepare_levircd.py
  45. 31 0
      tools/prepare_dataset/prepare_svcd.py

+ 1 - 2
README.md

@@ -200,8 +200,7 @@ Key parts of the PaddleRS directory tree are as follows:
  * [Python deployment](./deploy/README.md)
  * [Model inference API reference](./docs/apis/infer.md)
 * Practical cases
-  * [Remote sensing image change detection case](./docs/cases/csc_cd_cn.md)
-  * [Remote sensing image super-resolution reconstruction case](./docs/cases/sr_seg_cn.md)
+  * [PaddleRS practical case collection](./examples/README.md)
 * Code contribution
  * [Contribution guide](./docs/CONTRIBUTING.md)
  * [Development guide](./docs/dev/dev_guide.md)

+ 26 - 26
docs/data/coco_tools.md

@@ -17,7 +17,7 @@ coco_tools is a set of tools provided by PaddleRS for processing COCO-format annotation files,
 
 ## 3 Usage Examples
 
-## 3.1 Sample Dataset
+### 3.1 Sample Dataset
 
 This document uses the COCO 2017 dataset as sample data for demonstration. You can download the dataset from the following link:
 
@@ -47,11 +47,11 @@ coco_tools is a set of tools provided by PaddleRS for processing COCO-format annotation files,
 |  |--...
 ```
 
-## 3.2 Printing json Information
+### 3.2 Printing json Information
 
 Use `json_InfoShow.py` to print the key of each key-value pair in a json file and output the leading elements of each value, helping you quickly understand the annotation information. For COCO-format annotation data, pay special attention to the contents of the `'image'` and `'annotation'` fields.
 
-### 3.2.1 Command Demonstration
+#### 3.2.1 Command Demonstration
 
 Run the following command to print the information in `instances_val2017.json`:
 
@@ -61,7 +61,7 @@ python ./coco_tools/json_InfoShow.py \
        --show_num 5
 ```
 
-### 3.2.2 Parameter Description
+#### 3.2.2 Parameter Description
 
 
 | Parameter     | Meaning                                                       | Default  |
@@ -70,7 +70,7 @@ python ./coco_tools/json_InfoShow.py \
 | `--show_num`  | (Optional) Number of leading elements of each value to output | `5`      |
 | `--Args_show` | (Optional) Whether to print the input parameter information   | `True`   |
 
-### 3.2.3 Result Display
+#### 3.2.3 Result Display
 
 After running the above command, the output is as follows:
 
@@ -151,7 +151,7 @@ contributor : COCO Consortium
 
 ```
 
-### 3.2.4 Result Description
+#### 3.2.4 Result Description
 
 `instances_val2017.json` has 5 keys, namely:
 
@@ -166,11 +166,11 @@ contributor : COCO Consortium
 - The value of the `annotations` key is a list with 36781 elements in total; the output shows the first 5;
 - The value of the `categories` key is a list with 80 elements in total; the output shows the first 5.
 
-## 3.3 Image Statistics
+### 3.3 Image Statistics
 
 Use `json_ImgSta.py` to quickly extract image information from `instances_val2017.json`, generate a csv table, and produce statistical plots.
 
-### 3.3.1 Command Demonstration
+#### 3.3.1 Command Demonstration
 
 Run the following command to print the information in `instances_val2017.json`:
 
@@ -182,7 +182,7 @@ python ./coco_tools/json_ImgSta.py \
     --png_shapeRate_path=./img_sta/images_shapeRate.png
 ```
 
-### 3.3.2 Parameter Description
+#### 3.3.2 Parameter Description
 
 | Parameter              | Meaning                                                  | Default    |
 | ---------------------- | -------------------------------------------------------- | ---------- |
@@ -193,7 +193,7 @@ python ./coco_tools/json_ImgSta.py \
 | `--image_keyname`      | (Optional) Key corresponding to images in the json file   | `'images'` |
 | `--Args_show`          | (Optional) Whether to print the input parameter information | `True`   |
 
-### 3.3.3 Result Display
+#### 3.3.3 Result Display
 
 After running the above command, the output is as follows:
 
@@ -232,11 +232,11 @@ csv save to ./img_sta/images.csv
 1-D distribution of the shape ratio (width/height) of all images:
 ![image.png](./assets/1650011634205-image.png)
 
-## 3.4 Object Detection Bounding Box Statistics
+### 3.4 Object Detection Bounding Box Statistics
 
 Use `json_AnnoSta.py` to quickly extract annotation information from `instances_val2017.json`, generate a csv table, and produce statistical plots.
 
-### 3.4.1 Command Demonstration
+#### 3.4.1 Command Demonstration
 
 Run the following command to print the information in `instances_val2017.json`:
 
@@ -253,7 +253,7 @@ python ./coco_tools/json_AnnoSta.py \
     --get_relative=True
 ```
 
-### 3.4.2 Parameter Description
+#### 3.4.2 Parameter Description
 
 | Parameter              | Meaning                                                       | Default         |
 | ---------------------- | ------------------------------------------------------------- | --------------- |
@@ -270,7 +270,7 @@ python ./coco_tools/json_AnnoSta.py \
 | `--anno_keyname`       | (Optional) Key corresponding to annotations in the json file   | `'annotations'` |
 | `--Args_show`          | (Optional) Whether to print the input parameter information    | `True`          |
 
-### 3.4.3 Result Display
+#### 3.4.3 Result Display
 
 After running the above command, the output is as follows:
 
@@ -344,11 +344,11 @@ csv save to ./anno_sta/annos.csv
 
 ![image.png](./assets/1650026559309-image.png)
 
-## 3.5 Generating json from Image Information
+### 3.5 Generating json from Image Information
 
 Use `json_Test2Json.py` to quickly extract image information based on the files in `test2017` and the training set json file, and to generate a test set json file.
 
-### 3.5.1 Command Demonstration
+#### 3.5.1 Command Demonstration
 
 Run the following command to collect statistics and generate the `test2017` information:
 
@@ -359,7 +359,7 @@ python ./coco_tools/json_Img2Json.py \
     --json_test_path=./test.json
 ```
 
-### 3.5.2 Parameter Description
+#### 3.5.2 Parameter Description
 
 
 | Parameter           | Meaning                                                      | Default        |
@@ -371,7 +371,7 @@ python ./coco_tools/json_Img2Json.py \
 | `--cat_keyname`     | (Optional) Key corresponding to categories in the json file   | `'categories'` |
 | `--Args_show`       | (Optional) Whether to print the input parameter information   | `True`         |
 
-### 3.5.3 Result Display
+#### 3.5.3 Result Display
 
 After running the above command, the output is as follows:
 
@@ -431,11 +431,11 @@ json keys: dict_keys(['images', 'categories'])
 ...
 ```
 
-## 3.6 Splitting json Files
+### 3.6 Splitting json Files
 
 Use `json_Split.py` to split the `instances_val2017.json` file into 2 subsets.
 
-### 3.6.1 Command Demonstration
+#### 3.6.1 Command Demonstration
 
 Run the following command to split the `instances_val2017.json` file:
 
@@ -446,7 +446,7 @@ python ./coco_tools/json_Split.py \
     --json_val_path=./instances_val2017_val.json
 ```
 
-### 3.6.2 Parameter Description
+#### 3.6.2 Parameter Description
 
 
 | Parameter            | Meaning                                                      | Default        |
@@ -461,7 +461,7 @@ python ./coco_tools/json_Split.py \
 | `--cat_keyname`      | (Optional) Key corresponding to categories in the json file   | `'categories'` |
 | `--Args_show`        | (Optional) Whether to print the input parameter information   | `'True'`       |
 
-### 3.6.3 Result Display
+#### 3.6.3 Result Display
 
 After running the above command, the output is as follows:
 
@@ -485,11 +485,11 @@ image total 5000, train 4500, val 500
 anno total 36781, train 33119, val 3662
 ```
 
-## 3.7 Merging json Files
+### 3.7 Merging json Files
 
 Use `json_Merge.py` to merge 2 json files.
 
-### 3.7.1 Command Demonstration
+#### 3.7.1 Command Demonstration
 
 Run the following command to merge `instances_train2017.json` and `instances_val2017.json`:
 
@@ -500,7 +500,7 @@ python ./coco_tools/json_Merge.py \
     --save_path=./instances_trainval2017.json
 ```
 
-### 3.7.2 Parameter Description
+#### 3.7.2 Parameter Description
 
 
 | Parameter      | Meaning                                                    | Default                     |
@@ -511,7 +511,7 @@ python ./coco_tools/json_Merge.py \
 | `--merge_keys` | (Optional) Keys to be merged during merging                | `['images', 'annotations']` |
 | `--Args_show`  | (Optional) Whether to print the input parameter information | `True`                     |
 
-### 3.7.3 Result Display
+#### 3.7.3 Result Display
 
 After running the above command, the output is as follows:
 

+ 1 - 1
docs/intro/transforms.md

@@ -12,7 +12,7 @@ PaddleRS organizes the data preprocessing/augmentation (collectively referred to as
 | RandomResizeByShort  | Randomly resizes the input image, keeping the aspect ratio unchanged (the scale factor is computed from the shorter side). | All tasks | ... |
 | ResizeByLong         | Resizes the input image, keeping the aspect ratio unchanged (the scale factor is computed from the longer side). | All tasks | ... |
 | RandomHorizontalFlip | Randomly flips the input image horizontally. | All tasks | ... |
-| RandomVerticalFlip   | Randomly flips the input image. | All tasks | ... |
+| RandomVerticalFlip   | Randomly flips the input image vertically. | All tasks | ... |
 | Normalize            | Applies normalization to the input image. | All tasks | ... |
 | CenterCrop           | Applies center cropping to the input image. | All tasks | ... |
 | RandomCrop           | Applies random center cropping to the input image. | All tasks | ... |

+ 33 - 0
examples/README.md

@@ -0,0 +1,33 @@
+# PaddleRS Practical Cases
+
+PaddleRS provides a rich set of examples covering everything from scientific research to industrial applications. They are intended to help researchers in the remote sensing field quickly develop, validate, and tune algorithms, and to help developers working on industrial practice conveniently build end-to-end remote sensing deep learning applications, from data preprocessing to model deployment.
+
+## 1 Official Cases
+
+- [PaddleRS research in practice: designing a deep learning change detection model](./rs_research/)
+
+## 2 Community-Contributed Cases
+
+[AI Studio](https://aistudio.baidu.com/aistudio/index) is an AI learning and hands-on training community built on Baidu's deep learning platform PaddlePaddle. It offers an online programming environment, free GPU compute, a wealth of open-source algorithms, and open datasets to help developers quickly create and deploy models. You can explore more uses of PaddleRS on AI Studio:
+
+[PaddleRS-related projects on AI Studio](https://aistudio.baidu.com/aistudio/projectoverview/public?kw=PaddleRS)
+
+This document collects selected projects contributed by open-source enthusiasts:
+
+|Project link|Author|Type|Keywords|
+|-|-|-|-|
+|[Step-by-step: change detection with PaddleRS](https://aistudio.baidu.com/aistudio/projectdetail/3737991)|奔向未来的样子|Beginner tutorial|Change detection|
+|[【PPSIG】Deploying a PaddleRS change detection model: BIT as an example](https://aistudio.baidu.com/aistudio/projectdetail/4184759)|古代飞|Beginner tutorial|Change detection, model deployment|
+|[【PPSIG】Remote sensing scene classification with PaddleRS](https://aistudio.baidu.com/aistudio/projectdetail/4198965)|古代飞|Beginner tutorial|Scene classification|
+|[PaddleRS: improving the segmentation accuracy of real low-resolution UAV images with super-resolution models](https://aistudio.baidu.com/aistudio/projectdetail/3696814)|KeyK-小胡之父|Application case|Super-resolution reconstruction, UAV imagery|
+|[PaddleRS: vehicle detection from UAV imagery](https://aistudio.baidu.com/aistudio/projectdetail/3713122)|geoyee|Application case|Object detection, UAV imagery|
+|[PaddleRS: scene classification of hyperspectral satellite imagery](https://aistudio.baidu.com/aistudio/projectdetail/3711240)|geoyee|Application case|Scene classification, hyperspectral imagery|
+|[PaddleRS: landslide identification using satellite imagery and digital elevation models](https://aistudio.baidu.com/aistudio/projectdetail/4066570)|KeyK-小胡之父|Application case|Image segmentation, DEM|
+|[Adding a pocket-sized configuration system to PaddleRS](https://aistudio.baidu.com/aistudio/projectdetail/4203534)|古代飞|Creative development||
+|[Building generation based on PaddleGAN and PaddleRS](https://aistudio.baidu.com/aistudio/projectdetail/3716885)|奔向未来的样子|Creative development|Super-resolution reconstruction|
+|[【Official】The 11th "China Software Cup" Baidu remote sensing track: change detection](https://aistudio.baidu.com/aistudio/projectdetail/3684588)|古代飞|Competition|Change detection, competition baseline|
+|[【Official】The 11th "China Software Cup" Baidu remote sensing track: object extraction](https://aistudio.baidu.com/aistudio/projectdetail/3792610)|古代飞|Competition|Image segmentation, competition baseline|
+|[【Official】The 11th "China Software Cup" Baidu remote sensing track: land cover classification](https://aistudio.baidu.com/aistudio/projectdetail/3792606)|古代飞|Competition|Image segmentation, competition baseline|
+|[【Official】The 11th "China Software Cup" Baidu remote sensing track: object detection](https://aistudio.baidu.com/aistudio/projectdetail/3792609)|古代飞|Competition|Object detection, competition baseline|
+|[【11th Software Cup】Remote sensing interpretation track, change detection task: 4th-place preliminary solution](https://aistudio.baidu.com/aistudio/projectdetail/4116895)|lzzzzzm|Competition|Change detection, top-ranking solution|
+|[【Solution share】The 11th "China Software Cup" college software design contest, remote sensing interpretation track: competition solution](https://aistudio.baidu.com/aistudio/projectdetail/4146154)|trainer|Competition|Change detection, top-ranking solution|

+ 2 - 0
examples/rs_research/.gitignore

@@ -0,0 +1,2 @@
+/data/
+/exp/

+ 468 - 0
examples/rs_research/README.md

@@ -0,0 +1,468 @@
+# PaddleRS Research in Practice: Designing a Deep Learning Change Detection Model
+
+This case study demonstrates how to design a change detection model with PaddleRS and how to conduct comparison experiments, ablation studies, and feature visualization experiments.
+
+## 1 Environment Setup
+
+Install PaddleRS and its dependencies following the [tutorial](https://github.com/PaddlePaddle/PaddleRS/tree/develop/tutorials/train#环境准备). The GDAL library is not required for this case study.
+
+Once the environment is configured, run the following command from the root directory of the PaddleRS repository to switch to the directory of this case study:
+
+```shell
+cd examples/rs_research
+```
+
+Please note that all commands provided in this document follow bash syntax.
+
+## 2 Data Preparation
+
+The experiments in this case study are conducted on the [LEVIR-CD dataset](https://www.mdpi.com/2072-4292/12/10/1662) [1]. Please download the dataset from the [LEVIR-CD download link](https://justchenhao.github.io/LEVIR/), extract it to a local directory, and run the following commands:
+
+```bash
+mkdir data/
+python ../../tools/prepare_dataset/prepare_levircd.py \
+    --in_dataset_dir "{path to the LEVIR-CD dataset}" \
+    --out_dataset_dir "data/levircd" \
+    --crop_size 256 \
+    --crop_stride 256
+```
+
+The commands above use the dataset preparation tool provided by PaddleRS to split the dataset, create file lists, and so on. Specifically, they follow the official train/validation/test split of the LEVIR-CD dataset and cut the original `1024x1024` images into non-overlapping `256x256` patches (following the practice in [2]).
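+
+A minimal sketch of the non-overlapping cropping step is shown below, for illustration only; the actual implementation lives in `tools/prepare_dataset/common.py` and may differ in detail:
+
+```python
+import numpy as np
+
+def crop_patches(im, size=256, stride=256):
+    # `im` is an HxWxC array; with size == stride, patches do not overlap.
+    h, w = im.shape[:2]
+    patches = []
+    for i in range(0, h - size + 1, stride):
+        for j in range(0, w - size + 1, stride):
+            patches.append(im[i:i + size, j:j + size])
+    return patches
+
+# A 1024x1024 image yields 4x4 = 16 patches of size 256x256.
+assert len(crop_patches(np.zeros((1024, 1024, 3), dtype=np.uint8))) == 16
+```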
+
+## 3 Model Design
+
+### 3.1 Problem Analysis and Approach
+
+With the deepening application of deep learning techniques, many remote sensing image change detection algorithms based on fully convolutional networks (FCNs) have emerged in recent years. Compared with feature-based and patch-based methods, FCN-based methods offer higher processing efficiency and fewer hyperparameters, but they tend to have large numbers of parameters and therefore depend more heavily on the amount of training data. Although medium- and large-scale change detection datasets keep growing in number and training samples are increasingly abundant, the parameter counts of deep learning change detection models keep increasing as well. The figure below shows the parameter counts of FCN-based change detection models proposed in publications from 2018 to 2021, together with the F1 scores they achieved on the SVCD dataset [3] (bar height is proportional to the model's parameter count):
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/21275753/186670936-5f79983c-914c-4e81-8f01-11df2beadf09.png" width="850">
+</p>
+
+Admittedly, increasing the number of parameters is, in most cases, equivalent to increasing model capacity, and higher capacity means stronger fitting ability, which helps a model achieve better accuracy metrics on experimental datasets. But does "bigger" necessarily mean "better"? Clearly not. In practical applications, "bigger" remote sensing change detection models often run into the following problems:
+
+1. A huge parameter count implies a huge storage overhead. In many real-world scenarios, hardware resources are limited, and an excessive number of parameters makes deployment difficult.
+2. When data are limited, large models are more prone to overfitting, and detection results that look good on experimental datasets may fail to generalize to real scenes.
+
+This case study argues that the root of these problems is feature redundancy caused by an imbalance between parameter count and data volume. Since the model's features are redundant, i.e., some of them are "useless", might there be a way to optimize the features while keeping the parameter count fixed, thereby squeezing more potential out of a small model and obtaining more effective features? Based on this view, the basic idea of this case study is to add a "plug-in" feature optimization module to an existing change detection model, achieving enhanced change features while introducing only a small number of extra parameters. The plan is to take FC-Siam-conc [4], a classic network in the change detection field, as the baseline, and to use channel and temporal attention modules to optimize the network's intermediate features, thereby reducing feature redundancy and improving detection performance. For the concrete module design, the channel attention module proposed in [5] is adopted to implement feature enhancement along both the channel and temporal dimensions.
+
+The network architecture of FC-Siam-conc is shown below:
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/21275753/186671480-d869a500-6409-4f97-b48b-50ce95ea3a71.jpg" width="500">
+</p>
+
+This case study adds a hybrid attention module, composed of a channel attention module and a temporal attention module, before the first Concat module in the decoder to optimize the features passed from the encoders, and names the new model CustomModel.
+
+### 3.2 Model Definition
+
+This subsection implements the idea proposed in [Section 3.1](#31-problem-analysis-and-approach) with the PaddlePaddle framework and the PaddleRS library.
+
+The overall model structure and its component modules are defined in `custom_model.py`. This case study defines the improved FC-Siam-conc architecture there; its core parts are implemented as follows:
+
+```python
+...
+# PaddleRS provides many ready-to-use modules, including wrappers of low-level building blocks (such as conv-bn-relu) as well as higher-level structures such as attention modules
+from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
+from paddlers.rs_models.cd.layers import ChannelAttention
+
+from attach_tools import Attach
+
+attach = Attach.to(paddlers.rs_models.cd)
+
+@attach
+class CustomModel(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 num_classes,
+                 att_types='ct',
+                 use_dropout=False):
+        super().__init__()
+        ...
+        # Build a hybrid attention module att4 to process the final features output by the two encoders
+        self.att4 = MixedAttention(C4, att_types)
+
+        self.init_weight()
+
+    def forward(self, t1, t2):
+        ...
+        x4d = self.upconv4(x4p)
+        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
+                x43_1.shape[2] - x4d.shape[2])
+        x4d = F.pad(x4d, pad=pad4, mode='replicate')
+        # Plug the attention module into the first decoding unit
+        x43_1, x43_2 = self.att4(x43_1, x43_2)
+        x4d = paddle.concat([x4d, x43_1, x43_2], 1)
+        x43d = self.do43d(self.conv43d(x4d))
+        x42d = self.do42d(self.conv42d(x43d))
+        x41d = self.do41d(self.conv41d(x42d))
+        ...
+
+
+class MixedAttention(nn.Layer):
+    def __init__(self, in_channels, att_types='ct'):
+        super(MixedAttention, self).__init__()
+
+        self.att_types = att_types
+
+        # Each attention module is optional
+        if self.has_att_c:
+            self.att_c = ChannelAttention(in_channels, ratio=1)
+        else:
+            self.att_c = Identity()
+
+        if self.has_att_t:
+            # The temporal attention module partly reuses the logic of channel attention, as explained in `forward()`
+            self.att_t = ChannelAttention(2, ratio=1)
+        else:
+            self.att_t = Identity()
+
+    def forward(self, x1, x2):
+        # x1 and x2 are the features extracted by the two encoder branches of FC-Siam-conc
+
+        if self.has_att_c:
+            # First refine the features with the channel attention module
+            # The encoded features of the two phases share one channel attention module
+            # A residual connection is added to speed up convergence
+            x1 = (1 + self.att_c(x1)) * x1
+            x2 = (1 + self.att_c(x2)) * x2
+
+        if self.has_att_t:
+            b, c = x1.shape[:2]
+            # To reuse the channel attention module for attention along the temporal dimension, first stack the features of the two phases
+            y = paddle.stack([x1, x2], axis=2)
+            # The stacked y has shape [b, c, t, h, w], where b is the batch size, c the number of channels, t equals 2 (the number of phases), and h and w the height and width of the feature maps
+            # Merge the b and c dimensions; the resulting tensor has shape [b*c, t, h, w]
+            y = paddle.flatten(y, stop_axis=1)
+            # The temporal dimension now takes the place of the original channel dimension, so the 4-D tensor can be fed into the ChannelAttention module
+            # A residual connection is added here as well
+            y = (1 + self.att_t(y)) * y
+            # Separate the information of the two phases from the result
+            y = y.reshape((b, c, 2, *y.shape[2:]))
+            y1, y2 = y[:, :, 0], y[:, :, 1]
+        else:
+            y1, y2 = x1, x2
+
+        return y1, y2
+
+    @property
+    def has_att_c(self):
+        return 'c' in self.att_types
+
+    @property
+    def has_att_t(self):
+        return 't' in self.att_types
+```
+
+When writing the network code, please note the following three points:
+
+1. All models must be subclasses of `paddle.nn.Layer`;
+2. The outermost module containing the overall model logic (the `CustomModel` class in this example) must be decorated with `@attach`;
+3. For the change detection task, the `forward()` method of the outermost module takes, besides `self`, two arguments `t1` and `t2`, which are the first- and second-phase images, respectively.
+
+For more details on model definition, see the [Development Guide](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/dev/dev_guide.md).
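+
+Once `custom_model.py` is imported, `@attach` registers `CustomModel` on `paddlers.rs_models.cd`, so it can be used like a built-in PaddleRS model. A rough illustration of the effect (hypothetical snippet, not part of the example code):
+
+```python
+import paddlers
+import custom_model  # noqa: F401  # importing the module runs @attach
+
+# The custom network is now reachable through the PaddleRS namespace
+net = paddlers.rs_models.cd.CustomModel(in_channels=3, num_classes=2)
+```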
+
+## 4 Model Training
+
+This case study provides two ways to train models: a script-based approach and a configuration-file-based approach.
+
+- For beginners, the script-based approach is recommended: it is easier to understand, its code logic is simple, and it requires no custom trainer.
+- For more experienced researchers, or those who need to run many comparison and ablation experiments, the configuration-file-based approach is recommended: it makes it easier to manage different model configurations and to run multiple experiments in parallel.
+
+Note that the experimental results in this document were all obtained from models trained with the configuration-file-based approach. The configuration files for all the experiments covered in this document are provided in the `configs` directory.
+
+### 4.1 Script-Based Approach
+
+This case study provides the `train_cd.py` script, which trains and validates a model and reports the test set accuracy of the model that performs best on the validation set. Run the script with:
+
+```bash
+python train_cd.py
+```
+
+Reading the comments in the script helps clarify what each step does. By default, the script trains and validates the custom model CustomModel on the LEVIR-CD dataset. During experiments, parts of the script can be modified as needed, e.g., to tune hyperparameters or to train different models.
+
+The training program enables VisualDL logging by default. During or after training, VisualDL can be used to observe how the loss and the accuracy metrics evolve. For how to use VisualDL with PaddleRS, see the [tutorial](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tutorials/train/README.md#visualdl%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E6%8C%87%E6%A0%87).
+
+### 4.2 Configuration-File-Based Approach
+
+#### 4.2.1 Writing Configuration Files
+
+This case study provides a lightweight configuration system based on [YAML](https://yaml.org/). You can adjust hyperparameters, switch models, or switch datasets by modifying yaml files, or add new configurations by writing new yaml files.
+
+For the rules used to write the configuration files in this case study, see [this project](https://aistudio.baidu.com/aistudio/projectdetail/4203534).
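+
+In this configuration system, a child file inherits from a base file through the `_base_` key, and values in the child take precedence. A small sketch of how `config_utils.py` resolves this for the files in this case study:
+
+```python
+from config_utils import parse_configs
+
+# Follows the `_base_` chain (configs/levircd/custom_model.yaml inherits from
+# configs/levircd/levircd.yaml) and merges the files, with the child's values
+# overriding the base's.
+cfg = parse_configs('configs/levircd/custom_model.yaml')
+print(cfg['save_dir'])    # './exp/levircd/custom_model/', overridden by the child
+print(cfg['num_epochs'])  # 50, inherited from the base file
+```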
+
+#### 4.2.2 Custom Trainers
+
+When training with the configuration-file-based approach, a trainer must be defined in `custom_trainer.py`. For example, this case study defines in `custom_trainer.py` the trainer corresponding to the `CustomModel` model:
+
+```python
+@attach
+class CustomTrainer(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 in_channels=3,
+                 att_types='ct',
+                 use_dropout=False,
+                 **params):
+        params.update({
+            'in_channels': in_channels,
+            'att_types': att_types,
+            'use_dropout': use_dropout
+        })
+        super().__init__(
+            model_name='CustomModel',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
+```
+
+When writing the trainer definition, please note the following three points:
+
+1. For the change detection task, the trainer must be a subclass of `paddlers.tasks.cd.BaseChangeDetector`;
+2. Like the model, the trainer must be decorated with `@attach`;
+3. The trainer and the model may share the same name.
+
+In this case study, only the trainer's `__init__()` method is overridden. In actual research, methods such as `train()`, `evaluate()`, and `default_loss()` can be overridden to customize more sophisticated training and evaluation strategies or to replace the default loss function.
+
+For more details on trainers, see the [API reference](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/train.md).
+
+The `model` item in a configuration file specifies the trainer name and its constructor arguments. For example:
+
+```yaml
+model: !Node
+    type: CustomTrainer
+    args:
+        att_types: c
+```
+
+The configuration above constructs a trainer object equivalent to `CustomTrainer(att_types='c')`.
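+
+Under the hood, each `!Node` entry is parsed into a `CfgNode` object and instantiated reflectively: the `type` field is looked up in a module and called with `args` as keyword arguments. A rough sketch of that step (see `run_task.py` for the actual wiring; this is not its exact contents):
+
+```python
+import paddlers
+import custom_model  # noqa: F401  # registers CustomModel
+import custom_trainer  # noqa: F401  # registers CustomTrainer
+from config_utils import parse_args, build_objects
+
+# E.g., invoked as: python this_script.py train cd --config configs/levircd/custom_model.yaml
+cfg = parse_args()
+trainer = build_objects(cfg['model'], mod=paddlers.tasks.change_detector)
+```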
+
+#### 4.2.3 Training Commands
+
+Train a model using the following command pattern:
+
+```bash
+python run_task.py train cd \
+    --config "configs/levircd/{configuration file name}" \
+    2>&1 | tee "{log path}"
+```
+
+After training, use the following command to compute, on the test set, the metrics of the model that performed best on the validation set:
+
+```bash
+python run_task.py eval cd \
+    --config "configs/levircd/{configuration file name}" \
+    --datasets.eval.args.file_list "data/levircd/test.txt" \
+    --resume_checkpoint "exp/levircd/{model name}/best_model"
+```
+
+## 5 Comparison Experiments
+
+To verify the effectiveness of a model design, comparison experiments are usually conducted, comparing the accuracy and performance of the proposed model against other models on one or more datasets. In this case study, the custom model CustomModel is compared with three architectures: FC-EF, FC-Siam-diff, and FC-Siam-conc, all from [4].
+
+### 5.1 Experimental Procedure
+
+**When training and validating models with the configuration-file-based approach**, all models participating in the comparison can be trained on the LEVIR-CD dataset with:
+
+```bash
+bash scripts/run_benchmark.sh
+```
+
+**When training and validating models with the `train_cd.py` script**, the model type and constructor arguments must be modified manually for each experiment. In addition, the `EXP_DIR` variable can be set to a different value for each run so that each model's results are saved to a separate directory for easier comparison. The example commands in this subsection assume that `EXP_DIR` was set to `exp/levircd/{model name}` during the experiments.
+
+After training and metric validation are complete, the binary change maps output by a model can be saved with:
+
+```bash
+python predict_cd.py \
+    --model_dir "exp/levircd/{model name}/best_model" \
+    --data_dir "data/levircd" \
+    --file_list "data/levircd/test.txt" \
+    --save_dir "exp/predict/levircd/{model name}"
+```
+
+The saved outputs can then be viewed in the `exp/predict/levircd/{model name}` directory.
+
+The `tools/collect_imgs.py` script gathers the input images, the change labels, and the predictions of multiple models into a single directory for easy side-by-side inspection. The script accepts three command-line options:
+- `--globs` specifies a list of wildcard patterns (as accepted by Python's [`glob.glob()` function](https://docs.python.org/zh-cn/3/library/glob.html#glob.glob)) matching the images to collect;
+- `--tags` specifies an alias for each entry in `--globs`; in the output directory, the matched image names are replaced with the corresponding aliases;
+- `--save_dir` specifies the output directory path, which is created automatically if it does not exist.
+
+For example, for the LEVIR-CD dataset, run:
+
+```bash
+python tools/collect_imgs.py \
+    --globs "data/levircd/LEVIR-CD/test/A/*/*.png" "data/levircd/LEVIR-CD/test/B/*/*.png" "data/levircd/LEVIR-CD/test/label/*/*.png" \
+        "exp/predict/levircd/fc_ef/*.png" "exp/predict/levircd/fc_siam_conc/*.png" "exp/predict/levircd/fc_siam_diff/*.png" \
+        "exp/predict/levircd/custom_model/*.png" \
+    --tags 'A' 'B' 'GT' \
+        'fc_ef' 'fc_siam_conc' 'fc_siam_diff' \
+        'custom_model' \
+    --save_dir "exp/collect/levircd"
+```
+
+After the command finishes, the two-phase input images, the change labels, and each model's predictions can be found in the `exp/collect/levircd` directory. When a new model is added later, `tools/collect_imgs.py` can be invoked again to add its results to the `exp/collect/levircd` directory:
+
+```bash
+python tools/collect_imgs.py --globs "exp/predict/levircd/{new model name}/*.png" --tags '{new model name}' --save_dir "exp/collect/levircd"
+```
+
+In addition, to evaluate change detection algorithms comprehensively in terms of both accuracy and performance, the [floating point operations (FLOPs)](https://blog.csdn.net/IT_flying625/article/details/104898152) and the parameter count of a change detection model can be computed with:
+
+```bash
+python tools/analyze_model.py --model_dir "exp/levircd/{model name}/best_model"
+```
+
+### 5.2 Experimental Results
+
+This case study uses the [intersection over union (IoU)](https://paddlepedia.readthedocs.io/en/latest/tutorials/computer_vision/semantic_segmentation/Overview/Overview.html#id6) of the change class and the [F1 score](https://baike.baidu.com/item/F1%E5%88%86%E6%95%B0/13864979) as quantitative evaluation metrics; the higher these two metrics, the better the detection performance. On each dataset, the algorithms are assessed both visually and quantitatively, using the definitions below.
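+
+For reference, counting true positives (TP), false positives (FP), and false negatives (FN) over the change class, the two metrics are defined as:
+
+```latex
+\mathrm{IoU} = \frac{TP}{TP + FP + FN}, \qquad
+\mathrm{F1} = \frac{2PR}{P + R}, \quad
+P = \frac{TP}{TP + FP}, \quad R = \frac{TP}{TP + FN}
+```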
+
+#### 5.2.1 Visual Comparison
+
+The figure below shows the two-phase input images, the binary change maps output by each algorithm, and the change labels. All samples are taken from the test set of the LEVIR-CD dataset.
+
+|Phase-1 image|Phase-2 image|FC-EF|FC-Siam-diff|FC-Siam-conc|CustomModel|Change label|
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|<img src="https://user-images.githubusercontent.com/21275753/186671764-2dc990a8-b297-43a2-ae81-e31f2d5582e5.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672204-e8e46e9a-7f29-4506-9ed4-31314284a6fb.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672237-ee5f67d8-8966-457d-8a80-0452bdb7af89.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671987-7da0023a-0c96-413f-9088-0f6730ab54dd.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671895-c6c40196-b86a-49d1-a4b0-48a7f40cba06.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672068-89a60f8c-c80e-4f73-bb3e-b9ad146e795d.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672106-37e8dcd0-b0f0-46e1-90a1-bd5f566ef97b.png" width="100">|
+|<img src="https://user-images.githubusercontent.com/21275753/186672287-efa1209d-2786-4543-b136-5f50b7b0dd8c.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671791-beb82760-8c3f-480f-8ada-9c1081860691.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671861-7b7989e4-15d8-4342-9abe-2d6efa82811a.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672362-94993c68-7c31-4501-b009-755c193a00a8.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672348-3134129c-e2cd-4011-8894-901ef332a43d.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672415-da3984b2-0354-49ad-8dba-9c796a18d282.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672449-fd225e4f-ac58-4506-8b66-3a255567998a.png" width="100">|
+
+As the figure shows, although some missed and false detections remain, CustomModel delineates the changed regions more accurately than the other algorithms.
+
+#### 5.2.2 Quantitative Comparison
+
+|Model|FLOPs (G)|Params (M)|IoU (%)|F1 (%)|
+|:-:|:-:|:-:|:-:|:-:|
+|FC-EF|3.57|1.35|79.05|88.30|
+|FC-Siam-diff|4.71|1.35|81.33|89.70|
+|FC-Siam-conc|5.31|1.55|81.31|89.69|
+|CustomModel|5.31|1.58|**82.14**|**90.19**|
+
+The best metrics are shown in bold. As the table shows, CustomModel achieves the highest IoU and F1 score of all the algorithms (IoU up 3.09% and F1 up 1.89% compared with FC-EF), while introducing only 0.03 M additional parameters relative to the baseline FC-Siam-conc.
+
+## 6 Ablation Study
+
+In research, ablation studies are often needed to verify the effectiveness of the modifications made on top of a baseline. In this case study, CustomModel adds two kinds of attention modules, channel and temporal, to the FC-Siam-conc model, so an ablation study is needed to examine how much each attention module contributes to the final accuracy. Specifically, the following four settings are considered (the configuration files of the ablation models are stored in the `configs/levircd/ablation` directory):
+
+1. Base setting: no attention module, i.e., the baseline model FC-Siam-conc;
+2. Channel attention module only, with configuration file `custom_model_c.yaml`;
+3. Temporal attention module only, with configuration file `custom_model_t.yaml`;
+4. Full setting: the complete model with both the channel and the temporal attention module.
+
+Models 1 and 4, i.e., the baseline and the complete model, were already trained, validated, and tested in [Section 4](#4-model-training) and [Section 5](#5-comparison-experiments), so this section only needs to cover settings 2 and 3.
+
+### 6.1 Experimental Procedure
+
+**When training with the configuration-file-based approach**, all ablation models can be trained with:
+
+```bash
+bash scripts/run_ablation.sh
+```
+
+Alternatively, a single model can be trained with the following command pattern:
+
+```bash
+python run_task.py train cd \
+    --config "configs/levircd/ablation/{configuration file name}" \
+    2>&1 | tee {log path}
+```
+
+After training, use the following command to compute, on the test set, the metrics of the model that performed best on the validation set:
+
+```bash
+python run_task.py eval cd \
+    --config "configs/levircd/ablation/{configuration file name}" \
+    --datasets.eval.args.file_list "data/levircd/test.txt" \
+    --resume_checkpoint "exp/levircd/ablation/{ablation model name}/best_model"
+```
+
+Note that a configuration file such as `custom_model_c.yaml` corresponds, by default, to the ablation model name `att_c`.
+
+**When training with `train_cd.py`**, the `att_types` argument passed to the model constructor must be modified to obtain the different ablation models, as sketched below. For example, for the ablation model with only the channel attention module, set `att_types='c'`. Additionally, the `EXP_DIR` variable can be set to a different value for each experiment so that the results are saved to separate directories for easier comparison.
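+
+A rough sketch of the relevant change (illustrative only; the surrounding code in `train_cd.py` may differ):
+
+```python
+from custom_trainer import CustomTrainer
+
+# Channel attention only; use att_types='t' for temporal only,
+# and att_types='ct' (the default) for the complete model.
+model = CustomTrainer(att_types='c')
+```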
+
+### 6.2 Experimental Results
+
+The quantitative metrics obtained in the experiments are shown in the following table:
+
+|Channel attention|Temporal attention|IoU (%)|F1 (%)|
+|:-:|:-:|:-:|:-:|
+|||81.31|89.69|
+|✓||81.97|90.09|
+||✓|81.59|89.86|
+|✓|✓|**82.14**|**90.19**|
+
+The data in the table show that both the channel and the temporal attention module contribute positively to the algorithm's IoU and F1 score, and adding both modules yields the largest gain (IoU up 0.83% and F1 up 0.50% over the baseline).
+
+## 7 Feature Visualization Experiment
+
+This section visualizes the model's intermediate features to further verify that the modifications to the baseline model indeed enhance the features.
+
+### 7.1 Experimental Procedure
+
+Intermediate features are visualized with the `tools/visualize_feats.py` script, which accepts the following command-line options:
+- `--model_dir` specifies the path of the saved model to load.
+- `--im_path` specifies the path of the input image; for the change detection task, the paths of the two input images must be given in order.
+- `--save_dir` specifies the output directory path, which is created automatically if it does not exist.
+- `--hook_type` specifies the kind of feature to capture, with three possible values: `forward_in` captures the forward input features of the specified modules; `forward_out` captures their forward output features; `backward` captures the gradients of the specified parameters.
+- `--layer_names` specifies the names of the modules whose input or output features should be captured (with `.` separating parent and child modules), or the names of weight parameters in the model (i.e., keys in the [state_dict](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/load_cn.html)).
+- `--to_pseudo_color` specifies whether to save the feature maps as pseudo-color images.
+- `--output_size` specifies the size to which the feature maps are resized.
+
+Files generated by `tools/visualize_feats.py` are named following the pattern `{layer_name}_{j}_vis.png` or `{layer_name}_{i}_{j}_vis.png`, where `{layer_name}` corresponds to a value given in `--layer_names`; `{i}` indexes the current feature when multiple input or output features are captured at once; and `{j}`, when `--hook_type` is `forward_in` or `forward_out`, indicates during which invocation of the module the current feature map was input or output (some modules in a model may be called repeatedly, such as `conv4` in the FC-Siam-conc model). For example, the following commands capture and save the visualizations of the input and output features of the `att4` module of the CustomModel model:
+
+```bash
+IM1_PATH="data/levircd/LEVIR-CD/test/A/test_13/test_13_3.png"
+IM2_PATH="data/levircd/LEVIR-CD/test/B/test_13/test_13_3.png"
+
+python tools/visualize_feats.py \
+    --model_dir "exp/levircd/custom_model/best_model" \
+    --im_path "${IM1_PATH}" "${IM2_PATH}" \
+    --save_dir "exp/vis/test_13_3/in" \
+    --hook_type 'forward_in' \
+    --layer_names 'att4' \
+    --to_pseudo_color \
+    --output_size 256 256
+
+python tools/visualize_feats.py \
+    --model_dir "exp/levircd/custom_model/best_model" \
+    --im_path "${IM1_PATH}" "${IM2_PATH}" \
+    --save_dir "exp/vis/test_13_3/out" \
+    --hook_type 'forward_out' \
+    --layer_names 'att4' \
+    --to_pseudo_color \
+    --output_size 256 256
+```
+
+Running the above commands produces 2 subdirectories under `exp/vis/test_13_3/`, each containing 2 files: `in/att4_0_0_vis.png` and `in/att4_1_0_vis.png` are the visualizations of the two temporal-phase features fed into the `att4` module, and `out/att4_0_0_vis.png` and `out/att4_1_0_vis.png` are the visualizations of the two temporal-phase features output by `att4`.
+
+### 7.2 Experimental Results
+
+In the figure below, from left to right: the two-phase input images, the change label, the visualizations of the two-phase feature maps fed into the hybrid attention module `att4` (denoted x1 and x2), and the visualizations of the two-phase feature maps output by `att4` (denoted y1 and y2):
+
+|Phase-1 image|Phase-2 image|Change label|x1|x2|y1|y2|
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|<img src="https://user-images.githubusercontent.com/21275753/186672741-45c819f0-2591-4b97-ad32-05d787be1a0a.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672761-eb6958be-688d-4bc2-839b-6a60cb6cc3b5.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672791-ceb78cf7-5029-4991-88c2-6c4550fb27d8.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672835-7fda3499-33e0-4af1-b990-8d82f6c5c410.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672870-dba57441-509f-4cd0-bcc9-af343ddf07df.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672893-7bc692a7-c963-4686-b93c-895b5c51fecb.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672914-b99ffee3-9eb4-4f95-96f4-93cb00e0b109.png" width="100">|
+
+Comparing x2 with y2 shows that, after the channel and temporal attention modules, the change features are enhanced: the changed regions stand out more clearly in the feature maps.
+
+## 8 Summary and Outlook
+
+### 8.1 Summary
+
+- Using the addition of attention modules to the classic FC-Siam-conc model as an example, this case study demonstrated a typical workflow for conducting research with PaddleRS.
+- The modifications made to the model brought some improvement in visual quality and an increase in detection accuracy.
+- The effectiveness of the proposed modifications was confirmed through an ablation study and a feature visualization experiment.
+
+### 8.2 Outlook
+
+- All compared algorithms were trained with the same hyperparameters; however, because the models differ, training them all with one set of hyperparameters rarely lets every model reach its best. In follow-up work, each compared algorithm could be tuned individually to obtain its optimal accuracy.
+- As a simple example of doing research with PaddleRS, this case study makes no major algorithmic innovations, so the accuracy gain over the baseline is limited. More elaborate algorithm designs and more advanced model architectures could be explored in the future.
+
+## References
+
+> [1] Chen, Hao, and Zhenwei Shi. "A spatial-temporal attention-based method and a new dataset for remote sensing image change detection." *Remote Sensing* 12.10 (2020): 1662.  
+[2] Chen, Hao, Zipeng Qi, and Zhenwei Shi. "Remote sensing image change detection with transformers." *IEEE Transactions on Geoscience and Remote Sensing* 60 (2021): 1-14.  
+[3] Lebedev, M. A., et al. "Change detection in remote sensing images using conditional adversarial networks." *International Archives of the Photogrammetry, Remote Sensing & Spatial Information Sciences* 42.2 (2018).  
+[4] Daudt, Rodrigo Caye, Bertrand Le Saux, and Alexandre Boulch. "Fully convolutional siamese networks for change detection." *2018 25th IEEE International Conference on Image Processing (ICIP)*. IEEE, 2018.  
+[5] Woo, Sanghyun, et al. "CBAM: Convolutional block attention module." *Proceedings of the European Conference on Computer Vision (ECCV)*. 2018.

+ 35 - 0
examples/rs_research/attach_tools.py

@@ -0,0 +1,35 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Attach(object):
+    def __init__(self, dst):
+        self.dst = dst
+
+    def __call__(self, obj, name=None):
+        if name is None:
+            # Automatically get names of functions and classes
+            name = obj.__name__
+        if hasattr(self.dst, name):
+            raise RuntimeError(
+                f"{self.dst} already has the attribute {name}, which is {getattr(self.dst, name)}."
+            )
+        setattr(self.dst, name, obj)
+        if hasattr(self.dst, '__all__'):
+            self.dst.__all__.append(name)
+        return obj
+
+    @staticmethod
+    def to(dst):
+        return Attach(dst)
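
A quick illustration of how `Attach` is used (a hypothetical toy target; the example code attaches to real modules such as `paddlers.rs_models.cd`):

```python
from types import SimpleNamespace

from attach_tools import Attach

fake_module = SimpleNamespace()
attach = Attach.to(fake_module)

@attach
class Foo:
    pass

assert fake_module.Foo is Foo  # Foo is now an attribute of the target
```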

+ 267 - 0
examples/rs_research/config_utils.py

@@ -0,0 +1,267 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os.path as osp
+from collections.abc import Mapping
+
+import yaml
+
+
+def _chain_maps(*maps):
+    chained = dict()
+    keys = set().union(*maps)
+    for key in keys:
+        vals = [m[key] for m in maps if key in m]
+        if isinstance(vals[0], Mapping):
+            chained[key] = _chain_maps(*vals)
+        else:
+            chained[key] = vals[0]
+    return chained
+
+
+def read_config(config_path):
+    with open(config_path, 'r', encoding='utf-8') as f:
+        cfg = yaml.safe_load(f)
+    return cfg or {}
+
+
+def parse_configs(cfg_path, inherit=True):
+    if inherit:
+        cfgs = []
+        cfgs.append(read_config(cfg_path))
+        while cfgs[-1].get('_base_'):
+            base_path = cfgs[-1].pop('_base_')
+            curr_dir = osp.dirname(cfg_path)
+            cfgs.append(
+                read_config(osp.normpath(osp.join(curr_dir, base_path))))
+        return _chain_maps(*cfgs)
+    else:
+        return read_config(cfg_path)
+
+
+def _cfg2args(cfg, parser, prefix=''):
+    node_keys = set()
+    for k, v in cfg.items():
+        opt = prefix + k
+        if isinstance(v, list):
+            if len(v) == 0:
+                parser.add_argument(
+                    '--' + opt, type=object, nargs='*', default=v)
+            else:
+                # Only apply to homogeneous lists
+                if isinstance(v[0], CfgNode):
+                    node_keys.add(opt)
+                parser.add_argument(
+                    '--' + opt, type=type(v[0]), nargs='*', default=v)
+        elif isinstance(v, dict):
+            # Recursively parse a dict
+            _, new_node_keys = _cfg2args(v, parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, CfgNode):
+            node_keys.add(opt)
+            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, bool):
+            parser.add_argument('--' + opt, action='store_true', default=v)
+        else:
+            parser.add_argument('--' + opt, type=type(v), default=v)
+    return parser, node_keys
+
+
+def _args2cfg(cfg, args, node_keys):
+    args = vars(args)
+    for k, v in args.items():
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            dict_[k] = v
+        else:
+            cfg[k] = v
+
+    for k in node_keys:
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            v = dict_[k]
+            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+        else:
+            v = cfg[k]
+            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+
+    return cfg
+
+
+def parse_args(*args, **kwargs):
+    cfg_parser = argparse.ArgumentParser(add_help=False)
+    cfg_parser.add_argument('--config', type=str, default='')
+    cfg_parser.add_argument('--inherit_off', action='store_true')
+    cfg_args = cfg_parser.parse_known_args()[0]
+    cfg_path = cfg_args.config
+    inherit_on = not cfg_args.inherit_off
+
+    # Main parser
+    parser = argparse.ArgumentParser(
+        conflict_handler='resolve', parents=[cfg_parser])
+    # Global settings
+    parser.add_argument('cmd', choices=['train', 'eval'])
+    parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])
+
+    # Data
+    parser.add_argument('--datasets', type=dict, default={})
+    parser.add_argument('--transforms', type=dict, default={})
+    parser.add_argument('--download_on', action='store_true')
+    parser.add_argument('--download_url', type=str, default='')
+    parser.add_argument('--download_path', type=str, default='./')
+
+    # Optimizer
+    parser.add_argument('--optimizer', type=dict, default={})
+
+    # Training related
+    parser.add_argument('--num_epochs', type=int, default=100)
+    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--save_interval_epochs', type=int, default=1)
+    parser.add_argument('--log_interval_steps', type=int, default=1)
+    parser.add_argument('--save_dir', default='../exp/')
+    parser.add_argument('--learning_rate', type=float, default=0.01)
+    parser.add_argument('--early_stop', action='store_true')
+    parser.add_argument('--early_stop_patience', type=int, default=5)
+    parser.add_argument('--use_vdl', action='store_true')
+    parser.add_argument('--resume_checkpoint', type=str)
+    parser.add_argument('--train', type=dict, default={})
+
+    # Loss
+    parser.add_argument('--losses', type=dict, nargs='+', default={})
+
+    # Model
+    parser.add_argument('--model', type=dict, default={})
+
+    if osp.exists(cfg_path):
+        cfg = parse_configs(cfg_path, inherit_on)
+        parser, node_keys = _cfg2args(cfg, parser, '')
+        node_keys = sorted(node_keys, reverse=True)
+        args = parser.parse_args(*args, **kwargs)
+        return _args2cfg(dict(), args, node_keys)
+    elif cfg_path != '':
+        raise FileNotFoundError
+    else:
+        args = parser.parse_args()
+        return _args2cfg(dict(), args, set())
+
+
+class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
+    def __call__(cls, obj):
+        if isinstance(obj, CfgNode):
+            return obj
+        return super(_CfgNodeMeta, cls).__call__(obj)
+
+
+class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
+    yaml_tag = u'!Node'
+    yaml_loader = yaml.SafeLoader
+    # By default use a lexical scope
+    ctx = globals()
+
+    def __init__(self, dict_):
+        super().__init__()
+        self.type = dict_['type']
+        self.args = dict_.get('args', [])
+        self.module = dict_.get('module', '')
+
+    @classmethod
+    def set_context(cls, ctx):
+        # TODO: Implement dynamic scope with inspect.stack()
+        old_ctx = cls.ctx
+        cls.ctx = ctx
+        return old_ctx
+
+    def build_object(self, mod=None):
+        if mod is None:
+            mod = self._get_module(self.module)
+        cls = getattr(mod, self.type)
+        if isinstance(self.args, list):
+            args = build_objects(self.args)
+            obj = cls(*args)
+        elif isinstance(self.args, dict):
+            args = build_objects(self.args)
+            obj = cls(**args)
+        else:
+            raise NotImplementedError
+        return obj
+
+    def _get_module(self, s):
+        mod = None
+        while s:
+            idx = s.find('.')
+            if idx == -1:
+                next_ = s
+                s = ''
+            else:
+                next_ = s[:idx]
+                s = s[idx + 1:]
+            if mod is None:
+                mod = self.ctx[next_]
+            else:
+                mod = getattr(mod, next_)
+        return mod
+
+    @staticmethod
+    def build_objects(cfg, mod=None):
+        if isinstance(cfg, list):
+            return [CfgNode.build_objects(c, mod=mod) for c in cfg]
+        elif isinstance(cfg, CfgNode):
+            return cfg.build_object(mod=mod)
+        elif isinstance(cfg, dict):
+            return {
+                k: CfgNode.build_objects(
+                    v, mod=mod)
+                for k, v in cfg.items()
+            }
+        else:
+            return cfg
+
+    def __repr__(self):
+        return f"(type={self.type}, args={self.args}, module={self.module or ' '})"
+
+    @classmethod
+    def from_yaml(cls, loader, node):
+        map_ = loader.construct_mapping(node)
+        return cls(map_)
+
+    def items(self):
+        yield from [('type', self.type), ('args', self.args), ('module',
+                                                               self.module)]
+
+    def to_dict(self):
+        return dict(self.items())
+
+
+def build_objects(cfg, mod=None):
+    return CfgNode.build_objects(cfg, mod=mod)
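
A note on the merge semantics of `_chain_maps()` above: nested dicts are merged recursively and earlier maps take precedence, which is why values in a child configuration override those pulled in through `_base_`. A quick illustration (key order in the result may vary):

```python
from config_utils import _chain_maps

# The first (child) map wins on conflicts; nested dicts are merged recursively.
print(_chain_maps({'b': {'c': 2}}, {'a': 1, 'b': {'c': 0, 'd': 3}}))
# -> {'a': 1, 'b': {'c': 2, 'd': 3}}
```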

+ 8 - 0
examples/rs_research/configs/levircd/ablation/custom_model_c.yaml

@@ -0,0 +1,8 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/levircd/ablation/att_c/
+
+model: !Node
+    type: CustomTrainer
+    args:
+        att_types: c

+ 8 - 0
examples/rs_research/configs/levircd/ablation/custom_model_t.yaml

@@ -0,0 +1,8 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/levircd/ablation/att_t/
+
+model: !Node
+    type: CustomTrainer
+    args:
+        att_types: t

+ 6 - 0
examples/rs_research/configs/levircd/custom_model.yaml

@@ -0,0 +1,6 @@
+_base_: ./levircd.yaml
+
+save_dir: ./exp/levircd/custom_model/
+
+model: !Node
+    type: CustomTrainer

+ 6 - 0
examples/rs_research/configs/levircd/fc_ef.yaml

@@ -0,0 +1,6 @@
+_base_: ./levircd.yaml
+
+save_dir: ./exp/levircd/fc_ef/
+
+model: !Node
+    type: FCEarlyFusion

+ 6 - 0
examples/rs_research/configs/levircd/fc_siam_conc.yaml

@@ -0,0 +1,6 @@
+_base_: ./levircd.yaml
+
+save_dir: ./exp/levircd/fc_siam_conc/
+
+model: !Node
+    type: FCSiamConc

+ 6 - 0
examples/rs_research/configs/levircd/fc_siam_diff.yaml

@@ -0,0 +1,6 @@
+_base_: ./levircd.yaml
+
+save_dir: ./exp/levircd/fc_siam_diff/
+
+model: !Node
+    type: FCSiamDiff

+ 74 - 0
examples/rs_research/configs/levircd/levircd.yaml

@@ -0,0 +1,74 @@
+# Basic configurations of LEVIR-CD dataset
+
+datasets:
+    train: !Node
+        type: CDDataset
+        args: 
+            data_dir: ./data/levircd/
+            file_list: ./data/levircd/train.txt
+            label_list: null
+            num_workers: 2
+            shuffle: True
+            with_seg_labels: False
+            binarize_labels: True
+    eval: !Node
+        type: CDDataset
+        args:
+            data_dir: ./data/levircd/
+            file_list: ./data/levircd/val.txt
+            label_list: null
+            num_workers: 0
+            shuffle: False
+            with_seg_labels: False
+            binarize_labels: True
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomFlipOrRotate
+          args:
+            probs: [0.35, 0.35]
+            probsf: [0.5, 0.5, 0, 0, 0]
+            probsr: [0.33, 0.34, 0.33]
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['eval']
+download_on: False
+
+num_epochs: 50
+train_batch_size: 8
+optimizer: !Node
+    type: Adam
+    args:
+        learning_rate: !Node
+            type: StepDecay
+            module: paddle.optimizer.lr
+            args:
+                learning_rate: 0.002
+                step_size: 35000
+                gamma: 0.2
+save_interval_epochs: 5
+log_interval_steps: 50
+save_dir: ./exp/levircd/
+learning_rate: 0.002
+early_stop: False
+early_stop_patience: 5
+use_vdl: True
+resume_checkpoint: ''

+ 241 - 0
examples/rs_research/custom_model.py

@@ -0,0 +1,241 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddlers
+from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
+from paddlers.rs_models.cd.layers import ChannelAttention
+
+from attach_tools import Attach
+
+attach = Attach.to(paddlers.rs_models.cd)
+
+
+@attach
+class CustomModel(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 num_classes,
+                 att_types='ct',
+                 use_dropout=False):
+        super(CustomModel, self).__init__()
+
+        C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
+
+        self.use_dropout = use_dropout
+
+        self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
+        self.do11 = self._make_dropout()
+        self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
+        self.do12 = self._make_dropout()
+        self.pool1 = MaxPool2x2()
+
+        self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
+        self.do21 = self._make_dropout()
+        self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
+        self.do22 = self._make_dropout()
+        self.pool2 = MaxPool2x2()
+
+        self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
+        self.do31 = self._make_dropout()
+        self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
+        self.do32 = self._make_dropout()
+        self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
+        self.do33 = self._make_dropout()
+        self.pool3 = MaxPool2x2()
+
+        self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
+        self.do41 = self._make_dropout()
+        self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
+        self.do42 = self._make_dropout()
+        self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
+        self.do43 = self._make_dropout()
+        self.pool4 = MaxPool2x2()
+
+        self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
+
+        self.conv43d = Conv3x3(C5 + C4, C4, norm=True, act=True)
+        self.do43d = self._make_dropout()
+        self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
+        self.do42d = self._make_dropout()
+        self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
+        self.do41d = self._make_dropout()
+
+        self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
+
+        self.conv33d = Conv3x3(C4 + C3, C3, norm=True, act=True)
+        self.do33d = self._make_dropout()
+        self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
+        self.do32d = self._make_dropout()
+        self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
+        self.do31d = self._make_dropout()
+
+        self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
+
+        self.conv22d = Conv3x3(C3 + C2, C2, norm=True, act=True)
+        self.do22d = self._make_dropout()
+        self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
+        self.do21d = self._make_dropout()
+
+        self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
+
+        self.conv12d = Conv3x3(C2 + C1, C1, norm=True, act=True)
+        self.do12d = self._make_dropout()
+        self.conv11d = Conv3x3(C1, num_classes)
+
+        self.init_weight()
+
+        self.att4 = MixedAttention(C4, att_types)
+
+    def forward(self, t1, t2):
+        # Encode t1
+        # Stage 1
+        x11 = self.do11(self.conv11(t1))
+        x12_1 = self.do12(self.conv12(x11))
+        x1p = self.pool1(x12_1)
+
+        # Stage 2
+        x21 = self.do21(self.conv21(x1p))
+        x22_1 = self.do22(self.conv22(x21))
+        x2p = self.pool2(x22_1)
+
+        # Stage 3
+        x31 = self.do31(self.conv31(x2p))
+        x32 = self.do32(self.conv32(x31))
+        x33_1 = self.do33(self.conv33(x32))
+        x3p = self.pool3(x33_1)
+
+        # Stage 4
+        x41 = self.do41(self.conv41(x3p))
+        x42 = self.do42(self.conv42(x41))
+        x43_1 = self.do43(self.conv43(x42))
+        x4p = self.pool4(x43_1)
+
+        # Encode t2
+        # Stage 1
+        x11 = self.do11(self.conv11(t2))
+        x12_2 = self.do12(self.conv12(x11))
+        x1p = self.pool1(x12_2)
+
+        # Stage 2
+        x21 = self.do21(self.conv21(x1p))
+        x22_2 = self.do22(self.conv22(x21))
+        x2p = self.pool2(x22_2)
+
+        # Stage 3
+        x31 = self.do31(self.conv31(x2p))
+        x32 = self.do32(self.conv32(x31))
+        x33_2 = self.do33(self.conv33(x32))
+        x3p = self.pool3(x33_2)
+
+        # Stage 4
+        x41 = self.do41(self.conv41(x3p))
+        x42 = self.do42(self.conv42(x41))
+        x43_2 = self.do43(self.conv43(x42))
+        x4p = self.pool4(x43_2)
+
+        # Decode
+        # Stage 4d
+        x4d = self.upconv4(x4p)
+        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
+                x43_1.shape[2] - x4d.shape[2])
+        x4d = F.pad(x4d, pad=pad4, mode='replicate')
+        x43_1, x43_2 = self.att4(x43_1, x43_2)
+        x4d = paddle.concat([x4d, x43_1, x43_2], 1)
+        x43d = self.do43d(self.conv43d(x4d))
+        x42d = self.do42d(self.conv42d(x43d))
+        x41d = self.do41d(self.conv41d(x42d))
+
+        # Stage 3d
+        x3d = self.upconv3(x41d)
+        pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0,
+                x33_1.shape[2] - x3d.shape[2])
+        x3d = F.pad(x3d, pad=pad3, mode='replicate')
+        x3d = paddle.concat([x3d, x33_1, x33_2], 1)
+        x33d = self.do33d(self.conv33d(x3d))
+        x32d = self.do32d(self.conv32d(x33d))
+        x31d = self.do31d(self.conv31d(x32d))
+
+        # Stage 2d
+        x2d = self.upconv2(x31d)
+        pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0,
+                x22_1.shape[2] - x2d.shape[2])
+        x2d = F.pad(x2d, pad=pad2, mode='replicate')
+        x2d = paddle.concat([x2d, x22_1, x22_2], 1)
+        x22d = self.do22d(self.conv22d(x2d))
+        x21d = self.do21d(self.conv21d(x22d))
+
+        # Stage 1d
+        x1d = self.upconv1(x21d)
+        pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0,
+                x12_1.shape[2] - x1d.shape[2])
+        x1d = F.pad(x1d, pad=pad1, mode='replicate')
+        x1d = paddle.concat([x1d, x12_1, x12_2], 1)
+        x12d = self.do12d(self.conv12d(x1d))
+        x11d = self.conv11d(x12d)
+
+        return [x11d]
+
+    def init_weight(self):
+        pass
+
+    def _make_dropout(self):
+        if self.use_dropout:
+            return nn.Dropout2D(p=0.2)
+        else:
+            return Identity()
+
+
+class MixedAttention(nn.Layer):
+    def __init__(self, in_channels, att_types='ct'):
+        super(MixedAttention, self).__init__()
+
+        self.att_types = att_types
+
+        if self.has_att_c:
+            self.att_c = ChannelAttention(in_channels, ratio=1)
+        else:
+            self.att_c = Identity()
+
+        if self.has_att_t:
+            self.att_t = ChannelAttention(2, ratio=1)
+        else:
+            self.att_t = Identity()
+
+    def forward(self, x1, x2):
+        if self.has_att_c:
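+            # Channel attention is shared by the two temporal phases; the
+            # residual form (1 + attention) * x eases optimization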
+            x1 = (1 + self.att_c(x1)) * x1
+            x2 = (1 + self.att_c(x2)) * x2
+
+        if self.has_att_t:
+            b, c = x1.shape[:2]
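+            # Stack the two phases along a new axis ([b, c, 2, h, w]), then
+            # fold the batch and channel dims so the temporal axis plays the
+            # role of channels: [b*c, 2, h, w]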
+            y = paddle.stack([x1, x2], axis=2)
+            y = paddle.flatten(y, stop_axis=1)
+            y = (1 + self.att_t(y)) * y
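+            # Split the two phases back out: [b, c, 2, h, w]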
+            y = y.reshape((b, c, 2, *y.shape[2:]))
+            y1, y2 = y[:, :, 0], y[:, :, 1]
+        else:
+            y1, y2 = x1, x2
+
+        return y1, y2
+
+    @property
+    def has_att_c(self):
+        return 'c' in self.att_types
+
+    @property
+    def has_att_t(self):
+        return 't' in self.att_types

+ 79 - 0
examples/rs_research/custom_trainer.py

@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+import paddle
+import paddlers
+from paddlers.tasks.change_detector import BaseChangeDetector
+
+from attach_tools import Attach
+
+attach = Attach.to(paddlers.tasks.change_detector)
+
+
+def make_trainer(net_type, *args, **kwargs):
+    def _init_func(self,
+                   num_classes=2,
+                   use_mixed_loss=False,
+                   losses=None,
+                   **params):
+        sig = inspect.signature(net_type.__init__)
+        net_params = {
+            k: p.default
+            for k, p in sig.parameters.items() if p.default is not p.empty
+        }
+        net_params.pop('self', None)
+        net_params.pop('num_classes', None)
+        net_params.update(params)
+
+        super(trainer_type, self).__init__(
+            model_name=net_type.__name__,
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **net_params)
+
+    if not issubclass(net_type, paddle.nn.Layer):
+        raise TypeError("net must be a subclass of paddle.nn.Layer")
+
+    trainer_name = net_type.__name__
+
+    trainer_type = type(trainer_name, (BaseChangeDetector, ),
+                        {'__init__': _init_func})
+
+    return trainer_type(*args, **kwargs)
+
+
+@attach
+class CustomTrainer(BaseChangeDetector):
+    def __init__(self,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 in_channels=3,
+                 att_types='ct',
+                 use_dropout=False,
+                 **params):
+        params.update({
+            'in_channels': in_channels,
+            'att_types': att_types,
+            'use_dropout': use_dropout
+        })
+        super().__init__(
+            model_name='CustomModel',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
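
`make_trainer()` relies on `BaseChangeDetector` no longer rejecting unknown model names (see the `change_detector.py` hunk further below). A sketch of both ways to obtain a trainer for the custom network, assuming the two modules above are importable:

```python
from custom_model import CustomModel
from custom_trainer import CustomTrainer, make_trainer

# Option 1: synthesize a trainer from the network type on the fly.
model_a = make_trainer(CustomModel, in_channels=3, att_types='ct')

# Option 2: use the statically defined trainer, attached to
# paddlers.tasks.change_detector via the @attach decorator.
model_b = CustomTrainer(in_channels=3, att_types='ct')
```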

+ 82 - 0
examples/rs_research/predict_cd.py

@@ -0,0 +1,82 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+
+import cv2
+import paddle
+import paddlers
+from tqdm import tqdm
+
+import custom_model
+import custom_trainer
+
+
+def read_file_list(file_list, sep=' '):
+    with open(file_list, 'r') as f:
+        for line in f:
+            line = line.strip()
+            parts = line.split(sep)
+            yield parts
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_dir", default=None, type=str, help="Path of saved model.")
+    parser.add_argument("--data_dir", type=str, help="Path of input dataset.")
+    parser.add_argument("--file_list", type=str, help="Path of file list.")
+    parser.add_argument(
+        "--save_dir",
+        default='./exp/predict',
+        type=str,
+        help="Path of directory to save prediction results.")
+    parser.add_argument(
+        "--ext",
+        default='.png',
+        type=str,
+        help="Extension name of the saved image file.")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    model = paddlers.tasks.load_model(args.model_dir)
+
+    if not osp.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    with paddle.no_grad():
+        for parts in tqdm(read_file_list(args.file_list)):
+            im1_path = osp.join(args.data_dir, parts[0])
+            im2_path = osp.join(args.data_dir, parts[1])
+
+            pred = model.predict((im1_path, im2_path))
+            cm = pred['label_map']
+            # {0,1} -> {0,255}
+            cm[cm > 0] = 255
+            cm = cm.astype('uint8')
+
+            if len(parts) > 2:
+                name = osp.basename(parts[2])
+            else:
+                name = osp.basename(im1_path)
+            name = osp.splitext(name)[0] + args.ext
+            out_path = osp.join(args.save_dir, name)
+            cv2.imwrite(out_path, cm)
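
The same flow is available programmatically; a minimal sketch with placeholder paths (for change detection, `predict()` takes a two-image tuple):

```python
import paddlers

# Placeholder paths; any trained PaddleRS change detection model works here.
model = paddlers.tasks.load_model('exp/levircd/custom_model/best_model')
pred = model.predict(('path/to/T1.png', 'path/to/T2.png'))
change_map = pred['label_map']  # values in {0, 1}; scale to {0, 255} to save
```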

+ 129 - 0
examples/rs_research/run_task.py

@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+# Import cv2 and sklearn before paddlers to solve the
+# "ImportError: dlopen: cannot load any more object with static TLS" issue.
+import cv2
+import sklearn
+import paddle
+import paddlers
+from paddlers import transforms as T
+
+import custom_model
+import custom_trainer
+from config_utils import parse_args, build_objects, CfgNode
+
+
+def format_cfg(cfg, indent=0):
+    s = ''
+    if isinstance(cfg, dict):
+        for i, (k, v) in enumerate(sorted(cfg.items())):
+            s += ' ' * indent + str(k) + ': '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, list):
+        for i, v in enumerate(cfg):
+            s += ' ' * indent + '- '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, CfgNode):
+        s += ' ' * indent + f"type: {cfg.type}" + '\n'
+        s += ' ' * indent + f"module: {cfg.module}" + '\n'
+        s += ' ' * indent + 'args: \n' + format_cfg(cfg.args, indent + 1)
+    return s
+
+
+if __name__ == '__main__':
+    CfgNode.set_context(globals())
+
+    cfg = parse_args()
+    print(format_cfg(cfg))
+
+    # Automatically download data
+    if cfg['download_on']:
+        paddlers.utils.download_and_decompress(
+            cfg['download_url'], path=cfg['download_path'])
+
+    if not isinstance(cfg['datasets']['eval'].args, dict):
+        raise ValueError("args of eval dataset must be a dict!")
+    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
+        raise ValueError(
+            "Found key 'transforms' in args of eval dataset and the value is not None."
+        )
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # Inplace modification
+    cfg['datasets']['eval'].args['transforms'] = eval_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+
+    if cfg['cmd'] == 'train':
+        if not isinstance(cfg['datasets']['train'].args, dict):
+            raise ValueError("args of train dataset must be a dict!")
+        if cfg['datasets']['train'].args.get('transforms', None) is not None:
+            raise ValueError(
+                "Found key 'transforms' in args of train dataset and the value is not None."
+            )
+        train_transforms = T.Compose(
+            build_objects(
+                cfg['transforms']['train'], mod=T))
+        # Inplace modification
+        cfg['datasets']['train'].args['transforms'] = train_transforms
+        train_dataset = build_objects(
+            cfg['datasets']['train'], mod=paddlers.datasets)
+        model = build_objects(
+            cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
+        if cfg['optimizer']:
+            if len(cfg['optimizer'].args) == 0:
+                cfg['optimizer'].args = {}
+            if not isinstance(cfg['optimizer'].args, dict):
+                raise TypeError("args of optimizer must be a dict!")
+            if cfg['optimizer'].args.get('parameters', None) is not None:
+                raise ValueError(
+                    "Found key 'parameters' in args of optimizer and the value is not None."
+                )
+            cfg['optimizer'].args['parameters'] = model.net.parameters()
+            optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)
+        else:
+            optimizer = None
+
+        model.train(
+            num_epochs=cfg['num_epochs'],
+            train_dataset=train_dataset,
+            train_batch_size=cfg['train_batch_size'],
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=cfg['save_interval_epochs'],
+            log_interval_steps=cfg['log_interval_steps'],
+            save_dir=cfg['save_dir'],
+            learning_rate=cfg['learning_rate'],
+            early_stop=cfg['early_stop'],
+            early_stop_patience=cfg['early_stop_patience'],
+            use_vdl=cfg['use_vdl'],
+            resume_checkpoint=cfg['resume_checkpoint'] or None,
+            **cfg['train'])
+    elif cfg['cmd'] == 'eval':
+        model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
+        res = model.evaluate(eval_dataset)
+        print(res)
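
`format_cfg()` above only affects logging; a hypothetical config fragment showing its output (keys sorted, one extra space of indentation per nesting level):

```python
cfg = {
    'cmd': 'train',
    'transforms': {'train': [{'type': 'DecodeImg'}, {'type': 'Normalize'}]},
}
print(format_cfg(cfg))
# cmd: train
# transforms:
#  train:
#   -
#    type: DecodeImg
#   -
#    type: Normalize
```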

+ 17 - 0
examples/rs_research/scripts/run_ablation.sh

@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -e 
+
+CONFIG_DIR='configs/levircd/ablation'
+LOG_DIR='exp/logs/ablation'
+
+mkdir -p "${LOG_DIR}"
+
+for config_file in "${CONFIG_DIR}"/*.yaml; do
+    filename="$(basename "${config_file}")"
+    printf '=%.0s' {1..100} && echo
+    echo -e "\033[33m ${config_file} \033[0m"
+    printf '=%.0s' {1..100} && echo
+    python run_task.py train cd --config "${config_file}" 2>&1 | tee "${LOG_DIR}/${filename%.*}.log"
+    echo
+done

+ 22 - 0
examples/rs_research/scripts/run_benchmark.sh

@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e 
+
+DATASET='levircd'
+
+config_dir="configs/${DATASET}"
+log_dir="exp/logs/${DATASET}"
+
+mkdir -p "${log_dir}"
+
+for config_file in "${config_dir}"/*.yaml; do
+    filename="$(basename "${config_file}")"
+    if [ "${filename}" = "${DATASET}.yaml" ]; then
+        continue
+    fi
+    printf '=%.0s' {1..100} && echo
+    echo -e "\033[33m ${config_file} \033[0m"
+    printf '=%.0s' {1..100} && echo
+    python run_task.py train cd --config "${config_file}" 2>&1 | tee "${log_dir}/${filename%.*}.log"
+    echo
+done

+ 148 - 0
examples/rs_research/tools/analyze_model.py

@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py
+
+import argparse
+import os
+import os.path as osp
+import sys
+
+import paddle
+import numpy as np
+import paddlers
+from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
+                                       count_io_info)
+from paddle.hapi.static_flops import Table
+
+_dir = osp.dirname(osp.abspath(__file__))
+sys.path.append(osp.abspath(osp.join(_dir, '../')))
+import custom_model
+import custom_trainer
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_dir", default=None, type=str, help="Path of saved model.")
+    parser.add_argument(
+        "--input_shape",
+        nargs='+',
+        type=int,
+        default=[1, 3, 256, 256],
+        help="Shape of each input tensor.")
+    return parser.parse_args()
+
+
+def analyze(model, inputs, custom_ops=None, print_detail=False):
+    handler_collection = []
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+
+    def add_hooks(m):
+        if len(list(m.children())) > 0:
+            return
+        m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
+        m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
+        m_type = type(m)
+
+        flops_fn = None
+        if m_type in custom_ops:
+            flops_fn = custom_ops[m_type]
+            if m_type not in types_collection:
+                print("Customized function has been applied to {}".format(
+                    m_type))
+        elif m_type in register_hooks:
+            flops_fn = register_hooks[m_type]
+            if m_type not in types_collection:
+                print("{}'s FLOPs metric has been counted".format(m_type))
+        else:
+            if m_type not in types_collection:
+                print(
+                    "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
+                    .format(m_type))
+
+        if flops_fn is not None:
+            flops_handler = m.register_forward_post_hook(flops_fn)
+            handler_collection.append(flops_handler)
+        params_handler = m.register_forward_post_hook(count_parameters)
+        io_handler = m.register_forward_post_hook(count_io_info)
+        handler_collection.append(params_handler)
+        handler_collection.append(io_handler)
+        types_collection.add(m_type)
+
+    training = model.training
+
+    model.eval()
+    model.apply(add_hooks)
+
+    with paddle.framework.no_grad():
+        model(*inputs)
+
+    total_ops = 0
+    total_params = 0
+    for m in model.sublayers():
+        if len(list(m.children())) > 0:
+            continue
+        if set(['total_ops', 'total_params', 'input_shape',
+                'output_shape']).issubset(set(list(m._buffers.keys()))):
+            total_ops += m.total_ops
+            total_params += m.total_params
+
+    if training:
+        model.train()
+    for handler in handler_collection:
+        handler.remove()
+
+    table = Table(
+        ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"])
+
+    for n, m in model.named_sublayers():
+        if len(list(m.children())) > 0:
+            continue
+        if set(['total_ops', 'total_params', 'input_shape',
+                'output_shape']).issubset(set(list(m._buffers.keys()))):
+            table.add_row([
+                m.full_name(), list(m.input_shape.numpy()),
+                list(m.output_shape.numpy()),
+                round(float(m.total_params / 1e6), 3),
+                round(float(m.total_ops / 1e9), 3)
+            ])
+            m._buffers.pop("total_ops")
+            m._buffers.pop("total_params")
+            m._buffers.pop('input_shape')
+            m._buffers.pop('output_shape')
+    if print_detail:
+        table.print_table()
+    print('Total FLOPs: {}G     Total Params: {}M'.format(
+        round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3)))
+    return int(total_ops)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # Enforce the use of CPU
+    paddle.set_device('cpu')
+
+    model = paddlers.tasks.load_model(args.model_dir)
+    net = model.net
+
+    # Construct bi-temporal inputs
+    inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]
+
+    analyze(net, inputs)
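
`analyze()` is not tied to PaddleRS trainers; it accepts any `paddle.nn.Layer`, which is handy for quick sanity checks. A toy sketch (assuming `analyze()` from this script is in scope):

```python
import paddle
import paddle.nn as nn

toy = nn.Sequential(
    nn.Conv2D(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2D(8, 2, 1))
# Layers without a registered FLOPs counter are reported and counted as zero.
analyze(toy, [paddle.randn([1, 3, 64, 64])])
```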

+ 75 - 0
examples/rs_research/tools/collect_imgs.py

@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+import shutil
+from glob import glob
+
+from tqdm import tqdm
+
+
+def get_subdir_name(src_path):
+    basename = osp.basename(src_path)
+    subdir_name, _ = osp.splitext(basename)
+    return subdir_name
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--mode",
+        default='copy',
+        type=str,
+        choices=['copy', 'link'],
+        help="Copy or link images.")
+    parser.add_argument(
+        "--globs",
+        nargs='+',
+        type=str,
+        help="Glob patterns used to find the images to be copied.")
+    parser.add_argument(
+        "--tags", nargs='+', type=str, help="Tags of each source directory.")
+    parser.add_argument(
+        "--save_dir",
+        default='./',
+        type=str,
+        help="Path of directory to save collected results.")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    if len(args.globs) != len(args.tags):
+        raise ValueError(
+            "The number of globs does not match the number of tags!")
+
+    for pat, tag in zip(args.globs, args.tags):
+        im_paths = glob(pat)
+        print(f"Glob: {pat}\tTag: {tag}")
+        for p in tqdm(im_paths):
+            subdir_name = get_subdir_name(p)
+            ext = osp.splitext(p)[1]
+            subdir_path = osp.join(args.save_dir, subdir_name)
+            subdir_path = osp.abspath(osp.normpath(subdir_path))
+            if not osp.exists(subdir_path):
+                os.makedirs(subdir_path)
+            if args.mode == 'copy':
+                shutil.copyfile(p, osp.join(subdir_path, tag + ext))
+            elif args.mode == 'link':
+                os.symlink(p, osp.join(subdir_path, tag + ext))
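
The grouping rule: each matched image lands in a subdirectory named after its stem, renamed to the tag of its source glob. With hypothetical flags `--globs 'exp/predict/*.png' --tags custom_model --save_dir exp/collect`, an input `test_1.png` becomes `exp/collect/test_1/custom_model.png`. The stem logic is plain `os.path`:

```python
import os.path as osp

# Mirrors get_subdir_name() above: basename without its extension.
stem, _ = osp.splitext(osp.basename('exp/predict/test_1.png'))
assert stem == 'test_1'
```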

+ 228 - 0
examples/rs_research/tools/visualize_feats.py

@@ -0,0 +1,228 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import numpy as np
+import cv2
+import paddle
+import paddlers
+from sklearn.decomposition import PCA
+
+_dir = osp.dirname(osp.abspath(__file__))
+sys.path.append(osp.abspath(osp.join(_dir, '../')))
+import custom_model
+import custom_trainer
+
+FILENAME_PATTERN = "{key}_{idx}_vis.png"
+
+
+class FeatureContainer:
+    def __init__(self):
+        self._dict = OrderedDict()
+
+    def __setitem__(self, key, val):
+        if key not in self._dict:
+            self._dict[key] = list()
+        self._dict[key].append(val)
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __repr__(self):
+        return self._dict.__repr__()
+
+    def items(self):
+        return self._dict.items()
+
+    def keys(self):
+        return self._dict.keys()
+
+    def values(self):
+        return self._dict.values()
+
+
+class HookHelper:
+    def __init__(self,
+                 model,
+                 fetch_dict,
+                 out_dict,
+                 hook_type='forward_out',
+                 auto_key=True):
+        # XXX: A HookHelper object should only be used as a context manager and should not 
+        # persist in memory since it may keep references to some very large objects.
+        self.model = model
+        self.fetch_dict = fetch_dict
+        self.out_dict = out_dict
+        self._handles = []
+        self.hook_type = hook_type
+        self.auto_key = auto_key
+
+    def __enter__(self):
+        def _hook_proto(x, entry):
+            # `x` should be a tensor or a tuple;
+            # entry is expected to be a string or a non-nested tuple.
+            if isinstance(entry, tuple):
+                for key, f in zip(entry, x):
+                    self.out_dict[key] = f.detach().clone()
+            else:
+                if isinstance(x, tuple) and self.auto_key:
+                    for i, f in enumerate(x):
+                        key = self._gen_key(entry, i)
+                        self.out_dict[key] = f.detach().clone()
+                else:
+                    self.out_dict[entry] = x.detach().clone()
+
+        if self.hook_type == 'forward_in':
+            # NOTE: Register forward hooks for LAYERs
+            for name, layer in self.model.named_sublayers():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        layer.register_forward_pre_hook(
+                            lambda l, x, entry=entry:
+                                # x is a tuple
+                                _hook_proto(x[0] if len(x)==1 else x, entry)
+                        )
+                    )
+        elif self.hook_type == 'forward_out':
+            # NOTE: Register forward hooks for LAYERs.
+            for name, module in self.model.named_sublayers():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        module.register_forward_post_hook(
+                            lambda l, x, y, entry=entry:
+                                # y is a tensor or a tuple
+                                _hook_proto(y, entry)
+                        )
+                    )
+        elif self.hook_type == 'backward':
+            # NOTE: Register backward hooks for TENSORs.
+            for name, param in self.model.named_parameters():
+                if name in self.fetch_dict:
+                    entry = self.fetch_dict[name]
+                    self._handles.append(
+                        param.register_hook(
+                            lambda grad, entry=entry: _hook_proto(grad, entry)))
+        else:
+            raise RuntimeError("Hook type is not implemented.")
+
+    def __exit__(self, exc_type, exc_val, ext_tb):
+        for handle in self._handles:
+            handle.remove()
+
+    def _gen_key(self, key, i):
+        return key + f'_{i}'
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_dir", default=None, type=str, help="Path of saved model.")
+    parser.add_argument(
+        "--hook_type", default='forward_out', type=str, help="Type of hook.")
+    parser.add_argument(
+        "--layer_names",
+        nargs='+',
+        default=[],
+        type=str,
+        help="Layers that accepts or produces the features to visualize.")
+    parser.add_argument(
+        "--im_paths", nargs='+', type=str, help="Paths of input images.")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        help="Path of directory to save prediction results.")
+    parser.add_argument(
+        "--to_pseudo_color",
+        action='store_true',
+        help="Whether to save pseudo-color images.")
+    parser.add_argument(
+        "--output_size",
+        nargs='+',
+        type=int,
+        default=None,
+        help="Resize the visualized image to `output_size`.")
+    return parser.parse_args()
+
+
+def normalize_minmax(x):
+    EPS = 1e-32
+    return (x - x.min()) / (x.max() - x.min() + EPS)
+
+
+def quantize_8bit(x):
+    # [0.0,1.0] float => [0,255] uint8
+    # or [0,1] int => [0,255] uint8
+    return (x * 255).astype('uint8')
+
+
+def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
+    return cv2.applyColorMap(gray, color_map)
+
+
+def process_fetched_feat(feat, to_pcolor=True):
+    # Convert tensor to array
+    feat = feat.squeeze(0).numpy()
+    # Get principal component
+    shape = feat.shape
+    x = feat.reshape(shape[0], -1).transpose((1, 0))
+    pca = PCA(n_components=1)
+    y = pca.fit_transform(x)
+    feat = y.reshape(shape[1:])
+    feat = normalize_minmax(feat)
+    feat = quantize_8bit(feat)
+    if to_pcolor:
+        feat = to_pseudo_color(feat)
+    return feat
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # Load model
+    model = paddlers.tasks.load_model(args.model_dir)
+
+    fetch_dict = dict(zip(args.layer_names, args.layer_names))
+    out_dict = FeatureContainer()
+
+    with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type):
+        if len(args.im_paths) == 1:
+            model.predict(args.im_paths[0])
+        else:
+            if len(args.im_paths) != 2:
+                raise ValueError("Expected one or two input image paths.")
+            model.predict(tuple(args.im_paths))
+
+    if not osp.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    for key, feats in out_dict.items():
+        for idx, feat in enumerate(feats):
+            im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color)
+            if args.output_size is not None:
+                im_vis = cv2.resize(im_vis, tuple(args.output_size))
+            out_path = osp.join(
+                args.save_dir,
+                FILENAME_PATTERN.format(
+                    key=key.replace('.', '_'), idx=idx))
+            cv2.imwrite(out_path, im_vis)
+            print(f"Write feature map to {out_path}")

+ 111 - 0
examples/rs_research/train_cd.py

@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+
+import os.path as osp
+
+import paddle
+import paddlers as pdrs
+from paddlers import transforms as T
+
+from custom_model import CustomModel
+from custom_trainer import make_trainer
+
+# Dataset directory
+DATA_DIR = 'data/levircd/'
+# Directory to save experiment results
+EXP_DIR = 'exp/levircd/custom_model/'
+
+# Define the data transforms (data augmentation, preprocessing, etc.) used for
+# training and validation.
+# Compose combines multiple transforms, which are executed sequentially in order.
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the images
+    T.DecodeImg(),
+    # Random flipping and rotation
+    T.RandomFlipOrRotate(
+        # Perform random flipping with probability 0.35 and random rotation
+        # with probability 0.35
+        probs=[0.35, 0.35],
+        # Perform random horizontal flipping with probability 0.5 and random
+        # vertical flipping with probability 0.5
+        probsf=[0.5, 0.5, 0, 0, 0],
+        # Perform 90°, 180°, and 270° rotation with probabilities 0.33, 0.34,
+        # and 0.33, respectively
+        probsr=[0.33, 0.34, 0.33]),
+    # Normalize data to [-1,1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('eval')
+])
+
+# Build the datasets used for training, validation, and testing
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=osp.join(DATA_DIR, 'train.txt'),
+    label_list=None,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+val_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=osp.join(DATA_DIR, 'val.txt'),
+    label_list=None,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+test_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=osp.join(DATA_DIR, 'test.txt'),
+    label_list=None,
+    # Use the same transforms as in the validation phase
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+# Build the custom model CustomModel and automatically generate a trainer for it
+# The first argument of make_trainer() is the model type; the remaining
+# arguments are those used to construct the model
+model = make_trainer(CustomModel, in_channels=3)
+
+# Build the learning rate scheduler
+# Use a fixed-step learning rate decay strategy
+lr_scheduler = paddle.optimizer.lr.StepDecay(
+    learning_rate=0.002, step_size=35000, gamma=0.2)
+
+# Build the optimizer
+optimizer = paddle.optimizer.Adam(
+    parameters=model.net.parameters(), learning_rate=lr_scheduler)
+
+# Train the model
+model.train(
+    num_epochs=50,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=val_dataset,
+    optimizer=optimizer,
+    # Validate and save the model every this many epochs
+    save_interval_epochs=5,
+    # Log every this many iterations
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # Whether to enable early stopping, which terminates training early once
+    # the accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint to resume training from, if any
+    resume_checkpoint=None)
+
+# Load the model that performs best on the validation set
+model = pdrs.tasks.load_model(osp.join(EXP_DIR, 'best_model'))
+# Compute accuracy metrics on the test set
+model.evaluate(test_dataset)
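
`StepDecay` multiplies the learning rate by `gamma` every `step_size` scheduler steps; a quick check of the schedule in isolation (how often the trainer steps the scheduler, per iteration or per epoch, determines what `step_size=35000` counts):

```python
import paddle

sched = paddle.optimizer.lr.StepDecay(
    learning_rate=0.002, step_size=35000, gamma=0.2)
for _ in range(35000):
    sched.step()
print(sched.get_lr())  # 0.002 * 0.2 == 0.0004 after the first decay
```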

+ 5 - 5
paddlers/tasks/change_detector.py

@@ -52,9 +52,7 @@ class BaseChangeDetector(BaseModel):
         if 'with_net' in self.init_params:
             del self.init_params['with_net']
         super(BaseChangeDetector, self).__init__('change_detector')
-        if model_name not in __all__:
-            raise ValueError("ERROR: There is no model named {}.".format(
-                model_name))
+
         self.model_name = model_name
         self.num_classes = num_classes
         self.use_mixed_loss = use_mixed_loss
@@ -1066,11 +1064,12 @@ class ChangeStar(BaseChangeDetector):
 
 class ChangeFormer(BaseChangeDetector):
     def __init__(self,
-                 in_channels=3,
                  num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 in_channels=3,
                  decoder_softmax=False,
                  embed_dim=256,
-                 use_mixed_loss=False,
                  **params):
         params.update({
             'in_channels': in_channels,
@@ -1081,6 +1080,7 @@ class ChangeFormer(BaseChangeDetector):
             model_name='ChangeFormer',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
+            losses=losses,
             **params)
 

+ 3 - 3
paddlers/transforms/operators.py

@@ -616,10 +616,10 @@ class RandomFlipOrRotate(Transform):
         probs (list[float]): Probabilities of performing flipping and rotation. 
             Default: [0.35,0.25].
         probsf (list[float]): Probabilities of 5 flipping modes (horizontal, 
-            vertical, both horizontal diction and vertical, diagonal, 
-            anti-diagonal). Default: [0.3, 0.3, 0.2, 0.1, 0.1].
+            vertical, both horizontal and vertical, diagonal, anti-diagonal). 
+            Default: [0.3, 0.3, 0.2, 0.1, 0.1].
         probsr (list[float]): Probabilities of 3 rotation modes (90°, 180°, 270° 
-            clockwise). Default: [0.25,0.5,0.25].
+            clockwise). Default: [0.25, 0.5, 0.25].
 
     Examples:
 

+ 4 - 1
test_tipc/common_func.sh

@@ -87,7 +87,10 @@ function download_and_unzip_dataset() {
     fi
 
     wget -nc -P "${ds_dir}" "${url}" --no-check-certificate
-    cd "${ds_dir}" && unzip "${zip_name}" && cd - \
+
+    # The extracted file/directory must have the same name as the zip file.
+    cd "${ds_dir}" && unzip "${zip_name}" \
+        && mv "${zip_name%.*}" ${ds_name} && cd - \
         && echo "Successfully downloaded ${zip_name} from ${url}. File saved in ${ds_path}. "
 }
 

+ 3 - 2
test_tipc/config_utils.py

@@ -152,6 +152,7 @@ def parse_args(*args, **kwargs):
     if osp.exists(cfg_path):
         cfg = parse_configs(cfg_path, inherit_on)
         parser, node_keys = _cfg2args(cfg, parser, '')
+        node_keys = sorted(node_keys, reverse=True)
         args = parser.parse_args(*args, **kwargs)
         return _args2cfg(dict(), args, node_keys)
     elif cfg_path != '':
@@ -178,7 +179,7 @@ class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
         super().__init__()
         self.type = dict_['type']
         self.args = dict_.get('args', [])
-        self.module = self._get_module(dict_.get('module', ''))
+        self.module = dict_.get('module', '')
 
     @classmethod
     def set_context(cls, ctx):
@@ -189,7 +190,7 @@ class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
 
     def build_object(self, mod=None):
         if mod is None:
-            mod = self.module
+            mod = self._get_module(self.module)
         cls = getattr(mod, self.type)
         if isinstance(self.args, list):
             args = build_objects(self.args)

+ 62 - 0
test_tipc/configs/cd/_base_/levircd.yaml

@@ -0,0 +1,62 @@
+# Basic configurations of LEVIR-CD dataset
+
+datasets:
+    train: !Node
+        type: CDDataset
+        args: 
+            data_dir: ./test_tipc/data/levircd/
+            file_list: ./test_tipc/data/levircd/train.txt
+            label_list: null
+            num_workers: 0
+            shuffle: True
+            with_seg_labels: False
+            binarize_labels: True
+    eval: !Node
+        type: CDDataset
+        args:
+            data_dir: ./test_tipc/data/levircd/
+            file_list: ./test_tipc/data/levircd/val.txt
+            label_list: null
+            num_workers: 0
+            shuffle: False
+            with_seg_labels: False
+            binarize_labels: True
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomHorizontalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['eval']
+download_on: False
+
+num_epochs: 10
+train_batch_size: 8
+save_interval_epochs: 5
+log_interval_steps: 50
+save_dir: ./test_tipc/output/cd/
+learning_rate: 0.002
+early_stop: False
+early_stop_patience: 5
+use_vdl: False
+resume_checkpoint: ''

+ 1 - 1
test_tipc/configs/cd/bit/bit.yaml

@@ -5,4 +5,4 @@ _base_: ../_base_/airchange.yaml
 save_dir: ./test_tipc/output/cd/bit/
 
 model: !Node
-       type: BIT
+    type: BIT

+ 8 - 0
test_tipc/configs/cd/bit/bit_airchange.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of BIT with AirChange dataset
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/bit/
+
+model: !Node
+    type: BIT

+ 8 - 0
test_tipc/configs/cd/bit/bit_levircd.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of BIT with LEVIR-CD dataset
+
+_base_: ../_base_/levircd.yaml
+
+save_dir: ./test_tipc/output/cd/bit/
+
+model: !Node
+    type: BIT

+ 3 - 3
test_tipc/configs/cd/bit/train_infer_python.txt

@@ -8,12 +8,12 @@ use_gpu:null|null
 --save_dir:adaptive
 --train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
 --model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/bit/bit_levircd.yaml
 train_model_name:best_model
-train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
 null:null
 ##
 trainer:norm
-norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/bit/bit.yaml
+norm_train:test_tipc/run_task.py train cd
 pact_train:null
 fpgm_train:null
 distill_train:null
@@ -46,7 +46,7 @@ inference:test_tipc/infer.py
 --use_trt:False
 --precision:fp32
 --model_dir:null
---file_list:null:null
+--config:null
 --save_log_path:null
 --benchmark:True
 --model_name:bit

+ 1 - 1
test_tipc/configs/cd/changeformer/changeformer.yaml

@@ -5,4 +5,4 @@ _base_: ../_base_/airchange.yaml
 save_dir: ./test_tipc/output/cd/changeformer/
 
 model: !Node
-       type: ChangeFormer
+    type: ChangeFormer

+ 1 - 1
test_tipc/configs/clas/_base_/ucmerced.yaml

@@ -64,7 +64,7 @@ num_epochs: 2
 train_batch_size: 16
 save_interval_epochs: 5
 log_interval_steps: 50
-save_dir: e./test_tipc/output/clas/
+save_dir: ./test_tipc/output/clas/
 learning_rate: 0.01
 early_stop: False
 early_stop_patience: 5

+ 3 - 3
test_tipc/configs/clas/hrnet/hrnet.yaml

@@ -5,6 +5,6 @@ _base_: ../_base_/ucmerced.yaml
 save_dir: ./test_tipc/output/clas/hrnet/
 
 model: !Node
-       type: HRNet_W18_C
-       args:
-           num_classes: 21
+    type: HRNet_W18_C
+    args:
+        num_classes: 21

+ 3 - 3
test_tipc/configs/det/ppyolo/ppyolo.yaml

@@ -5,6 +5,6 @@ _base_: ../_base_/sarship.yaml
 save_dir: ./test_tipc/output/det/ppyolo/
 
 model: !Node
-       type: PPYOLO
-       args:
-           num_classes: 1
+    type: PPYOLO
+    args:
+        num_classes: 1

+ 1 - 1
test_tipc/configs/seg/unet/unet.yaml

@@ -8,4 +8,4 @@ model: !Node
        type: UNet
        args:
            in_channels: 10
-           num_classes: 5
+           num_classes: 5

+ 11 - 1
test_tipc/prepare.sh

@@ -27,7 +27,6 @@ DATA_DIR='./test_tipc/data/'
 mkdir -p "${DATA_DIR}"
 mkdir -p "${DATA_DIR}"
 if [[ ${MODE} == 'lite_train_lite_infer' \
 if [[ ${MODE} == 'lite_train_lite_infer' \
     || ${MODE} == 'lite_train_whole_infer' \
     || ${MODE} == 'lite_train_whole_infer' \
-    || ${MODE} == 'whole_train_whole_infer' \
     || ${MODE} == 'whole_infer' ]]; then
     || ${MODE} == 'whole_infer' ]]; then
 
 
     if [[ ${task_name} == 'cd' ]]; then
     if [[ ${task_name} == 'cd' ]]; then
@@ -40,4 +39,15 @@ if [[ ${MODE} == 'lite_train_lite_infer' \
         download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg_mini.zip
         download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg_mini.zip
     fi
     fi
 
 
+elif [[ ${MODE} == 'whole_train_whole_infer' ]]; then
+
+    if [[ ${task_name} == 'cd' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" raw_levircd https://paddlers.bj.bcebos.com/datasets/raw/LEVIR-CD.zip \
+        && python tools/prepare_dataset/prepare_levircd.py \
+            --in_dataset_dir "${DATA_DIR}/raw_levircd" \
+            --out_dataset_dir "${DATA_DIR}/levircd" \
+            --crop_size 256 \
+            --crop_stride 256
+    fi
+
 fi
 fi

+ 14 - 18
test_tipc/run_task.py

@@ -51,6 +51,17 @@ if __name__ == '__main__':
         paddlers.utils.download_and_decompress(
             cfg['download_url'], path=cfg['download_path'])
 
+    if not isinstance(cfg['datasets']['eval'].args, dict):
+        raise ValueError("args of eval dataset must be a dict!")
+    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
+        raise ValueError(
+            "Found key 'transforms' in args of eval dataset and the value is not None."
+        )
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # Inplace modification
+    cfg['datasets']['eval'].args['transforms'] = eval_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+
     if cfg['cmd'] == 'train':
         if not isinstance(cfg['datasets']['train'].args, dict):
             raise ValueError("args of train dataset must be a dict!")
@@ -65,21 +76,8 @@ if __name__ == '__main__':
         cfg['datasets']['train'].args['transforms'] = train_transforms
         train_dataset = build_objects(
             cfg['datasets']['train'], mod=paddlers.datasets)
-    if not isinstance(cfg['datasets']['eval'].args, dict):
-        raise ValueError("args of eval dataset must be a dict!")
-    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
-        raise ValueError(
-            "Found key 'transforms' in args of eval dataset and the value is not None."
-        )
-    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
-    # Inplace modification
-    cfg['datasets']['eval'].args['transforms'] = eval_transforms
-    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
-
-    model = build_objects(
-        cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
-
-    if cfg['cmd'] == 'train':
+        model = build_objects(
+            cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
         if cfg['optimizer']:
             if len(cfg['optimizer'].args) == 0:
                 cfg['optimizer'].args = {}
@@ -110,8 +108,6 @@ if __name__ == '__main__':
             resume_checkpoint=cfg['resume_checkpoint'] or None,
             **cfg['train'])
     elif cfg['cmd'] == 'eval':
-        state_dict = paddle.load(
-            os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
-        model.net.set_state_dict(state_dict)
+        model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
         res = model.evaluate(eval_dataset)
         print(res)

+ 40 - 46
test_tipc/test_train_inference_python.sh

@@ -22,15 +22,15 @@ train_use_gpu_value=$(func_parser_value "${lines[4]}")
 autocast_list=$(func_parser_value "${lines[5]}")
 autocast_list=$(func_parser_value "${lines[5]}")
 autocast_key=$(func_parser_key "${lines[5]}")
 autocast_key=$(func_parser_key "${lines[5]}")
 epoch_key=$(func_parser_key "${lines[6]}")
 epoch_key=$(func_parser_key "${lines[6]}")
-epoch_num=$(func_parser_params "${lines[6]}")
+epoch_value=$(func_parser_params "${lines[6]}")
 save_model_key=$(func_parser_key "${lines[7]}")
 save_model_key=$(func_parser_key "${lines[7]}")
 train_batch_key=$(func_parser_key "${lines[8]}")
 train_batch_key=$(func_parser_key "${lines[8]}")
 train_batch_value=$(func_parser_params "${lines[8]}")
 train_batch_value=$(func_parser_params "${lines[8]}")
 pretrain_model_key=$(func_parser_key "${lines[9]}")
 pretrain_model_key=$(func_parser_key "${lines[9]}")
 pretrain_model_value=$(func_parser_value "${lines[9]}")
 pretrain_model_value=$(func_parser_value "${lines[9]}")
-train_model_name=$(func_parser_value "${lines[10]}")
-train_infer_img_dir=$(parse_first_value "${lines[11]}")
-train_infer_img_file_list=$(parse_second_value "${lines[11]}")
+train_config_key=$(func_parser_key "${lines[10]}")
+train_config_value=$(func_parser_params "${lines[10]}")
+train_model_name=$(func_parser_value "${lines[11]}")
 train_param_key1=$(func_parser_key "${lines[12]}")
 train_param_key1=$(func_parser_key "${lines[12]}")
 train_param_value1=$(func_parser_value "${lines[12]}")
 train_param_value1=$(func_parser_value "${lines[12]}")
 
 
@@ -85,9 +85,8 @@ use_trt_list=$(func_parser_value "${lines[45]}")
 precision_key=$(func_parser_key "${lines[46]}")
 precision_key=$(func_parser_key "${lines[46]}")
 precision_list=$(func_parser_value "${lines[46]}")
 precision_list=$(func_parser_value "${lines[46]}")
 infer_model_key=$(func_parser_key "${lines[47]}")
 infer_model_key=$(func_parser_key "${lines[47]}")
-file_list_key=$(func_parser_key "${lines[48]}")
-infer_img_dir=$(parse_first_value "${lines[48]}")
-infer_img_file_list=$(parse_second_value "${lines[48]}")
+infer_config_key=$(func_parser_key "${lines[48]}")
+infer_config_value=$(func_parser_value "${lines[48]}")
 save_log_key=$(func_parser_key "${lines[49]}")
 save_log_key=$(func_parser_key "${lines[49]}")
 benchmark_key=$(func_parser_key "${lines[50]}")
 benchmark_key=$(func_parser_key "${lines[50]}")
 benchmark_value=$(func_parser_value "${lines[50]}")
 benchmark_value=$(func_parser_value "${lines[50]}")
@@ -117,37 +116,37 @@ function func_inference() {
     local _script="$2"
     local _script="$2"
     local _model_dir="$3"
     local _model_dir="$3"
     local _log_path="$4"
     local _log_path="$4"
-    local _img_dir="$5"
-    local _file_list="$6"
+    local _config="$5"
+
+    local set_infer_config=$(func_set_params "${infer_config_key}" "${_config}")
+    local set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+    local set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+    local set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+    local set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
 
 
     # Do inference
     # Do inference
     for use_gpu in ${use_gpu_list[*]}; do
     for use_gpu in ${use_gpu_list[*]}; do
+        local set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
         if [ ${use_gpu} = 'False' ] || [ ${use_gpu} = 'cpu' ]; then
         if [ ${use_gpu} = 'False' ] || [ ${use_gpu} = 'cpu' ]; then
             for use_mkldnn in ${use_mkldnn_list[*]}; do
             for use_mkldnn in ${use_mkldnn_list[*]}; do
                 if [ ${use_mkldnn} = 'False' ]; then
                 if [ ${use_mkldnn} = 'False' ]; then
                     continue
                     continue
                 fi
                 fi
-                for threads in ${cpu_threads_list[*]}; do
-                    for batch_size in ${batch_size_list[*]}; do
-                        for precision in ${precision_list[*]}; do
-                            if [ ${use_mkldnn} = 'False' ] && [ ${precision} = 'fp16' ]; then
-                                continue
-                            fi # Skip when enable fp16 but disable mkldnn
-
-                            set_precision=$(func_set_params "${precision_key}" "${precision}")
+                for precision in ${precision_list[*]}; do
+                    if [ ${use_mkldnn} = 'False' ] && [ ${precision} = 'fp16' ]; then
+                        continue
+                    fi # Skip when enable fp16 but disable mkldnn
 
 
-                            _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
-                            infer_value1="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}_results"
-                            set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
-                            set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
-                            set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
-                            set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
-                            set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
-                            set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
-                            set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
-                            set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                    for threads in ${cpu_threads_list[*]}; do
+                        for batch_size in ${batch_size_list[*]}; do
+                            local _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+                            local infer_value1="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}_results"
+                            local set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
+                            local set_precision=$(func_set_params "${precision_key}" "${precision}")
+                            local set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+                            local set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
                             
                             
-                            cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_precision} ${set_infer_params1} ${set_infer_params2}"
+                            local cmd="${_python} ${_script} ${set_config} ${set_device} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_precision} ${set_infer_params1} ${set_infer_params2}"
                             echo ${cmd}
                             echo ${cmd}
                             run_command "${cmd}" "${_save_log_path}"
                             run_command "${cmd}" "${_save_log_path}"
                             
                             
@@ -165,24 +164,18 @@ function func_inference() {
                     fi # Skip when enable fp16 but disable trt
                     fi # Skip when enable fp16 but disable trt
 
 
                     for batch_size in ${batch_size_list[*]}; do
                     for batch_size in ${batch_size_list[*]}; do
-                        _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
-                        infer_value1="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_results"
-                        set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
-                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
-                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
-                        set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
-                        set_precision=$(func_set_params "${precision_key}" "${precision}")
-                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
-                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
-                        set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                        local _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        local infer_value1="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_results"
+                        local set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+                        local set_precision=$(func_set_params "${precision_key}" "${precision}")
+                        local set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
                         
                         
-                        cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_infer_params2}"
+                        local cmd="${_python} ${_script} ${set_config} ${set_device} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_infer_params2}"
                         echo ${cmd}
                         echo ${cmd}
                         run_command "${cmd}" "${_save_log_path}"
                         run_command "${cmd}" "${_save_log_path}"
 
 
                         last_status=${PIPESTATUS[0]}
                         last_status=${PIPESTATUS[0]}
                         status_check $last_status "${cmd}" "${status_log}" "${model_name}"
                         status_check $last_status "${cmd}" "${status_log}" "${model_name}"
-
                     done
                     done
                 done
                 done
             done
             done
@@ -226,7 +219,7 @@ if [ ${MODE} = 'whole_infer' ]; then
             save_infer_dir=${infer_model}
             save_infer_dir=${infer_model}
         fi
         fi
         # Run inference
         # Run inference
-        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${OUT_PATH}" "${infer_img_dir}" "${infer_img_file_list}"
+        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${OUT_PATH}" "${infer_config_value}"
         count=$((${count} + 1))
         count=$((${count} + 1))
     done
     done
 else
 else
@@ -285,8 +278,9 @@ else
                 if [ ${run_train} = 'null' ]; then
                 if [ ${run_train} = 'null' ]; then
                     continue
                     continue
                 fi
                 fi
+                set_config=$(func_set_params "${train_config_key}" "${train_config_value}")
                 set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
                 set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
-                set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+                set_epoch=$(func_set_params "${epoch_key}" "${epoch_value}")
                 set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
                 set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
                 set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}")
@@ -312,11 +306,11 @@ else

                 set_save_model=$(func_set_params "${save_model_key}" "${save_dir}")
                 if [ ${#gpu} -le 2 ]; then  # Train with cpu or single gpu
-                    cmd="${python} ${run_train} ${set_use_gpu}  ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 elif [ ${#ips} -le 15 ]; then  # Train with multi-gpu
-                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 else     # Train with multi-machine
-                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 fi

                 echo ${cmd}
@@ -359,7 +353,7 @@ else
                     else
                         infer_model_dir=${save_infer_path}
                     fi
-                    func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${OUT_PATH}" "${train_infer_img_dir}" "${train_infer_img_file_list}"
+                    func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${OUT_PATH}" "${train_config_value}"

                     eval "unset CUDA_VISIBLE_DEVICES"
                 fi

+ 215 - 0
tools/prepare_dataset/common.py

@@ -0,0 +1,215 @@
+import argparse
+import os
+import os.path as osp
+from glob import glob
+from itertools import count
+from functools import partial
+from concurrent.futures import ThreadPoolExecutor
+
+from skimage.io import imread, imsave
+from tqdm import tqdm
+
+
+def get_default_parser():
+    """
+    Get argument parser with commonly used options.
+    
+    Returns:
+        argparse.ArgumentParser: Argument parser with the following arguments:
+            --in_dataset_dir: Input dataset directory.
+            --out_dataset_dir: Output dataset directory.
+    """
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--in_dataset_dir',
+        type=str,
+        required=True,
+        help="Input dataset directory.")
+    parser.add_argument(
+        '--out_dataset_dir', type=str, help="Output dataset directory.")
+    return parser
+
+
+def add_crop_options(parser):
+    """
+    Add patch cropping related arguments to an argument parser. The parser will be
+        modified in place.
+    
+    Args:
+        parser (argparse.ArgumentParser): Argument parser.
+    
+    Returns:
+        argparse.ArgumentParser: Argument parser with the following arguments:
+            --crop_size: Size of cropped patches.
+            --crop_stride: Stride of sliding windows when cropping patches.
+    """
+
+    parser.add_argument(
+        '--crop_size', type=int, help="Size of cropped patches.")
+    parser.add_argument(
+        '--crop_stride',
+        type=int,
+        help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
+    )
+    return parser
+
+
+def crop_and_save(path, out_subdir, crop_size, stride):
+    name, ext = osp.splitext(osp.basename(path))
+    out_subsubdir = osp.join(out_subdir, name)
+    if not osp.exists(out_subsubdir):
+        os.makedirs(out_subsubdir)
+    img = imread(path)
+    # For a NumPy image array, shape[0] is the height and shape[1] is the width.
+    h, w = img.shape[:2]
+    counter = count()
+    for i in range(0, h - crop_size + 1, stride):
+        for j in range(0, w - crop_size + 1, stride):
+            imsave(
+                osp.join(out_subsubdir, '{}_{}{}'.format(name,
+                                                         next(counter), ext)),
+                img[i:i + crop_size, j:j + crop_size],
+                check_contrast=False)
+
+
+def crop_patches(crop_size,
+                 stride,
+                 data_dir,
+                 out_dir,
+                 subsets=('train', 'val', 'test'),
+                 subdirs=('A', 'B', 'label'),
+                 glob_pattern='*',
+                 max_workers=0):
+    """
+    Crop patches from images in specific directories.
+    
+    Args:
+        crop_size (int): Height and width of the cropped patches will be `crop_size`.
+        stride (int): Stride of sliding windows when cropping patches.
+        data_dir (str): Root directory of the dataset that contains the input images.
+        out_dir (str): Directory to save the cropped patches.
+        subsets (tuple|list|None, optional): List or tuple of names of subdirectories 
+            or None. Images to be cropped should be stored in `data_dir/subset/subdir/` 
+            or `data_dir/subdir/` (when `subsets` is set to None), where `subset` is an 
+            element of `subsets`. Defaults to ('train', 'val', 'test').
+        subdirs (tuple|list, optional): List or tuple of names of subdirectories. Images 
+            to be cropped should be stored in `data_dir/subset/subdir/` or 
+            `data_dir/subdir/` (when `subsets` is set to None), where `subdir` is an 
+            element of `subdirs`. Defaults to ('A', 'B', 'label').
+        glob_pattern (str, optional): Glob pattern used to match image files. 
+            Defaults to '*', which matches arbitrary files. 
+        max_workers (int, optional): Number of worker threads used to crop 
+            patches. If 0, perform the cropping in the main thread. Defaults to 0.
+    """
+
+    if max_workers < 0:
+        raise ValueError("`max_workers` must be a non-negative integer!")
+
+    if subsets is None:
+        subsets = ('', )
+
+    if max_workers == 0:
+        for subset in subsets:
+            for subdir in subdirs:
+                paths = glob(
+                    osp.join(data_dir, subset, subdir, glob_pattern),
+                    recursive=True)
+                out_subdir = osp.join(out_dir, subset, subdir)
+                for p in tqdm(paths):
+                    crop_and_save(
+                        p,
+                        out_subdir=out_subdir,
+                        crop_size=crop_size,
+                        stride=stride)
+    else:
+        # Concurrently crop image patches
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            for subset in subsets:
+                for subdir in subdirs:
+                    paths = glob(
+                        osp.join(data_dir, subset, subdir, glob_pattern),
+                        recursive=True)
+                    out_subdir = osp.join(out_dir, subset, subdir)
+                    for _ in tqdm(
+                            executor.map(partial(
+                                crop_and_save,
+                                out_subdir=out_subdir,
+                                crop_size=crop_size,
+                                stride=stride),
+                                         paths),
+                            total=len(paths)):
+                        pass
+
+
+def get_path_tuples(*dirs, glob_pattern='*', data_dir=None):
+    """
+    Get tuples of image paths. Each tuple corresponds to a sample in the dataset.
+    
+    Args:
+        *dirs (str): Directories that contains the images.
+        glob_pattern (str, optional): Glob pattern used to match image files. 
+            Defaults to '*', which matches arbitrary files. 
+        data_dir (str|None, optional): Root directory of the dataset that 
+            contains the images. If not None, `data_dir` will be used to 
+            determine relative paths of images. Defaults to None.
+    
+    Returns:
+        list[tuple]: For directories with the following structure:
+            ├── img  
+            │   ├── im1.png
+            │   ├── im2.png
+            │   └── im3.png
+            │
+            ├── mask
+            │   ├── im1.png
+            │   ├── im2.png
+            │   └── im3.png
+            └── ...
+
+        `get_path_tuples('img', 'mask', glob_pattern='*.png')` returns the list of tuples:
+            [('img/im1.png', 'mask/im1.png'), ('img/im2.png', 'mask/im2.png'), ('img/im3.png', 'mask/im3.png')]
+    """
+
+    all_paths = []
+    for dir_ in dirs:
+        paths = glob(osp.join(dir_, glob_pattern), recursive=True)
+        paths = sorted(paths)
+        if data_dir is not None:
+            paths = [osp.relpath(p, data_dir) for p in paths]
+        all_paths.append(paths)
+    all_paths = list(zip(*all_paths))
+    return all_paths
+
+
+def create_file_list(file_list, path_tuples, sep=' '):
+    """
+    Create file list.
+    
+    Args:
+        file_list (str): Path of file list to create.
+        path_tuples (list[tuple]): See get_path_tuples().
+        sep (str, optional): Delimiter used to join paths when writing lines to 
+            the file list. Defaults to ' '.
+    """
+
+    with open(file_list, 'w') as f:
+        for tup in path_tuples:
+            line = sep.join(tup)
+            f.write(line + '\n')
+
+
+def link_dataset(src, dst):
+    """
+    Make a symbolic link to a dataset.
+    
+    Args:
+        src (str): Path of the original dataset.
+        dst (str): Directory in which the symbolic link is created.
+    """
+
+    if osp.exists(dst) and not osp.isdir(dst):
+        raise ValueError(f"{dst} exists and is not a directory.")
+    elif not osp.exists(dst):
+        os.makedirs(dst)
+    name = osp.basename(osp.normpath(src))
+    os.symlink(src, osp.join(dst, name), target_is_directory=True)
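
Taken together, these helpers cover the whole preparation flow: optionally crop large tiles into fixed-size patches, enumerate aligned image/label paths, and write them to a file list. The snippet below is a minimal sketch of that flow; the directory layout (a LEVIR-CD-style dataset with `A/`, `B/`, and `label/` subdirectories) and all paths are assumptions for illustration.

```python
import os.path as osp

from common import crop_patches, get_path_tuples, create_file_list

DATA_DIR = './data/LEVIR-CD'        # hypothetical input dataset root
OUT_DIR = './data/levircd_crop256'  # hypothetical output root

# Crop non-overlapping 256x256 patches in the main thread (max_workers=0).
crop_patches(
    crop_size=256,
    stride=256,
    data_dir=DATA_DIR,
    out_dir=OUT_DIR,
    subsets=('train', 'val', 'test'),
    subdirs=('A', 'B', 'label'),
    glob_pattern='*.png',
    max_workers=0)

# crop_and_save() stores the patches of each image in a subdirectory named
# after that image, so a recursive pattern is needed to enumerate them.
path_tuples = get_path_tuples(
    osp.join(OUT_DIR, 'train', 'A'),
    osp.join(OUT_DIR, 'train', 'B'),
    osp.join(OUT_DIR, 'train', 'label'),
    glob_pattern='**/*.png',
    data_dir=OUT_DIR)

# Each line of train.txt holds the space-separated paths of one sample.
create_file_list(osp.join(OUT_DIR, 'train.txt'), path_tuples)
```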

+ 42 - 0
tools/prepare_dataset/prepare_levircd.py

@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+
+import os.path as osp
+
+from common import (get_default_parser, add_crop_options, crop_patches,
+                    get_path_tuples, create_file_list, link_dataset)
+
+SUBSETS = ('train', 'val', 'test')
+SUBDIRS = ('A', 'B', 'label')
+FILE_LIST_PATTERN = "{subset}.txt"
+URL = ""
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    parser = add_crop_options(parser)
+    args = parser.parse_args()
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    if args.crop_size is not None:
+        crop_patches(
+            args.crop_size,
+            args.crop_stride,
+            data_dir=args.in_dataset_dir,
+            out_dir=out_dir,
+            subsets=SUBSETS,
+            subdirs=SUBDIRS,
+            glob_pattern='*.png',
+            max_workers=0)
+    else:
+        link_dataset(args.in_dataset_dir, args.out_dataset_dir)
+
+    for subset in SUBSETS:
+        path_tuples = get_path_tuples(
+            *(osp.join(out_dir, subset, subdir) for subdir in SUBDIRS),
+            glob_pattern='**/*.png',
+            data_dir=args.out_dataset_dir)
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, path_tuples)
+        print(f"Write file list to {file_list}.")

+ 31 - 0
tools/prepare_dataset/prepare_svcd.py

@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+
+import os.path as osp
+
+from common import (get_default_parser, get_path_tuples, create_file_list,
+                    link_dataset)
+
+SUBSETS = ('train', 'val', 'test')
+SUBDIRS = ('A', 'B', 'OUT')
+FILE_LIST_PATTERN = "{subset}.txt"
+URL = ""
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    args = parser.parse_args()
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    link_dataset(args.in_dataset_dir, args.out_dataset_dir)
+
+    for subset in SUBSETS:
+        # NOTE: Only use cropped real samples.
+        path_tuples = get_path_tuples(
+            *(osp.join(out_dir, 'Real', 'subset', subset, subdir)
+              for subdir in SUBDIRS),
+            data_dir=args.out_dataset_dir)
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, path_tuples)
+        print(f"Write file list to {file_list}.")