analyze_model.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py
  16. import argparse
  17. import os
  18. import os.path as osp
  19. import sys
  20. import paddle
  21. import numpy as np
  22. import paddlers
  23. from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
  24. count_io_info)
  25. from paddle.hapi.static_flops import Table
  26. _dir = osp.dirname(osp.abspath(__file__))
  27. sys.path.append(osp.abspath(osp.join(_dir, '../')))
  28. import custom_model
  29. import custom_trainer
  30. def parse_args():
  31. parser = argparse.ArgumentParser()
  32. parser.add_argument(
  33. "--model_dir", default=None, type=str, help="Path of saved model.")
  34. parser.add_argument(
  35. "--input_shape",
  36. nargs='+',
  37. type=int,
  38. default=[1, 3, 256, 256],
  39. help="Shape of each input tensor.")
  40. return parser.parse_args()
  41. def analyze(model, inputs, custom_ops=None, print_detail=False):
  42. handler_collection = []
  43. types_collection = set()
  44. if custom_ops is None:
  45. custom_ops = {}
  46. def add_hooks(m):
  47. if len(list(m.children())) > 0:
  48. return
  49. m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
  50. m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
  51. m_type = type(m)
  52. flops_fn = None
  53. if m_type in custom_ops:
  54. flops_fn = custom_ops[m_type]
  55. if m_type not in types_collection:
  56. print("Customized function has been applied to {}".format(
  57. m_type))
  58. elif m_type in register_hooks:
  59. flops_fn = register_hooks[m_type]
  60. if m_type not in types_collection:
  61. print("{}'s FLOPs metric has been counted".format(m_type))
  62. else:
  63. if m_type not in types_collection:
  64. print(
  65. "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
  66. .format(m_type))
  67. if flops_fn is not None:
  68. flops_handler = m.register_forward_post_hook(flops_fn)
  69. handler_collection.append(flops_handler)
  70. params_handler = m.register_forward_post_hook(count_parameters)
  71. io_handler = m.register_forward_post_hook(count_io_info)
  72. handler_collection.append(params_handler)
  73. handler_collection.append(io_handler)
  74. types_collection.add(m_type)
  75. training = model.training
  76. model.eval()
  77. model.apply(add_hooks)
  78. with paddle.framework.no_grad():
  79. model(*inputs)
  80. total_ops = 0
  81. total_params = 0
  82. for m in model.sublayers():
  83. if len(list(m.children())) > 0:
  84. continue
  85. if set(['total_ops', 'total_params', 'input_shape',
  86. 'output_shape']).issubset(set(list(m._buffers.keys()))):
  87. total_ops += m.total_ops
  88. total_params += m.total_params
  89. if training:
  90. model.train()
  91. for handler in handler_collection:
  92. handler.remove()
  93. table = Table(
  94. ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"])
  95. for n, m in model.named_sublayers():
  96. if len(list(m.children())) > 0:
  97. continue
  98. if set(['total_ops', 'total_params', 'input_shape',
  99. 'output_shape']).issubset(set(list(m._buffers.keys()))):
  100. table.add_row([
  101. m.full_name(), list(m.input_shape.numpy()),
  102. list(m.output_shape.numpy()),
  103. round(float(m.total_params / 1e6), 3),
  104. round(float(m.total_ops / 1e9), 3)
  105. ])
  106. m._buffers.pop("total_ops")
  107. m._buffers.pop("total_params")
  108. m._buffers.pop('input_shape')
  109. m._buffers.pop('output_shape')
  110. if print_detail:
  111. table.print_table()
  112. print('Total FLOPs: {}G Total Params: {}M'.format(
  113. round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3)))
  114. return int(total_ops)
  115. if __name__ == '__main__':
  116. args = parse_args()
  117. # Enforce the use of CPU
  118. paddle.set_device('cpu')
  119. model = paddlers.tasks.load_model(args.model_dir)
  120. net = model.net
  121. # Construct bi-temporal inputs
  122. inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]
  123. analyze(model.net, inputs)