Browse Source

Update rcan.py

Alleviate gradient explosion, but convergence is still difficult
kongdebug 3 years ago
parent
commit
393e570a68
1 changed files with 8 additions and 6 deletions
  1. 8 6
      paddlers/custom_models/gan/generators/rcan.py

+ 8 - 6
paddlers/custom_models/gan/generators/rcan.py

@@ -7,12 +7,14 @@ from .builder import GENERATORS
 
 
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
-    return nn.Conv2D(
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=(kernel_size // 2),
-        bias_attr=bias)
+    weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierUniform(),
+                                   need_clip =True)
+    return nn.Conv2D(in_channels,
+                     out_channels,
+                     kernel_size,
+                     padding=(kernel_size // 2),
+                     weight_attr=weight_attr,
+                     bias_attr=bias)
 
 
 class MeanShift(nn.Conv2D):