diff --git a/tools/model/analyze_model.py b/tools/model/analyze_model.py index e273d7008f..16414dec26 100644 --- a/tools/model/analyze_model.py +++ b/tools/model/analyze_model.py @@ -21,14 +21,14 @@ import os import sys -import paddle import numpy as np +import paddle +from paddle.hapi.dynamic_flops import (count_io_info, count_parameters, + register_hooks) +from paddle.hapi.static_flops import Table -from paddleseg.cvlibs import Config +from paddleseg.cvlibs import Config, SegBuilder from paddleseg.utils import get_sys_env, logger, op_flops_funs -from paddle.hapi.dynamic_flops import (count_parameters, register_hooks, - count_io_info) -from paddle.hapi.static_flops import Table def parse_args(): @@ -140,10 +140,11 @@ def analyze(args): paddle.set_device('cpu') cfg = Config(args.config) + builder = SegBuilder(cfg) custom_ops = {paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn} inputs = paddle.randn(args.input_shape) - _dynamic_flops(cfg.model, inputs, custom_ops=custom_ops, print_detail=True) + _dynamic_flops(builder.model, inputs, custom_ops=custom_ops, print_detail=True) if __name__ == '__main__':