#!/usr/bin/env python

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py

import argparse
import os
import os.path as osp
import sys

import paddle
import numpy as np
import paddlers
from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
                                       count_io_info)
from paddle.hapi.static_flops import Table

_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
import custom_model
import custom_trainer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", default=None, type=str, help="Path of saved model.")
    parser.add_argument(
        "--input_shape",
        nargs='+',
        type=int,
        default=[1, 3, 256, 256],
        help="Shape of each input tensor.")
    return parser.parse_args()


def analyze(model, inputs, custom_ops=None, print_detail=False):
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return
        m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
        m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
        m_type = type(m)

        flops_fn = None
        if m_type in custom_ops:
            flops_fn = custom_ops[m_type]
            if m_type not in types_collection:
                print("Customized function has been applied to {}".format(
                    m_type))
        elif m_type in register_hooks:
            flops_fn = register_hooks[m_type]
            if m_type not in types_collection:
                print("{}'s FLOPs metric has been counted".format(m_type))
        else:
            if m_type not in types_collection:
                print(
                    "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
                    .format(m_type))
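
        # Register forward-post hooks so that parameter counts and I/O shapes
        # (plus FLOPs, when a counting function is available) are recorded on
        # each leaf layer during the dry-run forward pass below.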
        if flops_fn is not None:
            flops_handler = m.register_forward_post_hook(flops_fn)
            handler_collection.append(flops_handler)
        params_handler = m.register_forward_post_hook(count_parameters)
        io_handler = m.register_forward_post_hook(count_io_info)
        handler_collection.append(params_handler)
        handler_collection.append(io_handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with paddle.framework.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.sublayers():
        if len(list(m.children())) > 0:
            continue
        if set(['total_ops', 'total_params', 'input_shape',
                'output_shape']).issubset(set(list(m._buffers.keys()))):
            total_ops += m.total_ops
            total_params += m.total_params

    if training:
        model.train()
    for handler in handler_collection:
        handler.remove()

    table = Table(
        ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"])

    for n, m in model.named_sublayers():
        if len(list(m.children())) > 0:
            continue
        if set(['total_ops', 'total_params', 'input_shape',
                'output_shape']).issubset(set(list(m._buffers.keys()))):
            table.add_row([
                m.full_name(), list(m.input_shape.numpy()),
                list(m.output_shape.numpy()),
                round(float(m.total_params / 1e6), 3),
                round(float(m.total_ops / 1e9), 3)
            ])
            m._buffers.pop("total_ops")
            m._buffers.pop("total_params")
            m._buffers.pop('input_shape')
            m._buffers.pop('output_shape')

    if print_detail:
        table.print_table()

    print('Total FLOPs: {}G    Total Params: {}M'.format(
        round(float(total_ops / 1e9), 3),
        round(float(total_params / 1e6), 3)))

    return int(total_ops)


if __name__ == '__main__':
    args = parse_args()

    # Enforce the use of CPU
    paddle.set_device('cpu')

    model = paddlers.tasks.load_model(args.model_dir)
    net = model.net

    # Construct bi-temporal inputs
    inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]

    analyze(net, inputs)
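
# Usage sketch. The script is typically invoked from the command line with a
# model directory produced by this project's training scripts; the path below
# is illustrative, not a path this repository guarantees to exist:
#
#   python tools/analyze_model.py --model_dir ./exp/custom_model/best_model \
#       --input_shape 1 3 256 256
#
# `analyze()` also accepts a `custom_ops` mapping for layer types that
# `paddle.hapi.dynamic_flops.register_hooks` does not cover. `MyCustomLayer`
# and `count_my_custom_layer` below are hypothetical names for illustration;
# the hook follows Paddle's forward-post-hook convention of
# (layer, inputs, outputs):
#
#   def count_my_custom_layer(m, x, y):
#       # Hypothetical estimate: one op per output element, accumulated into
#       # the `total_ops` buffer created by add_hooks().
#       m.total_ops += int(np.prod(y.shape))
#
#   analyze(net, inputs, custom_ops={MyCustomLayer: count_my_custom_layer})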