|
@@ -17,9 +17,6 @@ LABEL_LIST_PATH = './data/rsseg/labels.txt'
|
|
|
# 实验目录,保存输出的模型权重和结果
|
|
|
EXP_DIR = './output/farseg/'
|
|
|
|
|
|
-# 影像波段数量
|
|
|
-NUM_BANDS = 10
|
|
|
-
|
|
|
# 下载和解压多光谱地块分类数据集
|
|
|
pdrs.utils.download_and_decompress(
|
|
|
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
|
|
@@ -30,22 +27,26 @@ pdrs.utils.download_and_decompress(
|
|
|
train_transforms = T.Compose([
|
|
|
# 读取影像
|
|
|
T.DecodeImg(),
|
|
|
+ # 选择前三个波段
|
|
|
+ T.SelectBand([1, 2, 3]),
|
|
|
# 将影像缩放到512x512大小
|
|
|
T.Resize(target_size=512),
|
|
|
# 以50%的概率实施随机水平翻转
|
|
|
T.RandomHorizontalFlip(prob=0.5),
|
|
|
# 将数据归一化到[-1,1]
|
|
|
T.Normalize(
|
|
|
- mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
|
|
|
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
|
T.ArrangeSegmenter('train')
|
|
|
])
|
|
|
|
|
|
eval_transforms = T.Compose([
|
|
|
T.DecodeImg(),
|
|
|
+ # 验证阶段与训练阶段应当选择相同的波段
|
|
|
+ T.SelectBand([1, 2, 3]),
|
|
|
T.Resize(target_size=512),
|
|
|
# 验证阶段与训练阶段的数据归一化方式必须相同
|
|
|
T.Normalize(
|
|
|
- mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
|
|
|
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
|
T.ReloadMask(),
|
|
|
T.ArrangeSegmenter('eval')
|
|
|
])
|
|
@@ -70,8 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset(
|
|
|
# 构建FarSeg模型
|
|
|
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
|
|
|
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
|
|
|
-model = pdrs.tasks.seg.FarSeg(
|
|
|
- in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
|
|
|
+model = pdrs.tasks.seg.FarSeg(num_classes=len(train_dataset.labels))
|
|
|
|
|
|
# 执行模型训练
|
|
|
model.train(
|