| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import argparse
- from skimage.metrics import structural_similarity as SSIM
- from skimage.metrics import peak_signal_noise_ratio as PSNR
- from skimage import io
- import os
- from osgeo import gdal
- import numpy as np
- import math
- import itertools
- import sys
- import torchvision.transforms as transforms
- from torchvision.utils import save_image, make_grid
- from torch.utils.data import DataLoader
- from torch.autograd import Variable
- import torch.nn as nn
- import torch.nn.functional as F
- from torchvision.models import vgg19
- import glob
- import random
- import torch
- from torch.utils.data import Dataset
- from PIL import Image
- import time
- import datetime
- from torch.utils.tensorboard import SummaryWriter
- import cv2
- from models9 import Generator, Discriminator, FeatureExtractor
- from DEM_parameters import accumation, aspect, slope, hillshade, curvature, RMSE, getdis
- import evaluation_util
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a 2-D or 3-D numpy array to a georeferenced GeoTIFF file.

    Parameters:
        im_data: numpy array, shape (bands, height, width) or (height, width).
        im_geotrans: 6-element GDAL affine geotransform tuple.
        im_proj: projection as a WKT string.
        path: output file path.

    Raises:
        IOError: if the GTiff driver cannot create the output dataset.
    """
    # Pick the GDAL pixel type from the numpy dtype name.
    # NOTE(review): the substring test also matches 'uint8'/'uint16', and
    # signed int16 data is written as GDT_UInt16 — preserved from the
    # original; confirm this is intended for the DEM data in use.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Normalize a single-band (height, width) array to (1, height, width).
    if len(im_data.shape) == 2:
        im_data = np.array([im_data])
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # driver.Create returns None on failure; the original silently fell
        # through and crashed later on WriteArray — fail loudly instead.
        raise IOError("Cannot create output GeoTIFF: %s" % path)
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)        # projection (WKT)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset  # closing the dataset flushes it to disk
- # 读取tif数据集
def readTif(fileName):
    """Open a TIFF with GDAL and return the dataset, or None on failure.

    gdal.Open returns None instead of raising when the file cannot be
    opened; the original best-effort behavior (warn and return None) is
    kept so callers that check for None still work.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName + "文件无法打开")
    return dataset
- # 像素坐标和地理坐标仿射变换
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    """Convert pixel (column, row) coordinates to geographic coordinates.

    Applies the standard GDAL affine geotransform:
        XGeo = GT[0] + GT[1] * Xpixel + GT[2] * Ypixel
        YGeo = GT[3] + GT[4] * Xpixel + GT[5] * Ypixel

    Parameters:
        Xpixel: pixel column index.
        Ypixel: pixel row index.
        GeoTransform: 6-element affine tuple as returned by
            gdal.Dataset.GetGeoTransform().

    Returns:
        (XGeo, YGeo) geographic coordinates of the pixel origin.
    """
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + GeoTransform[2] * Ypixel
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + GeoTransform[5] * Ypixel
    return XGeo, YGeo
# ---------------------------------------------------------------------------
# Test-time inference loop: run the generator over the dataset and save the
# generated ("fake") and reference ("real") DEM tiles as georeferenced
# GeoTIFFs.
# NOTE(review): `cuda`, `epoch`, `n_epochs`, `batch_size`, `n_cpu`,
# `dataset_name`, `hr_shape`, `DATBASE_PATH`, `ImageDataset` and `generator`
# are defined elsewhere in this script (this is one chunk of a longer file).
# ---------------------------------------------------------------------------

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

dataloader = DataLoader(
    ImageDataset(DATBASE_PATH + "/" + "%s" % dataset_name, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)
writer = SummaryWriter(DATBASE_PATH + '/testlogs')

# DEM de-normalization range: network tensors live in [-1, 1] and map back to
# elevations in [DEM_MIN, DEM_MAX] (hard-coded for this dataset).
DEM_MIN = 567
DEM_MAX = 1823

# The georeference template is identical for every tile: open it once here
# instead of re-opening the file on every batch (the original re-opened it
# inside the inner loop).
dataset_img = gdal.Open("D:\\daima\\1channel\\ceshi\\3.tif")
proj = dataset_img.GetProjection()
geotrans = dataset_img.GetGeoTransform()
XGeo, YGeo = CoordTransf(0, 0, geotrans)
crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])

path3 = r'D:\\daima\\1channel\\Resresults\\fake12.5'
path4 = r'D:\\daima\\1channel\\Resresults\\real12.5'

for epoch in range(epoch, n_epochs):
    for i, imgs in enumerate(dataloader):
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))

        # NOTE(review): the generator is fed the HR image, not the LR one —
        # unusual for a super-resolution model; confirm this is intended.
        gen_hr = generator(imgs_hr)

        # Map tensors from [-1, 1] back to elevation values.
        gen_hr = (gen_hr + 1) * (DEM_MAX - DEM_MIN) / 2 + DEM_MIN
        imgs_hr = (imgs_hr + 1) * (DEM_MAX - DEM_MIN) / 2 + DEM_MIN
        imgs_lr = (imgs_lr + 1) * (DEM_MAX - DEM_MIN) / 2 + DEM_MIN

        # Save the first sample of the batch: generated ("fake") DEM.
        fake_dem1 = 1.0 * (gen_hr[0].detach().cpu().numpy().astype(np.float32))
        new_name1 = len(os.listdir(path3)) + 1  # next free index in the folder
        writeTiff(fake_dem1, crop_geotrans, proj, path3 + "/fake%d.tif" % new_name1)

        # ... and the corresponding reference ("real") DEM.
        real_dem1 = 1.0 * (imgs_hr[0].detach().cpu().numpy().astype(np.float32))
        new_name = len(os.listdir(path4)) + 1
        writeTiff(real_dem1, crop_geotrans, proj, path4 + "/real%d.tif" % new_name)
- # ######### Train discriminator #########
- # a = discriminator(imgs_hr)
- # b = discriminator(Variable(gen_hr.data))
- # discriminator_loss = adversarial_criterion(discriminator(imgs_hr), torch.ones_like(a)) + \
- # adversarial_criterion(discriminator(Variable(gen_hr.data)), torch.zeros_like(b))
- # fake_features = feature_extractor(gen_hr)
- # real_features = feature_extractor(imgs_hr)
- # loss_mse = content_criterion(gen_hr, imgs_lr)
- # loss_mse1 = content_criterion(imgs_lr, imgs_hr)
- # generator_content_loss = content_criterion(gen_hr, imgs_hr) + 0.006 * content_criterion(
- # fake_features, real_features.detach())
- # c = discriminator(gen_hr)
- # generator_adversarial_loss = adversarial_criterion(discriminator(gen_hr), torch.ones_like(c))
- # generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
- # sys.stdout.write(
- # "[Epoch %d/%d] [Batch %d/%d] [lose_mse: %f] ]"
- # % (epoch, n_epochs, i, len(dataloader),loss_mse.item())
- # )
- # batches_done = epoch * len(dataloader) + i
- # if batches_done % sample_interval == 0:
- # gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
- # imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
- # imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
- # writer.add_scalar('loss_mse', loss_mse.item(), batches_done)
- # writer.add_scalar('loss_total', generator_total_loss.item(), batches_done)
- # img_grid = torch.cat((imgs_lr, gen_hr,imgs_hr), -1)
- # save_image(img_grid, DATBASE_PATH+"/results/%d.png" % batches_done, normalize=False)
- # gen_hr[0] = (gen_hr[0] + 1) * (1459 - 767) / 2 + 767 #max:1458.5 min:767.6
- # gen_hr[1] = (gen_hr[:,1,:,:] + 1) * 83 / 2 #max:82.6 min:0
- # gen_hr[:,2,:,:] = ((gen_hr[:,2,:,:] + 1 ) * 361 / 2) - 1 # max: 360 min:-1
- # imgs_hr[:,0,:,:] = (imgs_hr[:,0,:,:] + 1) * (1459 - 767) / 2 + 767
- # imgs_hr[:,1,:,:] = (imgs_hr[:,1,:,:] + 1) * 83 / 2
- # imgs_hr[:,2,:,:] = ((gen_hr[:,2,:,:] + 1 ) * 361 / 2) - 1
- # for j in range(batch_size):
- # lose_rmse = RMSE(gen_hr[j],imgs_hr[j])
- # print(lose_rmse)
- # fake_dem1= 1.0*(gen_hr.detach().cpu().numpy().astype(np.float32))
- # path3=r'D:\\daima\\gisr1\\GANresults\\3fake'
- # new_name1 = len(os.listdir(path3))
- # io.imsave(path3 + "/%d.tif"%new_name1, fake_dem1[j][0])
- # new_name1 = new_name1 +1
- # real_dem1= 1.0*(imgs_hr.detach().cpu().numpy().astype(np.float32))
- # path4=r'D:\\daima\\gisr1\\GANresults\\3real'
- # new_name = len(os.listdir(path4))
- # io.imsave(path4 + "/%d.tif"%new_name, real_dem1[j][0])
- # new_name = new_name +1
|