analyze_model.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py
  16. import argparse
  17. import os
  18. import os.path as osp
  19. import sys
  20. import paddle
  21. import numpy as np
  22. import paddlers
  23. from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
  24. count_io_info)
  25. from paddle.hapi.static_flops import Table
  26. _dir = osp.dirname(osp.abspath(__file__))
  27. sys.path.append(osp.abspath(osp.join(_dir, '../')))
  28. import bootstrap
  29. def parse_args():
  30. parser = argparse.ArgumentParser()
  31. parser.add_argument(
  32. "--model_dir", default=None, type=str, help="Path of saved model.")
  33. parser.add_argument(
  34. "--input_shape",
  35. nargs='+',
  36. type=int,
  37. default=[1, 3, 256, 256],
  38. help="Shape of each input tensor.")
  39. return parser.parse_args()
  40. def analyze(model, inputs, custom_ops=None, print_detail=False):
  41. handler_collection = []
  42. types_collection = set()
  43. if custom_ops is None:
  44. custom_ops = {}
  45. def add_hooks(m):
  46. if len(list(m.children())) > 0:
  47. return
  48. m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
  49. m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
  50. m_type = type(m)
  51. flops_fn = None
  52. if m_type in custom_ops:
  53. flops_fn = custom_ops[m_type]
  54. if m_type not in types_collection:
  55. print("Customized function has been applied to {}".format(
  56. m_type))
  57. elif m_type in register_hooks:
  58. flops_fn = register_hooks[m_type]
  59. if m_type not in types_collection:
  60. print("{}'s FLOPs metric has been counted".format(m_type))
  61. else:
  62. if m_type not in types_collection:
  63. print(
  64. "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
  65. .format(m_type))
  66. if flops_fn is not None:
  67. flops_handler = m.register_forward_post_hook(flops_fn)
  68. handler_collection.append(flops_handler)
  69. params_handler = m.register_forward_post_hook(count_parameters)
  70. io_handler = m.register_forward_post_hook(count_io_info)
  71. handler_collection.append(params_handler)
  72. handler_collection.append(io_handler)
  73. types_collection.add(m_type)
  74. training = model.training
  75. model.eval()
  76. model.apply(add_hooks)
  77. with paddle.framework.no_grad():
  78. model(*inputs)
  79. total_ops = 0
  80. total_params = 0
  81. for m in model.sublayers():
  82. if len(list(m.children())) > 0:
  83. continue
  84. if set(['total_ops', 'total_params', 'input_shape',
  85. 'output_shape']).issubset(set(list(m._buffers.keys()))):
  86. total_ops += m.total_ops
  87. total_params += m.total_params
  88. if training:
  89. model.train()
  90. for handler in handler_collection:
  91. handler.remove()
  92. table = Table(
  93. ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"])
  94. for n, m in model.named_sublayers():
  95. if len(list(m.children())) > 0:
  96. continue
  97. if set(['total_ops', 'total_params', 'input_shape',
  98. 'output_shape']).issubset(set(list(m._buffers.keys()))):
  99. table.add_row([
  100. m.full_name(), list(m.input_shape.numpy()),
  101. list(m.output_shape.numpy()),
  102. round(float(m.total_params / 1e6), 3),
  103. round(float(m.total_ops / 1e9), 3)
  104. ])
  105. m._buffers.pop("total_ops")
  106. m._buffers.pop("total_params")
  107. m._buffers.pop('input_shape')
  108. m._buffers.pop('output_shape')
  109. if print_detail:
  110. table.print_table()
  111. print('Total FLOPs: {}G Total Params: {}M'.format(
  112. round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3)))
  113. return int(total_ops)
  114. if __name__ == '__main__':
  115. args = parse_args()
  116. # Enforce the use of CPU
  117. paddle.set_device('cpu')
  118. model = paddlers.tasks.load_model(args.model_dir)
  119. net = model.net
  120. # Construct bi-temporal inputs
  121. inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]
  122. analyze(model.net, inputs)