Selaa lähdekoodia

Add multi-channel support to FarSeg

Bobholamovic 2 vuotta sitten
vanhempi
commit
0b87782eec

+ 1 - 1
paddlers/rs_models/cd/bit.py

@@ -56,7 +56,7 @@ class BIT(nn.Layer):
             Default: 2.
         enc_with_pos (bool, optional): Whether to add leanred positional embedding to the input feature sequence of the 
             encoder. Default: True.
-        enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1
+        enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1.
         enc_head_dim (int, optional): Embedding dimension of each encoder head. Default: 64.
         dec_depth (int, optional): Number of attention blocks used in the decoder. Default: 8.
         dec_head_dim (int, optional): Embedding dimension of each decoder head. Default: 8.

+ 24 - 11
paddlers/rs_models/seg/farseg.py

@@ -164,7 +164,7 @@ class SceneRelation(nn.Layer):
         return refined_feats
 
 
-class AssymetricDecoder(nn.Layer):
+class AsymmetricDecoder(nn.Layer):
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -172,7 +172,7 @@ class AssymetricDecoder(nn.Layer):
                  out_feat_output_stride=4,
                  norm_fn=nn.BatchNorm2D,
                  num_groups_gn=None):
-        super(AssymetricDecoder, self).__init__()
+        super(AsymmetricDecoder, self).__init__()
         if norm_fn == nn.BatchNorm2D:
             norm_fn_args = dict(num_features=out_channels)
         elif norm_fn == nn.GroupNorm:
@@ -215,9 +215,12 @@ class AssymetricDecoder(nn.Layer):
 
 
 class ResNet50Encoder(nn.Layer):
-    def __init__(self, pretrained=True):
+    def __init__(self, in_ch=3, pretrained=True):
         super(ResNet50Encoder, self).__init__()
         self.resnet = resnet50(pretrained=pretrained)
+        if in_ch != 3:
+            self.resnet.conv1 = nn.Conv2D(
+                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)
 
     def forward(self, inputs):
         x = inputs
@@ -234,25 +237,35 @@ class ResNet50Encoder(nn.Layer):
 
 class FarSeg(nn.Layer):
     """
-        The FarSeg implementation based on PaddlePaddle.
+    The FarSeg implementation based on PaddlePaddle.
 
-        The original article refers to
-        Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object 
-            Segmentation in High Spatial Resolution Remote Sensing Imagery"
-        (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
+    The original article refers to
+    Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution 
+        Remote Sensing Imagery"
+    (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
+
+    Args:
+        in_channels (int, optional): Number of bands of the input images. Default: 3.
+        num_classes (int, optional): Number of target classes. Default: 16.
+        fpn_ch_list (list[int]|tuple[int], optional): Channel list of the FPN. Default: (256, 512, 1024, 2048).
+        mid_ch (int, optional): Output channels of the FPN. Default: 256.
+        out_ch (int, optional): Output channels of the decoder. Default: 128.
+        sr_ch_list (list[int]|tuple[int], optional): Channel list of the foreground-scene relation module. Default: (256, 256, 256, 256).
+        pretrained_encoder (bool, optional): Whether to use a pretrained encoder. Default: True.
     """
 
     def __init__(self,
+                 in_channels=3,
                  num_classes=16,
                  fpn_ch_list=(256, 512, 1024, 2048),
                  mid_ch=256,
                  out_ch=128,
                  sr_ch_list=(256, 256, 256, 256),
-                 encoder_pretrained=True):
+                 pretrained_encoder=True):
         super(FarSeg, self).__init__()
-        self.en = ResNet50Encoder(encoder_pretrained)
+        self.en = ResNet50Encoder(in_channels, pretrained_encoder)
         self.fpn = FPN(in_channels_list=fpn_ch_list, out_channels=mid_ch)
-        self.decoder = AssymetricDecoder(
+        self.decoder = AsymmetricDecoder(
             in_channels=mid_ch, out_channels=out_ch)
         self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1)
         self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4)

+ 26 - 19
paddlers/tasks/change_detector.py

@@ -31,7 +31,7 @@ import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image
 from paddlers.utils import get_single_card_bs
-from paddlers.utils.checkpoint import seg_pretrain_weights_dict
+from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferCDNet
@@ -275,7 +275,7 @@ class BaseChangeDetector(BaseModel):
                 exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                 exit=True)
         self.labels = train_dataset.labels
         if self.losses is None:
@@ -289,23 +289,30 @@ class BaseChangeDetector(BaseModel):
         else:
             self.optimizer = optimizer
 
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in seg_pretrain_weights_dict[
-                    self.model_name]:
-                logging.warning(
-                    "Path of pretrain_weights('{}') does not exist!".format(
-                        pretrain_weights))
-                logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                                "If don't want to use pretrain weights, "
-                                "set pretrain_weights to be None.".format(
-                                    seg_pretrain_weights_dict[self.model_name][
-                                        0]))
-                pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
+        if pretrain_weights is not None:
+            if not osp.exists(pretrain_weights):
+                if self.model_name not in cd_pretrain_weights_dict:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = None
+                elif pretrain_weights not in cd_pretrain_weights_dict[
+                        self.model_name]:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = cd_pretrain_weights_dict[
+                        self.model_name][0]
+                    logging.warning(
+                        "`pretrain_weights` is forcibly set to '{}'. "
+                        "If you don't want to use pretrained weights, "
+                        "please set `pretrain_weights` to None.".format(
+                            pretrain_weights))
+            else:
+                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                    logging.error(
+                        "Invalid pretrained weights. Please specify a .pdparams file.",
+                        exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         self.net_initialize(

+ 26 - 19
paddlers/tasks/classifier.py

@@ -246,7 +246,7 @@ class BaseClassifier(BaseModel):
                 exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                 exit=True)
         self.labels = train_dataset.labels
         if self.losses is None:
@@ -262,25 +262,32 @@ class BaseClassifier(BaseModel):
         else:
             self.optimizer = optimizer
 
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in cls_pretrain_weights_dict[
-                    self.model_name]:
-                logging.warning(
-                    "Path of pretrain_weights('{}') does not exist!".format(
-                        pretrain_weights))
-                logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                                "If don't want to use pretrain weights, "
-                                "set pretrain_weights to be None.".format(
-                                    cls_pretrain_weights_dict[self.model_name][
-                                        0]))
-                pretrain_weights = cls_pretrain_weights_dict[self.model_name][0]
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
+        if pretrain_weights is not None:
+            if not osp.exists(pretrain_weights):
+                if self.model_name not in cls_pretrain_weights_dict:
+                    logging.warning(
+                        "Path of `pretrain_weights` ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = None
+                elif pretrain_weights not in cls_pretrain_weights_dict[
+                        self.model_name]:
+                    logging.warning(
+                        "Path of `pretrain_weights` ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = cls_pretrain_weights_dict[
+                        self.model_name][0]
+                    logging.warning(
+                        "`pretrain_weights` is forcibly set to '{}'. "
+                        "If you don't want to use pretrained weights, "
+                        "set `pretrain_weights` to None.".format(
+                            pretrain_weights))
+            else:
+                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                    logging.error(
+                        "Invalid pretrained weights. Please specify a .pdparams file.",
+                        exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
-        is_backbone_weights = False  # pretrain_weights == 'IMAGENET'  # TODO: this is backbone
+        is_backbone_weights = False
         self.net_initialize(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,

+ 24 - 18
paddlers/tasks/object_detector.py

@@ -288,7 +288,7 @@ class BaseDetector(BaseModel):
                 exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                 exit=True)
         if train_dataset.__class__.__name__ == 'VOCDetDataset':
             train_dataset.data_fields = {
@@ -337,23 +337,29 @@ class BaseDetector(BaseModel):
             self.optimizer = optimizer
 
         # Initiate weights
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in det_pretrain_weights_dict['_'.join(
-                [self.model_name, self.backbone_name])]:
-                logging.warning(
-                    "Path of pretrain_weights('{}') does not exist!".format(
-                        pretrain_weights))
-                pretrain_weights = det_pretrain_weights_dict['_'.join(
-                    [self.model_name, self.backbone_name])][0]
-                logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                                "If you don't want to use pretrain weights, "
-                                "set pretrain_weights to be None.".format(
-                                    pretrain_weights))
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
+        if pretrain_weights is not None:
+            if not osp.exists(pretrain_weights):
+                key = '_'.join([self.model_name, self.backbone_name])
+                if key not in det_pretrain_weights_dict:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = None
+                elif pretrain_weights not in det_pretrain_weights_dict[key]:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = det_pretrain_weights_dict[key][0]
+                    logging.warning(
+                        "`pretrain_weights` is forcibly set to '{}'. "
+                        "If you don't want to use pretrained weights, "
+                        "please set `pretrain_weights` to None.".format(
+                            pretrain_weights))
+            else:
+                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                    logging.error(
+                        "Invalid pretrained weights. Please specify a .pdparams file.",
+                        exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
             pretrain_weights=pretrain_weights,

+ 26 - 9
paddlers/tasks/restorer.py

@@ -31,6 +31,7 @@ from paddlers.models import res_losses
 from paddlers.transforms import Resize, decode_image
 from paddlers.transforms.functions import calc_hr_shape
 from paddlers.utils import get_single_card_bs
+from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from .base import BaseModel
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
 from .utils.infer_nets import InferResNet
@@ -234,7 +235,7 @@ class BaseRestorer(BaseModel):
                 exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                 exit=True)
 
         if self.losses is None:
@@ -256,14 +257,30 @@ class BaseRestorer(BaseModel):
         else:
             self.optimizer = optimizer
 
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            logging.warning("Path of pretrain_weights('{}') does not exist!".
-                            format(pretrain_weights))
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
+        if pretrain_weights is not None:
+            if not osp.exists(pretrain_weights):
+                if self.model_name not in res_pretrain_weights_dict:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = None
+                elif pretrain_weights not in res_pretrain_weights_dict[
+                        self.model_name]:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = res_pretrain_weights_dict[
+                        self.model_name][0]
+                    logging.warning(
+                        "`pretrain_weights` is forcibly set to '{}'. "
+                        "If you don't want to use pretrained weights, "
+                        "please set `pretrain_weights` to None.".format(
+                            pretrain_weights))
+            else:
+                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                    logging.error(
+                        "Invalid pretrained weights. Please specify a .pdparams file.",
+                        exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         self.net_initialize(

+ 27 - 18
paddlers/tasks/segmenter.py

@@ -267,7 +267,7 @@ class BaseSegmenter(BaseModel):
                 exit=True)
         if pretrain_weights is not None and resume_checkpoint is not None:
             logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                 exit=True)
         self.labels = train_dataset.labels
         if self.losses is None:
@@ -281,23 +281,30 @@ class BaseSegmenter(BaseModel):
         else:
             self.optimizer = optimizer
 
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in seg_pretrain_weights_dict[
-                    self.model_name]:
-                logging.warning(
-                    "Path of pretrain_weights('{}') does not exist!".format(
-                        pretrain_weights))
-                logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                                "If don't want to use pretrain weights, "
-                                "set pretrain_weights to be None.".format(
-                                    seg_pretrain_weights_dict[self.model_name][
-                                        0]))
-                pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
+        if pretrain_weights is not None:
+            if not osp.exists(pretrain_weights):
+                if self.model_name not in seg_pretrain_weights_dict:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = None
+                elif pretrain_weights not in seg_pretrain_weights_dict[
+                        self.model_name]:
+                    logging.warning(
+                        "Path of pretrained weights ('{}') does not exist!".
+                        format(pretrain_weights))
+                    pretrain_weights = seg_pretrain_weights_dict[
+                        self.model_name][0]
+                    logging.warning(
+                        "`pretrain_weights` is forcibly set to '{}'. "
+                        "If you don't want to use pretrained weights, "
+                        "please set `pretrain_weights` to None.".format(
+                            pretrain_weights))
+            else:
+                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                    logging.error(
+                        "Invalid pretrained weights. Please specify a .pdparams file.",
+                        exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         self.net_initialize(
@@ -909,6 +916,7 @@ class BiSeNetV2(BaseSegmenter):
 
 class FarSeg(BaseSegmenter):
     def __init__(self,
+                 in_channels=3,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
@@ -918,4 +926,5 @@ class FarSeg(BaseSegmenter):
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             losses=losses,
+            in_channels=in_channels,
             **params)

+ 12 - 8
paddlers/utils/checkpoint.py

@@ -21,20 +21,14 @@ import paddle
 from . import logging
 from .download import download_and_decompress
 
+cd_pretrain_weights_dict = {}
+
 cls_pretrain_weights_dict = {
     'ResNet50_vd': ['IMAGENET'],
     'MobileNetV3_small_x1_0': ['IMAGENET'],
     'HRNet_W18_C': ['IMAGENET'],
 }
 
-seg_pretrain_weights_dict = {
-    'UNet': ['CITYSCAPES'],
-    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
-    'FastSCNN': ['CITYSCAPES'],
-    'HRNet': ['CITYSCAPES', 'PascalVOC'],
-    'BiSeNetV2': ['CITYSCAPES']
-}
-
 det_pretrain_weights_dict = {
     'PicoDet_ESNet_s': ['COCO', 'IMAGENET'],
     'PicoDet_ESNet_m': ['COCO', 'IMAGENET'],
@@ -74,6 +68,16 @@ det_pretrain_weights_dict = {
     'MaskRCNN_ResNet101_vd_fpn': ['COCO', 'IMAGENET']
 }
 
+res_pretrain_weights_dict = {}
+
+seg_pretrain_weights_dict = {
+    'UNet': ['CITYSCAPES'],
+    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
+    'FastSCNN': ['CITYSCAPES'],
+    'HRNet': ['CITYSCAPES', 'PascalVOC'],
+    'BiSeNetV2': ['CITYSCAPES']
+}
+
 cityscapes_weights = {
     'UNet_CITYSCAPES':
     'https://bj.bcebos.com/paddleseg/dygraph/cityscapes/unet_cityscapes_1024x512_160k/model.pdparams',

+ 2 - 1
tests/rs_models/test_seg_models.py

@@ -50,7 +50,8 @@ class TestFarSegModel(TestSegModel):
 
     def set_specs(self):
         self.specs = [
-            dict(), dict(num_classes=20), dict(encoder_pretrained=False)
+            dict(), dict(num_classes=20), dict(encoder_pretrained=False),
+            dict(in_channels=10)
         ]
 
     def set_targets(self):

+ 1 - 0
tutorials/train/README.md

@@ -26,6 +26,7 @@
 |object_detection/ppyolov2.py | 目标检测 | PP-YOLOv2 |
 |object_detection/yolov3.py | 目标检测 | YOLOv3 |
 |semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ |
+|semantic_segmentation/farseg.py | 图像分割 | FarSeg |
 |semantic_segmentation/unet.py | 图像分割 | UNet |
 
 ## 环境准备

+ 1 - 1
tutorials/train/semantic_segmentation/deeplabv3p.py

@@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset(
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
 model = pdrs.tasks.seg.DeepLabV3P(
-    input_channel=NUM_BANDS,
+    in_channels=NUM_BANDS,
     num_classes=len(train_dataset.labels),
     backbone='ResNet50_vd')
 

+ 94 - 0
tutorials/train/semantic_segmentation/farseg.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# 图像分割模型FarSeg训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/rsseg/'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
+# 数据集类别信息文件路径
+LABEL_LIST_PATH = './data/rsseg/labels.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/farseg/'
+
+# 影像波段数量
+NUM_BANDS = 10
+
+# 下载和解压多光谱地块分类数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # 读取影像
+    T.DecodeImg(),
+    # 将影像缩放到512x512大小
+    T.Resize(target_size=512),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 将数据归一化到[-1,1]
+    T.Normalize(
+        mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ArrangeSegmenter('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=512),
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
+    T.ReloadMask(),
+    T.ArrangeSegmenter('eval')
+])
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.SegDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 构建FarSeg模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
+model = pdrs.tasks.seg.FarSeg(
+    in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
+
+# 执行模型训练
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=4,
+    save_dir=EXP_DIR,
+    pretrain_weights=None,
+    # 初始学习率大小
+    learning_rate=0.001,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 1 - 1
tutorials/train/semantic_segmentation/unet.py

@@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset(
 # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
 model = pdrs.tasks.seg.UNet(
-    input_channel=NUM_BANDS, num_classes=len(train_dataset.labels))
+    in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
 
 # 执行模型训练
 model.train(