# Rrestest copy.py
import argparse
import datetime
import glob
import itertools
import math
import os
import random
import sys
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from osgeo import gdal
from PIL import Image
from skimage import io
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg19
from torchvision.utils import save_image, make_grid

from models9 import Generator, Discriminator, FeatureExtractor
from DEM_parameters import accumation, aspect, slope, hillshade, curvature, RMSE, getdis
import evaluation_util
  30. def writeTiff(im_data, im_geotrans, im_proj, path):
  31. if 'int8' in im_data.dtype.name:
  32. datatype = gdal.GDT_Byte
  33. elif 'int16' in im_data.dtype.name:
  34. datatype = gdal.GDT_UInt16
  35. else:
  36. datatype = gdal.GDT_Float32
  37. if len(im_data.shape) == 3:
  38. im_bands, im_height, im_width = im_data.shape
  39. elif len(im_data.shape) == 2:
  40. im_data = np.array([im_data])
  41. im_bands, im_height, im_width = im_data.shape
  42. # 创建文件
  43. driver = gdal.GetDriverByName("GTiff")
  44. dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
  45. if(dataset!= None):
  46. dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
  47. dataset.SetProjection(im_proj) # 写入投影
  48. for i in range(im_bands):
  49. dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
  50. del dataset
  51. # 读取tif数据集
  52. def readTif(fileName):
  53. dataset = gdal.Open(fileName)
  54. if dataset == None:
  55. print(fileName + "文件无法打开")
  56. return dataset
  57. # 像素坐标和地理坐标仿射变换
  58. def CoordTransf(Xpixel, Ypixel, GeoTransform):
  59. XGeo = GeoTransform[0]+GeoTransform[1]*Xpixel+Ypixel*GeoTransform[2];
  60. YGeo = GeoTransform[3]+GeoTransform[4]*Xpixel+Ypixel*GeoTransform[5];
  61. return XGeo, YGeo
# --- Test-run setup -------------------------------------------------------
# NOTE(review): the names used below (cuda, ImageDataset, DATBASE_PATH,
# dataset_name, hr_shape, batch_size, n_cpu) are not defined in this chunk —
# presumably set earlier in the file (e.g. via argparse); confirm before use.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
dataloader = DataLoader(
    ImageDataset(DATBASE_PATH + "/" + "%s" % dataset_name, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)
# TensorBoard writer for logging this test run.
writer = SummaryWriter(DATBASE_PATH + '/testlogs')
# --- Inference loop -------------------------------------------------------
# Runs the generator over every batch and writes each generated and
# ground-truth DEM tile as a georeferenced GeoTIFF, reusing the projection
# and geotransform of a reference tile.
for epoch in range(epoch, n_epochs):
    for i, imgs in enumerate(dataloader):
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        # NOTE(review): the generator is fed the HR image, not imgs_lr —
        # unusual for super-resolution; confirm this is intentional.
        gen_hr = generator(imgs_hr)
        # Undo the [-1, 1] normalisation back to the elevation range
        # [567, 1823] (dataset-specific min/max, hard-coded here).
        gen_hr = (gen_hr + 1) * (1823 - 567) / 2 + 567
        gen_hr1 = gen_hr.detach().cpu().numpy().astype(np.float32)
        imgs_hr = (imgs_hr + 1) * (1823 - 567) / 2 + 567
        imgs_hr1 = imgs_hr.detach().cpu().numpy().astype(np.float32)
        imgs_lr = (imgs_lr + 1) * (1823 - 567) / 2 + 567
        imgs_lr1 = imgs_lr.detach().cpu().numpy().astype(np.float32)
        # print(gen_hr.shape)
        # io.imshow(gen_hr1[0][0])
        # io.show()
        # io.imshow(imgs_hr1[0][0])
        # io.show()
        # io.imshow(imgs_lr1[0][0])
        # io.show()
        # First sample of the batch only; NOTE(review): samples 1..N of the
        # batch are silently dropped — confirm batch_size == 1 is intended.
        fake_dem1 = 1.0 * (gen_hr[0].detach().cpu().numpy().astype(np.float32))
        path3 = r'D:\\daima\\1channel\\Resresults\\fake12.5'
        # Next output index = current file count + 1.
        new_name1 = len(os.listdir(path3)) + 1
        # Borrow georeferencing from a fixed reference tile.
        dataset_img = gdal.Open("D:\\daima\\1channel\\ceshi\\3.tif")
        width = dataset_img.RasterXSize
        height = dataset_img.RasterYSize
        proj = dataset_img.GetProjection()
        geotrans = dataset_img.GetGeoTransform()
        # Geographic coordinate of the reference tile's top-left pixel.
        XGeo, YGeo = CoordTransf(0, 0, geotrans)
        crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])
        writeTiff(fake_dem1, crop_geotrans, proj, path3 + "/fake%d.tif" % new_name1)
        new_name1 = new_name1 + 1
        real_dem1 = 1.0 * (imgs_hr[0].detach().cpu().numpy().astype(np.float32))
        path4 = r'D:\\daima\\1channel\\Resresults\\real12.5'
        new_name = len(os.listdir(path4)) + 1
        writeTiff(real_dem1, crop_geotrans, proj, path4 + "/real%d.tif" % new_name)
        new_name = new_name + 1
        # ######### Train discriminator #########
        # a = discriminator(imgs_hr)
        # b = discriminator(Variable(gen_hr.data))
        # discriminator_loss = adversarial_criterion(discriminator(imgs_hr), torch.ones_like(a)) + \
        # adversarial_criterion(discriminator(Variable(gen_hr.data)), torch.zeros_like(b))
        # fake_features = feature_extractor(gen_hr)
        # real_features = feature_extractor(imgs_hr)
        # loss_mse = content_criterion(gen_hr, imgs_lr)
        # loss_mse1 = content_criterion(imgs_lr, imgs_hr)
        # generator_content_loss = content_criterion(gen_hr, imgs_hr) + 0.006 * content_criterion(
        #     fake_features, real_features.detach())
        # c = discriminator(gen_hr)
        # generator_adversarial_loss = adversarial_criterion(discriminator(gen_hr), torch.ones_like(c))
        # generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
        # sys.stdout.write(
        #     "[Epoch %d/%d] [Batch %d/%d] [lose_mse: %f] ]"
        #     % (epoch, n_epochs, i, len(dataloader), loss_mse.item())
        # )
        # batches_done = epoch * len(dataloader) + i
        # if batches_done % sample_interval == 0:
        #     gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
        #     imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
        #     imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
        #     writer.add_scalar('loss_mse', loss_mse.item(), batches_done)
        #     writer.add_scalar('loss_total', generator_total_loss.item(), batches_done)
        #     img_grid = torch.cat((imgs_lr, gen_hr, imgs_hr), -1)
        #     save_image(img_grid, DATBASE_PATH+"/results/%d.png" % batches_done, normalize=False)
        # gen_hr[0] = (gen_hr[0] + 1) * (1459 - 767) / 2 + 767  # max:1458.5 min:767.6
        # gen_hr[1] = (gen_hr[:,1,:,:] + 1) * 83 / 2  # max:82.6 min:0
        # gen_hr[:,2,:,:] = ((gen_hr[:,2,:,:] + 1) * 361 / 2) - 1  # max: 360 min:-1
        # imgs_hr[:,0,:,:] = (imgs_hr[:,0,:,:] + 1) * (1459 - 767) / 2 + 767
        # imgs_hr[:,1,:,:] = (imgs_hr[:,1,:,:] + 1) * 83 / 2
        # imgs_hr[:,2,:,:] = ((gen_hr[:,2,:,:] + 1) * 361 / 2) - 1
        # for j in range(batch_size):
        #     lose_rmse = RMSE(gen_hr[j], imgs_hr[j])
        #     print(lose_rmse)
        #     fake_dem1 = 1.0*(gen_hr.detach().cpu().numpy().astype(np.float32))
        #     path3 = r'D:\\daima\\gisr1\\GANresults\\3fake'
        #     new_name1 = len(os.listdir(path3))
        #     io.imsave(path3 + "/%d.tif"%new_name1, fake_dem1[j][0])
        #     new_name1 = new_name1 + 1
        #     real_dem1 = 1.0*(imgs_hr.detach().cpu().numpy().astype(np.float32))
        #     path4 = r'D:\\daima\\gisr1\\GANresults\\3real'
        #     new_name = len(os.listdir(path4))
        #     io.imsave(path4 + "/%d.tif"%new_name, real_dem1[j][0])
        #     new_name = new_name + 1