diff --git a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py index 4b569d7e4f..b31acfe664 100644 --- a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py @@ -3,7 +3,6 @@ import logging import torch -from nni._graph_utils import build_module_graph from nni.compression.torch.utils.mask_conflict import fix_mask_conflict from .compress_modules import replace_module from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape @@ -51,6 +50,8 @@ def __init__(self, model, dummy_input, masks_file, map_location=None): map_location : str the device on which masks are placed, same to map_location in ```torch.load``` """ + from nni._graph_utils import build_module_graph + self.bound_model = model self.masks = torch.load(masks_file, map_location) self.inferred_masks = dict() # key: module_name, value: ModuleMasks diff --git a/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py index 49aa32b7c9..d89e53d1a7 100644 --- a/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py +++ b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py @@ -4,8 +4,6 @@ import csv import logging -from nni._graph_utils import TorchModuleGraph - __all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency'] CONV_TYPE = 'aten::_convolution' @@ -19,6 +17,8 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): """ Build the graph for the model. """ + from nni._graph_utils import TorchModuleGraph + # check if the input is legal if traced_model is None: # user should provide model & dummy_input to trace