Procházet zdrojové kódy

Update rcan.py

Alleviate gradient explosion, but convergence is still difficult
kongdebug před 3 roky
rodič
revize
393e570a68
1 změnil soubory, kde provedl 8 přidání a 6 odebrání
  1. 8 6
      paddlers/custom_models/gan/generators/rcan.py

+ 8 - 6
paddlers/custom_models/gan/generators/rcan.py

@@ -7,12 +7,14 @@ from .builder import GENERATORS
 
 
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
-    return nn.Conv2D(
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=(kernel_size // 2),
-        bias_attr=bias)
+    weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierUniform(),
+                                   need_clip =True)
+    return nn.Conv2D(in_channels,
+                     out_channels,
+                     kernel_size,
+                     padding=(kernel_size // 2),
+                     weight_attr=weight_attr,
+                     bias_attr=bias)
 
 
 class MeanShift(nn.Conv2D):