I recently wrote a project, "PaddleSeg: Segmentation of Aerial Remote Sensing Images Using the Transformer Model", in which PaddleSeg was used to train Transformer semantic segmentation models. The Transformer model reached an mIoU of 74.50% on the UDD6 dataset, 1.32 percentage points higher than the 73.18% reported in the original paper. The training results are shown below, with the following colors: car: red; road: light blue; vegetation: dark blue; building facade: bright green; building roof: purple; other: burnt green.
%cd /home/aistudio/
import matplotlib.pyplot as plt
from PIL import Image
output = Image.open(r"work/example/Seg/UDD6_result/added_prediction/000161.JPG")
plt.figure(figsize=(18, 12)) # Set window size
plt.imshow(output), plt.axis('off')
In this project, the UAV remote sensing image super-resolution module provided by PaddleRS is used to super-resolve real, low-quality UAV images. The SegFormer model trained on UDD6 is then used to predict on the super-resolved images, and the results are compared with predictions made directly on the low-resolution images. Quantitative metrics cannot be computed because the low-quality data is not annotated. However, visual inspection suggests that the predictions after super-resolution are better. In the figure below, the left panel is the manually annotated label, the middle panel is the prediction on the low-resolution image, and the right panel is the prediction after super-resolution reconstruction.
# Side-by-side view of one annotated tile: the ground-truth overlay, the
# prediction on the low-resolution input, and the prediction on the
# super-resolved input.
img = Image.open(r"work/example/Seg/gt_result/data_05_2_14.png")
lq = Image.open(r"work/example/Seg/lq_result/added_prediction/data_05_2_14.png")
sr = Image.open(r"work/example/Seg/sr_result/added_prediction/data_05_2_14.png")

panels = (('GT', img), ('predict_LR', lq), ('predict_SR', sr))
plt.figure(figsize=(18, 12))
for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    plt.imshow(picture)
    plt.axis('off')
plt.show()
Some of the annotated data is shown below:
# Show one annotated tile in three panels: the raw UAV image, its label
# mask, and the 'add_label' view read from gt_result (presumably the label
# blended over the image — confirm).
add_lb = Image.open(r"work/example/Seg/gt_result/data_05_2_19.png")
lb = Image.open(r"work/example/Seg/gt_label/data_05_2_19.png")
img = Image.open(r"work/ValData/DJI300/data_05_2_19.png")

plt.figure(figsize=(18, 12))
for slot, (heading, shown) in enumerate(
        (('image', img), ('label', lb), ('add_label', add_lb)), start=1):
    plt.subplot(1, 3, slot)
    plt.title(heading)
    plt.imshow(shown)
    plt.axis('off')
plt.show()
Since PaddleRS provides a pre-trained super-resolution model, this step consists of two parts:
Call the super-resolution prediction interface in PaddleRS to perform super-resolution reconstruction of the low-resolution UAV images.
# Clone the PaddleRS repository from GitHub into the current directory
# (an earlier cell changed into /home/aistudio/).
!git clone https://github.com/PaddlePaddle/PaddleRS.git
# Enter the checkout and install its dependencies (takes about a minute).
# The next cell imports `paddlers` from this directory, so stay inside it.
%cd PaddleRS/
!pip install -r requirements.txt
# Super-resolve every low-resolution UAV image with the pre-trained DRN model
# shipped with PaddleRS.
import os
import paddle
import numpy as np
from PIL import Image
from paddlers.models.ppgan.apps.drn_predictor import DRNPredictor

# The folder where the prediction results are output.
# NOTE(review): the visualization cell reads results from ../work/example/DRN,
# so DRNPredictor presumably writes into a DRN/ subfolder of `output` — confirm.
output = r'../work/example'
# Folder of low-resolution images to be super-resolved
input_dir = r"../work/ValData/DJI300"

# Use the first GPU when Paddle was built with CUDA; otherwise fall back to the
# CPU instead of failing in set_device on a CPU-only install.
paddle.device.set_device("gpu:0" if paddle.device.is_compiled_with_cuda() else "cpu")

predictor = DRNPredictor(output)  # instantiation
# sorted() gives a deterministic processing order (os.listdir order is arbitrary).
filenames = sorted(f for f in os.listdir(input_dir) if f.endswith('.png'))
for filename in filenames:
    imgPath = os.path.join(input_dir, filename)
    predictor.run(imgPath)  # prediction
Comparison of the images before and after super-resolution reconstruction
# visualization
import os
import matplotlib.pyplot as plt
%matplotlib inline
lq_dir = r"../work/ValData/DJI300" # Low resolution image folder
sr_dir = r"../work/example/DRN" # super-resolution image folder
img_list = [f for f in os.listdir(lq_dir) if f.endswith('.png')]
show_num = 3 # How many pairs of images are shown
for i in range(show_num):
lq_box = (100, 100, 175, 175)
sr_box = (400, 400, 700, 700)
filename = img_list[i]
image = Image.open(os.path.join(lq_dir, filename)).crop(lq_box) # Read low resolution images
sr_img = Image.open(os.path.join(sr_dir, filename)).crop(sr_box) # Read super-resolution images
plt.figure(figsize=(12, 8))
plt.subplot(1,2,1), plt.title('Input')
plt.imshow(image), plt.axis('off')
plt.subplot(1,2,2), plt.title('Output')
plt.imshow(sr_img), plt.axis('off')
plt.show()
First, the model is used to predict on the low-quality UAV images; then it is used to predict on the super-resolved images. Finally, the two sets of predictions are compared.
# Leave the PaddleRS checkout (the earlier `%cd PaddleRS/` entered it).
%cd ..
# Clone PaddleSeg from Gitee
!git clone https://gitee.com/paddlepaddle/PaddleSeg
# Enter the PaddleSeg checkout and install its dependencies
%cd /home/aistudio/PaddleSeg
!pip install -r requirements.txt
# Predict on the original low-resolution UAV images with the UDD6-trained
# SegFormer-B3 config and weights; results go to ../work/example/Seg/lq_result.
# NOTE(review): predict.py is invoked from the repository root; some PaddleSeg
# releases may place it elsewhere (e.g. tools/) — pin a release or confirm the path.
!python predict.py \
--config ../work/segformer_b3_UDD.yml \
--model_path ../work/best_model/model.pdparams \
--image_path ../work/ValData/DJI300 \
--save_dir ../work/example/Seg/lq_result
# Predict on the DRN super-resolved images with the same config and weights;
# results go to ../work/example/Seg/sr_result.
!python predict.py \
--config ../work/segformer_b3_UDD.yml \
--model_path ../work/best_model/model.pdparams \
--image_path ../work/example/DRN \
--save_dir ../work/example/Seg/sr_result
Prediction Results (class color legend)
Kind | Color |
---|---|
Others | Burnt green |
Building facade | Bright green |
Road | Light blue |
Vegetation | Dark blue |
Car | Red |
Roof | Purple |
Since only five images are annotated, results are shown for only those five; the remaining prediction results are all in the folder work/example/Seg/. In each row, the left image is the ground truth, the middle image is the prediction on the low-resolution image, and the right image is the prediction after super-resolution reconstruction.
# Show part of prediction result
%cd /home/aistudio/
import matplotlib.pyplot as plt
from PIL import Image
import os
img_dir = r"work/example/Seg/gt_result" # Low resolution image folder
lq_dir = r"work/example/Seg/lq_result/added_prediction"
sr_dir = r"work/example/Seg/sr_result/added_prediction" # Super resolution prediction results image folder
img_list = [f for f in os.listdir(img_dir) if f.endswith('.png') ]
for filename in img_list:
img = Image.open(os.path.join(img_dir, filename))
lq_pred = Image.open(os.path.join(lq_dir, filename))
sr_pred = Image.open(os.path.join(sr_dir, filename))
plt.figure(figsize=(12, 8))
plt.subplot(1,3,1), plt.title('GT')
plt.imshow(img), plt.axis('off')
plt.subplot(1,3,2), plt.title('LR_pred')
plt.imshow(lq_pred), plt.axis('off')
plt.subplot(1,3,3), plt.title('SR_pred')
plt.imshow(sr_pred), plt.axis('off')
plt.show()