rcan.py

# Based on https://github.com/kongdebug/RCAN-Paddle
import math

import paddle
import paddle.nn as nn

from .builder import GENERATORS


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    weight_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.XavierUniform(), need_clip=True)
    return nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        padding=(kernel_size // 2),
        weight_attr=weight_attr,
        bias_attr=bias)


class MeanShift(nn.Conv2D):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = paddle.to_tensor(rgb_std)
        self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
        self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))
        mean = paddle.to_tensor(rgb_mean)
        self.bias.set_value(sign * rgb_range * mean / std)
        self.weight.trainable = False
        self.bias.trainable = False
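

# With the defaults used below (rgb_std of all ones), MeanShift reduces to a
# fixed per-channel shift y = x + sign * rgb_range * rgb_mean: sign=-1
# subtracts the dataset mean from the input; sign=+1 adds it back.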


## Channel Attention (CA) Layer
class CALayer(nn.Layer):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # Global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        # Feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2D(
                channel, channel // reduction, 1, padding=0, bias_attr=True),
            nn.ReLU(),
            nn.Conv2D(
                channel // reduction, channel, 1, padding=0, bias_attr=True),
            nn.Sigmoid())

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
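

# Shape sketch: for x of shape [N, C, H, W], avg_pool yields [N, C, 1, 1];
# conv_du squeezes that to [N, C // reduction, 1, 1] and expands it back to
# [N, C, 1, 1] with sigmoid-gated values in (0, 1), which broadcasting then
# uses to rescale each channel of x.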


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Layer):
    def __init__(self,
                 conv,
                 n_feat,
                 kernel_size,
                 reduction=16,
                 bias=True,
                 bn=False,
                 act=nn.ReLU(),
                 res_scale=1):
        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2D(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        # Stored for API compatibility; not applied in forward().
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Layer):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
                 n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False,
                act=nn.ReLU(), res_scale=1) for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale a power of two?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2D(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU())
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2D(n_feats))
            if act == 'relu':
                m.append(nn.ReLU())
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*m)
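

# Example: scale=4 is a power of two, so it unrolls into two
# (conv -> PixelShuffle(2)) stages, each trading a 4x channel expansion for a
# 2x spatial upscale; scale=3 instead uses a single PixelShuffle(3) fed by a
# 9x channel expansion. Other scales raise NotImplementedError.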


@GENERATORS.register()
class RCAN(nn.Layer):
    def __init__(self,
                 scale,
                 n_resgroups,
                 n_resblocks,
                 n_feats=64,
                 n_colors=3,
                 rgb_range=255,
                 kernel_size=3,
                 reduction=16,
                 conv=default_conv):
        super(RCAN, self).__init__()
        self.scale = scale
        act = nn.ReLU()

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)

        # Define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # Define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=1,
                n_resblocks=n_resblocks) for _ in range(n_resgroups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # Define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)
        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)
        return x
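

# Minimal usage sketch (hypothetical sizes; the module must be imported as
# part of its package because of the relative `.builder` import):
#
#     model = RCAN(scale=2, n_resgroups=10, n_resblocks=20)
#     lr = paddle.rand([1, 3, 48, 48]) * 255  # inputs lie in [0, rgb_range]
#     sr = model(lr)                          # shape: [1, 3, 96, 96]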