param_init.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle.nn as nn
  15. import paddle.nn.functional as F
  16. def normal_init(param, *args, **kwargs):
  17. """
  18. Initialize parameters with a normal distribution.
  19. Args:
  20. param (Tensor): The tensor that needs to be initialized.
  21. Returns:
  22. The initialized parameters.
  23. """
  24. return nn.initializer.Normal(*args, **kwargs)(param)
  25. def kaiming_normal_init(param, *args, **kwargs):
  26. """
  27. Initialize parameters with the Kaiming normal distribution.
  28. For more information about the Kaiming initialization method, please refer to
  29. https://arxiv.org/abs/1502.01852
  30. Args:
  31. param (Tensor): The tensor that needs to be initialized.
  32. Returns:
  33. The initialized parameters.
  34. """
  35. return nn.initializer.KaimingNormal(*args, **kwargs)(param)
  36. def constant_init(param, *args, **kwargs):
  37. """
  38. Initialize parameters with constants.
  39. Args:
  40. param (Tensor): The tensor that needs to be initialized.
  41. Returns:
  42. The initialized parameters.
  43. """
  44. return nn.initializer.Constant(*args, **kwargs)(param)
  45. class KaimingInitMixin:
  46. """
  47. A mix-in that provides the Kaiming initialization functionality.
  48. Examples:
  49. from paddlers.custom_models.cd.models.param_init import KaimingInitMixin
  50. class CustomNet(nn.Layer, KaimingInitMixin):
  51. def __init__(self, num_channels, num_classes):
  52. super().__init__()
  53. self.conv = nn.Conv2D(num_channels, num_classes, 3, 1, 0, bias_attr=False)
  54. self.bn = nn.BatchNorm2D(num_classes)
  55. self.init_weight()
  56. """
  57. def init_weight(self):
  58. for layer in self.sublayers():
  59. if isinstance(layer, nn.Conv2D):
  60. kaiming_normal_init(layer.weight)
  61. elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
  62. constant_init(layer.weight, value=1)
  63. constant_init(layer.bias, value=0)