[Example] Training C2FNet on iSAID Dataset

* new add files

* reproduce C2FNet

* add C2FNet ReadMe

* revision

* revision

* PR revision

* move to examples

* revised the link

* revision

* the second revision
sherwinchen · 2 years ago · commit 389e586d61

+ 122 - 0
examples/c2fnet/README.md

@@ -0,0 +1,122 @@
+# Small-Object Semantic Segmentation for Remote Sensing Images Based on PaddleRS
+This project is the official PaddleRS implementation of C2FNet. The method builds a coarse-to-fine model on top of any existing semantic segmentation method and refines its output, enabling accurate segmentation of small objects.
+
+## Installation
+### Requirements
+```
+Python: 3.8  
+PaddlePaddle: 2.3.2
+PaddleRS: 1.0
+```
+
+### Installation Steps
+a. Create and activate a conda virtual environment.
+```bash
+conda create -n paddlers python=3.8
+conda activate paddlers
+```
+b. Install PaddlePaddle (version >= 2.3 is required); see the [official installation guide](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/install/pip/linux-pip_en.html).
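+
+For example, a GPU build matching the version above can typically be installed with pip (check the official guide for the package variant that fits your CUDA setup):
+```bash
+pip install paddlepaddle-gpu==2.3.2
+```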
+
+c. Clone the PaddleRS repository.
+```bash
+git clone https://github.com/PaddlePaddle/PaddleRS
+```
+
+d. Install the PaddleRS dependencies.
+```bash
+cd PaddleRS
+git checkout develop
+pip install -r requirements.txt
+```
+
+e. Install the PaddleRS package (run inside the `PaddleRS` directory entered in step d).
+```bash
+python setup.py install
+```
+
+f. Enter the c2fnet directory.
+```bash
+cd examples/c2fnet
+```
+
+*Note: all subsequent commands are assumed to be run in the `c2fnet` directory.*
+
+## Datasets
+
++ iSAID: https://captain-whu.github.io/iSAID
++ ISPRS Potsdam/Vaihingen: support will be added in a future release.
+
+### Preparing the iSAID Dataset
+
+a. Download the [iSAID](https://captain-whu.github.io/iSAID) dataset from the official website.
+
+b. Run the c2fnet-specific iSAID preparation script.
+
+```bash
+python data/prepare_isaid_c2fnet.py {path to the downloaded raw iSAID dataset}
+```
+
+c. After processing, the dataset directory structure looks like this:
+
+```
+{c2fnet}/data/iSAID
+├── img_dir
+│   ├── train
+│   │   ├── *.png
+│   │   └── *.png
+│   ├── val
+│   │   ├── *.png
+│   │   └── *.png
+│   └── test
+├── ann_dir
+│   ├── train
+│   │   ├── *.png
+│   │   └── *.png
+│   ├── val
+│   │   ├── *.png
+│   │   └── *.png
+│   └── test
+├── label.txt
+├── train.txt
+└── val.txt
+```
+
+The `train.txt`, `val.txt`, and `label.txt` files can be generated following the [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/docs/data/marker/marker_cn.md) conventions.
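+
+For reference, each line of `train.txt` and `val.txt` pairs an image path with its label path (relative to the dataset root), and `label.txt` lists one class name per line. A minimal sketch with hypothetical file names:
+
+```
+# train.txt
+img_dir/train/P0001_0_896_0_896.png ann_dir/train/P0001_0_896_0_896_instance_color_RGB.png
+
+# label.txt
+background
+ship
+storage_tank
+```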
+
+## Training
+
+a. Train a coarse segmentation model with [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg) or [PaddleRS](https://github.com/PaddlePaddle/PaddleRS/tree/release/1.0/tutorials/train), or download our trained baseline model [FCN_HRNetW18](https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/fcn_hrnet_isaid.pdparams), and place it at the following location (the file name must match `coarse_model_path` in `train.py`):
+
+```
+{c2fnet}/coarse_model/{YOUR COARSE_MODEL NAME}.pdparams
+```
+
+b. Train the refinement model on a single GPU.
+```bash
+# Set the GPU device ID
+export CUDA_VISIBLE_DEVICES=0
+python train.py
+```
+
+c. Train the refinement model on multiple GPUs.
+```bash
+# Set the GPU device IDs
+export CUDA_VISIBLE_DEVICES={IDs of the GPUs to use}
+python -m paddle.distributed.launch train.py
+```
+
+d. For other training details, refer to the [PaddleRS training documentation](/tutorials/train/README.md).
+
+## Results
+
+| Model | Backbone | Resolution | Ship | Large_Vehicle | Small_Vehicle | Helicopter | Swimming_Pool | Plane | Harbor | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|FCN       |HRNet_W18|512x512|69.04|62.61|48.75|23.14|44.99|83.35|58.61|[model](https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/fcn_hrnet_isaid.pdparams)|
+|FCN_C2FNet|HRNet_W18|512x512|69.31|63.03|50.90|23.53|45.93|83.82|59.62|[model](https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/c2fnet_fcn_hrnet_isaid.pdparams)|
+
+## Contact
+
+wangqingzhong@baidu.com
+
+silin.chen@cumt.edu.cn

+ 309 - 0
examples/c2fnet/data/prepare_isaid_c2fnet.py

@@ -0,0 +1,309 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/open-mmlab/mmsegmentation/blob/master/tools/convert_datasets/isaid.py
+#
+# Original copyright info:
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+# See original LICENSE at https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE
+
+import argparse
+import glob
+import numbers
+import os
+import os.path as osp
+import shutil
+import tempfile
+import zipfile
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+
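+# Mapping from iSAID class index to the RGB color used in the color-encoded
+# semantic masks; the inverted mapping below converts mask colors back to
+# class indices.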
+iSAID_palette = \
+    {
+        0: (0, 0, 0),
+        1: (0, 0, 63),
+        2: (0, 63, 63),
+        3: (0, 63, 0),
+        4: (0, 63, 127),
+        5: (0, 63, 191),
+        6: (0, 63, 255),
+        7: (0, 127, 63),
+        8: (0, 127, 127),
+        9: (0, 0, 127),
+        10: (0, 0, 191),
+        11: (0, 0, 255),
+        12: (0, 191, 127),
+        13: (0, 127, 191),
+        14: (0, 127, 255),
+        15: (0, 100, 155)
+    }
+
+iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+    if dir_name == '':
+        return
+    dir_name = osp.expanduser(dir_name)
+    os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
+    """RGB-color encoding to grayscale labels."""
+    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
+
+    for c, i in palette.items():
+        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
+        arr_2d[m] = i
+
+    return arr_2d
+
+
+def pad(img, shape=None, padding=None, pad_val=0, padding_mode='constant'):
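+    # Pads `img` either to a target `shape` or by an explicit `padding` tuple
+    # (left, top, right, bottom); exactly one of the two must be given.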
+    assert (shape is not None) ^ (padding is not None)
+    if shape is not None:
+        width = max(shape[1] - img.shape[1], 0)
+        height = max(shape[0] - img.shape[0], 0)
+        padding = (0, 0, width, height)
+
+    # Check pad_val
+    if isinstance(pad_val, tuple):
+        assert len(pad_val) == img.shape[-1]
+    elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be an int or a tuple, '
+                        f'but received {type(pad_val)}')
+
+    # Check padding
+    if isinstance(padding, tuple) and len(padding) in [2, 4]:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+    elif isinstance(padding, numbers.Number):
+        padding = (padding, padding, padding, padding)
+    else:
+        raise ValueError('padding must be an int or a 2- or 4-element tuple, '
+                         f'but received {padding}')
+
+    # Check padding mode
+    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+    border_type = {
+        'constant': cv2.BORDER_CONSTANT,
+        'edge': cv2.BORDER_REPLICATE,
+        'reflect': cv2.BORDER_REFLECT_101,
+        'symmetric': cv2.BORDER_REFLECT
+    }
+    img = cv2.copyMakeBorder(
+        img,
+        padding[1],
+        padding[3],
+        padding[0],
+        padding[2],
+        border_type[padding_mode],
+        value=pad_val)
+
+    return img
+
+
+def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
+    img = np.asarray(Image.open(src_path).convert('RGB'))
+
+    img_H, img_W, _ = img.shape
+
+    if img_H < patch_H and img_W > patch_W:
+
+        img = pad(img, shape=(patch_H, img_W), pad_val=0)
+
+        img_H, img_W, _ = img.shape
+
+    elif img_H > patch_H and img_W < patch_W:
+
+        img = pad(img, shape=(img_H, patch_W), pad_val=0)
+
+        img_H, img_W, _ = img.shape
+
+    elif img_H < patch_H and img_W < patch_W:
+
+        img = pad(img, shape=(patch_H, patch_W), pad_val=0)
+
+        img_H, img_W, _ = img.shape
+
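+    # Slide a patch_W x patch_H window with the given overlap; windows that
+    # would extend past the image border are shifted back inside it.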
+    for x in range(0, img_W, patch_W - overlap):
+        for y in range(0, img_H, patch_H - overlap):
+            x_str = x
+            x_end = x + patch_W
+            if x_end > img_W:
+                diff_x = x_end - img_W
+                x_str -= diff_x
+                x_end = img_W
+            y_str = y
+            y_end = y + patch_H
+            if y_end > img_H:
+                diff_y = y_end - img_H
+                y_str -= diff_y
+                y_end = img_H
+
+            img_patch = img[y_str:y_end, x_str:x_end, :]
+            img_patch = Image.fromarray(img_patch.astype(np.uint8))
+            image = osp.basename(src_path).split('.')[0] + '_' + str(
+                y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
+                    x_end) + '.png'
+            save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
+            img_patch.save(save_path_image)
+
+
+def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
+    label = Image.open(src_path).convert('RGB')
+    label = np.asarray(label)
+    label = iSAID_convert_from_color(label)
+    img_H, img_W = label.shape
+
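+    # Undersized labels are padded with 255, which segmentation losses treat
+    # as the ignore index.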
+    if img_H < patch_H and img_W > patch_W:
+
+        label = pad(label, shape=(patch_H, img_W), pad_val=255)
+
+        img_H = patch_H
+
+    elif img_H > patch_H and img_W < patch_W:
+
+        label = pad(label, shape=(img_H, patch_W), pad_val=255)
+
+        img_W = patch_W
+
+    elif img_H < patch_H and img_W < patch_W:
+
+        label = pad(label, shape=(patch_H, patch_W), pad_val=255)
+
+        img_H = patch_H
+        img_W = patch_W
+
+    for x in range(0, img_W, patch_W - overlap):
+        for y in range(0, img_H, patch_H - overlap):
+            x_str = x
+            x_end = x + patch_W
+            if x_end > img_W:
+                diff_x = x_end - img_W
+                x_str -= diff_x
+                x_end = img_W
+            y_str = y
+            y_end = y + patch_H
+            if y_end > img_H:
+                diff_y = y_end - img_H
+                y_str -= diff_y
+                y_end = img_H
+
+            lab_patch = label[y_str:y_end, x_str:x_end]
+            lab_patch = Image.fromarray(lab_patch.astype(np.uint8))
+
+            image = osp.basename(src_path).split('.')[0].split('_')[
+                0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
+                    x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
+            lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('dataset_path', help='Path of raw iSAID dataset.')
+    parser.add_argument('--tmp_dir', help='Path of the temporary directory.')
+    parser.add_argument('-o', '--out_dir', help='Output path.')
+
+    parser.add_argument(
+        '--patch_width',
+        default=896,
+        type=int,
+        help='Width of the cropped image patch.')
+    parser.add_argument(
+        '--patch_height',
+        default=896,
+        type=int,
+        help='Height of the cropped image patch.')
+    parser.add_argument(
+        '--overlap_area', default=384, type=int, help='Overlap area.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    dataset_path = args.dataset_path
+    # image patch width and height
+    patch_H, patch_W = args.patch_width, args.patch_height
+
+    overlap = args.overlap_area  # overlap area
+
+    if args.out_dir is None:
+        out_dir = osp.join('data', 'iSAID')
+    else:
+        out_dir = args.out_dir
+
+    print('Creating directories...')
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
+
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
+
+    assert os.path.exists(os.path.join(dataset_path, 'train')), \
+        'train is not in {}'.format(dataset_path)
+    assert os.path.exists(os.path.join(dataset_path, 'val')), \
+        'val is not in {}'.format(dataset_path)
+    assert os.path.exists(os.path.join(dataset_path, 'test')), \
+        'test is not in {}'.format(dataset_path)
+
+    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+        for dataset_mode in ['train', 'val', 'test']:
+
+            print('Extracting {} zip files...'.format(dataset_mode))
+            img_zipp_list = glob.glob(
+                os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
+            print('Found zip files:', img_zipp_list)
+            for img_zipp in img_zipp_list:
+                zip_file = zipfile.ZipFile(img_zipp)
+                zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
+            src_path_list = glob.glob(
+                os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
+
+            for img_path in tqdm(src_path_list):
+                if dataset_mode != 'test':
+                    slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
+                                     patch_W, overlap)
+
+                else:
+                    shutil.move(img_path,
+                                os.path.join(out_dir, 'img_dir', dataset_mode))
+
+            if dataset_mode != 'test':
+                label_zipp_list = glob.glob(
+                    os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
+                                 '*.zip'))
+                for label_zipp in label_zipp_list:
+                    zip_file = zipfile.ZipFile(label_zipp)
+                    zip_file.extractall(
+                        os.path.join(tmp_dir, dataset_mode, 'lab'))
+
+                lab_path_list = glob.glob(
+                    os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
+                                 '*.png'))
+                for lab_path in tqdm(lab_path_list):
+                    slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
+                                     patch_W, overlap)
+
+        print('Removing the temporary files...')
+
+    print('Done!')

+ 90 - 0
examples/c2fnet/train.py

@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+
+# Training script for the C2FNet image segmentation model, with FCN_HRNet as the coarse segmenter.
+# Before running this script, make sure PaddleRS has been installed correctly.
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/iSAID'
+# Path of the training set `file_list` file
+TRAIN_FILE_LIST_PATH = './data/iSAID/train.txt'
+# Path of the validation set `file_list` file
+EVAL_FILE_LIST_PATH = './data/iSAID/val.txt'
+# Path of the dataset class-information file
+LABEL_LIST_PATH = './data/iSAID/label.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/c2fnet/'
+
+# Number of image bands
+NUM_BANDS = 3
+
+# Define the data transforms (data augmentation, preprocessing, etc.) for training and validation
+# Compose combines multiple transforms; the transforms it contains are executed sequentially in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Resize the image to 512x512
+    T.Resize(target_size=512),
+    # Randomly flip horizontally with 50% probability
+    T.RandomHorizontalFlip(prob=0.5),
+    # Normalize the data to [-1, 1]
+    T.Normalize(
+        mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ArrangeSegmenter('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Normalization at validation time must be identical to that at training time
+    T.Normalize(
+        mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ReloadMask(),
+    T.ArrangeSegmenter('eval')
+])
+
+# Build the datasets used for training and validation, respectively
+train_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
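+# Build the C2FNet refinement model. The weights file referenced by
+# `coarse_model_path` must already exist; rename the downloaded baseline
+# weights (see the README) to match this path, or adjust the path here.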
+model = pdrs.tasks.seg.C2FNet(
+    in_channels=NUM_BANDS,
+    num_classes=len(train_dataset.labels),
+    coarse_model='FCN',
+    coarse_model_backbone='HRNet_W18',
+    coarse_model_path='./coarse_model/fcn_hrnet_baseline_on_iSAID.pdparams')
+
+# Run model training
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # Interval (in iterations) between log records
+    log_interval_steps=4,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.01,
+    # Whether to use early stopping, i.e., terminate training early when accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint to resume training from (None to train from scratch)
+    resume_checkpoint=None,
+    pretrain_weights='IMAGENET')

+ 1 - 0
paddlers/rs_models/seg/__init__.py

@@ -14,3 +14,4 @@
 
 from .farseg import FarSeg
 from .factseg import FactSeg
+from .c2fnet import C2FNet

+ 340 - 0
paddlers/rs_models/seg/c2fnet.py

@@ -0,0 +1,340 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddlers.models.ppseg as ppseg
+import paddlers.utils.logging as logging
+
+from paddlers.models.ppseg.cvlibs import param_init
+from paddlers.rs_models.seg.layers import layers_lib as layers
+from paddlers.models.ppseg.utils import utils
+
+
+class C2FNet(nn.Layer):
+    """
+     A Coarse-to-Fine Segmentation Network for Small Objects in Remote Sensing Images.
+
+     Args:
+         num_classes (int): The unique number of target classes.
+         backbone (str): The backbone network.
+         backbone_indices (tuple, optional): The values in the tuple indicate the indices of output of backbone.
+            Default: (-1, ).
+         kernel_sizes(tuple, optional): The sliding windows' size. Default: (128,128).
+         training_stride(int, optional): The stride of sliding windows. Default: 32.
+         samples_per_gpu(int, optional): The fined process's batch size. Default: 32.
+         channels (int, optional): The channels between conv layer and the last layer of FCNHead.
+            If None, it will be the number of channels of input features. Default: None.
+         align_corners (bool, optional): An argument of `F.interpolate`. It should be set to False when the output size of feature
+            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.  Default: False.
+     """
+
+    def __init__(self,
+                 num_classes,
+                 backbone,
+                 backbone_indices=(-1, ),
+                 kernel_sizes=(128, 128),
+                 training_stride=32,
+                 samples_per_gpu=32,
+                 channels=None,
+                 align_corners=False):
+        super(C2FNet, self).__init__()
+        self.backbone = backbone
+        backbone_channels = [
+            backbone.feat_channels[i] for i in backbone_indices
+        ]
+        self.head_fgbg = FCNHead(2, backbone_indices, backbone_channels,
+                                 channels)
+        self.num_cls = num_classes
+        self.kernel_sizes = [kernel_sizes[0], kernel_sizes[1]]
+        self.training_stride = training_stride
+        self.samples = samples_per_gpu
+        self.align_corners = align_corners
+
+    def forward(self, x, heatmaps, label=None):
+        ori_heatmap = heatmaps
+        heatmap = paddle.argmax(heatmaps, axis=1, keepdim=True, dtype='int32')
+        if paddle.max(heatmap) > 15:
+            logging.warning(
+                "Please note that currently C2FNet can only be trained and evaluated on the iSAID dataset."
+            )
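+        # Class IDs 1, 8, 9, 10, 11, 14, and 15 are the iSAID small-object
+        # categories (ship, large vehicle, small vehicle, helicopter, swimming
+        # pool, plane, harbor); build a binary foreground mask over them.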
+        heatmap = paddle.where(
+            (heatmap == 10) | (heatmap == 11) | (heatmap == 8) |
+            (heatmap == 15) | (heatmap == 9) | (heatmap == 1) | (heatmap == 14),
+            paddle.ones_like(heatmap),
+            paddle.zeros_like(heatmap)).astype('float32')
+
+        if self.training:
+            label = paddle.unsqueeze(label, axis=1).astype('float32')
+            label = paddle.where((label == 10) | (label == 11) | (label == 8) |
+                                 (label == 15) | (label == 9) | (label == 1) |
+                                 (label == 14),
+                                 paddle.ones_like(label),
+                                 paddle.zeros_like(label))
+            mask_regions = F.unfold(
+                heatmap,
+                kernel_sizes=self.kernel_sizes,
+                strides=self.training_stride,
+                paddings=0,
+                dilations=1,
+                name=None)
+            mask_regions = paddle.transpose(mask_regions, perm=[0, 2, 1])
+            mask_regions = paddle.reshape(
+                mask_regions,
+                shape=[-1, self.kernel_sizes[0] * self.kernel_sizes[1]])
+
+            img_regions = F.unfold(
+                x,
+                kernel_sizes=self.kernel_sizes,
+                strides=self.training_stride,
+                paddings=0,
+                dilations=1,
+                name=None)
+            img_regions = paddle.transpose(img_regions, perm=[0, 2, 1])
+            img_regions = paddle.reshape(
+                img_regions,
+                shape=[-1, 3 * self.kernel_sizes[0] * self.kernel_sizes[1]])
+
+            label_regions = F.unfold(
+                label,
+                kernel_sizes=self.kernel_sizes,
+                strides=self.training_stride,
+                paddings=0,
+                dilations=1,
+                name=None)
+            label_regions = paddle.transpose(label_regions, perm=[0, 2, 1])
+            label_regions = paddle.reshape(
+                label_regions,
+                shape=[-1, self.kernel_sizes[0] * self.kernel_sizes[1]])
+
+            mask_regions_sum = paddle.sum(mask_regions, axis=1)
+            mask_regions_selected = paddle.where(
+                mask_regions_sum > 0,
+                paddle.ones_like(mask_regions_sum),
+                paddle.zeros_like(mask_regions_sum))
+            final_mask_regions_selected = paddle.zeros_like(
+                mask_regions_selected).astype('bool')
+            final_mask_regions_selected.stop_gradient = True
+
+            # Select the patches with the most small-object pixels. The fine
+            # batch holds samples_per_gpu patches per input image; if not
+            # enough patches contain small objects, fall back to 1/8 of that.
+            threshold = self.samples * paddle.shape(x)[0]
+            if paddle.sum(mask_regions_selected) < threshold:
+                threshold = threshold // 8
+            _, top_k_idx = paddle.topk(mask_regions_sum, k=threshold)
+            final_mask_regions_selected[top_k_idx] = True
+
+            selected_img_regions = img_regions[final_mask_regions_selected]
+            selected_img_regions = paddle.reshape(
+                selected_img_regions,
+                shape=[
+                    threshold, 3, self.kernel_sizes[0], self.kernel_sizes[1]
+                ])
+
+            selected_label_regions = label_regions[final_mask_regions_selected]
+            selected_label_regions = paddle.reshape(
+                selected_label_regions,
+                shape=[threshold, self.kernel_sizes[0],
+                       self.kernel_sizes[1]]).astype('int32')
+
+            feat_list = self.backbone(selected_img_regions)
+            bgfg = self.head_fgbg(feat_list)
+
+            binary_fea = F.interpolate(
+                bgfg[0],
+                self.kernel_sizes,
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+            return [binary_fea, selected_label_regions]
+
+        else:
+            mask_regions = F.unfold(
+                heatmap,
+                kernel_sizes=self.kernel_sizes,
+                strides=self.kernel_sizes[0],
+                paddings=0,
+                dilations=1,
+                name=None)
+            mask_regions = paddle.transpose(mask_regions, perm=[0, 2, 1])
+            mask_regions = paddle.reshape(
+                mask_regions,
+                shape=[-1, self.kernel_sizes[0] * self.kernel_sizes[1]])
+
+            img_regions = F.unfold(
+                x,
+                kernel_sizes=self.kernel_sizes,
+                strides=self.kernel_sizes[0],
+                paddings=0,
+                dilations=1,
+                name=None)
+            img_regions = paddle.transpose(img_regions, perm=[0, 2, 1])
+            img_regions = paddle.reshape(
+                img_regions,
+                shape=[-1, 3 * self.kernel_sizes[0] * self.kernel_sizes[1]])
+
+            mask_regions_sum = paddle.sum(mask_regions, axis=1)
+            mask_regions_selected = paddle.where(
+                mask_regions_sum > 0,
+                paddle.ones_like(mask_regions_sum),
+                paddle.zeros_like(mask_regions_sum)).astype('bool')
+
+            if paddle.sum(mask_regions_selected.astype('int')) == 0:
+                return [ori_heatmap]
+            else:
+                ori_fea_regions = F.unfold(
+                    ori_heatmap,
+                    kernel_sizes=self.kernel_sizes,
+                    strides=self.kernel_sizes[0],
+                    paddings=0,
+                    dilations=1,
+                    name=None)
+                ori_fea_regions = paddle.transpose(
+                    ori_fea_regions, perm=[0, 2, 1])
+                ori_fea_regions = paddle.reshape(
+                    ori_fea_regions,
+                    shape=[
+                        -1, self.num_cls * self.kernel_sizes[0] *
+                        self.kernel_sizes[1]
+                    ])
+                selected_img_regions = img_regions[mask_regions_selected]
+                selected_img_regions = paddle.reshape(
+                    selected_img_regions,
+                    shape=[
+                        paddle.shape(selected_img_regions)[0], 3,
+                        self.kernel_sizes[0], self.kernel_sizes[1]
+                    ])
+                selected_fea_regions = ori_fea_regions[mask_regions_selected]
+                selected_fea_regions = paddle.reshape(
+                    selected_fea_regions,
+                    shape=[
+                        paddle.shape(selected_fea_regions)[0], self.num_cls,
+                        self.kernel_sizes[0], self.kernel_sizes[1]
+                    ])
+                feat_list = self.backbone(selected_img_regions)
+                bgfg = self.head_fgbg(feat_list)
+                binary_fea = F.interpolate(
+                    bgfg[0],
+                    self.kernel_sizes,
+                    mode='bilinear',
+                    align_corners=self.align_corners)
+                binary_fea = F.softmax(binary_fea, axis=1)
+                bg_binary, fg_binary = paddle.chunk(
+                    binary_fea, chunks=2, axis=1)
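+                # Split the 16 class channels into background (front), the
+                # small-object classes (ship, lv, sv, hl, swp, pl, hb), and
+                # the remaining classes (mid, mid2); the foreground confidence
+                # is then added to each small-object channel.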
+                front, ship, mid, lv, sv, hl, swp, mid2, pl, hb = paddle.split(
+                    selected_fea_regions,
+                    num_or_sections=[1, 1, 6, 1, 1, 1, 1, 2, 1, 1],
+                    axis=1)
+                ship = paddle.add(ship, fg_binary)
+                lv = paddle.add(lv, fg_binary)
+                sv = paddle.add(sv, fg_binary)
+                hl = paddle.add(hl, fg_binary)
+                swp = paddle.add(swp, fg_binary)
+                pl = paddle.add(pl, fg_binary)
+                hb = paddle.add(hb, fg_binary)
+                selected_fea_regions = paddle.concat(
+                    x=[front, ship, mid, lv, sv, hl, swp, mid2, pl, hb], axis=1)
+                selected_fea_regions = paddle.reshape(
+                    selected_fea_regions,
+                    shape=[paddle.shape(selected_fea_regions)[0], -1])
+                ori_fea_regions[mask_regions_selected] = selected_fea_regions
+                ori_fea_regions = paddle.reshape(
+                    ori_fea_regions,
+                    shape=[
+                        paddle.shape(x)[0], -1, self.num_cls *
+                        self.kernel_sizes[0] * self.kernel_sizes[1]
+                    ])
+                ori_fea_regions = paddle.transpose(
+                    ori_fea_regions, perm=[0, 2, 1])
+                fea_out = F.fold(
+                    ori_fea_regions, [paddle.shape(x)[2], paddle.shape(x)[3]],
+                    self.kernel_sizes,
+                    strides=self.kernel_sizes[0],
+                    paddings=0,
+                    dilations=1,
+                    name=None)
+
+                return [fea_out]
+
+
+class FCNHead(nn.Layer):
+    def __init__(self,
+                 num_classes,
+                 backbone_indices=(-1, ),
+                 backbone_channels=(270, ),
+                 channels=None):
+        super(FCNHead, self).__init__()
+
+        self.num_classes = num_classes
+        self.backbone_indices = backbone_indices
+        if channels is None:
+            channels = backbone_channels[0]
+
+        self.conv_1 = layers.ConvBNReLU(
+            in_channels=backbone_channels[0],
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            bias_attr=True)
+        self.cls = nn.Conv2D(
+            in_channels=channels,
+            out_channels=self.num_classes,
+            kernel_size=1,
+            stride=1,
+            bias_attr=True)
+        self.init_weight()
+
+    def forward(self, feat_list):
+        logit_list = []
+        x = feat_list[self.backbone_indices[0]]
+        x = self.conv_1(x)
+        logit = self.cls(x)
+        logit_list.append(logit)
+        return logit_list
+
+    def init_weight(self):
+        for layer in self.sublayers():
+            if isinstance(layer, nn.Conv2D):
+                param_init.normal_init(layer.weight, std=0.001)
+            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
+                param_init.constant_init(layer.weight, value=1.0)
+                param_init.constant_init(layer.bias, value=0.0)

+ 125 - 2
paddlers/tasks/segmenter.py

@@ -36,7 +36,8 @@ from .utils.infer_nets import InferSegNet
 from .utils.slider_predict import slider_predict
 
 __all__ = [
-    "UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg", "FactSeg"
+    "UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg", "FactSeg",
+    "C2FNet"
 ]
 
 
@@ -410,7 +411,6 @@ class BaseSegmenter(BaseModel):
         """
 
         self._check_transforms(eval_dataset.transforms, 'eval')
-
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
         local_rank = paddle.distributed.get_rank()
@@ -911,3 +911,126 @@ class FactSeg(BaseSegmenter):
             losses=losses,
             in_channels=in_channels,
             **params)
+
+
+class C2FNet(BaseSegmenter):
+    def __init__(self,
+                 in_channels=3,
+                 num_classes=2,
+                 backbone='HRNet_W18',
+                 use_mixed_loss=False,
+                 losses=None,
+                 output_stride=8,
+                 backbone_indices=(-1, ),
+                 kernel_sizes=(128, 128),
+                 training_stride=32,
+                 samples_per_gpu=32,
+                 channels=None,
+                 align_corners=False,
+                 coarse_model=None,
+                 coarse_model_backbone=None,
+                 coarse_model_path=None,
+                 **params):
+        self.backbone_name = backbone
+        if params.get('with_net', True):
+            with DisablePrint():
+                backbone = getattr(ppseg.models, self.backbone_name)(
+                    in_channels=in_channels, align_corners=align_corners)
+        else:
+            backbone = None
+        if coarse_model_backbone in ['ResNet50_vd', 'ResNet101_vd']:
+            self.coarse_model_backbone = getattr(
+                ppseg.models, coarse_model_backbone)(output_stride=8)
+        elif coarse_model_backbone in ['HRNet_W18', 'HRNet_W48']:
+            self.coarse_model_backbone = getattr(
+                ppseg.models, coarse_model_backbone)(align_corners=False)
+        else:
+            raise ValueError(
+                "coarse_model_backbone: {} is not supported. Please choose one of "
+                "{'ResNet50_vd', 'ResNet101_vd', 'HRNet_W18', 'HRNet_W48'}.".
+                format(coarse_model_backbone))
+        self.coarse_model = getattr(ppseg.models, coarse_model)(
+            num_classes=num_classes, backbone=self.coarse_model_backbone)
+        self.coarse_params = paddle.load(coarse_model_path)
+        self.coarse_model.set_state_dict(self.coarse_params)
+        self.coarse_model.eval()
+        params.update({
+            'backbone': backbone,
+            'backbone_indices': backbone_indices,
+            'kernel_sizes': kernel_sizes,
+            'training_stride': training_stride,
+            'samples_per_gpu': samples_per_gpu,
+            'align_corners': align_corners
+        })
+        super(C2FNet, self).__init__(
+            model_name='C2FNet',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            **params)
+
+    def run(self, net, inputs, mode):
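+        # The frozen coarse model runs first; its logits serve as heatmaps
+        # that locate candidate small-object regions for the fine stage.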
+        with paddle.no_grad():
+            pre_coarse = self.coarse_model(inputs[0])
+            pre_coarse = pre_coarse[0]
+            heatmaps = pre_coarse
+        if mode == 'test':
+            net_out = net(inputs[0], heatmaps)
+            logit = net_out[0]
+            outputs = OrderedDict()
+            origin_shape = inputs[1]
+            if self.status == 'Infer':
+                label_map_list, score_map_list = self.postprocess(
+                    net_out, origin_shape, transforms=inputs[2])
+            else:
+                logit_list = self.postprocess(
+                    logit, origin_shape, transforms=inputs[2])
+                label_map_list = []
+                score_map_list = []
+                for logit in logit_list:
+                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
+                    label_map_list.append(
+                        paddle.argmax(
+                            logit, axis=-1, keepdim=False, dtype='int32')
+                        .squeeze().numpy())
+                    score_map_list.append(
+                        F.softmax(
+                            logit, axis=-1).squeeze().numpy().astype('float32'))
+            outputs['label_map'] = label_map_list
+            outputs['score_map'] = score_map_list
+
+        if mode == 'eval':
+            net_out = net(inputs[0], heatmaps)
+            logit = net_out[0]
+            outputs = OrderedDict()
+            if self.status == 'Infer':
+                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
+            else:
+                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
+            label = inputs[1]
+            if label.ndim == 3:
+                paddle.unsqueeze_(label, axis=1)
+            if label.ndim != 4:
+                raise ValueError("Expected label.ndim == 4 but got {}".format(
+                    label.ndim))
+            origin_shape = [label.shape[-2:]]
+            pred = self.postprocess(
+                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
+            intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
+                pred, label, self.num_classes)
+            outputs['intersect_area'] = intersect_area
+            outputs['pred_area'] = pred_area
+            outputs['label_area'] = label_area
+            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
+                                                           self.num_classes)
+        if mode == 'train':
+            net_out = net(inputs[0], heatmaps, inputs[1])
+            logit = [net_out[0], ]
+            labels = net_out[1]
+            outputs = OrderedDict()
+            loss_list = metrics.loss_computation(
+                logits_list=logit, labels=labels, losses=self.losses)
+            loss = sum(loss_list)
+            outputs['loss'] = loss
+
+        return outputs

+ 14 - 3
paddlers/utils/checkpoint.py

@@ -87,7 +87,8 @@ seg_pretrain_weights_dict = {
     'FastSCNN': ['CITYSCAPES'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
     'BiSeNetV2': ['CITYSCAPES'],
-    'FactSeg': ['iSAID']
+    'FactSeg': ['iSAID'],
+    'C2FNet': ['IMAGENET', 'iSAID']
 }
 
 cityscapes_weights = {
@@ -323,7 +324,15 @@ imagenet_weights = {
     'DeepLabV3P_ResNet50_vd_IMAGENET':
     'https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz',
     'DeepLabV3P_ResNet101_vd_IMAGENET':
-    'https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz'
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz',
+    'C2FNet_ResNet50_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz',
+    'C2FNet_ResNet101_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz',
+    'C2FNet_HRNet_W18_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz',
+    'C2FNet_HRNet_W48_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz',
 }
 
 pascalvoc_weights = {
@@ -441,7 +450,9 @@ levircd_weights = {
 
 isaid_weights = {
     'FactSeg_iSAID':
-    'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/factseg_isaid.pdparams'
+    'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/factseg_isaid.pdparams',
+    'C2FNet_HRNet_W18_iSAID':
+    'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/c2fnet_fcn_hrnet_isaid.pdparams'
 }