lesrcnn_train.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath('../PaddleRS'))
  4. import paddlers as pdrs
  5. # 定义训练和验证时的transforms
  6. train_transforms = pdrs.datasets.ComposeTrans(
  7. input_keys=['lq', 'gt'],
  8. output_keys=['lq', 'gt'],
  9. pipelines=[{
  10. 'name': 'SRPairedRandomCrop',
  11. 'gt_patch_size': 192,
  12. 'scale': 4
  13. }, {
  14. 'name': 'PairedRandomHorizontalFlip'
  15. }, {
  16. 'name': 'PairedRandomVerticalFlip'
  17. }, {
  18. 'name': 'PairedRandomTransposeHW'
  19. }, {
  20. 'name': 'Transpose'
  21. }, {
  22. 'name': 'Normalize',
  23. 'mean': [0.0, 0.0, 0.0],
  24. 'std': [255.0, 255.0, 255.0]
  25. }])
  26. test_transforms = pdrs.datasets.ComposeTrans(
  27. input_keys=['lq', 'gt'],
  28. output_keys=['lq', 'gt'],
  29. pipelines=[{
  30. 'name': 'Transpose'
  31. }, {
  32. 'name': 'Normalize',
  33. 'mean': [0.0, 0.0, 0.0],
  34. 'std': [255.0, 255.0, 255.0]
  35. }])
  36. # 定义训练集
  37. train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径
  38. train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径
  39. num_workers = 4
  40. batch_size = 16
  41. scale = 4
  42. train_dataset = pdrs.datasets.SRdataset(
  43. mode='train',
  44. gt_floder=train_gt_floder,
  45. lq_floder=train_lq_floder,
  46. transforms=train_transforms(),
  47. scale=scale,
  48. num_workers=num_workers,
  49. batch_size=batch_size)
  50. # 定义测试集
  51. test_gt_floder = r"../work/RSdata_for_SR/test_HR"
  52. test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
  53. test_dataset = pdrs.datasets.SRdataset(
  54. mode='test',
  55. gt_floder=test_gt_floder,
  56. lq_floder=test_lq_floder,
  57. transforms=test_transforms(),
  58. scale=scale)
  59. # 初始化模型,可以对网络结构的参数进行调整
  60. model = pdrs.tasks.LESRCNNet(scale=4, multi_scale=False, group=1)
  61. model.train(
  62. total_iters=1000000,
  63. train_dataset=train_dataset(),
  64. test_dataset=test_dataset(),
  65. output_dir='output_dir',
  66. validate=5000,
  67. snapshot=5000,
  68. log=100,
  69. lr_rate=0.0001,
  70. periods=[250000, 250000, 250000, 250000],
  71. restart_weights=[1, 1, 1, 1])