Skip to content

Commit

Permalink
change debug level
Browse files Browse the repository at this point in the history
  • Loading branch information
Archermmt committed Dec 1, 2023
1 parent a23f84c commit e51b1d5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 54 deletions.
4 changes: 2 additions & 2 deletions python/tvm/contrib/msc/core/tools/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None):
pruned_tensors[out.name] = _prune_by_channel(
out, pruned_tensors[node.input_at(0).name].dim_at("C")
)
if self.on_debug(3):
if self.on_debug(3, in_forward=False):
self._logger.debug(msc_utils.msg_block("Pruned Tensors", pruned_tensors))
pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors)
new_graphs.append(pruned_graph)
Expand All @@ -404,7 +404,7 @@ def _flatten_size(weights):
raw_size = _flatten_size(weights)
new_size = _flatten_size(new_weights)
self._logger.info(
"{} weights pruned, compress to {:g}% ({:g} M->{:g} M)".format(
"Prune {} weights, compress to {:g}% ({:g} M->{:g} M)".format(
pruned_weights_cnt,
new_size * 100 / raw_size,
raw_size / 2**20,
Expand Down
78 changes: 38 additions & 40 deletions python/tvm/contrib/msc/core/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,20 +432,29 @@ def reset(
self._tensor_cache = {}
if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")):
cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json"))
self.load_cache(cache_dir, cache_info)
else:
cache_info = {}
if self.tool_type() in cache_info:
self.load_cache(cache_dir, cache_info[self.tool_type()])
else:
graphs, weights = self.load_graphs(graphs, weights)
self._graphs, self._weights = graphs, {}
for sub_weights in weights:
self._weights.update(sub_weights)
self._logger.debug(
"%s load %d graphs and %d weights",
self.tool_type().upper(),
self.tool_type(),
len(self._graphs),
len(self._weights),
)
self._reset()
return self._graphs, weights

def _reset(self):
"""Extra reset for tool"""

return None

def change_stage(self, stage: str):
"""Change the stage of tools and strategy"""

Expand Down Expand Up @@ -523,7 +532,8 @@ def execute_before_build(self, *args, **kwargs):
if self._enabled:
self._graph_id = self._infer_graph_id(kwargs)
self._processed_tensor = {}
self._logger.debug("%sStart Build", self.msg_mark(in_forward=False))
if self.on_debug(3, in_forward=False):
self._logger.debug("%sStart Build", self.msg_mark(in_forward=False))
self._execute_before_build(*args, **kwargs)

def _execute_before_build(self, *args, **kwargs):
Expand Down Expand Up @@ -555,7 +565,8 @@ def execute_after_build(self, output: Any) -> Any:

if self._enabled:
output = self._execute_after_build(output)
self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False))
if self.on_debug(3, in_forward=False):
self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False))
return output

def _execute_after_build(self, output: Any) -> Any:
Expand Down Expand Up @@ -588,7 +599,7 @@ def execute_before_forward(self, *args, **kwargs):
if self._enabled:
self._graph_id = self._infer_graph_id(kwargs)
self._processed_tensor = {}
if self.on_debug(2):
if self.on_debug(3):
self._logger.debug("%sStart Forward", self.msg_mark())
self._execute_before_forward(*args, **kwargs)

Expand Down Expand Up @@ -621,7 +632,7 @@ def execute_after_forward(self, output: Any) -> Any:

if self._enabled:
output = self._execute_after_forward(output)
if self.on_debug(2):
if self.on_debug(3):
self._logger.debug(
"%sEnd Forward, process %d tensors",
self.msg_mark(),
Expand Down Expand Up @@ -925,27 +936,34 @@ def is_weight(self, name: str) -> bool:

return name in self._weights

def on_debug(self, debug_level: int = 1) -> bool:
def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool:
"""Check if should log
Parameters
-------
debug_level: int
The given debug_level.
in_forward: bool
Whether to check forward_cnt.
Returns
-------
on_debug: bool
Whether to log debug info.
"""

if self._forward_cnt % self._verbose_step != 0:
if in_forward and self._forward_cnt % self._verbose_step != 0:
return False
return self._debug_level >= debug_level

def msg_mark(self, in_forward=True) -> str:
def msg_mark(self, in_forward: bool = True) -> str:
"""Get the debug title
Parameters
-------
in_forward: bool
Whether to add forward mark.
Returns
-------
msg_mark: str
Expand All @@ -959,7 +977,7 @@ def msg_mark(self, in_forward=True) -> str:
return title

def debug_tensor(
self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 2
self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 3
) -> str:
"""Get the debug tensor info
Expand Down Expand Up @@ -1256,41 +1274,21 @@ def tool_style(cls):
class WeightTool(BaseTool):
"""Basic tool with weight graphs"""

def reset(
self,
graphs: List[MSCGraph],
weights: List[Dict[str, tvm.nd.array]],
cache_dir: msc_utils.MSCDirectory = None,
) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
"""Reset the tool with graphs and weights
def _reset(self):
"""Extra reset for tool"""

Parameters
----------
graphs: list<MSCgraph>
The msc graphs.
weights: list<dict<str, tvm.nd.array>>
The weights
cache_dir: MSCDirectory
cache path for save/load info
Returns
-------
graphs: list<MSCgraph>
The msc graphs.
weights: list<dict<str, tvm.nd.array>>
The weights
"""

graphs, weights = super().reset(graphs, weights, cache_dir)
assert len(graphs) == len(
super()._reset()
assert len(self._graphs) == len(
self._weight_graphs
), "Graphs {} mismatch with weight graphs {}".format(len(graphs), len(self._weight_graphs))
if self.on_debug(3):
), "Graphs {} mismatch with weight graphs {}".format(
len(self._graphs), len(self._weight_graphs)
)
self._logger.debug("%s load %d weight graphs", self.tool_type(), len(self._weight_graphs))
if self.on_debug(2, in_forward=False):
for idx, graph in enumerate(self._weight_graphs):
self._logger.debug(
msc_utils.msg_block("PRUNER.WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect())
msc_utils.msg_block("WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect())
)
return graphs, weights

def load_graphs(
self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/contrib/msc/core/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,4 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T
get_dataset_dir = partial(get_workspace_subdir, name="Dataset")
get_output_dir = partial(get_workspace_subdir, name="Output")
get_visual_dir = partial(get_workspace_subdir, name="Visual")
get_weights_dir = partial(get_workspace_subdir, name="Weights")
35 changes: 24 additions & 11 deletions src/contrib/msc/core/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,20 +708,21 @@ const Array<MSCJoint> MSCGraphNode::GetExits() const {
}

const bool MSCGraphNode::HasTensor(const String& name) const {
if (weight_holders.count(name)) {
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
if (weight_holders.count(tensor_name)) {
return true;
}
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
String host, index;
std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":");
return nodes.count(host) > 0 ? true : false;
}

const MSCTensor MSCGraphNode::FindTensor(const String& name) const {
if (weight_holders.count(name)) {
const auto& node = FindNode(weight_holders[name][0]);
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
if (weight_holders.count(tensor_name)) {
const auto& node = FindNode(weight_holders[tensor_name][0]);
for (const auto& pair : node->weights) {
if (pair.second->name == name) {
if (pair.second->name == tensor_name) {
return pair.second;
}
}
Expand All @@ -732,8 +733,9 @@ const MSCTensor MSCGraphNode::FindTensor(const String& name) const {
}

const MSCJoint MSCGraphNode::FindProducer(const String& name) const {
if (weight_holders.count(name)) {
return FindNode(weight_holders[name][0]);
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
if (weight_holders.count(tensor_name)) {
return FindNode(weight_holders[tensor_name][0]);
}
const auto& pair = FindProducerAndIdx(name);
return pair.first;
Expand All @@ -744,8 +746,8 @@ const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const {
}

const std::pair<MSCJoint, size_t> MSCGraphNode::FindProducerAndIdx(const String& name) const {
ICHECK(!weight_holders.count(name)) << "Weight " << name << " has no producer with index";
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
ICHECK(!weight_holders.count(tensor_name)) << "Weight " << name << " has no producer with index";
String host, index;
std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":");
if (index.size() == 0) {
Expand All @@ -762,8 +764,9 @@ const std::pair<MSCJoint, size_t> MSCGraphNode::FindProducerAndIdx(const MSCTens

const Array<MSCJoint> MSCGraphNode::FindConsumers(const String& name) const {
Array<MSCJoint> consumers;
if (weight_holders.count(name)) {
for (const auto& h : weight_holders[name]) {
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
if (weight_holders.count(tensor_name)) {
for (const auto& h : weight_holders[tensor_name]) {
consumers.push_back(FindNode(h));
}
} else {
Expand All @@ -781,7 +784,8 @@ const Array<MSCJoint> MSCGraphNode::FindConsumers(const MSCTensor& tensor) const

const std::vector<std::pair<MSCJoint, size_t>> MSCGraphNode::FindConsumersAndIndices(
const String& name) const {
ICHECK(!weight_holders.count(name)) << "Weight has no index";
const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name;
ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index";
std::vector<std::pair<MSCJoint, size_t>> consumers;
for (const auto& c : FindConsumers(name)) {
bool find_tensor = false;
Expand Down Expand Up @@ -836,6 +840,9 @@ void MSCGraphNode::AnalysisGraph() {
weight_holders.Set(w_name, holders);
} else {
weight_holders.Set(w_name, Array<String>({n}));
if (pair.second->alias.size() > 0) {
tensor_alias.Set(pair.second->alias, pair.second->name);
}
}
}
}
Expand Down Expand Up @@ -1320,6 +1327,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor")
return graph->FindTensor(name);
});

TVM_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias")
.set_body_typed([](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) {
tensor->alias = alias;
graph->tensor_alias.Set(alias, tensor->name);
});

TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer")
.set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint {
return graph->FindProducer(name);
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/core/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class MSCGraphNode : public BaseGraphNode {
/*! \brief The output names of graph. */
Array<String> output_names;
/*! \brief The tensor alias in graph, get by AnalysisGraph. */
Map<String, String> tensor_alias;
mutable Map<String, String> tensor_alias;
/*! \brief The weights in graph, get by AnalysisGraph. */
Map<String, Array<String>> weight_holders;
/*! \brief Export graph to json. */
Expand Down

0 comments on commit e51b1d5

Please sign in to comment.