analyze_model.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python
  2. # Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py
  3. import argparse
  4. import os
  5. import os.path as osp
  6. import sys
  7. import paddle
  8. import numpy as np
  9. import paddlers
  10. from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
  11. count_io_info)
  12. from paddle.hapi.static_flops import Table
  13. _dir = osp.dirname(osp.abspath(__file__))
  14. sys.path.append(osp.abspath(osp.join(_dir, '../')))
  15. import custom_model
  16. import custom_trainer
  def parse_args():
      """Parse the command-line options of this script.

      Returns:
          argparse.Namespace: with
              - ``model_dir`` (str): path of the saved model to load (required).
              - ``input_shape`` (list of int): shape used for each input tensor,
                ``[1, 3, 256, 256]`` unless given on the command line.
      """
      parser = argparse.ArgumentParser()
      # Without a model directory there is nothing to analyze. Let argparse
      # reject a missing value up front with a usage message instead of
      # passing None on to model loading.
      parser.add_argument(
          "--model_dir",
          type=str,
          required=True,
          help="Path of saved model.")
      parser.add_argument(
          "--input_shape",
          nargs='+',
          type=int,
          default=[1, 3, 256, 256],
          help="Shape of each input tensor.")
      return parser.parse_args()
  # Buffers that the counting hooks attach to every instrumented leaf layer.
  # They are temporary and are removed again before `analyze` returns.
  _STAT_BUFFERS = ('total_ops', 'total_params', 'input_shape', 'output_shape')


  def _leaf_layers(model):
      """Return the leaf layers (layers without children) of `model`.

      `model` itself is included when it has no children. `model.apply` also
      visits the root layer, so it is hooked in that case too.
      """
      return [
          layer for layer in model.sublayers(include_self=True)
          if not list(layer.children())
      ]


  def analyze(model, inputs, custom_ops=None, print_detail=False):
      """Count FLOPs and parameters of `model` with a single forward pass.

      Reuses the per-layer counting hooks of `paddle.hapi.dynamic_flops`. The
      model is called as ``model(*inputs)``, so networks taking several
      positional inputs can be analyzed.

      Args:
          model (paddle.nn.Layer): Network to analyze. Its train/eval mode is
              restored, and all temporary hooks and buffers are removed before
              returning, even if the forward pass raises.
          inputs (list|tuple): Positional inputs passed to `model`.
          custom_ops (dict|None): Optional mapping from a layer class to a
              counting function. The function is registered as a forward post
              hook and takes precedence over Paddle's built-in `register_hooks`.
          print_detail (bool): If True, print a per-layer table.

      Returns:
          int: Total FLOPs summed over the leaf layers that ran in the forward
              pass.
      """
      handler_collection = []
      # Layer types already reported, so each message is printed once per type.
      types_collection = set()
      if custom_ops is None:
          custom_ops = {}

      def add_hooks(m):
          # Only leaf layers are instrumented; containers are skipped.
          if len(list(m.children())) > 0:
              return
          m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
          m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
          m_type = type(m)
          flops_fn = None
          if m_type in custom_ops:
              flops_fn = custom_ops[m_type]
              if m_type not in types_collection:
                  print("Customized function has been applied to {}".format(
                      m_type))
          elif m_type in register_hooks:
              flops_fn = register_hooks[m_type]
              if m_type not in types_collection:
                  print("{}'s FLOPs metric has been counted".format(m_type))
          else:
              if m_type not in types_collection:
                  print(
                      "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
                      .format(m_type))
          if flops_fn is not None:
              handler_collection.append(m.register_forward_post_hook(flops_fn))
          handler_collection.append(
              m.register_forward_post_hook(count_parameters))
          handler_collection.append(m.register_forward_post_hook(count_io_info))
          types_collection.add(m_type)

      training = model.training
      model.eval()
      rows = []
      total_ops = 0
      total_params = 0
      try:
          model.apply(add_hooks)
          with paddle.framework.no_grad():
              model(*inputs)
          for m in _leaf_layers(model):
              # Shape buffers exist only on layers whose hooks actually fired.
              # Layers that were not called are left out of the statistics.
              if not set(_STAT_BUFFERS).issubset(m._buffers.keys()):
                  continue
              # Accumulate as Python ints to avoid low-precision tensor division.
              layer_ops = int(m.total_ops)
              layer_params = int(m.total_params)
              total_ops += layer_ops
              total_params += layer_params
              rows.append([
                  m.full_name(), list(m.input_shape.numpy()),
                  list(m.output_shape.numpy()),
                  round(layer_params / 1e6, 3),
                  round(layer_ops / 1e9, 3)
              ])
      finally:
          for handler in handler_collection:
              handler.remove()
          if training:
              model.train()
          # Strip the temporary buffers from every leaf layer, including layers
          # whose hooks never fired; otherwise they would stay registered.
          for m in _leaf_layers(model):
              for name in _STAT_BUFFERS:
                  m._buffers.pop(name, None)

      if print_detail:
          table = Table([
              "Layer Name", "Input Shape", "Output Shape", "Params(M)",
              "FLOPs(G)"
          ])
          for row in rows:
              table.add_row(row)
          table.print_table()
      print('Total FLOPs: {}G Total Params: {}M'.format(
          round(total_ops / 1e9, 3), round(total_params / 1e6, 3)))
      return total_ops
  if __name__ == '__main__':
      args = parse_args()
      # Enforce the use of CPU (presumably so the analysis does not depend on
      # a GPU being available).
      paddle.set_device('cpu')
      model = paddlers.tasks.load_model(args.model_dir)
      net = model.net
      # Construct bi-temporal inputs: two random tensors of the same shape,
      # which `analyze` passes to the network as net(t1, t2).
      inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]
      analyze(net, inputs)