Browse Source

Factseg paddle (#54)

* fact-seg paddle

* fact-paddle

* Coding modify

* Add files via upload

* Passed testing code

* Use offical hook for pre-commit

* Use the official .style.yarp file and passed the yarp test locally

* Restore the file which was changed by error

* Revert "Use the official .style.yarp file and passed the yarp test locally"

This reverts commit 6a294fa21b74166bc6be94b858fcf486ff13d4ea.

* Fix code style

* Code Modify

* Code Modify according to reviewer

* solve conflict

* Solve conflict

* solve yarp

* yarp

* Fix code style

* [Feat] Make FactSeg public

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
LHE-IT 3 years ago
parent
commit
afc25c93c9

+ 1 - 0
docs/intro/model_zoo.md

@@ -34,6 +34,7 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型
 | 目标检测 | YOLOv3 | 否 |
 | 目标检测 | YOLOv3 | 否 |
 | 图像分割 | BiSeNet V2 | 是 |
 | 图像分割 | BiSeNet V2 | 是 |
 | 图像分割 | DeepLab V3+ | 是 |
 | 图像分割 | DeepLab V3+ | 是 |
+| 图像分割 | \*FactSeg | 是 |
 | 图像分割 | \*FarSeg | 是 |
 | 图像分割 | \*FarSeg | 是 |
 | 图像分割 | Fast-SCNN | 是 |
 | 图像分割 | Fast-SCNN | 是 |
 | 图像分割 | HRNet | 是 |
 | 图像分割 | HRNet | 是 |

+ 2 - 2
paddlers/models/ppgan/models/generators/generator_firstorder.py

@@ -131,8 +131,8 @@ class FirstOrderGenerator(nn.Layer):
                     transformed_kp['jacobian']))
                     transformed_kp['jacobian']))
                 normed_driving = paddle.inverse(kp_driving['jacobian'])
                 normed_driving = paddle.inverse(kp_driving['jacobian'])
                 normed_transformed = jacobian_transformed
                 normed_transformed = jacobian_transformed
-                value = paddle.matmul(
-                    *broadcast(normed_driving, normed_transformed))
+                value = paddle.matmul(*broadcast(normed_driving,
+                                                 normed_transformed))
                 eye = paddle.tensor.eye(2, dtype='float32').reshape(
                 eye = paddle.tensor.eye(2, dtype='float32').reshape(
                     (1, 1, 2, 2))
                     (1, 1, 2, 2))
                 eye = paddle.tile(eye, [1, value.shape[1], 1, 1])
                 eye = paddle.tile(eye, [1, value.shape[1], 1, 1])

+ 2 - 2
paddlers/models/ppseg/models/losses/lovasz_loss.py

@@ -77,8 +77,8 @@ class LovaszHingeLoss(nn.Layer):
         """
         """
         if logits.shape[1] == 2:
         if logits.shape[1] == 2:
             logits = binary_channel_to_unary(logits)
             logits = binary_channel_to_unary(logits)
-        loss = lovasz_hinge_flat(
-            *flatten_binary_scores(logits, labels, self.ignore_index))
+        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels,
+                                                        self.ignore_index))
         return loss
         return loss
 
 
 
 

+ 1 - 0
paddlers/rs_models/seg/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from .farseg import FarSeg
 from .farseg import FarSeg
+from .factseg import FactSeg

+ 141 - 0
paddlers/rs_models/seg/factseg.py

@@ -0,0 +1,141 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlers.models.ppdet.modeling import \
+                         initializer as init
+from paddlers.rs_models.seg.farseg import FPN, \
+                         ResNetEncoder,AsymmetricDecoder
+
+
+def conv_with_kaiming_uniform(use_gn=False, use_relu=False):
+    def make_conv(in_channels, out_channels, kernel_size, stride=1, dilation=1):
+        conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=dilation * (kernel_size - 1) // 2,
+            dilation=dilation,
+            bias_attr=False if use_gn else True)
+
+        init.kaiming_uniform_(conv.weight, a=1)
+        if not use_gn:
+            init.constant_(conv.bias, 0)
+        module = [conv, ]
+        if use_gn:
+            raise NotImplementedError
+        if use_relu:
+            module.append(nn.ReLU())
+        if len(module) > 1:
+            return nn.Sequential(*module)
+        return conv
+
+    return make_conv
+
+
+default_conv_block = conv_with_kaiming_uniform(use_gn=False, use_relu=False)
+
+
+class FactSeg(nn.Layer):
+    """
+     The FactSeg implementation based on PaddlePaddle.
+
+     The original article refers to
+     A. Ma, J. Wang, Y. Zhong and Z. Zheng, "FactSeg: Foreground Activation
+     -Driven Small Object Semantic Segmentation in Large-Scale Remote Sensing
+      Imagery,"in IEEE Transactions on Geoscience and Remote Sensing, vol. 60,
+       pp. 1-16, 2022, Art no. 5606216.
+
+
+     Args:
+         in_channels (int): The number of image channels for the input model.
+         num_classes (int): The unique number of target classes.
+         backbone (str, optional): A backbone network, models available in
+         `paddle.vision.models.resnet`. Default: resnet50.
+         backbone_pretrained (bool, optional): Whether the backbone network uses
+         IMAGENET pretrained weights. Default: True.
+     """
+
+    def __init__(self,
+                 in_channels,
+                 num_classes,
+                 backbone='resnet50',
+                 backbone_pretrained=True):
+        super(FactSeg, self).__init__()
+        backbone = backbone.lower()
+        self.resencoder = ResNetEncoder(
+            backbone=backbone,
+            in_channels=in_channels,
+            pretrained=backbone_pretrained)
+        self.resencoder.resnet._sub_layers.pop('fc')
+        self.fgfpn = FPN(in_channels_list=[256, 512, 1024, 2048],
+                         out_channels=256,
+                         conv_block=default_conv_block)
+        self.bifpn = FPN(in_channels_list=[256, 512, 1024, 2048],
+                         out_channels=256,
+                         conv_block=default_conv_block)
+        self.fg_decoder = AsymmetricDecoder(
+            in_channels=256,
+            out_channels=128,
+            in_feature_output_strides=(4, 8, 16, 32),
+            out_feature_output_stride=4,
+            conv_block=nn.Conv2D)
+        self.bi_decoder = AsymmetricDecoder(
+            in_channels=256,
+            out_channels=128,
+            in_feature_output_strides=(4, 8, 16, 32),
+            out_feature_output_stride=4,
+            conv_block=nn.Conv2D)
+        self.fg_cls = nn.Conv2D(128, num_classes, kernel_size=1)
+        self.bi_cls = nn.Conv2D(128, 1, kernel_size=1)
+        self.config_loss = ['joint_loss']
+        self.config_foreground = []
+        self.fbattention_atttention = False
+
+    def forward(self, x):
+        feat_list = self.resencoder(x)
+        if 'skip_decoder' in []:
+            fg_out = self.fgskip_deocder(feat_list)
+            bi_out = self.bgskip_deocder(feat_list)
+        else:
+            forefeat_list = list(self.fgfpn(feat_list))
+            binaryfeat_list = self.bifpn(feat_list)
+            if self.fbattention_atttention:
+                for i in range(len(binaryfeat_list)):
+                    forefeat_list[i] = self.fbatt_block_list[i](
+                        binaryfeat_list[i], forefeat_list[i])
+            fg_out = self.fg_decoder(forefeat_list)
+            bi_out = self.bi_decoder(binaryfeat_list)
+        fg_pred = self.fg_cls(fg_out)
+        bi_pred = self.bi_cls(bi_out)
+        fg_pred = F.interpolate(
+            fg_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
+        bi_pred = F.interpolate(
+            bi_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
+        if self.training:
+            return [fg_pred]
+        else:
+            binary_prob = F.sigmoid(bi_pred)
+            cls_prob = F.softmax(fg_pred, axis=1)
+            cls_prob[:, 0, :, :] = cls_prob[:, 0, :, :] * (
+                1 - binary_prob).squeeze(axis=1)
+            cls_prob[:, 1:, :, :] = cls_prob[:, 1:, :, :] * binary_prob
+            z = paddle.sum(cls_prob, axis=1)
+            z = z.unsqueeze(axis=1)
+            cls_prob = paddle.divide(cls_prob, z)
+            return [cls_prob]

+ 19 - 2
paddlers/tasks/segmenter.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import math
 import math
-import os
 import os.path as osp
 import os.path as osp
 from collections import OrderedDict
 from collections import OrderedDict
 
 
@@ -36,7 +35,9 @@ from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferSegNet
 from .utils.infer_nets import InferSegNet
 from .utils.slider_predict import slider_predict
 from .utils.slider_predict import slider_predict
 
 
-__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
+__all__ = [
+    "UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg", "FactSeg"
+]
 
 
 
 
 class BaseSegmenter(BaseModel):
 class BaseSegmenter(BaseModel):
@@ -894,3 +895,19 @@ class FarSeg(BaseSegmenter):
             losses=losses,
             losses=losses,
             in_channels=in_channels,
             in_channels=in_channels,
             **params)
             **params)
+
+
+class FactSeg(BaseSegmenter):
+    def __init__(self,
+                 in_channels=3,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 losses=None,
+                 **params):
+        super(FactSeg, self).__init__(
+            model_name='FactSeg',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            losses=losses,
+            in_channels=in_channels,
+            **params)

+ 1 - 0
test_tipc/README.md

@@ -46,6 +46,7 @@
 | 目标检测 | YOLOv3 | 支持 | - | - | - |
 | 目标检测 | YOLOv3 | 支持 | - | - | - |
 | 图像分割 | BiSeNet V2 | 支持 | - | - | - |
 | 图像分割 | BiSeNet V2 | 支持 | - | - | - |
 | 图像分割 | DeepLab V3+ | 支持 | - | - | - |
 | 图像分割 | DeepLab V3+ | 支持 | - | - | - |
+| 图像分割 | FactSeg | 支持 | - | - | - |
 | 图像分割 | FarSeg | 支持 | - | - | - |
 | 图像分割 | FarSeg | 支持 | - | - | - |
 | 图像分割 | Fast-SCNN | 支持 | - | - | - |
 | 图像分割 | Fast-SCNN | 支持 | - | - | - |
 | 图像分割 | HRNet | 支持 | - | - | - |
 | 图像分割 | HRNet | 支持 | - | - | - |

+ 11 - 0
test_tipc/configs/seg/factseg/factseg_rsseg.yaml

@@ -0,0 +1,11 @@
+# Configurations of FactSeg with RSSeg dataset
+
+_base_: ../_base_/rsseg.yaml
+
+save_dir: ./test_tipc/output/seg/factseg/
+
+model: !Node
+    type: FactSeg
+    args:
+        in_channels: 3
+        num_classes: 5

+ 53 - 0
test_tipc/configs/seg/factseg/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:seg:factseg
+python:python
+gpu_list:0
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml
+train_model_name:best_model
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train seg
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,512,512]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--config:null
+--save_log_path:null
+--benchmark:True
+--model_name:factseg
+null:null

+ 2 - 0
test_tipc/docs/test_train_inference_python.md

@@ -33,6 +33,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
 |  图像复原  | LESRCNN | 正常训练 | 正常训练 | PSNR=23.67 |
 |  图像复原  | LESRCNN | 正常训练 | 正常训练 | PSNR=23.67 |
 |  图像分割  | BiSeNet V2 | 正常训练 | 正常训练 | mIoU=70.52% |
 |  图像分割  | BiSeNet V2 | 正常训练 | 正常训练 | mIoU=70.52% |
 |  图像分割  | DeepLab V3+ | 正常训练 | 正常训练 | mIoU=64.41% |
 |  图像分割  | DeepLab V3+ | 正常训练 | 正常训练 | mIoU=64.41% |
+|  图像分割  | FactSeg | 正常训练 | 正常训练 |  |
 |  图像分割  | FarSeg | 正常训练 | 正常训练 | mIoU=50.60% |
 |  图像分割  | FarSeg | 正常训练 | 正常训练 | mIoU=50.60% |
 |  图像分割  | Fast-SCNN | 正常训练 | 正常训练 | mIoU=49.27% |
 |  图像分割  | Fast-SCNN | 正常训练 | 正常训练 | mIoU=49.27% |
 |  图像分割  | HRNet | 正常训练 | 正常训练 | mIoU=33.03% |
 |  图像分割  | HRNet | 正常训练 | 正常训练 | mIoU=33.03% |
@@ -68,6 +69,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
 |  目标检测  | YOLOv3 | 支持 | 支持 | 1 |
 |  目标检测  | YOLOv3 | 支持 | 支持 | 1 |
 |  图像分割  | BiSeNet V2 | 支持 | 支持 | 1 |
 |  图像分割  | BiSeNet V2 | 支持 | 支持 | 1 |
 |  图像分割  | DeepLab V3+ | 支持 | 支持 | 1 |
 |  图像分割  | DeepLab V3+ | 支持 | 支持 | 1 |
+|  图像分割  | FactSeg | 支持 | 支持 | 1 |
 |  图像分割  | FarSeg | 支持 | 支持 | 1 |
 |  图像分割  | FarSeg | 支持 | 支持 | 1 |
 |  图像分割  | Fast-SCNN | 支持 | 支持 | 1 |
 |  图像分割  | Fast-SCNN | 支持 | 支持 | 1 |
 |  图像分割  | HRNet | 支持 | 支持 | 1 |
 |  图像分割  | HRNet | 支持 | 支持 | 1 |

+ 19 - 1
tests/rs_models/test_seg_models.py

@@ -15,7 +15,7 @@
 import paddlers
 import paddlers
 from rs_models.test_model import TestModel
 from rs_models.test_model import TestModel
 
 
-__all__ = ['TestFarSegModel']
+__all__ = ['TestFarSegModel', 'TestFactSegModel']
 
 
 
 
 class TestSegModel(TestModel):
 class TestSegModel(TestModel):
@@ -70,3 +70,21 @@ class TestFarSegModel(TestSegModel):
         self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
         self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
                         [self.get_zeros_array(2)], [self.get_zeros_array(2)],
                         [self.get_zeros_array(2)], [self.get_zeros_array(2)],
                         [self.get_zeros_array(2)]]
                         [self.get_zeros_array(2)]]
+
+
+class TestFactSegModel(TestSegModel):
+    MODEL_CLASS = paddlers.rs_models.seg.FactSeg
+
+    def set_specs(self):
+        base_spec = dict(in_channels=3, num_classes=2)
+        self.specs = [
+            base_spec,
+            dict(in_channels=6, num_classes=10),
+            dict(**base_spec,
+                 backbone='resnet50',
+                 backbone_pretrained=False)
+        ]  # yapf: disable
+
+    def set_targets(self):
+        self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
+                        [self.get_zeros_array(2)]]

+ 1 - 0
tutorials/train/README.md

@@ -29,6 +29,7 @@
 |object_detection/yolov3.py | 目标检测 | YOLOv3 |
 |object_detection/yolov3.py | 目标检测 | YOLOv3 |
 |semantic_segmentation/bisenetv2.py | 图像分割 | BiSeNet V2 |
 |semantic_segmentation/bisenetv2.py | 图像分割 | BiSeNet V2 |
 |semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ |
 |semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ |
+|semantic_segmentation/factseg.py | 图像分割 | FactSeg |
 |semantic_segmentation/farseg.py | 图像分割 | FarSeg |
 |semantic_segmentation/farseg.py | 图像分割 | FarSeg |
 |semantic_segmentation/fast_scnn.py | 图像分割 | Fast-SCNN |
 |semantic_segmentation/fast_scnn.py | 图像分割 | Fast-SCNN |
 |semantic_segmentation/hrnet.py | 图像分割 | HRNet |
 |semantic_segmentation/hrnet.py | 图像分割 | HRNet |

+ 94 - 0
tutorials/train/semantic_segmentation/factseg.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# 图像分割模型FactSeg训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rsseg/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
+# 数据集类别信息文件路径
+LABEL_LIST_PATH = './data/rsseg/labels.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/factseg/'
+
+# 下载和解压多光谱地块分类数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 选择前三个波段
+    T.SelectBand([1, 2, 3]),
+    # 将影像缩放到512x512大小
+    T.Resize(target_size=512),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 将数据归一化到[-1,1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeSegmenter('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # 验证阶段与训练阶段应当选择相同的波段
+    T.SelectBand([1, 2, 3]),
+    T.Resize(target_size=512),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeSegmenter('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 构建FactSeg模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
+model = pdrs.tasks.seg.FactSeg(num_classes=len(train_dataset.labels))
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=4,
+    save_dir=EXP_DIR,
+    pretrain_weights=None,
+    # 初始学习率大小
+    learning_rate=0.001,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)