Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix _graph_utils import (#2675)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms authored Jul 17, 2020
1 parent 0f33bc7 commit bccda3d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/sdk/pynni/nni/compression/torch/speedup/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import logging
import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
Expand Down Expand Up @@ -51,6 +50,8 @@ def __init__(self, model, dummy_input, masks_file, map_location=None):
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
from nni._graph_utils import build_module_graph

self.bound_model = model
self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
Expand Down
4 changes: 2 additions & 2 deletions src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import csv
import logging

from nni._graph_utils import TorchModuleGraph

__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency']

CONV_TYPE = 'aten::_convolution'
Expand All @@ -19,6 +17,8 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Build the graph for the model.
"""
from nni._graph_utils import TorchModuleGraph

# check if the input is legal
if traced_model is None:
# user should provide model & dummy_input to trace
Expand Down

0 comments on commit bccda3d

Please sign in to comment.