|
@@ -7,12 +7,14 @@ from .builder import GENERATORS
|
|
|
|
|
|
|
|
|
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
|
|
- return nn.Conv2D(
|
|
|
- in_channels,
|
|
|
- out_channels,
|
|
|
- kernel_size,
|
|
|
- padding=(kernel_size // 2),
|
|
|
- bias_attr=bias)
|
|
|
+ weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierUniform(),
|
|
|
+ need_clip =True)
|
|
|
+ return nn.Conv2D(in_channels,
|
|
|
+ out_channels,
|
|
|
+ kernel_size,
|
|
|
+ padding=(kernel_size // 2),
|
|
|
+ weight_attr=weight_attr,
|
|
|
+ bias_attr=bias)
|
|
|
|
|
|
|
|
|
class MeanShift(nn.Conv2D):
|