# drn_train.py: train a DRN (pdrs.tasks.DRNet) super-resolution model with PaddleRS.

  1. import os
  2. import sys
  3. sys.path.append(os.path.abspath('../PaddleRS'))
  4. import paddle
  5. import paddlers as pdrs
  6. # 定义训练和验证时的transforms
  7. train_transforms = pdrs.datasets.ComposeTrans(
  8. input_keys=['lq', 'gt'],
  9. output_keys=['lq', 'lqx2', 'gt'],
  10. pipelines=[{
  11. 'name': 'SRPairedRandomCrop',
  12. 'gt_patch_size': 192,
  13. 'scale': 4,
  14. 'scale_list': True
  15. }, {
  16. 'name': 'PairedRandomHorizontalFlip'
  17. }, {
  18. 'name': 'PairedRandomVerticalFlip'
  19. }, {
  20. 'name': 'PairedRandomTransposeHW'
  21. }, {
  22. 'name': 'Transpose'
  23. }, {
  24. 'name': 'Normalize',
  25. 'mean': [0.0, 0.0, 0.0],
  26. 'std': [1.0, 1.0, 1.0]
  27. }])
  28. test_transforms = pdrs.datasets.ComposeTrans(
  29. input_keys=['lq', 'gt'],
  30. output_keys=['lq', 'gt'],
  31. pipelines=[{
  32. 'name': 'Transpose'
  33. }, {
  34. 'name': 'Normalize',
  35. 'mean': [0.0, 0.0, 0.0],
  36. 'std': [1.0, 1.0, 1.0]
  37. }])
  38. # 定义训练集
  39. train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径
  40. train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径
  41. num_workers = 4
  42. batch_size = 8
  43. scale = 4
  44. train_dataset = pdrs.datasets.SRdataset(
  45. mode='train',
  46. gt_floder=train_gt_floder,
  47. lq_floder=train_lq_floder,
  48. transforms=train_transforms(),
  49. scale=scale,
  50. num_workers=num_workers,
  51. batch_size=batch_size)
  52. train_dict = train_dataset()
  53. # 定义测试集
  54. test_gt_floder = r"../work/RSdata_for_SR/test_HR"
  55. test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
  56. test_dataset = pdrs.datasets.SRdataset(
  57. mode='test',
  58. gt_floder=test_gt_floder,
  59. lq_floder=test_lq_floder,
  60. transforms=test_transforms(),
  61. scale=scale)
  62. # 初始化模型,可以对网络结构的参数进行调整
  63. model = pdrs.tasks.DRNet(
  64. n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
  65. model.train(
  66. total_iters=100000,
  67. train_dataset=train_dataset(),
  68. test_dataset=test_dataset(),
  69. output_dir='output_dir',
  70. validate=5000,
  71. snapshot=5000,
  72. lr_rate=0.0001,
  73. log=10)