# esrgan_train.py — ESRGAN super-resolution training script (PaddleRS)
  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath('../PaddleRS'))
  4. import paddlers as pdrs
  5. # 定义训练和验证时的transforms
  6. train_transforms = pdrs.datasets.ComposeTrans(
  7. input_keys=['lq', 'gt'],
  8. output_keys=['lq', 'gt'],
  9. pipelines=[{
  10. 'name': 'SRPairedRandomCrop',
  11. 'gt_patch_size': 128,
  12. 'scale': 4
  13. }, {
  14. 'name': 'PairedRandomHorizontalFlip'
  15. }, {
  16. 'name': 'PairedRandomVerticalFlip'
  17. }, {
  18. 'name': 'PairedRandomTransposeHW'
  19. }, {
  20. 'name': 'Transpose'
  21. }, {
  22. 'name': 'Normalize',
  23. 'mean': [0.0, 0.0, 0.0],
  24. 'std': [255.0, 255.0, 255.0]
  25. }])
  26. test_transforms = pdrs.datasets.ComposeTrans(
  27. input_keys=['lq', 'gt'],
  28. output_keys=['lq', 'gt'],
  29. pipelines=[{
  30. 'name': 'Transpose'
  31. }, {
  32. 'name': 'Normalize',
  33. 'mean': [0.0, 0.0, 0.0],
  34. 'std': [255.0, 255.0, 255.0]
  35. }])
  36. # 定义训练集
  37. train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径
  38. train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径
  39. num_workers = 6
  40. batch_size = 32
  41. scale = 4
  42. train_dataset = pdrs.datasets.SRdataset(
  43. mode='train',
  44. gt_floder=train_gt_floder,
  45. lq_floder=train_lq_floder,
  46. transforms=train_transforms(),
  47. scale=scale,
  48. num_workers=num_workers,
  49. batch_size=batch_size)
  50. # 定义测试集
  51. test_gt_floder = r"../work/RSdata_for_SR/test_HR"
  52. test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
  53. test_dataset = pdrs.datasets.SRdataset(
  54. mode='test',
  55. gt_floder=test_gt_floder,
  56. lq_floder=test_lq_floder,
  57. transforms=test_transforms(),
  58. scale=scale)
  59. # 初始化模型,可以对网络结构的参数进行调整
  60. # 若loss_type='gan' 使用感知损失、对抗损失和像素损失
  61. # 若loss_type = 'pixel' 只使用像素损失
  62. model = pdrs.tasks.ESRGANet(loss_type='pixel')
  63. model.train(
  64. total_iters=1000000,
  65. train_dataset=train_dataset(),
  66. test_dataset=test_dataset(),
  67. output_dir='output_dir',
  68. validate=5000,
  69. snapshot=5000,
  70. log=100,
  71. lr_rate=0.0001,
  72. periods=[250000, 250000, 250000, 250000],
  73. restart_weights=[1, 1, 1, 1])