前段时间写了个项目:PaddleSeg:使用Transfomer模型对航空遥感图像分割,项目利用PaddleSeg模块训练Transfomer类的语义分割模型,在UDD6数据集中mIOU达到74.50% ,原论文使用DeepLabV3+的mIOU为73.18%, 高1.32% ,训练效果图如下,其中:车辆:红色;道路:浅蓝色;植被:深蓝色;建筑立面:亮绿色;建筑屋顶:紫色;其他:焦绿色
%cd /home/aistudio/
import matplotlib.pyplot as plt
from PIL import Image
output = Image.open(r"work/example/Seg/UDD6_result/added_prediction/000161.JPG")
plt.figure(figsize=(18, 12)) # 设置窗口大小
plt.imshow(output), plt.axis('off')
本项目使用PaddleRS提供的无人机遥感图像超分模块,对真实的低质量无人机影像数据进行超分,然后再使用前段时间用UDD6训练的Segformer模型预测,与直接使用低分辨率模型对比。由于没有对低质量数据进行标注无法计算指标。但人眼判别,超分之后的预测结果更好,左边是人工标注的label,中间是低分辨率的预测结果,右边是超分辨率重建后的结果
img = Image.open(r"work/example/Seg/gt_result/data_05_2_14.png")
lq = Image.open(r"work/example/Seg/lq_result/added_prediction/data_05_2_14.png")
sr = Image.open(r"work/example/Seg/sr_result/added_prediction/data_05_2_14.png")
plt.figure(figsize=(18, 12))
plt.subplot(1,3,1), plt.title('GT')
plt.imshow(img), plt.axis('off')
plt.subplot(1,3,2), plt.title('predict_LR')
plt.imshow(lq), plt.axis('off')
plt.subplot(1,3,3), plt.title('predict_SR')
plt.imshow(sr), plt.axis('off')
plt.show()
部分标注数据展示如下
add_lb = Image.open(r"work/example/Seg/gt_result/data_05_2_19.png")
lb = Image.open(r"work/example/Seg/gt_label/data_05_2_19.png")
img = Image.open(r"work/ValData/DJI300/data_05_2_19.png")
plt.figure(figsize=(18, 12))
plt.subplot(1,3,1), plt.title('image')
plt.imshow(img), plt.axis('off')
plt.subplot(1,3,2), plt.title('label')
plt.imshow(lb), plt.axis('off')
plt.subplot(1,3,3), plt.title('add_label')
plt.imshow(add_lb), plt.axis('off')
plt.show()
因为PaddleRS提供了预训练的超分模型,所以这步主要分为以下两个步骤:
调用PaddleRS中的超分预测接口,对低分辨率无人机影像进行超分重建
# 从github上克隆仓库
!git clone https://github.com/PaddlePaddle/PaddleRS.git
# 安装依赖,大概一分多钟
%cd PaddleRS/
!pip install -r requirements.txt
# 进行图像超分处理,使用的模型为DRN
import os
import paddle
import numpy as np
from PIL import Image
from paddlers.models.ppgan.apps.drn_predictor import DRNPredictor
# 输出预测结果的文件夹
output = r'../work/example'
# 待输入的低分辨率影像位置
input_dir = r"../work/ValData/DJI300"
paddle.device.set_device("gpu:0") # 若是cpu环境,则替换为 paddle.device.set_device("cpu")
predictor = DRNPredictor(output) # 实例化
filenames = [f for f in os.listdir(input_dir) if f.endswith('.png')]
for filename in filenames:
imgPath = os.path.join(input_dir, filename)
predictor.run(imgPath) # 预测
超分重建结果前后对比展示
# 可视化
import os
import matplotlib.pyplot as plt
%matplotlib inline
lq_dir = r"../work/ValData/DJI300" # 低分辨率影像文件夹
sr_dir = r"../work/example/DRN" # 超分辨率影像所在文件夹
img_list = [f for f in os.listdir(lq_dir) if f.endswith('.png')]
show_num = 3 # 展示多少对影像
for i in range(show_num):
lq_box = (100, 100, 175, 175)
sr_box = (400, 400, 700, 700)
filename = img_list[i]
image = Image.open(os.path.join(lq_dir, filename)).crop(lq_box) # 读取低分辨率影像
sr_img = Image.open(os.path.join(sr_dir, filename)).crop(sr_box) # 读取超分辨率影像
plt.figure(figsize=(12, 8))
plt.subplot(1,2,1), plt.title('Input')
plt.imshow(image), plt.axis('off')
plt.subplot(1,2,2), plt.title('Output')
plt.imshow(sr_img), plt.axis('off')
plt.show()
首先用该模型对低质量的无人机数据进行预测,然后再使用超分重建后的图像预测,最后对比一下预测的效果
%cd ..
# clone PaddleSeg的项目
!git clone https://gitee.com/paddlepaddle/PaddleSeg
# 安装依赖
%cd /home/aistudio/PaddleSeg
!pip install -r requirements.txt
# 对低分辨率的无人机影像进行预测
!python predict.py \
--config ../work/segformer_b3_UDD.yml \
--model_path ../work/best_model/model.pdparams \
--image_path ../work/ValData/DJI300 \
--save_dir ../work/example/Seg/lq_result
# 对使用DRN超分重建后的影像进行预测
!python predict.py \
--config ../work/segformer_b3_UDD.yml \
--model_path ../work/best_model/model.pdparams \
--image_path ../work/example/DRN \
--save_dir ../work/example/Seg/sr_result
展示预测结果
种类 | 颜色 |
---|---|
其他 | 焦绿色 |
建筑外立面 | 亮绿色 |
道路 | 淡蓝色 |
植被 | 深蓝色 |
车辆 | 红色 |
屋顶 | 紫色 |
由于只标注了五张图片,所以只展示五张图片的结果,剩下的预测结果均在 work/example/Seg/
文件夹下,其中左边是真值,中间是低分辨率影像预测结果,右边是超分重建后预测结果
# 展示部分预测的结果
%cd /home/aistudio/
import matplotlib.pyplot as plt
from PIL import Image
import os
img_dir = r"work/example/Seg/gt_result" # 低分辨率影像文件夹
lq_dir = r"work/example/Seg/lq_result/added_prediction"
sr_dir = r"work/example/Seg/sr_result/added_prediction" # 超分辨率预测的结果影像所在文件夹
img_list = [f for f in os.listdir(img_dir) if f.endswith('.png') ]
for filename in img_list:
img = Image.open(os.path.join(img_dir, filename))
lq_pred = Image.open(os.path.join(lq_dir, filename))
sr_pred = Image.open(os.path.join(sr_dir, filename))
plt.figure(figsize=(12, 8))
plt.subplot(1,3,1), plt.title('GT')
plt.imshow(img), plt.axis('off')
plt.subplot(1,3,2), plt.title('LR_pred')
plt.imshow(lq_pred), plt.axis('off')
plt.subplot(1,3,3), plt.title('SR_pred')
plt.imshow(sr_pred), plt.axis('off')
plt.show()