From 53723251c5aff1e6009843514b0e1cb13098c22d Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 30 Nov 2023 08:23:51 +0800 Subject: [PATCH 1/3] add tool and pruner --- .../tvm/contrib/msc/core/codegen/sources.py | 56 +- python/tvm/contrib/msc/core/ir/graph.py | 166 ++- python/tvm/contrib/msc/core/runtime/runner.py | 93 +- python/tvm/contrib/msc/core/tools/__init__.py | 21 + python/tvm/contrib/msc/core/tools/execute.py | 386 ++++++ .../contrib/msc/core/tools/prune/__init__.py | 20 + .../contrib/msc/core/tools/prune/method.py | 118 ++ .../contrib/msc/core/tools/prune/pruner.py | 546 ++++++++ python/tvm/contrib/msc/core/tools/tool.py | 1231 +++++++++++++++++ python/tvm/contrib/msc/core/utils/dataset.py | 51 +- python/tvm/contrib/msc/core/utils/file.py | 53 +- python/tvm/contrib/msc/core/utils/info.py | 113 +- .../tvm/contrib/msc/core/utils/namespace.py | 12 +- python/tvm/contrib/msc/core/utils/register.py | 134 +- .../framework/tensorflow/runtime/runner.py | 3 +- .../framework/tensorflow/tools/__init__.py | 19 + .../tensorflow/tools/prune/__init__.py | 19 + .../tensorflow/tools/prune/pruner.py | 42 + .../msc/framework/tensorrt/codegen/codegen.py | 40 +- .../msc/framework/tensorrt/runtime/runner.py | 2 + .../msc/framework/tensorrt/tools/__init__.py | 19 + .../tensorrt/tools/prune/__init__.py | 19 + .../framework/tensorrt/tools/prune/pruner.py | 42 + .../msc/framework/torch/runtime/runner.py | 2 + .../msc/framework/torch/tools/__init__.py | 19 + .../framework/torch/tools/prune/__init__.py | 19 + .../msc/framework/torch/tools/prune/pruner.py | 42 + .../msc/framework/tvm/runtime/runner.py | 7 +- .../msc/framework/tvm/tools/__init__.py | 19 + .../msc/framework/tvm/tools/prune/__init__.py | 19 + .../msc/framework/tvm/tools/prune/pruner.py | 42 + python/tvm/contrib/msc/pipeline/manager.py | 105 +- src/contrib/msc/core/codegen/base_codegen.h | 8 - src/contrib/msc/core/codegen/codegen_utils.h | 3 + src/contrib/msc/core/codegen/cpp_codegen.h | 15 +- src/contrib/msc/core/ir/graph.cc | 395 +++++- src/contrib/msc/core/ir/graph.h | 197 ++- src/contrib/msc/core/utils.cc | 20 +- src/contrib/msc/framework/tensorrt/codegen.cc | 97 +- src/contrib/msc/framework/tensorrt/codegen.h | 6 + .../msc/framework/tensorrt/codegen_utils.h | 9 + .../msc/framework/tensorrt/tensorrt_opcode.cc | 15 +- src/contrib/msc/framework/tvm/relax_opcode.cc | 6 +- src/contrib/msc/framework/tvm/relax_opcode.h | 2 +- src/runtime/contrib/msc/tensorrt_runtime.cc | 71 +- tests/python/contrib/test_msc/test_manager.py | 16 +- tests/python/contrib/test_msc/test_tools.py | 191 +++ 47 files changed, 4297 insertions(+), 233 deletions(-) create mode 100644 python/tvm/contrib/msc/core/tools/__init__.py create mode 100644 python/tvm/contrib/msc/core/tools/execute.py create mode 100644 python/tvm/contrib/msc/core/tools/prune/__init__.py create mode 100644 python/tvm/contrib/msc/core/tools/prune/method.py create mode 100644 python/tvm/contrib/msc/core/tools/prune/pruner.py create mode 100644 python/tvm/contrib/msc/core/tools/tool.py create mode 100644 python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py create mode 100644 
python/tvm/contrib/msc/framework/torch/tools/__init__.py create mode 100644 python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py create mode 100644 python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py create mode 100644 python/tvm/contrib/msc/framework/tvm/tools/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py create mode 100644 python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py create mode 100644 tests/python/contrib/test_msc/test_tools.py diff --git a/python/tvm/contrib/msc/core/codegen/sources.py b/python/tvm/contrib/msc/core/codegen/sources.py index cf1aed7b4764..825ec390f895 100644 --- a/python/tvm/contrib/msc/core/codegen/sources.py +++ b/python/tvm/contrib/msc/core/codegen/sources.py @@ -34,6 +34,7 @@ def get_base_h_code() -> str: #include <fstream> #include <iostream> #include <string> +#include <unordered_map> #include <utility> #include <vector> @@ -78,11 +79,19 @@ class DatasetReader { bool ReadNext(void* buffers[], int num_datas = -1); + const std::vector<std::string> GetTensorNames() { return tensor_names_; } + + size_t GetTensorSize(const std::string& name); + + const std::string GetSaveName(const std::string& name); + private: std::string folder_; size_t max_size_; size_t cur_cnt_; - std::vector<std::pair<std::string, size_t>> tensor_info_; + std::vector<std::string> tensor_names_; + std::unordered_map<std::string, std::string> save_names_; + std::unordered_map<std::string, size_t> tensor_sizes_; }; } // namespace msc @@ -102,10 +111,10 @@ def get_base_cc_code() -> str: The base cc source. """ - return """#include <cassert> -#include <fstream> + return """#include "base.h" -#include "base.h" +#include <cassert> +#include <fstream> namespace tvm { namespace contrib { @@ -122,23 +131,31 @@ def get_base_cc_code() -> str: DatasetReader::DatasetReader(const std::string& folder, int max_size) { folder_ = folder; - const std::string info_file = folder_ + "/tensor_info"; + const std::string info_file = folder_ + "/datas_info.txt"; std::ifstream input(info_file, std::ios::binary); assert(input.is_open() && ("Failed to open file " + info_file).c_str()); std::string line; while (getline(input, line)) { + // define name int pos = line.find(" "); assert(pos > 0 && ("Can not find space in line " + line).c_str()); const auto& name = line.substr(0, pos); - const auto& byte_size = line.substr(pos + 1, line.size()); - tensor_info_.push_back(std::make_pair(name, static_cast<size_t>(std::stoi(byte_size)))); + tensor_names_.push_back(name); + const auto& left = line.substr(pos + 1, line.size()); + // define save_name + pos = left.find(" "); + assert(pos > 0 && ("Can not find space in left " + left).c_str()); + save_names_[name] = left.substr(0, pos); + // define size + const auto& byte_size = left.substr(pos + 1, left.size()); + tensor_sizes_[name] = static_cast<size_t>(std::stoi(byte_size)); } size_t file_cnt = 0; while (true) { bool all_exists = true; - for (const auto& pair : tensor_info_) { + for (const auto& pair : save_names_) { const auto& d_file = - folder_ + "/" + pair.first + "/batch_" + std::to_string(file_cnt) + ".bin"; + folder_ + "/" + pair.second + "/batch_" + std::to_string(file_cnt) + ".bin"; if (!FileUtils::FileExist(d_file)) { all_exists = false; break; @@ -160,12 +177,13 @@ def get_base_cc_code() -> str: if (cur_cnt_ >= max_size_) { return false; } - size_t max_num = num_datas > 0 ? static_cast<size_t>(num_datas) : tensor_info_.size(); - max_num = std::min(max_num, tensor_info_.size()); + size_t max_num = num_datas > 0 ?
static_cast<size_t>(num_datas) : tensor_names_.size(); + max_num = std::min(max_num, tensor_names_.size()); for (size_t i = 0; i < max_num; i++) { - const auto& pair = tensor_info_[i]; - const auto& d_file = folder_ + "/" + pair.first + "/batch_" + std::to_string(cur_cnt_) + ".bin"; - if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], pair.second)) { + const auto& name = tensor_names_[i]; + const auto& d_file = + folder_ + "/" + GetSaveName(name) + "/batch_" + std::to_string(cur_cnt_) + ".bin"; + if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], GetTensorSize(name))) { return false; } } @@ -173,6 +191,16 @@ def get_base_cc_code() -> str: return true; } +size_t DatasetReader::GetTensorSize(const std::string& name) { + assert(tensor_sizes_.count(name)); + return tensor_sizes_[name]; +} + +const std::string DatasetReader::GetSaveName(const std::string& name) { + assert(save_names_.count(name)); + return save_names_[name]; +} + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 154703d332ce..8fabed30acfe 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -68,6 +68,9 @@ def dim_at(self, axis: Union[int, str]) -> int: return int(self.shape[axis]) return int(_ffi_api.MSCTensorDimAt(self, axis)) + def layout_of(self, axis: str) -> int: + """Get the index of the given axis in the layout, -1 when the axis is missing""" + + return self.layout.index_of(axis) + def set_alias(self, alias: str): """Set alis for the tensor @@ -162,7 +165,6 @@ def __init__( outputs: List[MSCTensor], weights: Dict[str, MSCTensor], ): - parents = [i[0] for i in inputs] out_indices = [i[1] for i in inputs] self.__init_handle_by_constructor__( @@ -350,10 +352,12 @@ class WeightJoint(BaseJoint): The optype of the node. wtype: string The weight type of the node. + strategy: string + The prune strategy of the node. + weight: MSCTensor + The weight of the node. attrs: dict The attributes of the node. - weight: MSCTensor, - The weight of the node. parents: list The parents of the node. friends: list @@ -367,12 +371,12 @@ def __init__( shared_ref: str, optype: str, wtype: str, - attrs: Dict[str, str], + strategy: str, weight: MSCTensor, + attrs: Dict[str, str], parents: List[BaseJoint], friends: List[BaseJoint], ): - self.__init_handle_by_constructor__( _ffi_api.WeightJoint, index, @@ -380,12 +384,58 @@ def __init__( shared_ref, optype, wtype, - attrs, + strategy, weight, + attrs, parents, friends, ) + def get_attrs(self) -> Dict[str, str]: """Get all the attributes from node + Returns ------- attributes: dict The attributes of node. """ + return _ffi_api.WeightJointGetAttrs(self) + def get_attr(self, key: str, default: Optional[Any] = None) -> str: """Get the attribute of key from node + Parameters ------- key: str The key of the attribute. default: Any The default value when key is missing. + Returns ------- attribute: str The attribute of node. """ + return self.get_attrs().get(key, default) + def has_attr(self, key: str) -> bool: """Check if key in attributes + Parameters ------- key: str The key of the attribute. + Returns ------- has_attr: bool Whether the key is in the attributes. """ + return bool(_ffi_api.WeightJointHasAttr(self, key)) + class BaseGraph(Object): """Base class of all MSC Graphs.""" @@ -727,6 +777,110 @@ def __init__( nodes, ) + def has_node(self, name: str) -> bool: + """Check if the weight node is in the graph.
+ + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + has_node: bool + Whether the node is in the graph. + """ + + return bool(_ffi_api.WeightGraphHasNode(self, name)) + + def find_node(self, name: str) -> WeightJoint: + """Find weight node by name. + + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + node: WeightJoint + The found node. + """ + + return _ffi_api.WeightGraphFindNode(self, name) + + def get_nodes(self) -> Iterable[WeightJoint]: + """Get all the weight nodes in the graph. + + Returns + ------- + nodes: generator + The generator of nodes. + """ + + for n in self.node_names: + yield self.find_node(n) + + def to_json(self) -> str: + """Dump the graph to json. + + Returns + ------- + graph_json: string + The graph in json format. + """ + + return _ffi_api.WeightGraphToJson(self) + + def inspect(self) -> dict: + """Extract important info of the graph. + + Returns + ------- + graph_des: dict + The graph description in json format. + """ + + graph_des = { + "nodes": {"total": 0}, + } + for node in self.get_nodes(): + graph_des["nodes"]["total"] += 1 + if node.weight_type not in graph_des["nodes"]: + graph_des["nodes"][node.weight_type] = 1 + else: + graph_des["nodes"][node.weight_type] += 1 + return graph_des + + @classmethod + def from_json(cls, json_str: str) -> BaseGraph: + """Load the graph from json. + + Parameters + ---------- + json_str: string + The file_path or json string. + + Returns + ------- + graph: WeightGraph + The graph. + """ + + dict_obj = msc_utils.load_dict(json_str) + return _ffi_api.WeightGraphFromJson(msc_utils.dump_dict(dict_obj)) + + def clone(self) -> BaseGraph: + """Clone the graph. + + Returns + ------- + new_graph: WeightGraph + The cloned graph. + """ + + return WeightGraph.from_json(self.to_json()) + def visualize(self, path: Optional[str] = None) -> str: """Dump the graph to prototxt format. diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index cc4b56eae432..12e2dec02a30 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -26,7 +26,9 @@ import tvm from tvm.contrib.msc.core.ir import MSCGraph from tvm.contrib.msc.core.frontend import from_relax +from tvm.contrib.msc.core.tools import BaseTool, ToolType, create_tool from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.core import _ffi_api @@ -54,6 +56,8 @@ class BaseRunner(object): The device of the model, cpu| cuda| cuda:0|... is_training: bool Whether use model in training + debug_level: int + The debug level.
logger: logging.Logger The logger """ @@ -68,6 +72,7 @@ def __init__( name: str = "main", device: str = "cpu", is_training: bool = False, + debug_level: int = 0, logger: logging.Logger = None, ): self._mod = mod @@ -78,6 +83,7 @@ def __init__( self._name = name self._device = device if self._device_enabled(device) else "cpu" self._is_training = is_training + self._debug_level = debug_level self._logger = logger or msc_utils.get_global_logger() self._logger.info( msc_utils.msg_block( @@ -111,6 +117,7 @@ def setup(self) -> dict: "name": self._name, "device": self._device, "is_training": self._is_training, + "debug_level": self._debug_level, } def change_stage(self, stage: str): @@ -143,7 +150,10 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph: bool = Fa # Create tools if self._tools_config: - raise NotImplementedError("Tools is not supported") + for t_type, config in self._tools_config.items(): + self._tools[t_type] = create_tool( + self.framework, t_type, self._name, stage=self._stage, **config + ) # Load graphs from cache if cache_info.get("graphs"): @@ -157,6 +167,12 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph: bool = Fa self._graphs, self._weights = self._translate() self._logger.debug("Translate {} graphs from module".format(len(self._graphs))) + # reset graph for tools + for tool in self._tools.values(): + self._graphs, self._weights = tool.reset( + self._graphs, self._weights, cache_dir=cache_dir + ) + if cache_info.get("model") and not build_graph: # Load model from cache self._model = self._load_model(cache_dir, cache_info["model"]) @@ -172,7 +188,8 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph: bool = Fa ) ) self._model_info = self._inspect_model() - self._logger.debug(msc_utils.msg_block("MODEL_INFO", self._model_info)) + if self._debug_level >= 3: + self._logger.debug(msc_utils.msg_block("RUNNER.MODEL_INFO", self._model_info)) if cache_info.get("runnable") and not build_graph: # Load runnable from cache @@ -201,11 +218,14 @@ def save_cache(self, cache_dir: msc_utils.MSCDirectory): "model": self._save_model(cache_dir), "runnable": self._save_runnable(cache_dir), } + for tool in self._tools.values(): + cache_info.update(tool.save_cache(cache_dir)) with open(cache_dir.relpath("cache_info.json"), "w") as f: f.write(json.dumps(cache_info, indent=2)) - self._logger.debug( - msc_utils.msg_block("CACHE_INFO", {"folder": cache_dir, "info": cache_info}) - ) + if self._debug_level >= 3: + self._logger.debug( + msc_utils.msg_block("RUNNER.CACHE_INFO", {"folder": cache_dir, "info": cache_info}) + ) def run( self, inputs: Union[List[np.ndarray], Dict[str, np.ndarray]], ret_type="dict" @@ -263,6 +283,51 @@ def run( outputs = [msc_utils.cast_array(data) for data in outputs] return outputs + def get_tool(self, tool_type: str) -> BaseTool: + """Get tool by type + + Parameters + ------- + tool_type: str + The type of the tool prune| quantize| distill... + + Returns + ------- + tool: BaseTool + The saved tool. 
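+ + Examples + -------- + A minimal sketch of the expected usage; the runner and its pruner config here are assumed: + + >>> pruner = runner.get_tool(ToolType.PRUNER) + >>> # None is returned when no tool of this type was configured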
+ """ + + return self._tools.get(tool_type) + + def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: + """Execute tool and get plan + + Parameters + ------- + tool_type: str + The tool type, should be in ToolType + data_loader: + The data loader + """ + + assert tool_type in self._tools, "Can not find tool " + str(tool_type) + if tool_type == ToolType.PRUNER: + pruner = self.get_tool(ToolType.PRUNER) + if not pruner.finalize(): + assert data_loader, "data_loader should be given to plan prune" + for inputs in data_loader(): + self.run(inputs) + break + plan = pruner.finalize() + else: + plan = self.get_tool(tool_type).finalize() + assert plan, "Failed to create plan for {}".format(tool_type) + plan_file = self._tools_config[tool_type]["plan_file"] + with open(plan_file, "w") as f: + f.write(json.dumps(plan, indent=2)) + self._logger.info("Save %s plan -> %s", tool_type, plan_file) + return plan + def visualize(self, visual_dir: msc_utils.MSCDirectory): """Visualize MSCGraphs @@ -638,7 +703,7 @@ def _generate_model( graphs or self._graphs[0], weights or self._weights[0], codegen_config=self._generate_config.get("codegen"), - print_config=self._generate_config.get("build"), + print_config=self._generate_config.get("print"), build_folder=self._generate_config["build_folder"], ) @@ -799,13 +864,16 @@ def _generate_model( graph_infos = list(zip(graphs or self._graphs, weights or self._weights)) extra_option = self._generate_config.get("extra_option", {}) - extra_option["tool_tag"] = self._name + if self._stage == MSCStage.COMPILE and not self.get_tool(ToolType.TRACKER): + extra_option["tool_tag"] = "" + else: + extra_option["tool_tag"] = self._name return self.codegen_func( self._byoc_mod, graph_infos, - codegen_config=self._generate_config.get("codegen"), - print_config=self._generate_config.get("build"), - extra_option=extra_option, + codegen_configs=self._generate_config.get("codegen"), + print_configs=self._generate_config.get("print"), + extra_options=extra_option, build_folder=self._generate_config["build_folder"], output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()), ) @@ -886,6 +954,11 @@ def _inspect_model(self) -> dict: The inspected model info """ + if self._debug_level >= 3: + for idx, graph in enumerate(self._graphs): + self._logger.debug( + msc_utils.msg_block("RUNNER.GRAPH[{}].INFO".format(idx), graph.inspect()) + ) return self._byoc_graph.inspect() def _device_enabled(self, device: str) -> bool: diff --git a/python/tvm/contrib/msc/core/tools/__init__.py b/python/tvm/contrib/msc/core/tools/__init__.py new file mode 100644 index 000000000000..3d60ce22c63a --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.core.tools""" + +from .tool import * +from .execute import * +from .prune import * diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py new file mode 100644 index 000000000000..7623de109e08 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/execute.py @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.execute""" + +from functools import wraps +from typing import List, Iterable, Any, Dict + +import tvm +from tvm.contrib.msc.core.utils.namespace import MSCMap, MSCKey +from tvm.contrib.msc.core import utils as msc_utils +from .tool import ToolType, BaseTool + + +def _get_tool_key(tool_type: str) -> str: + """Get the key according to tool_type + + Parameters + ------- + tool_type: str + The type of the tool prune| quantize| distill... + + Returns + ------- + tool_key: str + The tool key. + """ + + if tool_type == ToolType.PRUNER: + return MSCKey.PRUNERS + if tool_type == ToolType.QUANTIZER: + return MSCKey.QUANTIZERS + if tool_type == ToolType.DISTILLER: + return MSCKey.DISTILLERS + if tool_type == ToolType.TRACKER: + return MSCKey.TRACKERS + raise TypeError("Unexpected tool type " + str(tool_type)) + + +def add_tool(tool: BaseTool, tool_type: str, tag: str = "main"): + """Add tool by type and tag + + Parameters + ------- + tool: BaseTool + The tool. + tool_type: str + The type of the tool prune| quantize| distill... + tag: str + The tag of the tool. + """ + + tool_key = _get_tool_key(tool_type) + tools = MSCMap.get(tool_key, {}) + tools[tag] = tool + MSCMap.set(tool_key, tools) + return tool + + +def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> BaseTool: + """Create tool by type, config and tag + + Parameters + ------- + framework: str + The framework for implement + tool_type: str + The type of the tool prune| quantize| distill... + tag: str + The tag of the tool. + config: dict + The config of tool. + """ + + tool_style = config.pop("tool_style") if "tool_style" in config else "default" + tool_cls = msc_utils.get_registered_tool_cls(framework, tool_type, tool_style) + assert tool_cls, "Can not find tool class for {}:{} @ {}".format( + tool_type, tool_style, framework + ) + return add_tool(tool_cls(**config), tool_type, tag) + + +def get_tool(tool_type: str, tag: str = "main") -> BaseTool: + """Get tool by type and tag + + Parameters + ------- + tool_type: str + The type of the tool prune| quantize| distill... + tag: str + The tag of the tool. + + Returns + ------- + tool: BaseTool + The saved tool. 
+ """ + + tool_key = _get_tool_key(tool_type) + tools = MSCMap.get(tool_key, {}) + return tools.get(tag) + + +def get_tools(tag: str = "main") -> Iterable[BaseTool]: + """Get all saved tools by tag + + Parameters + ------- + tag: str + The tag of the tool. + + Returns + ------- + tools: iterable + The saved tools. + """ + + for t_type in ToolType.all_types(): + tool = get_tool(t_type, tag) + if tool: + yield tool + + +def remove_tool(tool_type: str, tag: str = "main"): + """Remove tool by type and tag + + Parameters + ------- + tool_type: str + The type of the tool prune| quantize| distill... + tag: str + The tag of the tool. + """ + + tool_key = _get_tool_key(tool_type) + tools = MSCMap.get(tool_key, {}) + if tag in tools: + tools.pop(tag) + MSCMap.set(tool_key, tools) + + +def remove_tools(tag: str = "main"): + """Remove all saved tools by tag + + Parameters + ------- + tag: str + The tag of the tool. + + Returns + ------- + tools: iterable + The saved tools. + """ + + for t_type in ToolType.all_types(): + remove_tool(t_type, tag) + + +def process_tensor(tensor: Any, name: str, consumer: str, scope: str, tag: str = "main") -> Any: + """Process tensor with tools + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. + consumer: str + The name of the consumer. + scope: str + The scope mark teacher| student| null + tag: str + The tag of the tool. + + Returns + ------- + tensor: Any + The processed tensor. + """ + + for tool in get_tools(tag): + tensor = tool.process_tensor(tensor, name, consumer, scope) + return tensor + + +@tvm.register_func("msc_tool.codegen_tensor") +def codegen_tensor( + tensor_ctx: Dict[str, str], name: str, consumer: str, scope: str, tag: str = "main" +) -> List[str]: + """Codegen processed tensor describe with tools + + Parameters + ------- + tensor_ctx: dict + Tensor describe items. + name: str + The name of the tensor. + consumer: str + The name of the consumer. + scope: str + The scope mark teacher| student| null + tag: str + The tag of the tool. + + Returns + ------- + processed: list + The tensor describe for processed tensor. + """ + + tensor_ctx = {**dict(tensor_ctx), "processed": []} + tensor_ctx = process_tensor(dict(tensor_ctx), name, consumer, scope, tag) + return tensor_ctx["processed"] + + +def wrap_step(step: str, tag: str = "main") -> callable: + """Wrapper for tool execution + + Parameters + ------- + step: str + The step for tool execution build| forward + tag: str + The tag of the tool. + + Returns + ------- + decorate: callable + The decorate. + """ + + def decorate(func): + @wraps(func) + def wrapper(*args, **kwargs): + for tool in get_tools(tag): + if step == "build": + tool.execute_before_build(*args, **kwargs) + elif step == "forward": + tool.execute_before_forward(*args, **kwargs) + else: + raise TypeError("Unexpected step " + str(step)) + output = func(*args, **kwargs) + for tool in get_tools(tag): + if step == "build": + output = tool.execute_after_build(output) + elif step == "forward": + output = tool.execute_after_forward(output) + else: + raise TypeError("Unexpected step " + str(step)) + return output + + return wrapper + + return decorate + + +def execute_step(step: str, *args, **kwargs): + """Execute tools for a step + + Parameters + ------- + step: str + The step for tool execution build| forward + args: list + The arguments for model build. + kwargs: dict + The key word arguments for model build. 
+ """ + + if step in ("before_build", "before_forward"): + output = None + else: + assert ( + len(args) == 1 and not kwargs + ), "after step only accept 1 argument, get args {}, kwargs {}".format(args, kwargs) + output = args[0] + tag = kwargs.pop("tag") if "tag" in kwargs else "main" + for tool in get_tools(tag): + if step == "before_build": + tool.execute_before_build(*args, **kwargs) + elif step == "before_forward": + tool.execute_before_forward(*args, **kwargs) + elif step == "after_build": + output = tool.execute_after_build(output) + elif step == "after_forward": + output = tool.execute_after_forward(output) + else: + raise TypeError("Unexpected step " + str(step)) + return output + + +def _execute_step_with_context( + step_ctx: Dict[str, Any], step: str, graph_name: str, tag: str = "main" +) -> Dict[str, Any]: + """Execute step with contect + + Parameters + ------- + step_ctx: dict + The step context. + step: str + The step for tool execution build| forward + graph_name: str + The graph name. + tag: str + The tag of the tool. + + Returns + ------- + step_ctx: dict + The processed step context. + """ + + for tool in get_tools(tag): + if step == "before_build": + tool.execute_before_build(step_ctx, graph_name=graph_name) + elif step == "before_forward": + tool.execute_before_forward(step_ctx, graph_name=graph_name) + elif step == "after_build": + step_ctx = tool.execute_after_build(step_ctx) + elif step == "after_forward": + step_ctx = tool.execute_after_forward(step_ctx) + else: + raise TypeError("Unexpected step " + str(step)) + return step_ctx + + +@tvm.register_func("msc_tool.codegen_step") +def codegen_step( + step_ctx: Dict[str, str], step: str, graph_name: str, tag: str = "main" +) -> List[str]: + """Codegen step codes + + Parameters + ------- + step_ctx: dict + The step describe items. + step: str + The step for tool execution build| forward + graph_name: str + The graph name. + tag: str + The tag of the tool. + + Returns + ------- + processed: list + The tensor describe for processed tensor. + """ + + step_ctx = {**dict(step_ctx), "processed": []} + step_ctx = _execute_step_with_context(step_ctx, step, graph_name, tag) + return step_ctx["processed"] + + +@tvm.register_func("msc_tool.callback_step") +def callback_step(step_ctx: Dict[str, Any], step: str, graph_name: str = "main", tag: str = "main"): + """Execute tools for a step + + Parameters + ------- + step_ctx: dict + The step context. + step: str + The step for tool execution build| forward + graph_name: str + The graph name. + tag: str + The tag of the tool. + """ + + _execute_step_with_context(step_ctx, step, graph_name, tag) diff --git a/python/tvm/contrib/msc/core/tools/prune/__init__.py b/python/tvm/contrib/msc/core/tools/prune/__init__.py new file mode 100644 index 000000000000..8317d52ac12b --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/prune/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.prune""" + +from .pruner import * +from .method import * diff --git a/python/tvm/contrib/msc/core/tools/prune/method.py b/python/tvm/contrib/msc/core/tools/prune/method.py new file mode 100644 index 000000000000..6e294dc15c86 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/prune/method.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.core.tools.prune.method""" + +from typing import List +import numpy as np + +from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class PruneMethod(object): + """Default prune method""" + + @classmethod + def prune_axis(cls, data: np.ndarray, axis: int, indices: List[int]) -> np.ndarray: + """Keep only the given indices on the axis + + Parameters + ---------- + data: np.ndarray + The source data. + axis: int + The axis to prune + indices: list + The indices to keep + + Returns + ------- + data: np.ndarray + The pruned data. + """ + + left_datas = [ + d for idx, d in enumerate(np.split(data, data.shape[axis], axis)) if idx in indices + ] + return np.concatenate(left_datas, axis=axis) + + @classmethod + def per_channel( + cls, + pruner: BaseTool, + data: np.ndarray, + name: str, + consumer: str, + in_axis: int, + out_axis: int, + in_indices: List[int], + density: float, + stride: int = 8, + ) -> dict: + """Prune the data + + Parameters + ---------- + pruner: BasePruner + The pruner + data: np.ndarray + The source data. + name: str + The name of the weight. + consumer: str + The name of the consumer. + in_axis: int + The input axis + out_axis: int + The output axis + in_indices: list + The input indices to keep + density: float + The density of channels to keep + stride: int + The alignment stride for the kept channels + + Returns + ------- + plan: dict + The plan of the tensor.
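+ + Examples + -------- + The plan records the indices kept on the input and output axes; a typical result could look like this (values are illustrative): + + >>> config = {"in_indices": [], "out_indices": [0, 2, 5, 7]}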
+ """ + + config = {"in_indices": in_indices, "out_indices": []} + if density == 1: + return config + if len(in_indices) > 0: + data = cls.prune_axis(data, in_axis, in_indices) + weight = pruner.find_tensor(name) + left_num = int(((density * weight.dim_at(out_axis) + stride) // stride) * stride) + axis_sum = [np.abs(d).sum() for d in np.split(data, data.shape[out_axis], out_axis)] + rank = np.argsort(np.array(axis_sum)) + config["out_indices"] = rank[-left_num:].tolist() + return config + + @classmethod + def framework(cls): + return MSCFramework.MSC + + @classmethod + def tool_type(cls): + return ToolType.PRUNER + + +msc_utils.register_tool_method(PruneMethod) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py new file mode 100644 index 000000000000..f761a37c56d0 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -0,0 +1,546 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.prune.pruner""" + +from typing import List, Dict, Iterable, Tuple, Any + +import tvm +from tvm.contrib.msc.core.ir import MSCGraph, WeightGraph, WeightJoint, MSCTensor +from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, Strategy +from tvm.contrib.msc.core import _ffi_api +from tvm.contrib.msc.core import utils as msc_utils +from .method import PruneMethod + + +class BasePruner(BaseTool): + """Base pruner for all""" + + def setup(self) -> dict: + """Setup the tool + + Returns + ------- + info: dict + The setup info. + """ + + # Build weight graphs + if "prunable_types" in self._options: + self._prunable_types = self._options["prunable_types"] + else: + self._prunable_types = { + "constant": ["const"], + "nn.conv2d": ["weight"], + "msc.conv2d_bias": ["weight"], + "msc.linear": ["weight"], + "msc.linear_bias": ["weight"], + } + + if "relation_types" in self._options: + self._relation_types = self._options["relation_types"] + else: + self._relation_types = { + "concatenate": "multi_inputs", + "reshape": "reshape", + "add": "passby", + "substract": "passby", + "multiply": "passby", + "divide": "passby", + } + + return super().setup() + + def reset( + self, + graphs: List[MSCGraph], + weights: List[Dict[str, tvm.nd.array]], + cache_dir: msc_utils.MSCDirectory = None, + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Reset the tool with graphs and weights + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + cache_dir: MSCDirectory + cache path for save/load info + + Returns + ------- + graphs: list + The msc graphs. 
+ weights: list<dict<str, tvm.nd.array>> + The weights + """ + + self._unpruned_tensors = {} + res = super().reset(graphs, weights, cache_dir) + if self.on_debug(3): + for idx, graph in enumerate(self._weight_graphs): + self._logger.debug( + msc_utils.msg_block("PRUNER.WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) + ) + return res + + def load_graphs( + self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Load the graphs and weights + + Parameters + ---------- + graphs: list<MSCGraph> + The msc graphs. + weights: list<dict<str, tvm.nd.array>> + The weights + + Returns + ------- + graphs: list<MSCGraph> + The msc graphs. + weights: list<dict<str, tvm.nd.array>> + The weights + """ + + self._weight_graphs = [ + _ffi_api.WeightGraph(graph, self._prunable_types, self._relation_types) + for graph in graphs + ] + if not self._plan: + return graphs, weights + return self.prune_graphs(graphs, weights) + + def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): + """Load the tool from cache + + Parameters + ------- + cache_dir: MSCDirectory + cache path for save/load info + cache_info: dict + The cache_info + """ + + assert ( + "weight_graphs" in cache_info + ), "weight_graphs should be given in cache_info, get " + str(cache_info) + self._weight_graphs = [ + WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"] + ] + + def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: + """Save the tool to cache + + Parameters + ------- + cache_dir: MSCDirectory + cache path for save/load info + + Returns + ------- + cache_info: dict + The cache_info. + """ + + cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]} + with cache_dir: + for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]): + with open(f_path, "w") as f_graph: + f_graph.write(graph.to_json()) + return cache_info + + def _parse_strategys(self, strategy_list: list) -> Dict[str, Strategy]: + """Parse the strategy list to get valid strategys + + Parameters + ------- + strategy_list: list + The given strategys + + Returns + ------- + strategys: dict + The parsed strategys. + """ + + def _update_stages(strategy): + if "stages" not in strategy: + strategy["stages"] = [msc_utils.MSCStage.PRUNE] + return strategy + + return super()._parse_strategys([_update_stages(s) for s in strategy_list]) + + def _check_tensor(self, name: str, consumer: str) -> bool: + """Check if the tensor should be processed + + Parameters + ------- + name: str + The name of the tensor. + consumer: str + The name of the consumer. + + Returns + ------- + valid: bool + Whether to process the tensor. + """ + + if not self.has_w_node(name): + return False + strategy = self._get_tensor_strategy(name, consumer) + if not strategy: + return False + if strategy.get_config("density", 1.0) == 1.0: + return False + return True + + def _process_tensor( + self, tensor: Any, name: str, consumer: str, strategys: List[Strategy] + ) -> Any: + """Process tensor + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. + consumer: str + The name of the consumer. + strategys: list + The strategys for the tensor. + + Returns + ------- + tensor: Any + The processed tensor.
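+ + Note + ---- + The tensor itself is returned unchanged; the pruning decision is only + recorded in the plan and applied to the weights later by prune_graphs.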
+ """ + + if name in self._plan: + return tensor + + assert len(strategys) == 1, "pruner should only has 1 strategy, get " + str(strategys) + strategy = strategys[0] + + def _get_in_indices(w_node: WeightJoint) -> List[int]: + """Get input indices for weight node""" + if not w_node.parents: + return [] + if w_node.name in self._plan and "in_indices" in self._plan[w_node.name]: + return self._plan[w_node.name]["in_indices"] + assert all( + p.name in self._plan for p in w_node.parents + ), "Missing some parents in runtime config " + str(w_node) + if len(w_node.parents) == 1: + return self._plan[w_node.parents[0].name]["out_indices"] + if w_node.parents[0].friends: + return self._plan[w_node.parents[0].friends[0].name]["out_indices"] + raise Exception("Unexpected w_node " + str(w_node)) + + def _prunable(w_node: WeightJoint) -> bool: + """Check if weight node is prunable""" + if w_node.get_attr("prune_strategy") != "prune": + return False + if not w_node.children: + return False + childrens = list(w_node.children) + while childrens: + current = childrens.pop(0) + prune_strategy = current.get_attr("prune_strategy") + if prune_strategy == "prune": + return True + childrens.extend(list(current.children)) + return False + + w_node = self.find_w_node(name) + in_axis, out_axis = self._get_io_axes(w_node) + if w_node.weight.dim_at(in_axis) == 1: + in_indices = [] + else: + in_indices = _get_in_indices(w_node) + self._plan[w_node.name] = {"in_indices": in_indices} + if w_node.friends and w_node != w_node.friends[0]: + lead_name = w_node.friends[0].name + if lead_name not in self._plan: + self._unpruned_tensors[name] = { + "lead_name": lead_name, + "tensor": tensor, + "consumer": consumer, + } + self._plan.pop(w_node.name) + return tensor + self._plan[w_node.name]["out_indices"] = self._plan[lead_name]["out_indices"] + elif _prunable(w_node): + self._plan[w_node.name] = strategy( + self, + self.get_data(w_node.name), + w_node.name, + consumer, + in_axis=in_axis, + out_axis=out_axis, + in_indices=in_indices, + ) + elif w_node.get_attr("prune_strategy") == "follow": + self._plan[w_node.name]["out_indices"] = [] + elif w_node.get_attr("prune_strategy") == "passby": + self._plan[w_node.name]["out_indices"] = in_indices + else: + self._plan[w_node.name]["out_indices"] = [] + lazy_pruned = set() + for lazy_name, info in self._unpruned_tensors.items(): + if info["lead_name"] in self._plan: + strategys = self._get_tensor_strategys(lazy_name, info["consumer"]) + lazy_tensor = self._process_tensor( + info["tensor"], lazy_name, info["consumer"], strategys + ) + strategy_mark = ".".join([s.get_executor().name for s in strategys]) + self.debug_tensor( + lazy_tensor, lazy_name, consumer, "lazy processed({})".format(strategy_mark) + ) + lazy_pruned.add(lazy_name) + if lazy_pruned: + self._unpruned_tensors = { + k: v for k, v in self._unpruned_tensors.items() if k not in lazy_pruned + } + return tensor + + def prune_graphs( + self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Reset the tool + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + + Returns + ------- + graphs: list + The msc graphs. 
+ weights: list<dict<str, tvm.nd.array>> + The weights + """ + + def _prune_by_shape(tensor: MSCTensor, shape: List[int]): + return MSCTensor(tensor.name, tensor.dtype, tensor.layout.name, shape, tensor.alias) + + def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): + shape = tensor.get_shape() + if channel_axis is None: + channel_axis = tensor.layout_of("C") + shape[channel_axis] = dim + return _prune_by_shape(tensor, shape) + + new_graphs, new_weights = [], [] + pruned_weights_cnt = 0 + for graph, sub_weights in zip(graphs, weights): + pruned_tensors, pruned_weights = {}, {} + for node in graph.get_nodes(): + for weight in node.get_weights().values(): + w_name = weight.name + if w_name in self._plan: + data = msc_utils.cast_array(sub_weights[w_name]) + in_axis, out_axis = self._get_io_axes(self.find_w_node(w_name)) + w_config = self._plan[w_name] + if w_config["in_indices"]: + data = PruneMethod.prune_axis(data, in_axis, w_config["in_indices"]) + if w_config["out_indices"]: + data = PruneMethod.prune_axis(data, out_axis, w_config["out_indices"]) + pruned_tensors[w_name] = _prune_by_shape(weight, data.shape) + pruned_weights[w_name] = tvm.nd.array(data) + pruned_weights_cnt += 1 + else: + pruned_weights[w_name] = sub_weights[w_name] + if node.optype == "constant" and node.weight_at("const").name in pruned_tensors: + ref_tensor = pruned_tensors[node.weight_at("const").name] + pruned_tensors[node.output_at(0).name] = MSCTensor( + node.output_at(0).name, + ref_tensor.dtype, + ref_tensor.layout.name, + ref_tensor.get_shape(), + ref_tensor.alias, + ) + elif ( + node.optype in ("nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias") + and node.weight_at("weight").name in pruned_tensors + ): + out = node.output_at(0) + if node.optype in ("msc.linear", "msc.linear_bias"): + channel_axis = out.ndim - 1 + else: + channel_axis = out.layout_of("C") + pruned_tensors[out.name] = _prune_by_channel( + out, + pruned_tensors[node.weight_at("weight").name].dim_at("O"), + channel_axis, + ) + else: + for out in node.get_outputs(): + if out.name in self._plan: + pruned_tensors[out.name] = _prune_by_channel( + out, len(self._plan[out.name]["out_indices"]) + ) + elif ( + node.get_inputs() + and node.input_at(0).name in pruned_tensors + and node.input_at(0).layout_of("C") >= 0 + and out.layout_of("C") >= 0 + ): + pruned_tensors[out.name] = _prune_by_channel( + out, pruned_tensors[node.input_at(0).name].dim_at("C") + ) + if self.on_debug(3): + self._logger.debug(msc_utils.msg_block("Pruned Tensors", pruned_tensors)) + pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors) + new_graphs.append(pruned_graph) + new_weights.append(pruned_weights) + + # log compression rate + def _flatten_size(weights): + weight_size = 0 + for sub_weights in weights: + for w_data in sub_weights.values(): + weight_size += w_data.asnumpy().size + return weight_size + + raw_size = _flatten_size(weights) + new_size = _flatten_size(new_weights) + self._logger.info( + "{} weights pruned, compressed to {:g}%".format( + pruned_weights_cnt, new_size * 100 / raw_size + ) + ) + return new_graphs, new_weights + + def visualize(self, visual_dir: msc_utils.MSCDirectory): + """Visualize MSCGraphs + + Parameters + ------- + visual_dir: MSCDirectory + Visualize path for saving graph + """ + + for w_graph in self._weight_graphs: + w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt")) + + def finalize(self) -> dict: + """Get the plan""" + + assert not self._unpruned_tensors, "Some tensors are not pruned " + str( + self._unpruned_tensors
) + self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} + return super().finalize() + + def get_w_nodes(self) -> Iterable[WeightJoint]: + """Get all the weight nodes in the weight_graphs. + + Returns + ------- + nodes: generator + The generator of weight nodes. + """ + + for g in self._weight_graphs: + for n in g.get_nodes(): + yield n + + def has_w_node(self, name: str) -> bool: + """Check if the name is in weight_graphs. + + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + has_node: bool + Whether the node is in weight_graphs. + """ + + for g in self._weight_graphs: + if g.has_node(name): + return True + return False + + def find_w_node(self, name: str) -> WeightJoint: + """Find weight node by name. + + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + node: WeightJoint + The found node. + """ + + for g in self._weight_graphs: + if g.has_node(name): + return g.find_node(name) + raise Exception("Can not find node {} from graphs".format(name)) + + def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: + """Get the input and output axes + + Parameters + ---------- + w_node: WeightJoint + The weight node. + + Returns + ------- + axes: (int, int) + The input and output axes. + """ + + if w_node.weight.ndim == 1: + return 0, 0 + if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"): + return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis")) + in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") + if in_axis >= 0 and out_axis >= 0: + return in_axis, out_axis + if w_node.weight.layout_of("C") >= 0: + return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") + raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) + + @classmethod + def tool_type(cls): + return ToolType.PRUNER + + +class DefaultPruner(BasePruner): + """Default pruner""" + + @classmethod + def tool_style(cls): + return "default" + + +msc_utils.register_tool_cls(DefaultPruner) diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py new file mode 100644 index 000000000000..0032394a6506 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -0,0 +1,1231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +# pylint: disable=unused-argument +"""tvm.contrib.msc.core.tools.tool""" + +import os +import copy +import logging +from itertools import product +from typing import List, Iterable, Any, Tuple, Dict +import numpy as np + +import tvm +from tvm.contrib.msc.core.ir import MSCGraph, MSCJoint, MSCTensor +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class ToolType(object): + """Enum all msc tool types""" + + BASE = "base" + PRUNER = "pruner" + QUANTIZER = "quantizer" + DISTILLER = "distiller" + TRACKER = "tracker" + ALL = [PRUNER, QUANTIZER, DISTILLER, TRACKER] + + @classmethod + def all_types(cls) -> List[str]: + return cls.ALL + + +class ToolScope(object): + """Enum all msc tool scopes""" + + TEACHER = "teacher" + STUDENT = "student" + + +class Executor(object): + """Executor to process the tensor + + Parameters + ---------- + name: str + The name. + method: callable + The method to execute. + config: dict + The config for execution + """ + + def __init__(self, name: str, method: callable, config: dict = None): + self._name = name + self._method = method + self._config = config or {} + + def __str__(self): + return "{}({})".format(self._name, self._config) + + def execute(self, *args, **kwargs) -> Any: + """Execute the method + + Parameters + ---------- + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + plan or tensor: + The plan generated by method or processed tensor. + """ + + kwargs.update({k: v for k, v in self._config.items() if k not in kwargs}) + return self._method(*args, **kwargs) + + def get_config(self, key: str, default: Any) -> Any: + """Get the value in config""" + + return self._config.get(key, default) + + def copy(self, name: str = None, method: callable = None, config: dict = None): + """Copy an executor + + Parameters + ---------- + name: str + The name for the new executor. + method: callable + The method for the new executor. + config: dict + The config for the new executor + + Returns + ------- + new_executor: Executor + The copied executor + """ + + new_config = config or {} + new_config.update({k: v for k, v in self._config.items() if k not in new_config}) + return Executor(name or self._name, method or self._method, new_config) + + @property + def name(self): + return self._name + + +class Strategy(object): + """Strategy to process tensor + + Parameters + ---------- + name: str + The name. + tensor_type: str + The tensor type. + stage: str + The init stage + """ + + def __init__(self, name: str, tensor_type: str, stage: str = "default"): + self._name = name + self._tensor_type = tensor_type + self._stage = stage + self._executors = {} + + def __str__(self): + return "{}({} @ {}) ".format(self._name, self._tensor_type, self._stage) + "; ".join( + ["{}:{}".format(k, v) for k, v in self._executors.items()] + ) + + def inspect(self) -> dict: + """Get inspect of strategy + + Returns + ------- + inspect: dict + The inspect of the strategy. + """ + + return {"{}({})".format(s, self._tensor_type): str(e) for s, e in self._executors.items()} + + def __call__(self, *args, **kwargs) -> Any: + return self.apply(*args, **kwargs) + + def apply(self, *args, **kwargs) -> Any: + """Apply the strategy + + Parameters + ---------- + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + plan or tensor: + The plan generated by method or processed tensor.
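+ + Examples + -------- + A sketch of how the pruner invokes a strategy; the argument names follow PruneMethod.per_channel: + + >>> plan = strategy(pruner, data, name, consumer, in_axis=1, out_axis=0, in_indices=[])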
+ """ + + return self.get_executor().execute(*args, **kwargs) + + def change_stage(self, stage: str): + """Change the stage of strategy""" + + self._stage = stage + + def add_executor(self, stage: str, executor: Executor): + """Add a executor to strategy + + Parameters + ---------- + stage: str + The mark of the executor. + executor: Executor + The executor to process tensor. + """ + + self._executors[stage] = executor + if not self._stage: + self._stage = stage + + def get_executor(self) -> Tuple[callable, dict]: + """Get executor of current stage + + Returns + ------- + executor: tuple + The method and config to execute strategy + """ + + if self._stage in self._executors: + return self._executors[self._stage] + return self._executors["default"] + + def get_config(self, key: str, default: Any) -> Any: + """Get the value in config""" + + return self.get_executor().get_config(key, default) + + def support_stage(self, stage: str) -> bool: + """Check if the strategy support a stage + + Parameters + ---------- + stage: str + The mark of the executor + + Returns + ------- + support: bool + Whether the strategy support the strategy + """ + + return stage in self._executors or "default" in self._executors + + def copy( + self, + name: str = None, + tensor_type: str = None, + stage: str = None, + configs: Dict[str, dict] = None, + ): + """Copy a strategy + + Parameters + ---------- + name: str + The name for new strategy + tensor_type: + The tensor type for new strategy + stage: str + The init stage for new strategy + configs: dict + The method config of new executors. + + Returns + ------- + new_strategy: Strategy + The copied strategy + """ + + configs = configs or {} + strategy = Strategy( + name or self._name, tensor_type or self._tensor_type, stage or self._stage + ) + for st_name, executor in self._executors.items(): + new_executor = executor.copy(config=configs.get(st_name, {})) + strategy.add_executor(st_name, new_executor) + return strategy + + +class BaseTool(object): + """Basic tool of MSC + + Parameters + ---------- + stage: str + The stage of tool + plan_file: str + The plan file path. + strategys: list[dict] + The strategys of the tool. + cache_processed: bool + Whether to cache processed tensor. + options: dict + The extra options for the tool + debug_level: int + The debug level. + verbose_step: int + The verbose interval step. + logger: logging.Logger + The logger + """ + + def __init__( + self, + stage: str, + plan_file: str, + strategys: dict, + cache_processed: bool = True, + options: dict = None, + debug_level: int = 0, + verbose_step: int = 50, + logger: logging.Logger = None, + ): + self._stage = stage + if os.path.isfile(plan_file): + self._plan = msc_utils.load_dict(plan_file) + else: + self._plan = {} + self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys)) + self._cache_processed = cache_processed + self._options = options or {} + self._debug_level = debug_level + self._verbose_step = verbose_step + self._logger = logger or msc_utils.get_global_logger() + title = "{}.SETUP({} @ {})".format(self.tool_type().upper(), self._stage, self.framework()) + self._logger.info(msc_utils.msg_block(title, self.setup(), width=0)) + if self._debug_level >= 3 and self._plan: + self._logger.debug( + msc_utils.msg_block("{}.PLAN".format(self.tool_type().upper()), self._plan) + ) + + def setup(self) -> dict: + """Setup the tool + + Returns + ------- + info: dict + The setup info. 
+ """ + + self._tensor_cache = {} + self._enabled, self._is_training = True, True + self._graphs, self._weights = [], {} + self._graph_id, self._forward_cnt = 0, 0 + self._processed_tensor = {} + return { + "style": self.tool_style(), + "strategys": {k: v.inspect() for k, v in self._strategys.items()}, + "cache_processed": self._cache_processed, + "options": self._options, + "planed_num": len(self._plan), + "verbose_step": self._verbose_step, + "debug_level": self._debug_level, + } + + def reset( + self, + graphs: List[MSCGraph], + weights: List[Dict[str, tvm.nd.array]], + cache_dir: msc_utils.MSCDirectory = None, + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Reset the tool with graphs and weights + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + cache_dir: MSCDirectory + cache path for save/load info + + Returns + ------- + graphs: list + The msc graphs. + weights: list> + The weights + """ + + self._forward_cnt = 0 + self._tensor_cache = {} + if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")): + cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json")) + self.load_cache(cache_dir, cache_info) + else: + graphs, weights = self.load_graphs(graphs, weights) + self._graphs, self._weights = graphs, {} + for sub_weights in weights: + self._weights.update(sub_weights) + self._logger.debug( + "%s load %d graphs and %d weights", + self.tool_type().upper(), + len(self._graphs), + len(self._weights), + ) + return self._graphs, weights + + def change_stage(self, stage: str): + """Change the stage of tools and strategy""" + + self._stage = stage + for strategy in self._strategys.values(): + strategy.change_stage(stage) + + def destory(self): + """Destory tool""" + + self._graphs, self._weights = [], {} + + def load_graphs( + self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Load the graphs and weights + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + + Returns + ------- + graphs: list + The msc graphs. + weights: list> + The weights + """ + + return graphs, weights + + def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): + """Save runner to cache + + Parameters + ------- + cache_dir: MSCDirectory + cache path for save/load info + cache_info: dict + The cache_info + """ + + return None + + def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: + """Save runner to cache + + Parameters + ------- + cache_dir: MSCDirectory + cache path for save/load info + + Returns + ------- + cache_info: dict + The cache_info. + """ + + return {} + + def execute_before_build(self, *args, **kwargs): + """Execute before model build + + Parameters + ---------- + args: list + The arguments for model build. + kwargs: dict + The key word arguments for model build. + """ + + if self._enabled: + self._graph_id = self._infer_graph_id(kwargs) + self._processed_tensor = {} + self._logger.debug("%sStart Build", self.msg_mark(in_forward=False)) + self._execute_before_build(*args, **kwargs) + + def _execute_before_build(self, *args, **kwargs): + """Execute before model build + + Parameters + ---------- + args: list + The arguments for model build. + kwargs: dict + The key word arguments for model build. 
+ """ + + return None + + def execute_after_build(self, output: Any) -> Any: + """Execute after model build + + Parameters + ---------- + output: Any + The output reference of the model. + + Returns + ------- + output: Any + The modified output reference. + """ + + if self._enabled: + output = self._execute_after_build(output) + self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False)) + return output + + def _execute_after_build(self, output: Any) -> Any: + """Execute after model build + + Parameters + ---------- + output: Any + The output reference of the model. + + Returns + ------- + output: Any + The modified output reference. + """ + + return output + + def execute_before_forward(self, *args, **kwargs): + """Execute before model forward + + Parameters + ---------- + args: list + The arguments for model forward. + kwargs: dict + The key word arguments for model forward. + """ + + if self._enabled: + self._graph_id = self._infer_graph_id(kwargs) + self._processed_tensor = {} + if self.on_debug(2): + self._logger.debug("%sStart Forward", self.msg_mark()) + self._execute_before_forward(*args, **kwargs) + + def _execute_before_forward(self, *args, **kwargs): + """Execute before model forward + + Parameters + ---------- + args: list + The arguments for model forward. + kwargs: dict + The key word arguments for model forward. + """ + + return None + + def execute_after_forward(self, output: Any) -> Any: + """Execute after model forward + + Parameters + ---------- + output: Any + The output reference of the model. + + Returns + ------- + output: Any + The modified output reference. + """ + + if self._enabled: + output = self._execute_after_forward(output) + if self.on_debug(2): + self._logger.debug( + "%sEnd Forward, process %d tensors", + self.msg_mark(), + len(self._processed_tensor), + ) + self._forward_cnt += 1 + return output + + def _execute_after_forward(self, output: Any) -> Any: + """Execute after model forward + + Parameters + ---------- + output: Any + The output reference of the model. + + Returns + ------- + output: Any + The modified output reference. + """ + + return output + + def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> Any: + """Process tensor + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. + consumer: str + The name of the consumer. + scope: str + The scope mark teacher| student| null + + Returns + ------- + tensor: Any + The processed tensor. 
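+
+        Examples
+        --------
+        A hedged sketch of how a framework hook might call this method; the
+        tensor and consumer names are illustrative, not a fixed API:
+
+        >>> data = tool.process_tensor(data, "conv1.weight", "conv1", "")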
+        """
+
+        if not self._support_scope(scope):
+            return tensor
+        strategys = self._get_tensor_strategys(name, consumer)
+        strategy_mark = ".".join([s.get_executor().name for s in strategys])
+        cached_tensor = self._get_processed(name, consumer, strategy_mark)
+        if cached_tensor is not None:
+            self.debug_tensor(cached_tensor, name, consumer, "cached({})".format(strategy_mark))
+            return cached_tensor
+        process = self._get_tensor_cache(name, consumer, "process")
+        if process is None:
+            process = self._check_tensor(name, consumer)
+            self._save_tensor_cache(name, consumer, "process", process)
+        if process and self.on_debug(3):
+            self._logger.debug("%sprocess tensor %s-%s", self.msg_mark(), name, consumer)
+        if not process:
+            return tensor
+        tensor = self._process_tensor(tensor, name, consumer, strategys)
+        self._save_processed(name, consumer, tensor, strategy_mark)
+        self.debug_tensor(tensor, name, consumer, "processed({})".format(strategy_mark))
+        return tensor
+
+    def _support_scope(self, scope: str) -> bool:
+        """Check if the scope is supported
+
+        Parameters
+        ----------
+        scope: str
+            The scope mark, should be null or ToolScope.
+
+        Returns
+        -------
+        valid: bool
+            Whether to process the tensor.
+        """
+
+        if not scope:
+            return True
+        return scope != ToolScope.TEACHER
+
+    def _get_processed(self, name: str, consumer: str, strategy_mark: str) -> Any:
+        """Get cached processed tensor
+
+        Parameters
+        ----------
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        strategy_mark: str
+            The strategy mark.
+
+        Returns
+        -------
+        processed_tensor
+            The cached processed tensor.
+        """
+
+        if self._cache_processed:
+            return self._processed_tensor.get(name + "." + strategy_mark)
+        return None
+
+    def _save_processed(self, name: str, consumer: str, tensor: Any, strategy_mark: str):
+        """Save cached processed tensor
+
+        Parameters
+        ----------
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        tensor: Any
+            The processed tensor.
+        strategy_mark: str
+            The strategy mark.
+        """
+
+        if self._cache_processed:
+            self._processed_tensor[name + "." + strategy_mark] = tensor
+        else:
+            self._processed_tensor[self.to_tensor_id(name, consumer)] = None
+
+    def _check_tensor(self, name: str, consumer: str) -> bool:
+        """Check if the tensor should be processed
+
+        Parameters
+        ----------
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+
+        Returns
+        -------
+        valid: bool
+            Whether to process the tensor.
+        """
+
+        strategys = self._get_tensor_strategys(name, consumer)
+        return len(strategys) > 0
+
+    def _process_tensor(
+        self, tensor: Any, name: str, consumer: str, strategys: List[Strategy]
+    ) -> Any:
+        """Process tensor
+
+        Parameters
+        ----------
+        tensor: Any
+            Tensor in framework.
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        strategys: list
+            The strategys for the tensor.
+
+        Returns
+        -------
+        tensor: Any
+            The processed tensor.
+        """
+
+        return tensor
+
+    def visualize(self, visual_dir: msc_utils.MSCDirectory):
+        """Visualize MSCGraphs
+
+        Parameters
+        ----------
+        visual_dir: MSCDirectory
+            The visualize path for saving graphs.
+        """
+
+        return None
+
+    def update_plan(self, plan: dict):
+        """Update the plan
+
+        Parameters
+        ----------
+        plan: dict
+            The new plan.
+        """
+
+        self._plan.update(plan)
+
+    def get_plan(self, name: str) -> dict:
+        """Get the plan for name
+
+        Parameters
+        ----------
+        name: str
+            The plan name.
+
+        Returns
+        -------
+        plan: dict
+            The plan of the name.
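+
+        Examples
+        --------
+        Illustrative only; the layout of a plan entry depends on the
+        concrete tool:
+
+        >>> tool.update_plan({"conv1.weight": {"density": 0.8}})
+        >>> tool.get_plan("conv1.weight")
+        {'density': 0.8}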
+ """ + + return self._plan.get(name, {}) + + def finalize(self) -> dict: + """Get the plan""" + + return self._plan + + def enable(self): + """Enable the tool""" + + self._enabled = True + + def disable(self): + """Disable the tool""" + + self._enabled = False + + def train(self): + """Set the tool to train mode""" + + self._is_training = True + + def eval(self): + """Set the tool to eval mode""" + + self._is_training = False + + def to_tensor_id(self, name: str, consumer: str) -> str: + """Concat name to unique id + + Parameters + ---------- + name: str + The name of tensor. + consumer: str + The name of consumer. + + Returns + ------- + tensor_id: str + The unique name of edge. + """ + + return "{}-c-{}".format(name, consumer) + + def from_tensor_id(self, tensor_id: str) -> Tuple[str]: + """Split name from unique id + + Parameters + ---------- + tensor_id: str + The unique name of edge. + + Returns + ------- + name: str + The name of tensor. + consumer: str + The name of consumer. + """ + + return tensor_id.split("-c-") + + def is_weight(self, name: str) -> bool: + """Check if the tensor is weight + + Parameters + ---------- + name: str + The name of tensor. + + Returns + ------- + is_weight: bool + Whether the name is weight. + """ + + return name in self._weights + + def on_debug(self, debug_level: int = 1) -> bool: + """Check if should log + + Parameters + ------- + debug_level: int + The given debug_level. + + Returns + ------- + on_debug: bool + Whether to log debug info. + """ + + if self._forward_cnt % self._verbose_step != 0: + return False + return self._debug_level >= debug_level + + def msg_mark(self, in_forward=True) -> str: + """Get the debug title + + Returns + ------- + msg_mark: str + Get the debug title. + """ + + title = "{}.G[{}]".format(self.tool_type().upper(), self._graph_id) + if in_forward: + title += ".F[{}]".format(self._forward_cnt) + title += "({}) ".format(self._stage) + return title + + def debug_tensor( + self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 2 + ) -> str: + """Get the debug tensor info + + Parameters + ------- + tensor: array_like + The tensor + name: str + The name of tensor. + consumer: str + The name of consumer. + t_mark: str + The mark of tensor. + debug_level: int + The given debug_level. + """ + + if self.on_debug(debug_level): + self._logger.debug( + "%s%s %s-%s: %s", + self.msg_mark(), + t_mark, + name, + consumer, + msc_utils.inspect_array(tensor), + ) + + def _infer_graph_id(self, kwargs: dict) -> int: + """Infer graph id from kwargs + + Parameters + ---------- + kwargs: dict + The kwargs for execute. + """ + + if "graph_id" in kwargs: + return kwargs.pop("graph_id") + if "graph_name" in kwargs: + name = kwargs.pop("graph_name") + for idx, g in enumerate(self._graphs): + if g.name == name: + return idx + return 0 + + def get_nodes(self) -> Iterable[MSCJoint]: + """Get all the nodes in the graphs. + + Returns + ------- + nodes: generator + The generator of nodes. + """ + + for g in self._graphs: + for n in g.get_nodes(): + yield n + + def find_node(self, name: str) -> MSCJoint: + """Find node by name. + + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + node: MSCJoint + The found node. + """ + + for g in self._graphs: + if g.has_node(name): + return g.find_node(name) + raise Exception("Can not find node {} from {} graphs".format(name, len(self._graphs))) + + def find_tensor(self, name: str) -> MSCTensor: + """Find tensor by name. 
+
+        Parameters
+        ----------
+        name: string
+            The name of the tensor.
+
+        Returns
+        -------
+        node: MSCTensor
+            The found tensor.
+        """
+
+        for g in self._graphs:
+            if g.has_tensor(name):
+                return g.find_tensor(name)
+        raise Exception("Can not find tensor {} from {} graphs".format(name, len(self._graphs)))
+
+    def find_producer(self, name: str) -> MSCJoint:
+        """Find producer by tensor_name.
+
+        Parameters
+        ----------
+        name: string
+            The name of the tensor.
+
+        Returns
+        -------
+        node: MSCJoint
+            The found producer.
+        """
+
+        for g in self._graphs:
+            if g.has_tensor(name):
+                return g.find_producer(name)
+        raise Exception(
+            "Can not find producer of {} from {} graphs".format(name, len(self._graphs))
+        )
+
+    def find_consumers(self, name: str) -> List[MSCJoint]:
+        """Find consumers by tensor_name.
+
+        Parameters
+        ----------
+        name: string
+            The name of the tensor.
+
+        Returns
+        -------
+        nodes: list
+            The found consumers.
+        """
+
+        for g in self._graphs:
+            if g.has_tensor(name):
+                return g.find_consumers(name)
+        raise Exception(
+            "Can not find consumers of {} from {} graphs".format(name, len(self._graphs))
+        )
+
+    def get_data(self, name: str) -> np.ndarray:
+        """Get the data by name
+
+        Parameters
+        ----------
+        name: str
+            The tensor name.
+
+        Returns
+        -------
+        data: np.ndarray
+            The data.
+        """
+
+        if name in self._weights:
+            return msc_utils.cast_array(self._weights[name])
+        raise Exception("Can not find data {} from {} weights".format(name, len(self._weights)))
+
+    def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, Strategy]:
+        """Parse the strategy list to get valid strategys
+
+        Parameters
+        ----------
+        strategy_list: list<dict>
+            The given strategys.
+
+        Returns
+        -------
+        strategys: dict
+            The parsed strategys.
+        """
+
+        strategys = {}
+        assert isinstance(strategy_list, list) and all(
+            isinstance(s, dict) for s in strategy_list
+        ), "Strategys should be given as list of dict"
+        for stra in strategy_list:
+            method_cls_name = stra.pop("method_cls") if "method_cls" in stra else "default"
+            method_cls = msc_utils.get_registered_tool_method(
+                self.framework(), self.tool_type(), method_cls_name
+            )
+            method_name = stra.pop("method") if "method" in stra else "default"
+            if hasattr(method_cls, method_name):
+                method = getattr(method_cls, method_name)
+            else:
+                default_cls = msc_utils.get_registered_tool_method(
+                    MSCFramework.MSC, self.tool_type(), method_cls_name
+                )
+                assert hasattr(
+                    default_cls, method_name
+                ), "Can not find method {} from neither {} nor {}".format(
+                    method_name, method_cls, default_cls
+                )
+                method = getattr(default_cls, method_name)
+            tensor_types = stra.pop("tensor_types") if "tensor_types" in stra else ["all"]
+            if "op_types" in stra:
+                op_types = stra.pop("op_types")
+                marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)]
+            elif "op_names" in stra:
+                op_names = stra.pop("op_names")
+                marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)]
+            else:
+                marks = [("default", "all")]
+            stages = stra.pop("stages") if "stages" in stra else ["default"]
+            for mark, t_type in marks:
+                if mark not in strategys:
+                    strategys[mark] = Strategy(mark, t_type, self._stage)
+                for stage in stages:
+                    strategys[mark].add_executor(
+                        stage, Executor(method_name, method, copy.deepcopy(stra))
+                    )
+        return strategys
+
+    def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any):
+        """Save the data to tensor cache
+
+        Parameters
+        ----------
+        name: str
+            The tensor name.
+        consumer: str
+            The name of the consumer.
+ key: str + The data key. + value: any + The value to cache. + """ + + tensor_id = self.to_tensor_id(name, consumer) + if tensor_id not in self._tensor_cache: + self._tensor_cache[tensor_id] = {} + self._tensor_cache[tensor_id][key] = value + + def _get_tensor_cache(self, name: str, consumer: str, key: str) -> Any: + """Get the cached tensor data + + Parameters + ------- + name: str + The tensor name. + consumer: str + The name of the consumer. + key: str + The data key. + + Returns + ------- + value: any + The cached value. + """ + + tensor_id = self.to_tensor_id(name, consumer) + if tensor_id not in self._tensor_cache: + return None + return self._tensor_cache[tensor_id].get(key) + + def _get_tensor_strategys(self, name: str, consumer: str) -> List[Strategy]: + """Get the strategys by name and consumer + + Parameters + ------- + name: str + The tensor name. + consumer: str + The name of the consumer. + + Returns + ------- + strategys: list + The strategys for the tensor. + """ + + tensor_id = self.to_tensor_id(name, consumer) + mark = "strategy.{}".format(self._stage) + if mark not in self._tensor_cache.get(tensor_id, {}): + if self.is_weight(name): + consumer = self.find_node(consumer) + name_refs = [ + consumer.name + ".weight", + consumer.optype + ".weight", + consumer.optype + ".all", + ] + elif consumer == "exit": + producer = self.find_producer(name) + name_refs = [ + producer.name + ".output", + producer.optype + ".output", + producer.optype + ".all", + ] + else: + consumer = self.find_node(consumer) + producer = self.find_producer(name) + name_refs = [ + producer.name + ".output", + producer.optype + ".output", + producer.optype + ".all", + consumer.name + ".input", + consumer.optype + ".input", + consumer.optype + ".all", + ] + strategys = [] + for n in name_refs: + if n in self._strategys and self._strategys[n].support_stage(self._stage): + strategys.append(self._strategys[n]) + d_strategy = self._strategys.get("default") + if not strategys and d_strategy and d_strategy.support_stage(self._stage): + strategys.append(d_strategy) + self._save_tensor_cache(name, consumer, mark, strategys) + return self._get_tensor_cache(name, consumer, mark) + + def _get_tensor_strategy(self, name: str, consumer: str) -> Strategy: + """Get the unique strategy by name and consumer + + Parameters + ------- + name: str + The tensor name. + consumer: str + The name of the consumer. + + Returns + ------- + strategy: Strategy + The unique strategy for the tensor. 
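+
+        Examples
+        --------
+        Assuming ``conv1`` is a ``nn.conv2d`` node (all names illustrative),
+        the lookup falls back from specific to generic marks:
+
+        >>> strategy = tool._get_tensor_strategy("conv1.weight", "conv1")
+        >>> # checks conv1.weight, nn.conv2d.weight, nn.conv2d.all, then default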
+        """
+
+        strategys = self._get_tensor_strategys(name, consumer)
+        if not strategys:
+            return None
+        assert len(strategys) == 1, "{} should only have 1 strategy, get {}".format(
+            self._stage, strategys
+        )
+        return strategys[0]
+
+    def get_graph(self):
+        return self._graphs[self._graph_id]
+
+    @classmethod
+    def tool_type(cls):
+        return ToolType.BASE
+
+    @classmethod
+    def framework(cls):
+        return MSCFramework.MSC
+
+    @classmethod
+    def tool_style(cls):
+        return "base"
diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py
index 7835eb346e03..a96369b320f1 100644
--- a/python/tvm/contrib/msc/core/utils/dataset.py
+++ b/python/tvm/contrib/msc/core/utils/dataset.py
@@ -51,7 +51,7 @@ def __init__(self, folder: str, start: int = 0, end: int = -1):
         self._end = min(end, self._info["num_datas"])
 
     def __str__(self):
-        return "<{}> @ {}".format(self._class__.__name__, self._folder)
+        return "<{}> @ {}".format(self.__class__.__name__, self._folder)
 
     def __getitem__(self, idx):
         if idx + self._start >= self._end:
@@ -167,6 +167,10 @@ def _data_info(self, name: str) -> dict:
 
         raise NotImplementedError("_data_info is not implemented for BaseDataLoader")
 
+    @property
+    def folder(self):
+        return self._folder
+
     @property
     def info(self):
         return self._info
@@ -290,6 +294,9 @@ def __init__(
     def setup(self, options: dict):
         return {"num_datas": 0}
 
+    def __str__(self):
+        return "<{}> @ {}".format(self.__class__.__name__, self._folder)
+
     def __enter__(self):
        return self
 
@@ -298,9 +305,22 @@ def __exit__(self, exception_type, exception_value, traceback):
         self.finalize()
 
     def finalize(self):
+        """Finalize the saver"""
+
         with open(os.path.join(self._folder, "datas_info.json"), "w") as f:
             f.write(json.dumps(self._info, indent=2))
 
+    def is_finalized(self) -> bool:
+        """Check if the saver is finalized
+
+        Returns
+        -------
+        is_finalized: bool
+            Whether the saver is finalized.
+        """
+
+        return os.path.isfile(os.path.join(self._folder, "datas_info.json"))
+
     def reset(self):
         self._current = 0
 
@@ -353,6 +373,10 @@ def _save_batch(self, *args, **kwargs) -> dict:
 
         raise NotImplementedError("_save_batch is not implemented for BaseDataSaver")
 
+    @property
+    def folder(self):
+        return self._folder
+
     @property
     def info(self):
         return self._info
@@ -400,6 +424,31 @@ def setup(self, options: dict):
         self._output_names = options.get("output_names", [])
         return {"inputs": {}, "outputs": {}, "num_datas": 0}
 
+    def finalize(self):
+        """Finalize the saver"""
+
+        super().finalize()
+        with open(os.path.join(self._folder, "datas_info.txt"), "w") as f:
+            for name in self._input_names:
+                info = self._info["inputs"][name]
+                f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"]))
+            for name in self._output_names:
+                info = self._info["outputs"][name]
+                f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"]))
+
+    def is_finalized(self) -> bool:
+        """Check if the saver is finalized
+
+        Returns
+        -------
+        is_finalized: bool
+            Whether the saver is finalized.
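+
+        Examples
+        --------
+        A minimal sketch; the constructor arguments are assumptions, check
+        the concrete saver for the real signature:
+
+        >>> with IODataSaver("msc_datas", {"input_names": ["data"]}) as saver:
+        ...     saver.save_batch({"data": np.zeros((1, 3, 224, 224), "float32")})
+        >>> saver.is_finalized()
+        True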
+        """
+
+        if not super().is_finalized():
+            return False
+        return os.path.isfile(os.path.join(self._folder, "datas_info.txt"))
+
     def save_batch(
         self,
         inputs: Union[Dict[str, np.ndarray], List[np.ndarray]],
diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py
index f59295640cda..4936bf28e0a4 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -109,17 +109,15 @@ def add_file(self, name: str, contains: str) -> str:
             f.write(contains)
         return file_path
 
-    def move_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
-        """Move a file to another folder
+    def move(self, src_path: str, dst_path: str = None):
+        """Move a file or folder to another path
 
         Parameters
         ----------
-        src_file: str
-            The name of the source file.
-        dst_folder: MSCDirectory
-            The target folder.
-        dst_file: str
-            The target file name.
+        src_path: str
+            The path of the source file or folder.
+        dst_path: str
+            The target file or folder path.
 
         Returns
         -------
@@ -127,23 +125,25 @@ def move_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
             The abs file path.
         """
 
-        src_path = os.path.join(self.relpath(src_file))
-        assert os.path.isfile(src_path), "Source file {} not exist".format(src_path)
-        dst_path = dst_folder.relpath(dst_file or src_file)
+        if src_path != os.path.abspath(src_path):
+            src_path = self.relpath(src_path)
+        assert os.path.exists(src_path), "Source path {} does not exist".format(src_path)
+        if not dst_path:
+            dst_path = self.relpath(os.path.basename(src_path))
+        if dst_path != os.path.abspath(dst_path):
+            dst_path = self.relpath(dst_path)
         os.rename(src_path, dst_path)
         return dst_path
 
-    def copy_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
-        """Copy a file to another folder
+    def copy(self, src_path: str, dst_path: str = None):
+        """Copy a file or folder to another path
 
         Parameters
         ----------
-        src_file: str
-            The name of the source file.
-        dst_folder: MSCDirectory
-            The target folder.
-        dst_file: str
-            The target file name.
+        src_path: str
+            The path of the source file or folder.
+        dst_path: str
+            The target file or folder path.
 
         Returns
         -------
@@ -151,10 +151,19 @@ def copy_file(self, src_file: str, dst_folder: Any, dst_file: str = None):
             The abs file path.
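+
+        Examples
+        --------
+        Illustrative usage, assuming ``msc_dir`` from this module and a
+        stand-in source path:
+
+        >>> folder = msc_dir("workspace")
+        >>> plan_copy = folder.copy("/tmp/plan.json")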
""" - src_path = os.path.join(self.relpath(src_file)) - assert os.path.isfile(src_path), "Source file {} not exist".format(src_path) - dst_path = dst_folder.relpath(dst_file or src_file) - shutil.copy2(src_path, dst_path) + if src_path != os.path.abspath(src_path): + src_path = os.path.join(self.relpath(src_path)) + assert os.path.exists(src_path), "Source path {} not exist".format(src_path) + if not dst_path: + dst_path = self.relpath(os.path.basename(src_path)) + if dst_path != os.path.abspath(dst_path): + dst_path = self.relpath(dst_path) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + else: + if os.path.isdir(dst_path): + os.remove(dst_path) + shutil.copytree(src_path, dst_path) return dst_path def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = False) -> Any: diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 5d8d4fdd5a87..782be2604950 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -19,37 +19,15 @@ import os import json import copy -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Dict, Any, Union from distutils.version import LooseVersion import numpy as np import tvm +from tvm.contrib.msc.core import _ffi_api from .namespace import MSCFramework -def inspect_array(data: np.ndarray) -> Dict[str, Any]: - """Inspect the array - - Parameters - ---------- - data: np.ndarray - The data to inspect - - Returns - ------- - info: dict - The data info. - """ - - return { - "shape": list(data.shape), - "dtype": data.dtype.name, - "max": float(data.max()), - "min": float(data.min()), - "avg": float(data.sum() / data.size), - } - - class MSCArray(object): """MSC wrapper for array like object @@ -67,11 +45,14 @@ def __str__(self): def _analysis(self, data: Any) -> Tuple[str, np.ndarray]: if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): - return "np", np.array(data) + return "list", np.array(data) if isinstance(data, np.ndarray): return "np", data if isinstance(data, tvm.runtime.NDArray): return "tvm", data.asnumpy() + if isinstance(data, tvm.relax.Var): + shape = [int(s) for s in data.struct_info.shape] + return "var", np.zeros(shape, dtype=data.struct_info.dtype) try: import torch # pylint: disable=import-outside-toplevel @@ -84,6 +65,7 @@ def _analysis(self, data: Any) -> Tuple[str, np.ndarray]: def abstract(self) -> str: """Get abstract describe of the data""" + return "[S:{},D:{}] Max {:g}, Min {:g}, Avg {:g}".format( ";".join([str(s) for s in self._data.shape]), self._data.dtype.name, @@ -92,6 +74,36 @@ def abstract(self) -> str: self._data.sum() / self._data.size, ) + @classmethod + def is_array(cls, data: Any) -> bool: + """Check if the data is array like + + Parameters + ---------- + data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... + The data object. + + Returns + ------- + is_array: bool + Whether the data is array like. + """ + + normal_types = (np.ndarray, tvm.runtime.NDArray, tvm.relax.Var) + if isinstance(data, normal_types): + return True + if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): + return True + try: + import torch # pylint: disable=import-outside-toplevel + + if isinstance(data, torch.Tensor): + return True + except: # pylint: disable=bare-except + pass + + return False + @property def type(self): return self._type @@ -115,9 +127,40 @@ def cast_array(data: Any) -> np.ndarray: The output as numpy array. 
""" + assert MSCArray.is_array(data), "{} is not array like".format(data) return MSCArray(data).data +def inspect_array(data: Any, as_str: bool = True) -> Union[Dict[str, Any], str]: + """Inspect the array + + Parameters + ---------- + data: array like + The data to inspect + as_str: bool + Whether inspect the array as string. + + Returns + ------- + info: dict + The data info. + """ + + if not MSCArray.is_array(data): + return str(data) + if as_str: + return str(MSCArray(data)) + data = cast_array(data) + return { + "shape": list(data.shape), + "dtype": data.dtype.name, + "max": float(data.max()), + "min": float(data.min()), + "avg": float(data.sum() / data.size), + } + + def compare_arrays( golden: Dict[str, np.ndarray], datas: Dict[str, np.ndarray], @@ -333,3 +376,23 @@ def get_version(framework: str) -> List[int]: raw_version = "1.0.0" return LooseVersion(raw_version).version + + +def compare_version(given_version: List[int], target_version: List[int]) -> int: + """Compare version + + Parameters + ---------- + given_version: list + The version in . + + target_version: list + The version in . + + Returns + ------- + compare_res: int + The compare result: 0 for same version, 1 for greater version, -1 for less version + """ + + return int(_ffi_api.CompareVersion(given_version, target_version)) diff --git a/python/tvm/contrib/msc/core/utils/namespace.py b/python/tvm/contrib/msc/core/utils/namespace.py index 63a4365c5419..6744548ddfc4 100644 --- a/python/tvm/contrib/msc/core/utils/namespace.py +++ b/python/tvm/contrib/msc/core/utils/namespace.py @@ -47,6 +47,10 @@ def delete(cls, key: str): def contains(cls, key: str): return key in cls.MAP + @classmethod + def reset(cls): + cls.MAP = {} + class MSCKey: """Keys for the MSCMap""" @@ -54,12 +58,14 @@ class MSCKey: WORKSPACE = "workspace" VERBOSE = "verbose" GLOBALE_LOGGER = "global_logger" - REGISTERED_FUNCS = "registered_funcs" - REGISTERED_TOOLS = "registered_tools" - MSC_STAGE = "msc_stage" TIME_STAMPS = "time_stamps" + PRUNERS = "pruners" + QUANTIZERS = "quantizers" + DISTILLERS = "distillers" + TRACKERS = "trackers" + FUSED_CNT = "fused_cnt" diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index 2b921d33e004..31ae8942a106 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -16,7 +16,40 @@ # under the License. """tvm.contrib.msc.core.utils.register""" -from .namespace import MSCMap, MSCKey, MSCFramework +from typing import Any, Optional +from .namespace import MSCFramework + + +class MSCRegistery: + """The registery for MSC""" + + REGISTERY = {} + MSC_FUNCS = "msc_funcs" + MSC_TOOLS_CLS = "msc_tools_cls" + MSC_TOOLS_METHOD = "msc_tools_method" + + @classmethod + def register(cls, key: str, value: Any): + cls.REGISTERY[key] = value + return value + + @classmethod + def unregister(cls, key: str): + if key in cls.REGISTERY: + return cls.REGISTERY.pop(key) + return None + + @classmethod + def get(cls, key: str, default: Optional[Any] = None) -> Any: + return cls.REGISTERY.get(key, default) + + @classmethod + def contains(cls, key: str): + return key in cls.REGISTERY + + @classmethod + def reset(cls): + cls.REGISTERY = {} def register_func(name: str, func: callable, framework: str = MSCFramework.MSC): @@ -32,11 +65,11 @@ def register_func(name: str, func: callable, framework: str = MSCFramework.MSC): Should be from MSCFramework. 
""" - funcs = MSCMap.get(MSCKey.REGISTERED_FUNCS, {}) + funcs = MSCRegistery.get(MSCRegistery.MSC_FUNCS, {}) if framework not in funcs: funcs[framework] = {} funcs[framework][name] = func - MSCMap.set(MSCKey.REGISTERED_FUNCS, funcs) + MSCRegistery.register(MSCRegistery.MSC_FUNCS, funcs) def get_registered_func(name: str, framework: str = MSCFramework.MSC): @@ -55,7 +88,100 @@ def get_registered_func(name: str, framework: str = MSCFramework.MSC): The registered function. """ - funcs = MSCMap.get(MSCKey.REGISTERED_FUNCS, {}) + funcs = MSCRegistery.get(MSCRegistery.MSC_FUNCS, {}) if framework not in funcs: return None return funcs[framework].get(name) + + +def register_tool_cls(tool_cls: Any): + """Register a tool class. + + Parameters + ---------- + tool_cls: class + The tool class to be registered. + """ + + tools_cls = MSCRegistery.get(MSCRegistery.MSC_TOOLS_CLS, {}) + for key in ["framework", "tool_type", "tool_style"]: + assert hasattr(tool_cls, key), "{} should be given to register tool class".format(key) + if tool_cls.framework() not in tools_cls: + tools_cls[tool_cls.framework()] = {} + framework_tools = tools_cls[tool_cls.framework()] + if tool_cls.tool_type() not in framework_tools: + framework_tools[tool_cls.tool_type()] = {} + tools = framework_tools[tool_cls.tool_type()] + tools[tool_cls.tool_style()] = tool_cls + MSCRegistery.register(MSCRegistery.MSC_TOOLS_CLS, tools_cls) + + +def get_registered_tool_cls(framework: str, tool_type: str, tool_style: str) -> Any: + """Get the registered tool class. + + Parameters + ---------- + framework: string + Should be from MSCFramework. + tool_type: string + The type of the tool prune| quantize| distill| debug. + tool_style: string + The style of the tool. + + Returns + ------- + tool_cls: class + The registered tool class. + """ + + tools_cls = MSCRegistery.get(MSCRegistery.MSC_TOOLS_CLS, {}) + if tool_style == "all": + return tools_cls.get(framework, {}).get(tool_type, {}) + return tools_cls.get(framework, {}).get(tool_type, {}).get(tool_style) + + +def register_tool_method(method_cls: Any, method_style: str = "default"): + """Register a tool method. + + Parameters + ---------- + method_cls: class + The method class. + method_style: string + The style of the method. + """ + + tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) + assert hasattr(method_cls, "framework") and hasattr( + method_cls, "tool_type" + ), "framework and tool_type should be given to register tool method" + if method_cls.framework() not in tools_method: + tools_method[method_cls.framework()] = {} + register_name = "{}.{}".format(method_cls.tool_type(), method_style) + tools_method[method_cls.framework()][register_name] = method_cls + MSCRegistery.register(MSCRegistery.MSC_TOOLS_METHOD, tools_method) + + +def get_registered_tool_method( + framework: str, tool_type: str, method_style: str = "default" +) -> Any: + """Get the registered tool method. + + Parameters + ---------- + framework: string + Should be from MSCFramework. + tool_type: string + The type of the tool prune| quantize| distill| debug. + method_style: string + The style of the method. + + Returns + ------- + method_cls: class + The method class. 
+ """ + + tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) + register_name = "{}.{}".format(tool_type, method_style) + return tools_method.get(framework, {}).get(register_name) diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index 6fd26e04f17f..c686647bfefa 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=not-context-manager +# pylint: disable=not-context-manager,unused-import """tvm.contrib.msc.framework.tensorflow.runtime.runner""" import time @@ -30,6 +30,7 @@ from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow from tvm.contrib.msc.framework.tensorflow import tf_v1 +from tvm.contrib.msc.framework.tensorflow import tools class WrapSession(tf_v1.Session): diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py new file mode 100644 index 000000000000..94e19b7e074f --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorflow.tools""" + +from .prune import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py new file mode 100644 index 000000000000..8bdd61d3aa12 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.framework.tensorflow.tools.prune""" + +from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py new file mode 100644 index 000000000000..b59865c9ce0f --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorflow.tools.prune.pruner""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.prune import BasePruner +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class TensorflowPrunerFactory(object): + """Pruner factory for tensorflow""" + + def create(self, base_cls: BasePruner): + class Pruner(base_cls): + """Adaptive pruner for tensorflow""" + + @classmethod + def framework(cls): + return MSCFramework.TENSORFLOW + + return Pruner + + +factory = TensorflowPrunerFactory() +tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +for tool in tools.values(): + msc_utils.register_tool_cls(factory.create(tool)) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py index 1ff74f27b9c4..318334d33302 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py @@ -18,7 +18,7 @@ import os import subprocess -from typing import Dict, Optional, Tuple, List +from typing import Dict, Optional, Tuple, List, Union import numpy as np import tvm @@ -68,6 +68,11 @@ def to_sub_tensorrt( codegen_config["tensorrt_root"] = _ffi_api.GetTensorRTRoot() build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True) output_folder = output_folder or msc_utils.msc_dir("msc_output") + depends = {} + if "range_file" in codegen_config: + range_file = codegen_config["range_file"] + codegen_config["range_file"] = os.path.basename(range_file) + depends[codegen_config["range_file"]] = {"src": range_file, "copy_back": True} def _create_depends(folder: msc_utils.MSCDirectory) -> str: if weights: @@ -90,6 +95,10 @@ def _create_depends(folder: msc_utils.MSCDirectory) -> str: with folder.create_dir("utils") as utils_folder: for name, source in get_trt_sources().items(): utils_folder.add_file(name, source) + # copy depends + for path, info in depends.items(): + if os.path.exists(info["src"]): + folder.copy(info["src"], path) def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: with open("engine.log", "w") as log_f: @@ -100,7 +109,10 @@ def _build_engine(engine_name: str, folder: 
msc_utils.MSCDirectory) -> str:
             ), "Failed to test engine {} under {}, check engine.log for detail".format(
                 engine_name, os.getcwd()
             )
-            return folder.move_file(engine_name + ".trt", output_folder.create_dir(graph.name))
+            for path, info in depends.items():
+                if info.get("copy_back", False) and os.path.exists(path):
+                    folder.copy(path, info["src"])
+            return folder.move(engine_name + ".trt", output_folder.relpath(engine_name + ".trt"))
 
     codegen = CodeGen(
         graph,
@@ -121,9 +133,9 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str:
 def to_tensorrt(
     mod: tvm.IRModule,
     graph_infos: List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]],
-    codegen_config: Optional[Dict[str, str]] = None,
-    print_config: Optional[Dict[str, str]] = None,
-    extra_option: Optional[Dict[str, str]] = None,
+    codegen_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
+    print_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
+    extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
     build_folder: msc_utils.MSCDirectory = None,
     output_folder: msc_utils.MSCDirectory = None,
 ) -> Dict[str, str]:
@@ -135,9 +147,9 @@ def to_tensorrt(
         The IRModule of relax.
     graph_infos: list
         The translated graph.
-    codegen_config: dict
+    codegen_configs: dict or list
         The config for codegen.
-    print_config: dict
+    print_configs: dict or list
         The config for print.
-    extra_option: dict
-        The extra option for sub engine.
+    extra_options: dict or list
+        The extra options for sub engines.
@@ -153,12 +165,18 @@ def to_tensorrt(
     """
 
     target_options = {}
-    for graph, weights in graph_infos:
+    if not isinstance(codegen_configs, (list, tuple)):
+        codegen_configs = [codegen_configs] * len(graph_infos)
+    if not isinstance(print_configs, (list, tuple)):
+        print_configs = [print_configs] * len(graph_infos)
+    if not isinstance(extra_options, (list, tuple)):
+        extra_options = [extra_options] * len(graph_infos)
+    for idx, (graph, weights) in enumerate(graph_infos):
         options = to_sub_tensorrt(
-            graph, weights, codegen_config, print_config, build_folder, output_folder
+            graph, weights, codegen_configs[idx], print_configs[idx], build_folder, output_folder
        )
-        if extra_option:
-            options.update(extra_option)
+        if extra_options[idx]:
+            options.update(extra_options[idx])
         target_options[graph.name] = msc_utils.dump_dict(options)
     mod = tvm.transform.Sequential(
         [
diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
index 615cc4ba31e3..15a42b2cf967 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-import
 """tvm.contrib.msc.framework.tensorrt.runtime.runner"""
 
 import tvm
@@ -24,6 +25,7 @@
     transform_for_tensorrt,
 )
 from tvm.contrib.msc.framework.tensorrt.codegen import to_tensorrt
+from tvm.contrib.msc.framework.tensorrt import tools
 
 
 class TensorRTRunner(BYOCRunner):
diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
new file mode 100644
index 000000000000..0247da2642f7
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.tools""" + +from .prune import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py new file mode 100644 index 000000000000..24ef6a62b24b --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.tools.prune""" + +from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py new file mode 100644 index 000000000000..de7ccb0747be --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.framework.tensorrt.tools.prune.pruner""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.prune import BasePruner +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class TensorRTPrunerFactory(object): + """Pruner factory for tensorrt""" + + def create(self, base_cls: BasePruner): + class Pruner(base_cls): + """Adaptive pruner for tensorrt""" + + @classmethod + def framework(cls): + return MSCFramework.TENSORRT + + return Pruner + + +factory = TensorRTPrunerFactory() +tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +for tool in tools.values(): + msc_utils.register_tool_cls(factory.create(tool)) diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index a4f65b4fe1f1..9401e6047fa5 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-import """tvm.contrib.msc.framework.torch.runtime.runner""" import time @@ -28,6 +29,7 @@ from tvm.contrib.msc.framework.torch.codegen import to_torch from tvm.contrib.msc.framework.torch.frontend import set_weight_alias from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.framework.torch import tools class TorchRunner(ModelRunner): diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/__init__.py new file mode 100644 index 000000000000..ff26491f54af --- /dev/null +++ b/python/tvm/contrib/msc/framework/torch/tools/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.torch.tools""" + +from .prune import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py new file mode 100644 index 000000000000..6364e14945aa --- /dev/null +++ b/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.torch.tools.prune""" + +from .pruner import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py new file mode 100644 index 000000000000..171c639ceaa3 --- /dev/null +++ b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.torch.tools.prune.pruner""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.prune import BasePruner +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class TorchPrunerFactory(object): + """Pruner factory for torch""" + + def create(self, base_cls: BasePruner): + class Pruner(base_cls): + """Adaptive pruner for torch""" + + @classmethod + def framework(cls): + return MSCFramework.TORCH + + return Pruner + + +factory = TorchPrunerFactory() +tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +for tool in tools.values(): + msc_utils.register_tool_cls(factory.create(tool)) diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index b3a3f3bf7045..1e2b5257576b 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# pylint: disable=unused-import """tvm.contrib.msc.framework.runtime.tvm.runner""" from typing import Dict, List, Union, Any @@ -21,8 +22,10 @@ import tvm from tvm.contrib.msc.core.runtime import ModelRunner +from tvm.contrib.msc.core.tools import execute_step from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.framework.tvm.codegen import to_relax +from tvm.contrib.msc.framework.tvm import tools class WrapRunnable(object): @@ -41,7 +44,9 @@ def __init__(self, runnable: tvm.relax.VirtualMachine, entry: str = "main"): self._entry = entry def __call__(self, *inputs) -> List[tvm.nd.array]: - return self._runnable[self._entry](*inputs) + execute_step("before_forward", *inputs) + output = self._runnable[self._entry](*inputs) + return execute_step("after_forward", output) class TVMRunner(ModelRunner): diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py new file mode 100644 index 000000000000..91f07fd58149 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tvm.tools""" + +from .prune import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py new file mode 100644 index 000000000000..47c1478611e8 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tvm.tools.prune""" + +from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py new file mode 100644 index 000000000000..788f1090bd79 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tvm.tools.prune.pruner""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.prune import BasePruner +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class TVMPrunerFactory(object): + """Pruner factory for tvm""" + + def create(self, base_cls: BasePruner): + class Pruner(base_cls): + """Adaptive pruner for tvm""" + + @classmethod + def framework(cls): + return MSCFramework.TVM + + return Pruner + + +factory = TVMPrunerFactory() +tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +for tool in tools.values(): + msc_utils.register_tool_cls(factory.create(tool)) diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index f571884860ea..4f31eeacfaee 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -25,7 +25,8 @@ import tvm from tvm.contrib.msc.core.runtime import BaseRunner -from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils @@ -45,10 +46,11 @@ def __init__(self, model, config): # check config for stage in ["inputs", "outputs", "dataset", "prepare", "compile"]: assert stage in config, "{} should be given to run the pipeline".format(stage) + MSCMap.reset() self._model = model self._workspace = msc_utils.set_workspace(config.get("workspace")) log_path = config.get("log_path") or self._workspace.relpath("MSC_LOG", keep_history=False) - if config.get("debug", False) and "verbose" not in config: + if config.get("debug_level", 0) > 0 and "verbose" not in config: verbose = "debug" else: verbose = config.get("verbose", "info") @@ -70,7 +72,7 @@ def setup(self, config: dict) -> dict: The setup info. 
""" - self._config, self._debug_config = self.update_config(config) + self._config, self._debug_levels = self.update_config(config) self._tools_config = {} self._relax_mod, self._runner = None, None self._data_loader, self._sample_inputs = None, None @@ -102,7 +104,7 @@ def update_config(self, config: dict) -> dict: # update prepare and parse assert "inputs" in config, "inputs should be given to run manager" assert "outputs" in config, "outputs should be given to run manager" - config = msc_utils.copy_dict(config) + config, debug_levels = msc_utils.copy_dict(config), {} for stage in ["prepare", "parse"]: if stage not in config: config[stage] = {} @@ -111,25 +113,36 @@ def update_config(self, config: dict) -> dict: for stage in ["baseline", "optimize", "compile"]: config = self._update_runner_config(config, stage) config = self._update_tool_config(config) - debug_config = {} - def _set_debug(stage, stage_config, default=None): - if "debug" in stage_config: - debug_config[stage] = stage_config.pop("debug") + def _get_tool_stage(tool_type: str) -> str: + if tool_type == ToolType.PRUNER: + return MSCStage.PRUNE + if tool_type == ToolType.QUANTIZER: + return MSCStage.QUANTIZE + if tool_type == ToolType.DISTILLER: + return MSCStage.DISTILL + return tool_type + + def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dict: + if "debug_level" in stage_config: + debug_levels[stage] = stage_config["debug_level"] elif default is not None: - debug_config[stage] = default - return debug_config + debug_levels[stage] = default + stage_config["debug_level"] = default + return debug_levels - if "debug" in config: - for stage in ["baseline", "optimize", "compile"]: - if stage not in config: - continue - debug_config = _set_debug(stage, config[stage], config["debug"]) - else: - for stage in ["baseline", "optimize", "compile"]: - if stage not in config: + debug_level = config.get("debug_level") + for stage in ["baseline", "optimize", "compile"]: + if stage not in config: + continue + debug_levels = _set_debug_level(stage, config[stage]["run_config"], debug_level) + if "optimize" in config: + for t_type in ToolType.all_types(): + if t_type not in config["optimize"]: continue - debug_config = _set_debug(stage, config[stage]) + debug_levels = _set_debug_level( + _get_tool_stage(t_type), config["optimize"][t_type], debug_level + ) ordered_keys = [ "model_type", "inputs", @@ -141,7 +154,7 @@ def _set_debug(stage, stage_config, default=None): "optimize", "compile", ] - return {k: config[k] for k in ordered_keys if k in config}, debug_config + return {k: config[k] for k in ordered_keys if k in config}, debug_levels def run_pipe(self) -> dict: """Run the pipeline and return object. @@ -365,6 +378,17 @@ def optimize(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: The runner. """ + # run prune + if ToolType.PRUNER in stage_config: + self._tools_config[ToolType.PRUNER] = stage_config[ToolType.PRUNER] + plan_file = stage_config[ToolType.PRUNER]["plan_file"] + if os.path.isfile(plan_file): + self._logger.info("Skip %s with plan_file %s", ToolType.PRUNER, plan_file) + else: + msc_utils.time_stamp(MSCStage.PRUNE) + runner = self._create_tool_runner(MSCStage.PRUNE, stage_config) + runner.apply_tool(ToolType.PRUNER, self._data_loader) + # optimize and get the runner msc_utils.time_stamp(MSCStage.OPTIMIZE) return self._create_runner( @@ -425,7 +449,6 @@ def destory(self, keep_workspace: bool = False): Whether to keep workspace. 
""" - MSCMap.delete(MSCKey.TIME_STAMPS) if self._runner: self._runner.destory() if not keep_workspace: @@ -465,7 +488,7 @@ def _create_runner( if self._runner: self._runner.destory() - on_debug = self._debug_config.get(stage, False) + debug_level = self._debug_levels.get(stage, 0) cache_dir = msc_utils.get_cache_dir().create_dir(stage) if use_cache else None tools_config = tools_config or {} msc_utils.time_stamp(stage + ".build", False) @@ -475,7 +498,9 @@ def _create_runner( run_config["generate_config"] = {} run_config["generate_config"].update( { - "build_folder": msc_utils.get_build_dir().create_dir(stage, cleanup=not on_debug), + "build_folder": msc_utils.get_build_dir().create_dir( + stage, cleanup=(debug_level == 0) + ), } ) self._logger.debug("Create runner(%s) by %s(%s)", stage, runner_cls.__name__, run_config) @@ -550,10 +575,15 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: if check_config: loader = msc_utils.IODataLoader(msc_utils.get_dataset_dir().relpath("Golden")) total, passed = 0, 0 - acc_report = {} + acc_report = {"config": check_config} for idx, (inputs, outputs) in enumerate(loader): results = runner.run(inputs) - iter_report = msc_utils.compare_arrays(outputs, results) + iter_report = msc_utils.compare_arrays( + outputs, + results, + atol=check_config.get("atol", 1e-2), + rtol=check_config.get("rtol", 1e-2), + ) total += iter_report["total"] passed += iter_report["passed"] acc_report["iter_" + str(idx)] = iter_report["info"] @@ -562,15 +592,17 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: title = "Check({}) pass {}".format(stage, report["accuracy"]) self._logger.debug(msc_utils.msg_block(title, acc_report)) msg += " acc {} iters -> {}".format(len(loader), report["accuracy"]) - required_err, err_rate = check_config.get("err_rate", 0), (1 - pass_rate) - if err_rate > required_err >= 0: - raise Exception( - "Failed to profile the runner({}), err_rate {} > required {}".format( - stage, err_rate, required_err + if runner.get_tool(ToolType.PRUNER) or runner.get_tool(ToolType.QUANTIZER): + self._logger.debug("Disable accuracy check(%s) by tools", stage) + else: + required_err, err_rate = check_config.get("err_rate", 0), (1 - pass_rate) + if err_rate > required_err >= 0: + raise Exception( + "Failed to profile the runner({}), err_rate {} > required {}".format( + stage, err_rate, required_err + ) ) - ) - # benchmark model benchmark_config = profile_config.get("benchmark", {}) if benchmark_config: for _ in range(benchmark_config.get("warm_up", 10)): @@ -721,6 +753,15 @@ def _update_tool_config(self, config: dict) -> dict: if "optimize" not in config: return config + for tool_type in ToolType.all_types(): + if tool_type not in config["optimize"]: + continue + tool_config = config["optimize"][tool_type] + if "plan_file" not in tool_config: + tool_config["plan_file"] = "msc_{}.json".format(tool_type) + tool_config["plan_file"] = msc_utils.to_abs_path( + tool_config["plan_file"], msc_utils.get_config_dir() + ) return config def get_runnable(self, ret_type: str = "runner") -> Any: diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index 26c9de5d8b8b..0b5756df7e63 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -169,14 +169,6 @@ class BaseCodeGen { } } - /*! 
- * \brief Compare version with version in config - * 0 for same version, 1 for greater version, -1 for less version - */ - int CompareVersion(size_t major, size_t minor, size_t patch) { - return CommonUtils::CompareVersion({major, minor, patch}, config_->version); - } - /*! \brief Get the docs for the op*/ virtual const Array GetOpCodes(const MSCJoint& node) = 0; diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index 126e9847d690..bd5d543dc2b1 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -93,6 +93,9 @@ using namespace tvm::script::printer; return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ } \ const String Comment(const MSCJoint& node) { return helper_.Comment(node, config()->prefix); } \ + int CompareVersion(size_t major, size_t minor, size_t patch) { \ + return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ + } \ \ private: \ std::shared_ptr config_; \ diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 97b5c221f586..2c07aeb4c741 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -140,7 +140,8 @@ class CppCodeGen : public BaseCodeGen { } } - virtual Map GetTensorCtx(const MSCTensor& tensor) { + /*! \brief Get the tensor context for codegen_tensor*/ + virtual const Map GetTensorCtx(const MSCTensor& tensor) { Map tensor_ctx; MSCJoint producer; if (this->graph()->weight_holders.count(tensor->name)) { @@ -162,6 +163,18 @@ class CppCodeGen : public BaseCodeGen { return tensor_ctx; } + /*! \brief Get the step context for codegen_step*/ + virtual const Map GetStepCtx() { + Map step_ctx; + std::string version = ""; + for (size_t i = 0; i < this->config()->version.size(); i++) { + version += std::to_string(this->config()->version[i]) + + (i < this->config()->version.size() - 1 ? "." 
: ""); + } + step_ctx.Set("version", version); + return step_ctx; + } + void StartNamespace() { this->stack_.line("namespace tvm {").line("namespace contrib {").line("namespace msc {").line(); } diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index abd581c79737..97e91ca1839b 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -23,6 +23,9 @@ #include "graph.h" +#include +#include +#include #include #include "../printer/prototxt_printer.h" @@ -100,6 +103,10 @@ const Integer MSCTensorNode::DimAt(const String& axis) const { return DimAt(index); } +int32_t MSCTensorNode::LayoutOf(const String& axis) const { + return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); +} + const Integer MSCTensorNode::GetSize() const { Integer size = Integer(1); for (const auto& s : shape) { @@ -454,15 +461,14 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor } WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref, - const String& optype, const String& wtype, - const Map& attrs, const MSCTensor& weight, - const Array parents, const Array& friends) { + const String& weight_type, const MSCTensor& weight, + const Array parents, const Map& attrs, + const Array& friends) { ObjectPtr n = make_object(); n->index = index; n->name = std::move(name); n->shared_ref = std::move(shared_ref); - n->optype = std::move(optype); - n->wtype = std::move(wtype); + n->weight_type = std::move(weight_type); n->attrs = std::move(attrs); n->weight = std::move(weight); for (const auto& p : parents) { @@ -472,6 +478,62 @@ WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref data_ = std::move(n); } +WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes) { + ObjectPtr n = make_object(); + n->FromJson(j_joint, nodes); + data_ = std::move(n); +} + +WeightJoint::WeightJoint(const std::string& json_str, const Map& nodes) { + ObjectPtr n = make_object(); + n->FromJson(json_str, nodes); + data_ = std::move(n); +} + +const JsonWeightJoint WeightJointNode::ToJson() const { + JsonWeightJoint j_joint; + j_joint.index = index; + j_joint.name = name; + j_joint.shared_ref = shared_ref; + j_joint.weight_type = weight_type; + j_joint.weight = weight->ToJson(); + for (const auto& pair : attrs) { + j_joint.attrs[pair.first] = pair.second; + } + for (const auto& p : parents) { + j_joint.parents.push_back(Downcast(p)->name); + } + for (const auto& f : friends) { + j_joint.friends.push_back(Downcast(f)->name); + } + + return j_joint; +} + +void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, + const Map& nodes) { + index = j_joint.index; + name = j_joint.name; + shared_ref = j_joint.shared_ref; + weight_type = j_joint.weight_type; + weight = MSCTensor(j_joint.weight); + for (const auto& pair : j_joint.attrs) { + attrs.Set(pair.first, pair.second); + } + for (const auto& p_name : j_joint.parents) { + ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; + parents.push_back(nodes[p_name]); + } +} + +void WeightJointNode::FromJson(const std::string& json_str, const Map& nodes) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonWeightJoint j_joint; + reader.Read(&j_joint); + FromJson(j_joint, nodes); +} + const WeightJoint WeightJointNode::ParentAt(int index) const { size_t v_index = CommonUtils::GetIndex(index, parents.size()); return Downcast(parents[v_index]); @@ -482,6 +544,10 @@ const WeightJoint WeightJointNode::ChildAt(int index) const { return 
Downcast(children[v_index]); } +const bool BaseGraphNode::HasNode(const String& name) const { + return nodes.count(name) ? true : false; +} + MSCGraph::MSCGraph(const String& name, const Array& nodes, const Array& input_names, const Array& output_names) { ObjectPtr n = make_object(); @@ -588,10 +654,6 @@ const String MSCGraphNode::ToPrototxt() const { return printer.GetString(); } -const bool MSCGraphNode::HasNode(const String& name) const { - return nodes.count(name) ? true : false; -} - const MSCJoint MSCGraphNode::FindNode(const String& name) const { ICHECK(nodes.count(name)) << "Can not find node " << name; return Downcast(nodes[name]); @@ -779,16 +841,204 @@ void MSCGraphNode::AnalysisGraph() { } } -WeightGraph::WeightGraph(const String& name, const Array& nodes) { +WeightGraph::WeightGraph(const MSCGraph& graph, const Map>& prunable_types, + const Map& relation_types) { ObjectPtr n = make_object(); - n->name = std::move(name); - for (const auto& node : nodes) { - n->node_names.push_back(node->name); - n->nodes.Set(node->name, node); - } + n->name = graph->name + "_weights"; + n->Build(graph, prunable_types, relation_types); data_ = std::move(n); } +WeightGraph::WeightGraph(const JsonWeightGraph& j_graph) { + ObjectPtr n = make_object(); + n->FromJson(j_graph); + data_ = std::move(n); +} + +WeightGraph::WeightGraph(const std::string& json_str) { + ObjectPtr n = make_object(); + n->FromJson(json_str); + data_ = std::move(n); +} + +void WeightGraphNode::Build(const MSCGraph& graph, const Map>& prunable_types, + const Map& relation_types) { + auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) { + return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index; + }; + + auto find_parents = [this, &prunable_types, &relation_types, &sort_nodes](const MSCJoint& node) { + std::vector parents; + std::queue frontier; + std::set explored; + for (const auto& p : node->parents) { + frontier.push(Downcast(p)); + } + while (!frontier.empty()) { + const auto& current = frontier.front(); + if (explored.count(current)) { + frontier.pop(); + continue; + } + explored.insert(current); + if (prunable_types.count(current->optype)) { + for (const auto& t_type : prunable_types[current->optype]) { + if (current->weights.count(t_type)) { + parents.push_back(FindNode(current->WeightAt(t_type)->name)); + } + } + } else if (relation_types.count(current->optype)) { + parents.push_back(FindNode(current->OutputAt(0)->name)); + } else { + for (const auto& p : current->parents) { + const auto& new_parent = Downcast(p); + if (!explored.count(new_parent)) { + frontier.push(new_parent); + } + } + } + frontier.pop(); + } + Array parents_array; + if (parents.size() > 1) { + std::sort(parents.begin(), parents.end(), sort_nodes); + } + for (const auto& p : parents) { + parents_array.push_back(p); + } + return parents_array; + }; + + for (const auto& n : graph->node_names) { + const auto& node = graph->FindNode(n); + if (node->shared_ref.size() > 0) { + continue; + } + if (prunable_types.count(node->optype) || relation_types.count(node->optype) || + node->weights.size() > 0) { + const auto& w_parents = find_parents(node); + bool bind_friends = true; + if (relation_types.count(node->optype) && relation_types[node->optype] == "multi_inputs") { + bind_friends = false; + } + if (w_parents.size() > 1 && bind_friends) { + for (const auto& p : w_parents) { + Downcast(p)->friends = w_parents; + } + } + if (prunable_types.count(node->optype)) { + for (const auto& wtype : 
prunable_types[node->optype]) {
+          if (node->weights.count(wtype)) {
+            const auto& weight = node->WeightAt(wtype);
+            Map<String, String> attrs;
+            attrs.Set("producer_type", node->optype);
+            attrs.Set("prune_strategy", "prune");
+            const auto& w_node =
+                WeightJoint(node_names.size(), weight->name, "", wtype, weight, w_parents, attrs);
+            for (const auto& p : w_parents) {
+              p->AddChild(w_node);
+            }
+            nodes.Set(weight->name, w_node);
+            node_names.push_back(weight->name);
+          }
+        }
+        const BaseJoint& head = FindNode(node_names[node_names.size() - 1]);
+        for (const auto& pair : node->weights) {
+          if (!nodes.count(pair.second->name)) {
+            Map<String, String> attrs;
+            attrs.Set("producer_type", node->optype);
+            attrs.Set("prune_strategy", "follow");
+            const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
+                                             pair.second, {head}, attrs);
+            head->AddChild(w_node);
+            nodes.Set(pair.second->name, w_node);
+            node_names.push_back(pair.second->name);
+          }
+        }
+      } else if (relation_types.count(node->optype)) {
+        const auto& tensor = node->OutputAt(0);
+        Map<String, String> attrs;
+        attrs.Set("producer_type", node->optype);
+        if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 &&
+            node->OutputAt(0)->LayoutOf("C") >= 0 &&
+            node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) {
+          attrs.Set("prune_strategy", "passby");
+        } else {
+          attrs.Set("prune_strategy", relation_types[node->optype]);
+        }
+        const auto& t_node =
+            WeightJoint(node_names.size(), tensor->name, "", "output", tensor, w_parents, attrs);
+        for (const auto& p : w_parents) {
+          p->AddChild(t_node);
+        }
+        nodes.Set(tensor->name, t_node);
+        node_names.push_back(tensor->name);
+      } else if (node->weights.size() > 0) {
+        for (const auto& pair : node->weights) {
+          if (!nodes.count(pair.second->name)) {
+            Map<String, String> attrs;
+            attrs.Set("producer_type", node->optype);
+            attrs.Set("prune_strategy", "follow");
+            const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
+                                             pair.second, w_parents, attrs);
+            for (const auto& p : w_parents) {
+              p->AddChild(w_node);
+            }
+            nodes.Set(pair.second->name, w_node);
+            node_names.push_back(pair.second->name);
+          }
+        }
+      }
+    }
+  }
+}
+
+const WeightJoint WeightGraphNode::FindNode(const String& name) const {
+  ICHECK(nodes.count(name)) << "Can not find node " << name;
+  return Downcast<WeightJoint>(nodes[name]);
+}
+
+const JsonWeightGraph WeightGraphNode::ToJson() const {
+  JsonWeightGraph j_graph;
+  j_graph.name = name;
+  for (const auto& n : node_names) {
+    const auto& node = FindNode(n);
+    j_graph.nodes.push_back(node->ToJson());
+  }
+  return j_graph;
+}
+
+void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) {
+  name = j_graph.name;
+  Map<String, BaseJoint> loaded_nodes;
+  for (const auto& n : j_graph.nodes) {
+    const auto& node = WeightJoint(n, loaded_nodes);
+    loaded_nodes.Set(node->name, node);
+    for (const auto& p : node->parents) {
+      Downcast<WeightJoint>(p)->AddChild(node);
+    }
+    node_names.push_back(node->name);
+    nodes.Set(node->name, node);
+  }
+  // set friends, looking nodes up by joint name so the graph name is not clobbered
+  for (const auto& j_joint : j_graph.nodes) {
+    const auto& node = Downcast<WeightJoint>(nodes[j_joint.name]);
+    for (const auto& f_name : j_joint.friends) {
+      ICHECK(nodes.count(f_name)) << "Can not find friend " << f_name;
+      node->friends.push_back(nodes[f_name]);
+    }
+  }
+}
+
+void WeightGraphNode::FromJson(const std::string& json_str) {
+  std::istringstream is(json_str);
+  dmlc::JSONReader reader(&is);
+  JsonWeightGraph j_graph;
+  reader.Read(&j_graph);
+  FromJson(j_graph);
+}
+
 const String WeightGraphNode::ToPrototxt() const {
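+  // Render each weight node as a caffe-style prototxt layer: the node name
+  // becomes name/top, its parents become bottoms, and index/weight/friends/
+  // attrs are packed into layer_param. Intended for visualization/debugging.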
PrototxtPrinter printer; printer.Append(Map{{"name", name}}); @@ -797,8 +1047,7 @@ const String WeightGraphNode::ToPrototxt() const { // define layer std::vector> layer; layer.push_back(std::make_pair("name", node->name)); - layer.push_back( - std::make_pair("type", StringUtils::Replace(node->optype, ".", "_") + "_" + node->wtype)); + layer.push_back(std::make_pair("type", node->weight_type)); layer.push_back(std::make_pair("top", node->name)); for (const auto& p : node->parents) { layer.push_back(std::make_pair("bottom", Downcast(p)->name)); @@ -808,7 +1057,10 @@ const String WeightGraphNode::ToPrototxt() const { param.Set("idx", Integer(node->index)); param.Set("weight", node->weight); for (size_t i = 0; i < node->friends.size(); i++) { - param.Set("friend_" + std::to_string(i), Downcast(node->friends[i])); + param.Set("friend_" + std::to_string(i), Downcast(node->friends[i])); + } + for (const auto& pair : node->attrs) { + param.Set(pair.first, pair.second); } layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); // Append the layer Map @@ -817,9 +1069,49 @@ const String WeightGraphNode::ToPrototxt() const { return printer.GetString(); } -const WeightJoint WeightGraphNode::FindNode(const String& name) const { - ICHECK(nodes.count(name)) << "Can not find node " << name; - return Downcast(nodes[name]); +MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors) { + Array nodes; + std::unordered_map> inputs_map; + for (const auto& name : graph->node_names) { + const auto& node = graph->FindNode(name); + // define inputs + std::vector> inputs; + for (const auto& input : node->GetInputs()) { + ICHECK(inputs_map.count(input->name)) << "Can not find input " << input; + inputs.push_back(inputs_map[input->name]); + } + // define outputs + Array outputs; + for (const auto& out : node->outputs) { + const auto& output = pruned_tensors.count(out->name) ? pruned_tensors[out->name] : out; + outputs.push_back(output); + } + // define weights + Map weights; + for (const auto& pair : node->weights) { + const auto& weight = + pruned_tensors.count(pair.second->name) ? pruned_tensors[pair.second->name] : pair.second; + weights.Set(pair.first, weight); + } + // define attributes + Map attrs = node->attrs; + if (node->optype == "reshape" && attrs.count("shape") && + pruned_tensors.count(node->OutputAt(0)->name)) { + const auto& new_shape = pruned_tensors[node->OutputAt(0)->name]->shape; + attrs.Set("shape", StringUtils::ToString(new_shape)); + } + // create new node + const auto& new_node = MSCJoint(static_cast(nodes.size()), node->name, node->shared_ref, + node->optype, attrs, node->scope, inputs, outputs, weights); + nodes.push_back(new_node); + for (size_t i = 0; i < new_node->outputs.size(); i++) { + inputs_map[new_node->OutputAt(i)->name] = std::make_pair(new_node, i); + } + for (const auto& p : new_node->parents) { + Downcast(p)->AddChild(new_node); + } + } + return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -910,13 +1202,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << (i == joint->friends.size() - 1 ? 
"\n" : ","); } } - p->stream << " OPTYPE: " << joint->optype; - p->stream << "\n WEIGHT_TYPE: " << joint->wtype; + p->stream << " WEIGHT_TYPE: " << joint->weight_type; p->stream << "\n WEIGHT: " << joint->weight; if (joint->attrs.size() > 0) { p->stream << "\n ATTRS: "; for (const auto& pair : joint->attrs) { - p->stream << pair.first << " = " << pair.second << " "; + p->stream << pair.first << "=" << pair.second << " "; } } p->stream << "\n"; @@ -981,9 +1272,9 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint") TVM_REGISTER_GLOBAL("msc.core.WeightJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, - const String& optype, const String& wtype, const Map& attrs, - const MSCTensor& weight, const Array parents, - const Array& friends) -> WeightJoint { + const String& weight_type, const MSCTensor& weight, + const Array parents, const Map& attrs, + const Array& friends) -> WeightJoint { Array b_parents, b_friends; for (const auto& p : parents) { b_parents.push_back(p); @@ -991,7 +1282,7 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJoint") for (const auto& f : friends) { b_friends.push_back(f); } - return WeightJoint(index->value, name, shared_ref, optype, wtype, attrs, weight, b_parents, + return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, attrs, b_friends); }); @@ -1003,11 +1294,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraph") }); TVM_REGISTER_GLOBAL("msc.core.WeightGraph") - .set_body_typed([](const String& name, const Array& nodes) -> WeightGraph { - return WeightGraph(name, nodes); + .set_body_typed([](const MSCGraph& graph, const Map>& prunable_types, + const Map& relation_types) -> WeightGraph { + return WeightGraph(graph, prunable_types, relation_types); }); -// Graph APIS +// MSC Graph APIS TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasNode(name)); @@ -1068,6 +1360,31 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") TVM_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") .set_body_typed([](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); +// Weight Graph APIS +TVM_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") + .set_body_typed([](const WeightGraph& graph, const String& name) -> Bool { + return Bool(graph->HasNode(name)); + }); + +TVM_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") + .set_body_typed([](const WeightGraph& graph, const String& name) -> WeightJoint { + return graph->FindNode(name); + }); + +TVM_REGISTER_GLOBAL("msc.core.WeightGraphToJson") + .set_body_typed([](const WeightGraph& graph) -> String { + const auto& graph_json = graph->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + graph_json.Save(&writer); + return os.str(); + }); + +TVM_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") + .set_body_typed([](const String& graph_json) -> WeightGraph { + return WeightGraph(graph_json); + }); + TVM_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") .set_body_typed([](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }); @@ -1103,6 +1420,14 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") TVM_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") .set_body_typed([](const MSCJoint& node) -> Map { return node->attrs; }); +TVM_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") + .set_body_typed([](const WeightJoint& node, const String& key) -> Bool { + return Bool(node->HasAttr(key)); + }); + +TVM_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") + .set_body_typed([](const WeightJoint& node) -> Map { 
return node->attrs; });
+
 TVM_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName")
     .set_body_typed([](const MSCTensor& tensor) -> String { return tensor->DTypeName(); });
@@ -1115,8 +1440,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCTensorGetSize")
     .set_body_typed([](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); });
 TVM_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias")
-    .set_body_typed([](const MSCTensor& tensor, const String& alias) {
-      return tensor->alias = alias;
+    .set_body_typed([](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; });
+
+TVM_REGISTER_GLOBAL("msc.core.PruneWeights")
+    .set_body_typed([](const MSCGraph& graph,
+                       const Map<String, MSCTensor>& pruned_tensors) -> MSCGraph {
+      return PruneWeights(graph, pruned_tensors);
     });
 }  // namespace msc
diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h
index c8b69b08fc26..85880841d4d8 100644
--- a/src/contrib/msc/core/ir/graph.h
+++ b/src/contrib/msc/core/ir/graph.h
@@ -147,6 +147,65 @@ struct JsonMSCJoint {
   }
 };
+/*!
+ * \brief Json serialize and deserialize for WeightJoint.
+ * WeightJoint is a node in WeightGraph with name, weight_type and attributes.
+ * WeightJoint has MSCTensor as weight.
+ */
+struct JsonWeightJoint {
+  size_t index;
+  std::string name;
+  std::string shared_ref;
+  std::string weight_type;
+  JsonMSCTensor weight;
+  std::vector<std::string> parents;
+  std::vector<std::string> friends;
+  std::unordered_map<std::string, std::string> attrs;
+
+  void Save(dmlc::JSONWriter* writer) const {
+    writer->BeginObject();
+    writer->WriteObjectKeyValue("index", index);
+    writer->WriteObjectKeyValue("name", name);
+    writer->WriteObjectKeyValue("shared_ref", shared_ref);
+    writer->WriteObjectKeyValue("weight_type", weight_type);
+    writer->WriteObjectKeyValue("weight", weight);
+    writer->WriteObjectKeyValue("parents", parents);
+    writer->WriteObjectKeyValue("friends", friends);
+    writer->WriteObjectKeyValue("attrs", attrs);
+    writer->EndObject();
+  }
+
+  void Load(dmlc::JSONReader* reader) {
+    int bitmask = 0;
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      if (key == "index") {
+        reader->Read(&index);
+        bitmask |= 1;
+      } else if (key == "name") {
+        reader->Read(&name);
+        bitmask |= 2;
+      } else if (key == "shared_ref") {
+        reader->Read(&shared_ref);
+      } else if (key == "weight_type") {
+        reader->Read(&weight_type);
+        bitmask |= 4;
+      } else if (key == "weight") {
+        reader->Read(&weight);
+        bitmask |= 8;
+      } else if (key == "parents") {
+        reader->Read(&parents);
+      } else if (key == "friends") {
+        reader->Read(&friends);
+      } else if (key == "attrs") {
+        reader->Read(&attrs);
+      }
+    }
+    ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "index, name, weight_type and weight should be given";
+  }
+};
+
 /*!
  * \brief Json serialize and deserialize for MSCGraph.
  * MSCGraph is core of MSC.
@@ -190,6 +249,39 @@ struct JsonMSCGraph {
   }
 };
+/*!
+ * \brief Json serialize and deserialize for WeightGraph.
+ * WeightGraph is core of MSC.prune.
+ * WeightGraph contains WeightJoints as nodes.
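+ * A serialized graph is the graph name plus the topologically ordered node
+ * list, roughly (an illustrative sketch, not a normative schema):
+ *   {"name": "net_weights", "nodes": [{"index": 0, "name": "conv.weight",
+ *    "weight_type": "weight", "weight": {...}, "parents": [], "friends": [],
+ *    "attrs": {"prune_strategy": "prune"}}, ...]}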
+ */ +struct JsonWeightGraph { + std::string name; + std::vector nodes; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("nodes", nodes); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "name") { + reader->Read(&name); + bitmask |= 1; + } else if (key == "nodes") { + reader->Read(&nodes); + bitmask |= 2; + } + } + ICHECK_EQ(bitmask, 1 | 2) << "name and nodes should be given"; + } +}; + /*! * \brief Tensor in MSCGraph. */ @@ -217,6 +309,8 @@ class MSCTensorNode : public Object { const Integer DimAt(int index) const; /*! \brief Get dim at given axis. */ const Integer DimAt(const String& axis) const; + /*! \brief Get layout index of given axis. */ + int32_t LayoutOf(const String& axis) const; /*! \brief Get size of the tensor. */ const Integer GetSize() const; /*! \brief Get name of the dtype. */ @@ -290,8 +384,6 @@ class BaseJointNode : public Object { String name; /*! \brief The shared_ref of node, can be changed. */ String shared_ref; - /*! \brief The op type of node. */ - String optype; /*! \brief The attributes of node. */ Map attrs; /*! \brief The parents of node. */ @@ -333,15 +425,13 @@ class BaseJointNode : public Object { v->Visit("index", &index); v->Visit("name", &name); v->Visit("shared_ref", &shared_ref); - v->Visit("optype", &optype); v->Visit("attrs", &attrs); v->Visit("parents", &parents); - v->Visit("childern", &children); + v->Visit("children", &children); } bool SEqualReduce(const BaseJointNode* other, SEqualReducer equal) const { - return equal(name, other->name) && - equal(shared_ref, other->shared_ref) & equal(optype, other->optype) && + return equal(name, other->name) && equal(shared_ref, other->shared_ref) && equal(attrs, other->attrs) && equal(parents, other->parents) && equal(children, other->children); } @@ -349,7 +439,6 @@ class BaseJointNode : public Object { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(name); hash_reduce(shared_ref); - hash_reduce(optype); hash_reduce(attrs); hash_reduce(parents); hash_reduce(children); @@ -377,6 +466,8 @@ class BaseJoint : public ObjectRef { class MSCJoint; class MSCJointNode : public BaseJointNode { public: + /*! \brief The op type of node. */ + String optype; /*! \brief The scope of node. */ Array scope; /*! \brief The inputs of node, can be changed. 
*/ @@ -416,6 +507,7 @@ class MSCJointNode : public BaseJointNode { void VisitAttrs(AttrVisitor* v) { BaseJointNode::VisitAttrs(v); + v->Visit("optype", &optype); v->Visit("scope", &scope); v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); @@ -423,13 +515,14 @@ class MSCJointNode : public BaseJointNode { } bool SEqualReduce(const MSCJointNode* other, SEqualReducer equal) const { - return BaseJointNode::SEqualReduce(other, equal) && equal(scope, other->scope) && - equal(inputs, other->inputs) && equal(outputs, other->outputs) && - equal(weights, other->weights); + return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype) && + equal(scope, other->scope) && equal(inputs, other->inputs) && + equal(outputs, other->outputs) && equal(weights, other->weights); } void SHashReduce(SHashReducer hash_reduce) const { BaseJointNode::SHashReduce(hash_reduce); + hash_reduce(optype); hash_reduce(scope); hash_reduce(inputs); hash_reduce(outputs); @@ -488,11 +581,17 @@ class WeightJoint; class WeightJointNode : public BaseJointNode { public: /*! \brief The weight reference of weight node. */ - String wtype; + String weight_type; /*! \brief The weight of weight node. */ MSCTensor weight; /*! \brief The friends of weight node. */ - Array friends; + mutable Array friends; + /*! \brief Export node to json. */ + const JsonWeightJoint ToJson() const; + /*! \brief Load node from json struct. */ + void FromJson(const JsonWeightJoint& j_joint, const Map& nodes); + /*! \brief Load node from json string. */ + void FromJson(const std::string& json_str, const Map& nodes); /*! \brief Get parent from the node. */ const WeightJoint ParentAt(int index) const; /*! \brief Get child from the node. */ @@ -500,19 +599,19 @@ class WeightJointNode : public BaseJointNode { void VisitAttrs(AttrVisitor* v) { BaseJointNode::VisitAttrs(v); - v->Visit("wtype", &wtype); + v->Visit("weight_type", &weight_type); v->Visit("weight", &weight); v->Visit("friends", &friends); } bool SEqualReduce(const WeightJointNode* other, SEqualReducer equal) const { - return BaseJointNode::SEqualReduce(other, equal) && equal(wtype, other->wtype) && + return BaseJointNode::SEqualReduce(other, equal) && equal(weight_type, other->weight_type) && equal(weight, other->weight) && equal(friends, other->friends); } void SHashReduce(SHashReducer hash_reduce) const { BaseJointNode::SHashReduce(hash_reduce); - hash_reduce(wtype); + hash_reduce(weight_type); hash_reduce(weight); hash_reduce(friends); } @@ -532,17 +631,29 @@ class WeightJoint : public BaseJoint { * \param index The index of the node. * \param name The name of the node. * \param shared_ref The shared_ref of the node. - * \param optype The optype of the node. - * \param wtype The weight type of the node. - * \param attrs The attributes of the node. + * \param weight_type The weight type of the node. * \param weight The weight tensor of the node. * \param parents The parents of the node. + * \param attrs The attributes of the node. * \param friends The friends of the node. */ - TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref, const String& optype, - const String& wtype, const Map& attrs, - const MSCTensor& weight, const Array parents, - const Array& friends); + TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref, + const String& weight_type, const MSCTensor& weight, + const Array parents, + const Map& attrs = Map(), + const Array& friends = Array()); + + /*! + * \brief The json constructor. 
+ * \param j_joint The json describe of the node. + */ + TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the node. + */ + TVM_DLL WeightJoint(const std::string& json_str, const Map& nodes); TVM_DEFINE_OBJECT_REF_METHODS(WeightJoint, BaseJoint, WeightJointNode); }; @@ -558,6 +669,8 @@ class BaseGraphNode : public Object { Array node_names; /*! \brief The nodes in graph, can be changed. */ Map nodes; + /*! \brief Check if node in the graph. */ + const bool HasNode(const String& name) const; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -614,8 +727,6 @@ class MSCGraphNode : public BaseGraphNode { void FromJson(const std::string& json_str); /*! \brief Export graph to prototxt. */ const String ToPrototxt() const; - /*! \brief Check if node in the graph. */ - const bool HasNode(const String& name) const; /*! \brief Find node in graph. */ const MSCJoint FindNode(const String& name) const; /*! \brief Get input from the graph. */ @@ -715,10 +826,19 @@ class MSCGraph : public BaseGraph { */ class WeightGraphNode : public BaseGraphNode { public: - /*! \brief Export graph to prototxt. */ - const String ToPrototxt() const; + /*! \brief build from MSCGraph. */ + void Build(const MSCGraph& graph, const Map>& prunable_types, + const Map& relation_types); /*! \brief Find node in graph. */ const WeightJoint FindNode(const String& name) const; + /*! \brief Export graph to json. */ + const JsonWeightGraph ToJson() const; + /*! \brief Load graph from json. */ + void FromJson(const JsonWeightGraph& json_str); + /*! \brief Load graph from json string. */ + void FromJson(const std::string& json_str); + /*! \brief Export graph to prototxt. */ + const String ToPrototxt() const; void VisitAttrs(AttrVisitor* v) { BaseGraphNode::VisitAttrs(v); } @@ -739,16 +859,31 @@ class WeightGraphNode : public BaseGraphNode { class WeightGraph : public BaseGraph { public: /*! - * \brief The constructor. - * \param name The name of the node. - * \param node_names The node names in the graph - * \param nodes The nodes in the graph. + * \brief The constructor based on MSCGraph. + * \param graph The msc graph. + * \param prunable_types The prunable types. + * \param relation_types The relation types. */ - TVM_DLL WeightGraph(const String& name, const Array& node_names); + TVM_DLL WeightGraph(const MSCGraph& graph, const Map>& prunable_types, + const Map& relation_types); + + /*! + * \brief The json constructor. + * \param j_graph The json describe of the graph. + */ + TVM_DLL WeightGraph(const JsonWeightGraph& j_graph); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the graph. 
+ */ + TVM_DLL WeightGraph(const std::string& json_str); TVM_DEFINE_OBJECT_REF_METHODS(WeightGraph, BaseGraph, WeightGraphNode); }; +MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors); + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index e8a79535add9..663de50b0b8b 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -56,9 +56,9 @@ int CommonUtils::CompareVersion(const std::vector& given_version, ICHECK_EQ(target_version.size(), 3) << "Target version should be in format major,minor,patch"; for (size_t i = 0; i < 3; i++) { if (given_version[i] > target_version[i]) { - return -1; - } else if (given_version[i] < target_version[i]) { return 1; + } else if (given_version[i] < target_version[i]) { + return -1; } } return 0; @@ -402,6 +402,22 @@ TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); +TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") + .set_body_typed([](const String& key, const String& value) -> Span { + return SpanUtils::SetAttr(Span(), key, value); + }); + +TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr") + .set_body_typed([](const Span& span, const String& key, const String& value) -> Span { + return SpanUtils::SetAttr(span, key, value); + }); + +TVM_REGISTER_GLOBAL("msc.core.CompareVersion") + .set_body_typed([](const Array& given_version, + const Array& target_version) -> Integer { + return Integer(CommonUtils::CompareVersion(given_version, target_version)); + }); + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 4f73767b2e76..6efc9b26afb5 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -39,10 +39,11 @@ void TensorRTCodeGen::CodeGenClassDeclare() { stack_.line("#include \"NvInfer.h\"") .line("#include \"NvInferRuntimeCommon.h\"") .line("#include \"utils/base.h\"") - .line("#include \"utils/trt_common.h\"") - .line() - .line("using namespace nvinfer1;") - .line(); + .line("#include \"utils/trt_common.h\""); + if (config()->precision == "int8") { + stack_.line("#include \"utils/trt_quantize.h\""); + } + stack_.line().line("using namespace nvinfer1;").line(); StartNamespace(); // start class declare stack_.class_def(graph()->name).class_start().scope_start("public:"); @@ -111,11 +112,6 @@ void TensorRTCodeGen::CodeGenClassDefine() { for (const auto& n : graph()->node_names) { const auto& node = graph()->FindNode(n); CodeGenNode(node, config()->use_tools); - /* - for (const auto& d : GetOpCodes(node)) { - stack_.line(d); - } - */ } // mark outputs stack_.comment("Mark outputs"); @@ -146,6 +142,48 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtrDoc("builder")) .call_arg(config()->max_workspace); } + // set data type + if (config()->precision == "float16") { + stack_.comment("Set network precision") + .cond_if("!builder->platformHasFastFp16()") + .func_call("logger.log") + .call_arg("ILogger::Severity::kINTERNAL_ERROR") + .call_arg(DocUtils::ToStrDoc("platform do not support float16, fallback to float32")) + .cond_else() + .func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kFP16"); + if (config()->precision_mode == "strict") { + stack_.func_call("setFlag", 
NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kSTRICT_TYPES"); + } + stack_.func_call("logger.log") + .call_arg("ILogger::Severity::kINFO") + .call_arg(DocUtils::ToStrDoc("use float16 to build the engine")) + .cond_end(); + } else if (config()->precision == "int8") { + stack_.comment("Set network precision") + .cond_if("!builder->platformHasFastInt8()") + .func_call("logger.log") + .call_arg("ILogger::Severity::kINTERNAL_ERROR") + .call_arg(DocUtils::ToStrDoc("platform do not support int8, fallback to float32")) + .cond_else() + .func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kINT8"); + if (config()->precision_mode == "strict") { + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kSTRICT_TYPES"); + } else if (config()->precision_mode == "prefer") { + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kPREFER_PRECISION_CONSTRAINTS"); + } else if (config()->precision_mode == "obey") { + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .call_arg("BuilderFlag::kOBEY_PRECISION_CONSTRAINTS"); + } + stack_.func_call("logger.log") + .call_arg("ILogger::Severity::kINFO") + .call_arg(DocUtils::ToStrDoc("use int8 to build the engine")) + .cond_end(); + } // end define build method stack_.func_end("true"); // start define test function @@ -301,6 +339,16 @@ void TensorRTCodeGen::CodeGenMain() { .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtrDoc("builder")) .pop_nest(); ReturnOnFail("config", "Failed to create config"); + // codegen before build + if (config()->use_tools) { + const auto* pf = runtime::Registry::Get("msc_tool.codegen_step"); + ICHECK(pf != nullptr) << "Cannot find msc_tool.codegen_step func."; + const Array& lines = + (*pf)(GetStepCtx(), "before_build", graph()->name, config()->tools_tag); + for (const auto& l : lines) { + stack_.line(l); + } + } // build model stack_.comment("Build model") .declare(graph()->name, "model") @@ -312,6 +360,16 @@ void TensorRTCodeGen::CodeGenMain() { } stack_.call_arg("logger"); ReturnOnFail("pass", "Failed to build model"); + // codegen after build + if (config()->use_tools) { + const auto* pf = runtime::Registry::Get("msc_tool.codegen_step"); + ICHECK(pf != nullptr) << "Cannot find msc_tool.codegen_step func."; + const Array& lines = + (*pf)(GetStepCtx(), "after_build", graph()->name, config()->tools_tag); + for (const auto& l : lines) { + stack_.line(l); + } + } // Set profile flag stack_.comment("Set profile flag") .declare("ProfilingVerbosity", "profile_verbose") @@ -471,6 +529,27 @@ const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { } } +const Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { + Map tensor_ctx; + tensor_ctx.Set("ctx", "network"); + for (const auto& pair : + CppCodeGen::GetTensorCtx(tensor)) { + tensor_ctx.Set(pair.first, pair.second); + } + return tensor_ctx; +} + +const Map TensorRTCodeGen::GetStepCtx() { + Map step_ctx; + step_ctx.Set("network", "network"); + step_ctx.Set("config", "config"); + step_ctx.Set("builder", "builder"); + for (const auto& pair : CppCodeGen::GetStepCtx()) { + step_ctx.Set(pair.first, pair.second); + } + return step_ctx; +} + TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String print_config) -> Map { diff --git a/src/contrib/msc/framework/tensorrt/codegen.h b/src/contrib/msc/framework/tensorrt/codegen.h index 
28d69d3a4f5c..21b556d1cecc 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.h +++ b/src/contrib/msc/framework/tensorrt/codegen.h @@ -62,6 +62,12 @@ class TensorRTCodeGen : public CppCodeGen GetOpCodes(const MSCJoint& node) final; + /*! \brief Get the tensor context for codegen_tensor*/ + const Map GetTensorCtx(const MSCTensor& tensor) final; + + /*! \brief Get the step context for codegen_step*/ + const Map GetStepCtx() final; + /*! \brief Generate return on fail codes*/ void ReturnOnFail(const String& flag, const String& err); diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h index eab2ec616ddb..bfaecb8d3dc8 100644 --- a/src/contrib/msc/framework/tensorrt/codegen_utils.h +++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h @@ -85,6 +85,9 @@ struct TensorRTCodeGenConfig { size_t max_workspace{1 << 20}; std::string cmake_version{"3.5"}; std::string dataset{"Dataset"}; + std::string range_file{""}; + std::string precision{"float32"}; + std::string precision_mode{"strict"}; std::string tensorrt_root{"/usr/local/cuda"}; CODEGEN_CONFIG_MEMBERS void Load(dmlc::JSONReader* reader) { @@ -103,6 +106,12 @@ struct TensorRTCodeGenConfig { reader->Read(&cmake_version); } else if (key == "dataset") { reader->Read(&dataset); + } else if (key == "range_file") { + reader->Read(&range_file); + } else if (key == "precision") { + reader->Read(&precision); + } else if (key == "precision_mode") { + reader->Read(&precision_mode); } else if (key == "tensorrt_root") { reader->Read(&tensorrt_root); } else { diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index df5d4f343c88..fc6217c31de1 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -230,11 +230,12 @@ class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { stride.push_back(in_sizes[i] / out_sizes[i]); kernel.push_back((in_sizes[i] - (out_sizes[i] - 1) * stride[i])); } + const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; stack_.op_call() .op_input_arg() .call_arg("PoolingType::k" + symbol_) .call_arg(ToDims(kernel, false)); - SetLayerByDimsValue("Stride", stride, false); + SetLayerByDimsValue("Stride" + suffix, stride, false); } }; @@ -339,8 +340,9 @@ class TensorRTConvCodeGen : public TensorRTOpCode { } else { stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); } - SetLayerByDimsAttr("Stride", "strides", false); - SetLayerByDimsAttr("Dilation", "dilation", false); + const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + SetLayerByDimsAttr("Stride" + suffix, "strides", false); + SetLayerByDimsAttr("Dilation" + suffix, "dilation", false); SetLayerByAttr("NbGroups", "groups"); SetPadding(); } @@ -472,7 +474,7 @@ class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { class TensorRTPool2dCodeGen : public TensorRTOpCode { public: - explicit TensorRTPool2dCodeGen(const String& symbol) : TensorRTOpCode("Pooling") { + explicit TensorRTPool2dCodeGen(const String& symbol) : TensorRTOpCode("PoolingNd") { symbol_ = symbol; } @@ -482,7 +484,8 @@ class TensorRTPool2dCodeGen : public TensorRTOpCode { .op_input_arg() .call_arg("PoolingType::k" + symbol_) .call_arg(AttrToDims("pool_size", false)); - SetLayerByDimsAttr("Stride", "strides", false); + const String& suffix = CompareVersion(8, 0, 0) >= 0 ? 
"Nd" : ""; + SetLayerByDimsAttr("Stride" + suffix, "strides", false); if (node()->GetTypeAttr("ceil_mode")) { SetLayerByValue("PaddingMode", "PaddingMode::kEXPLICIT_ROUND_UP"); } @@ -777,7 +780,7 @@ GetTensorRTOpCodes() { // nn ops map->emplace("nn.adaptive_avg_pool2d", - std::make_shared("Pooling", "AVERAGE")); + std::make_shared("PoolingNd", "AVERAGE")); map->emplace("nn.avg_pool2d", std::make_shared("AVERAGE")); map->emplace("nn.batch_matmul", std::make_shared("MatrixMultiply")); map->emplace("nn.conv2d", std::make_shared("ConvolutionNd", false)); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 796025374156..3bd40cbfd79f 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -53,7 +53,11 @@ void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { } } -const ExprDoc RelaxOpCode::GetOutDtype(const String& key) { +const ExprDoc RelaxOpCode::GetOutDtype(const String& key, int input_idx) { + if (config()->use_tools && input_idx >= 0 && + node()->inputs.size() > static_cast(input_idx)) { + return DocUtils::ToDoc(IdxInput(input_idx) + ".struct_info.dtype"); + } std::string out_dtype; if (!node()->GetAttr(key, &out_dtype) && config()->from_relay) { return DocUtils::ToStrDoc(node()->OutputAt(0)->DTypeName()); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h index 32e3e1926f7a..e5914149184e 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.h +++ b/src/contrib/msc/framework/tvm/relax_opcode.h @@ -65,7 +65,7 @@ class RelaxOpCode : public BaseOpCode { void BuilderEmit(const String& ret, const String& name = ""); /*! \brief Get the out_dtype attribute*/ - const ExprDoc GetOutDtype(const String& key = "out_dtype"); + const ExprDoc GetOutDtype(const String& key = "out_dtype", int input_idx = 0); /*! 
\brief Get the axes attribute*/ const std::vector GetAxes(const String& key = "axes"); diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 8a4f5fe4bae0..8b7ba0c4b5b8 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -90,6 +90,12 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalOptions(); + for (size_t nid = 0; nid < nodes_.size(); nid++) { + for (size_t oid = 0; oid < nodes_[nid].GetNumOutput(); oid++) { + const auto& t_name = nodes_[nid].GetOpName() + ":" + std::to_string(oid); + tensor_ids_[t_name] = std::make_pair(nid, oid); + } + } LoadEngine(engine_file_); } @@ -99,6 +105,12 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { for (size_t i = 0; i < nodes_.size(); ++i) { if (nodes_[i].HasAttr("msc_global_options_num")) { engine_file_ = nodes_[i].GetAttr>("msc_global_engine")[0]; + graph_name_ = nodes_[i].GetAttr>("msc_global_graph_name")[0]; + if (nodes_[i].HasAttr("msc_global_tool_tag")) { + tool_tag_ = nodes_[i].GetAttr>("msc_global_tool_tag")[0]; + } else { + tool_tag_ = ""; + } } } } @@ -106,6 +118,18 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { #ifdef TVM_GRAPH_EXECUTOR_TENSORRT void Run() override { SetInputOutputBinds(); + if (tool_tag_.size() > 0) { + const auto* pf = runtime::Registry::Get("msc_tool.callback_step"); + ICHECK(pf != nullptr) << "Cannot find msc_tool.callback_step func."; + Map input_datas; + for (const auto& pair : input_bindings_) { + const auto& tensor_name = engine_->getBindingName(pair.first); + input_datas.Set(tensor_name, device_buffers_[pair.first]); + } + Map> context; + context.Set("datas", input_datas); + (*pf)(context, "before_forward", graph_name_, tool_tag_); + } auto tvm_stream = CUDAThreadEntry::ThreadLocal()->stream; #if TRT_VERSION_GE(6, 0, 1) ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) @@ -120,11 +144,26 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); int binding_index = engine_->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); - if (data_entry_[eid]->device.device_type != kDLCUDA) { + if (data_entry_[eid]->device.device_type != kDLCUDA || tool_tag_.size() > 0) { auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); device_buffer.CopyTo(const_cast(data_entry_[eid])); } } + if (tool_tag_.size() > 0) { + const auto* pf = runtime::Registry::Get("msc_tool.callback_step"); + ICHECK(pf != nullptr) << "Cannot find msc_tool.callback_step func."; + Map output_datas; + for (int bid = 0; bid < engine_->getNbBindings(); bid++) { + if (input_bindings_.count(bid)) { + continue; + } + const auto& tensor_name = engine_->getBindingName(bid); + output_datas.Set(tensor_name, device_buffers_[bid]); + } + Map> context; + context.Set("datas", output_datas); + (*pf)(context, "after_forward", graph_name_, tool_tag_); + } } bool LoadEngine(const String& engine_file) { @@ -190,6 +229,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { void SetInputOutputBinds() { // Setup input bindings + std::set binded; for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { @@ -203,7 +243,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { data_entry_[eid]->shape + data_entry_[eid]->ndim); 
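        // propagate the concrete input shape to the engine before buffers are bound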
ICHECK(context_->setBindingDimensions(binding_index, VectorToTrtDims(shape))); #endif - if (data_entry_[eid]->device.device_type == kDLCUDA) { + if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { bindings_[binding_index] = data_entry_[eid]->data; } else { auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); @@ -214,6 +254,8 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { int num_elements = 1; for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i]; binding_sizes_[binding_index] = num_elements; + input_bindings_[binding_index] = eid; + binded.insert(binding_index); } } } @@ -224,12 +266,30 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); int binding_index = engine_->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); - if (data_entry_[eid]->device.device_type == kDLCUDA) { + if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { bindings_[binding_index] = data_entry_[eid]->data; } else { auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); bindings_[binding_index] = device_buffer->data; } + output_bindings_[binding_index] = eid; + binded.insert(binding_index); + } + // Setup tool bindings + for (int bid = 0; bid < engine_->getNbBindings(); bid++) { + if (binded.count(bid)) { + continue; + } + if (!device_buffers_.count(bid)) { + const auto& tensor_name = engine_->getBindingName(bid); + ICHECK(tensor_ids_.count(tensor_name)) << "Can not find tensor_name " << tensor_name; + const auto& pair = tensor_ids_[tensor_name]; + auto shape = nodes_[pair.first].GetOpShape()[pair.second]; + auto dtype = nodes_[pair.first].GetOpDataType()[pair.second]; + device_buffers_[bid] = runtime::NDArray::Empty(shape, dtype, {kDLCUDA, 0}); + } + bindings_[bid] = device_buffers_[bid]->data; + binded.insert(bid); } } @@ -267,10 +327,15 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { private: String engine_file_; + String tool_tag_; + String graph_name_; + std::unordered_map> tensor_ids_; #ifdef TVM_GRAPH_EXECUTOR_TENSORRT TensorRTLogger logger_; ICudaEngine* engine_{nullptr}; IExecutionContext* context_{nullptr}; + std::unordered_map input_bindings_; + std::unordered_map output_bindings_; std::vector bindings_; std::vector binding_sizes_; std::unordered_map device_buffers_; diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_manager.py index de846c10eb89..393c8decbc79 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_manager.py @@ -31,7 +31,7 @@ ) -def _get_config(model_type, deploy_type, inputs, outputs, atol=1e-2, rtol=1e-2): +def _get_config(model_type, compile_type, inputs, outputs, atol=1e-2, rtol=1e-2): """Get msc config""" return { "model_type": model_type, @@ -44,7 +44,7 @@ def _get_config(model_type, deploy_type, inputs, outputs, atol=1e-2, rtol=1e-2): "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, }, "compile": { - "run_type": deploy_type, + "run_type": compile_type, "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, }, } @@ -88,14 +88,14 @@ def _get_tf_graph(): return None -def _test_from_torch(deploy_type, expected_info, is_training=False, atol=1e-2, rtol=1e-2): +def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-2, rtol=1e-2): torch_model = _get_torch_model("resnet50", is_training) if torch_model: if 
torch.cuda.is_available(): torch_model = torch_model.to(torch.device("cuda:0")) config = _get_config( MSCFramework.TORCH, - deploy_type, + compile_type, inputs=[["input_0", [1, 3, 224, 224], "float32"]], outputs=["output"], atol=atol, @@ -103,7 +103,7 @@ def _test_from_torch(deploy_type, expected_info, is_training=False, atol=1e-2, r ) manager = MSCManager(torch_model, config) report = manager.run_pipe() - assert report["success"], "Failed to run pipe for torch -> {}".format(deploy_type) + assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) model_info = manager.runner.model_info assert msc_utils.dict_equal( model_info, expected_info @@ -111,12 +111,12 @@ def _test_from_torch(deploy_type, expected_info, is_training=False, atol=1e-2, r manager.destory() -def _test_from_tf(deploy_type, expected_info, atol=1e-2, rtol=1e-2): +def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): graphdef = _get_tf_graph() if graphdef: config = _get_config( MSCFramework.TENSORFLOW, - deploy_type, + compile_type, inputs=[["input", [1, 224, 224, 3], "float32"]], outputs=["MobilenetV2/Predictions/Reshape_1:0"], atol=atol, @@ -125,7 +125,7 @@ def _test_from_tf(deploy_type, expected_info, atol=1e-2, rtol=1e-2): config["compile"]["profile"]["check"]["err_rate"] = -1 manager = MSCManager(graphdef, config) report = manager.run_pipe() - assert report["success"], "Failed to run pipe for tensorflow -> {}".format(deploy_type) + assert report["success"], "Failed to run pipe for tensorflow -> {}".format(compile_type) model_info = manager.runner.model_info assert msc_utils.dict_equal( model_info, expected_info diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py new file mode 100644 index 000000000000..037507cf692f --- /dev/null +++ b/tests/python/contrib/test_msc/test_tools.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test Tools in MSC. 
""" + +import os +import pytest + +import torch + +import tvm.testing +from tvm.contrib.msc.pipeline import MSCManager +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + +requires_tensorrt = pytest.mark.skipif( + tvm.get_global_func("relax.ext.tensorrt", True) is None, + reason="TENSORRT is not enabled", +) + + +def _get_config( + model_type, + compile_type, + tools_config, + inputs, + outputs, + atol=1e-2, + rtol=1e-2, + optimize_type=None, +): + """Get msc config""" + return { + "model_type": model_type, + "inputs": inputs, + "outputs": outputs, + "debug_level": 0, + "dataset": {"loader": "from_random", "max_iter": 5}, + "prepare": {"profile": {"benchmark": {"repeat": 10}}}, + "baseline": { + "run_type": model_type, + "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, + }, + "optimize": { + "run_type": optimize_type or model_type, + "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, + **tools_config, + }, + "compile": { + "run_type": compile_type, + "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, + }, + } + + +def get_tool_config(tool_type): + config = {} + if tool_type == ToolType.PRUNER: + config = { + "plan_file": "msc_pruner.json", + "strategys": [{"method": "per_channel", "density": 0.8}], + } + return {tool_type: config} + + +def _get_torch_model(name, is_training=False): + """Get model from torch vision""" + # pylint: disable=import-outside-toplevel + try: + import torchvision + + model = getattr(torchvision.models, name)(pretrained=True) + if is_training: + model = model.train() + else: + model = model.eval() + return model + except: # pylint: disable=bare-except + print("please install torchvision package") + return None + + +def _test_from_torch( + compile_type, + tools_config, + expected_info, + is_training=False, + atol=1e-2, + rtol=1e-2, + optimize_type=None, +): + torch_model = _get_torch_model("resnet50", is_training) + if torch_model: + if torch.cuda.is_available(): + torch_model = torch_model.to(torch.device("cuda:0")) + config = _get_config( + MSCFramework.TORCH, + compile_type, + tools_config, + inputs=[["input_0", [1, 3, 224, 224], "float32"]], + outputs=["output"], + atol=atol, + rtol=rtol, + optimize_type=optimize_type, + ) + manager = MSCManager(torch_model, config) + report = manager.run_pipe() + assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) + for t_type, config in tools_config.items(): + assert os.path.isfile( + msc_utils.get_config_dir().relpath(config["plan_file"]) + ), "Failed to find plan of " + str(t_type) + model_info = manager.runner.model_info + assert msc_utils.dict_equal( + model_info, expected_info + ), "Model info {} mismatch with expected {}".format(model_info, expected_info) + manager.destory() + + +@pytest.mark.parametrize("tool_type", [ToolType.PRUNER]) +def test_tvm_tools(tool_type): + """Test tools for tvm""" + + model_info = { + "inputs": [ + {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "nodes": { + "total": 229, + "input": 1, + "nn.conv2d": 53, + "nn.batch_norm": 53, + "get_item": 53, + "nn.relu": 49, + "nn.max_pool2d": 1, + "add": 16, + "nn.adaptive_avg_pool2d": 1, + "reshape": 1, + "msc.linear_bias": 1, + }, + } + tool_config = get_tool_config(tool_type) + 
_test_from_torch(MSCFramework.TVM, tool_config, model_info, is_training=True) + + +@requires_tensorrt +@pytest.mark.parametrize( + "tool_type,use_native", + [(ToolType.PRUNER, False)], +) +def test_tensorrt_tools(tool_type, use_native): + """Test tools for tensorrt""" + + model_info = { + "inputs": [ + {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}], + "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, + } + tool_config = get_tool_config(tool_type) + if tool_type == ToolType.QUANTIZER and use_native: + tool_config[ToolType.QUANTIZER]["strategys"] = [] + optimize_type = MSCFramework.TENSORRT if use_native else None + _test_from_torch( + MSCFramework.TENSORRT, + tool_config, + model_info, + is_training=False, + optimize_type=optimize_type, + ) + + +if __name__ == "__main__": + tvm.testing.main() From a23f84ca0e8e9cd8a6eb7ad3f021e53d5c8a622e Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 30 Nov 2023 21:52:48 +0800 Subject: [PATCH 2/3] update pruner --- .../contrib/msc/core/tools/prune/pruner.py | 329 +++++----------- python/tvm/contrib/msc/core/tools/tool.py | 370 ++++++++++++++---- python/tvm/contrib/msc/core/utils/file.py | 7 +- .../tensorflow/tools/prune/pruner.py | 15 +- .../framework/tensorrt/tools/prune/pruner.py | 15 +- .../msc/framework/torch/tools/prune/pruner.py | 15 +- .../msc/framework/tvm/tools/prune/pruner.py | 15 +- src/contrib/msc/core/ir/graph.cc | 44 +-- 8 files changed, 494 insertions(+), 316 deletions(-) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index f761a37c56d0..d5dc2ee5a7a9 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -16,33 +16,34 @@ # under the License. """tvm.contrib.msc.core.tools.prune.pruner""" -from typing import List, Dict, Iterable, Tuple, Any +from typing import List, Dict, Tuple, Any import tvm -from tvm.contrib.msc.core.ir import MSCGraph, WeightGraph, WeightJoint, MSCTensor -from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, Strategy +from tvm.contrib.msc.core.ir import MSCGraph, WeightJoint, MSCTensor +from tvm.contrib.msc.core.tools.tool import ToolType, WeightTool, Strategy from tvm.contrib.msc.core import _ffi_api from tvm.contrib.msc.core import utils as msc_utils from .method import PruneMethod -class BasePruner(BaseTool): +class BasePruner(WeightTool): """Base pruner for all""" - def setup(self) -> dict: - """Setup the tool + def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: + """Get the weight types from options Returns ------- - info: dict - The setup info. + main_wtypes: dict> + The main weight types. 
+ relation_wtypes: dict + The relation weight types """ - # Build weight graphs - if "prunable_types" in self._options: - self._prunable_types = self._options["prunable_types"] + if "main_wtypes" in self._options: + main_wtypes = self._options["main_wtypes"] else: - self._prunable_types = { + main_wtypes = { "constant": ["const"], "nn.conv2d": ["weight"], "msc.conv2d_bias": ["weight"], @@ -50,10 +51,10 @@ def setup(self) -> dict: "msc.linear_bias": ["weight"], } - if "relation_types" in self._options: - self._relation_types = self._options["relation_types"] + if "relation_wtypes" in self._options: + relation_wtypes = self._options["relation_wtypes"] else: - self._relation_types = { + relation_wtypes = { "concatenate": "multi_inputs", "reshape": "reshape", "add": "passby", @@ -61,42 +62,28 @@ def setup(self) -> dict: "multiply": "passby", "divide": "passby", } + return main_wtypes, relation_wtypes - return super().setup() - - def reset( - self, - graphs: List[MSCGraph], - weights: List[Dict[str, tvm.nd.array]], - cache_dir: msc_utils.MSCDirectory = None, - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: - """Reset the tool with graphs and weights + def _parse_strategys(self, strategy_list: dict) -> Dict[str, Strategy]: + """Parse the strategy to get valid strategy Parameters - ---------- - graphs: list - The msc graphs. - weights: list> - The weights - cache_dir: MSCDirectory - cache path for save/load info + ------- + strategy_list: dict + The given strategy Returns ------- - graphs: list - The msc graphs. - weights: list> - The weights + strategys: dict + The parsed strategy. """ - self._unpruned_tensors = {} - res = super().reset(graphs, weights, cache_dir) - if self.on_debug(3): - for idx, graph in enumerate(self._weight_graphs): - self._logger.debug( - msc_utils.msg_block("PRUNER.WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) - ) - return res + def _update_stages(strategy): + if "stages" not in strategy: + strategy["stages"] = [msc_utils.MSCStage.PRUNE] + return strategy + + return super()._parse_strategys([_update_stages(s) for s in strategy_list]) def load_graphs( self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] @@ -107,87 +94,54 @@ def load_graphs( ---------- graphs: list The msc graphs. - weights: list> + weights: list> The weights - as_cache: bool - Whether the graphs and weights are loaded from cache - Returns ------- graphs: list The msc graphs. - weights: list> + weights: list> The weights """ - self._weight_graphs = [ - _ffi_api.WeightGraph(graph, self._prunable_types, self._relation_types) - for graph in graphs - ] + graphs, weights = super().load_graphs(graphs, weights) if not self._plan: return graphs, weights return self.prune_graphs(graphs, weights) - def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): - """Save runner to cache + def _execute_before_build(self, *args, **kwargs): + """Execute before model build Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache_info - """ - - assert ( - "weight_graphs" in cache_info - ), "weight_graphs should be given in cache_info, get " + str(cache_info) - self._weight_graphs = [ - WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"] - ] - - def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache_info. 
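The defaults above can be swapped out through the tool options, since `_get_wtypes` consults `self._options` first. A hedged sketch; how the options reach the pruner from the pipeline config is an assumption of this example:

pruner_config = {
    "plan_file": "msc_pruner.json",
    "strategys": [{"method": "per_channel", "density": 0.8}],
    # assumed plumbing: these land in self._options
    "options": {
        "main_wtypes": {"nn.conv2d": ["weight"]},  # prune conv weights only
        "relation_wtypes": {"add": "passby"},      # residual adds pass indices through
    },
}
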
+        ----------
+        args: list
+           The arguments for model build.
+        kwargs: dict
+           The keyword arguments for model build.
         """
 
-        cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]}
-        with cache_dir:
-            for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]):
-                with open(f_path, "w") as f_graph:
-                    f_graph.write(graph.to_json())
-        return cache_info
+        self._unpruned_tensors = {}
+        super()._execute_before_build(*args, **kwargs)
 
-    def _parse_strategys(self, strategy_list: dict) -> Dict[str, Strategy]:
-        """Parse the strategy to get valid strategy
+    def _execute_after_build(self, output: Any) -> Any:
+        """Execute after model build
 
         Parameters
-        -------
-        strategy_list: dict
-            The given strategy
+        ----------
+        output: Any
+            The output reference of the model.
 
         Returns
         -------
-        strategys: dict
-            The parsed strategy.
+        output: Any
+            The modified output reference.
         """
 
-        def _update_stages(strategy):
-            if "stages" not in strategy:
-                strategy["stages"] = [msc_utils.MSCStage.PRUNE]
-            return strategy
-
-        return super()._parse_strategys([_update_stages(s) for s in strategy_list])
+        assert not self._unpruned_tensors, "Some tensors are not pruned " + str(
+            self._unpruned_tensors
+        )
+        return super()._execute_after_build(output)
 
     def _check_tensor(self, name: str, consumer: str) -> bool:
         """Check if the tensor should be processed
@@ -215,7 +169,7 @@ def _check_tensor(self, name: str, consumer: str) -> bool:
         return True
 
     def _process_tensor(
-        self, tensor: Any, name: str, consumer: str, strategys: List[Strategy]
+        self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[Strategy]
     ) -> Any:
         """Process tensor
 
@@ -227,6 +181,8 @@ def _process_tensor(
             The name of the tensor.
         consumer: str
             The name of the consumer.
+        scope: str
+            The scope mark (teacher | student | null).
         strategys: list
             The strategys for the tensor.
 
@@ -239,6 +195,41 @@ def _process_tensor(
         if name in self._plan:
             return tensor
 
+        self._prune_tensor(name, consumer, strategys)
+        lazy_pruned = set()
+        for lazy_name, info in self._unpruned_tensors.items():
+            if info["lead_name"] in self._plan:
+                strategys = self._get_tensor_strategys(lazy_name, info["consumer"])
+                self._prune_tensor(lazy_name, info["consumer"], strategys)
+                t_mark = ".".join([s.get_executor().name for s in strategys])
+                self.debug_tensor(
+                    self.find_tensor(lazy_name),
+                    lazy_name,
+                    consumer,
+                    "lazy processed({})".format(t_mark),
+                )
+                lazy_pruned.add(lazy_name)
+        if lazy_pruned:
+            self._unpruned_tensors = {
+                k: v for k, v in self._unpruned_tensors.items() if k not in lazy_pruned
+            }
+        return tensor
+
+    def _prune_tensor(self, name: str, consumer: str, strategys: List[Strategy]) -> Any:
+        """Prune tensor
+
+        Parameters
+        -------
+        name: str
+            The name of the tensor.
+        consumer: str
+            The name of the consumer.
+        strategys: list
+            The strategys for the tensor.
+ """ + assert len(strategys) == 1, "pruner should only has 1 strategy, get " + str(strategys) strategy = strategys[0] @@ -259,15 +250,15 @@ def _get_in_indices(w_node: WeightJoint) -> List[int]: def _prunable(w_node: WeightJoint) -> bool: """Check if weight node is prunable""" - if w_node.get_attr("prune_strategy") != "prune": + if w_node.get_attr("weight_strategy") != "main": return False if not w_node.children: return False childrens = list(w_node.children) while childrens: current = childrens.pop(0) - prune_strategy = current.get_attr("prune_strategy") - if prune_strategy == "prune": + weight_strategy = current.get_attr("weight_strategy") + if weight_strategy == "main": return True childrens.extend(list(current.children)) return False @@ -284,11 +275,10 @@ def _prunable(w_node: WeightJoint) -> bool: if lead_name not in self._plan: self._unpruned_tensors[name] = { "lead_name": lead_name, - "tensor": tensor, "consumer": consumer, } self._plan.pop(w_node.name) - return tensor + return None self._plan[w_node.name]["out_indices"] = self._plan[lead_name]["out_indices"] elif _prunable(w_node): self._plan[w_node.name] = strategy( @@ -300,29 +290,12 @@ def _prunable(w_node: WeightJoint) -> bool: out_axis=out_axis, in_indices=in_indices, ) - elif w_node.get_attr("prune_strategy") == "follow": + elif w_node.get_attr("weight_strategy") == "follow": self._plan[w_node.name]["out_indices"] = [] - elif w_node.get_attr("prune_strategy") == "passby": + elif w_node.get_attr("weight_strategy") == "passby": self._plan[w_node.name]["out_indices"] = in_indices else: self._plan[w_node.name]["out_indices"] = [] - lazy_pruned = set() - for lazy_name, info in self._unpruned_tensors.items(): - if info["lead_name"] in self._plan: - strategys = self._get_tensor_strategys(lazy_name, info["consumer"]) - lazy_tensor = self._process_tensor( - info["tensor"], lazy_name, info["consumer"], strategys - ) - strategy_mark = ".".join([s.get_executor().name for s in strategys]) - self.debug_tensor( - lazy_tensor, lazy_name, consumer, "lazy processed({})".format(strategy_mark) - ) - lazy_pruned.add(lazy_name) - if lazy_pruned: - self._unpruned_tensors = { - k: v for k, v in self._unpruned_tensors.items() if k not in lazy_pruned - } - return tensor def prune_graphs( self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] @@ -333,14 +306,14 @@ def prune_graphs( ---------- graphs: list The msc graphs. - weights: list> + weights: list> The weights Returns ------- graphs: list The msc graphs. 
- weights: list> + weights: list> The weights """ @@ -361,7 +334,7 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): for node in graph.get_nodes(): for weight in node.get_weights().values(): w_name = weight.name - if w_name in self._plan: + if w_name in self._plan and not self._plan[w_name].get("pruned", False): data = msc_utils.cast_array(sub_weights[w_name]) in_axis, out_axis = self._get_io_axes(self.find_w_node(w_name)) w_config = self._plan[w_name] @@ -371,6 +344,7 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): data = PruneMethod.prune_axis(data, out_axis, w_config["out_indices"]) pruned_tensors[w_name] = _prune_by_shape(weight, data.shape) pruned_weights[w_name] = tvm.nd.array(data) + self._plan[w_name]["pruned"] = True pruned_weights_cnt += 1 else: pruned_weights[w_name] = sub_weights[w_name] @@ -399,10 +373,11 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): ) else: for out in node.get_outputs(): - if out.name in self._plan: + if out.name in self._plan and not self._plan[out.name].get("pruned", False): pruned_tensors[out.name] = _prune_by_channel( out, len(self._plan[out.name]["out_indices"]) ) + self._plan[out.name]["pruned"] = True elif ( node.get_inputs() and node.input_at(0).name in pruned_tensors @@ -429,109 +404,21 @@ def _flatten_size(weights): raw_size = _flatten_size(weights) new_size = _flatten_size(new_weights) self._logger.info( - "{} weights pruned, compress to {:g}%".format( - pruned_weights_cnt, new_size * 100 / raw_size + "{} weights pruned, compress to {:g}% ({:g} M->{:g} M)".format( + pruned_weights_cnt, + new_size * 100 / raw_size, + raw_size / 2**20, + new_size / 2**20, ) ) return new_graphs, new_weights - def visualize(self, visual_dir: msc_utils.MSCDirectory): - """Visualize MSCGraphs - - Parameters - ------- - visual_dir: MSCDirectory - Visualize path for saving graph - """ - - for w_graph in self._weight_graphs: - w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt")) - def finalize(self) -> dict: """Get the plan""" - assert not self._unpruned_tensors, "Some tensors are not pruned " + str( - self._unpruned_tensors - ) self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} return super().finalize() - def get_w_nodes(self) -> Iterable[WeightJoint]: - """Get all the weight nodes in the weight_graphs. - - Returns - ------- - nodes: generator - The generator of weight nodes. - """ - - for g in self._weight_graphs: - for n in g.get_nodes(): - yield n - - def has_w_node(self, name: str) -> bool: - """Check if name in weight_graphs. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - has_node: bool - Whether node in weight_graphs. - """ - - for g in self._weight_graphs: - if g.has_node(name): - return True - return False - - def find_w_node(self, name: str) -> WeightJoint: - """Find weight node by name. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - node: WeightJoint - The found node. - """ - - for g in self._weight_graphs: - if g.has_node(name): - return g.find_node(name) - raise Exception("Can not find node {} from graphs".format(name)) - - def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: - """Get the input output axes - - Parameters - ---------- - w_node: WeightJoint - The weight node. - - Returns - ------- - axes: (int, int) - The input output axis. 
- """ - - if w_node.weight.ndim == 1: - return 0, 0 - if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"): - return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis")) - in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") - if in_axis >= 0 and out_axis >= 0: - return in_axis, out_axis - if w_node.weight.layout_of("C") >= 0: - return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") - raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) - @classmethod def tool_type(cls): return ToolType.PRUNER diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 0032394a6506..dad96d5b9eb2 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -25,15 +25,17 @@ import numpy as np import tvm -from tvm.contrib.msc.core.ir import MSCGraph, MSCJoint, MSCTensor +from tvm.contrib.msc.core.ir import MSCGraph, WeightGraph, MSCJoint, WeightJoint, MSCTensor from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core import _ffi_api class ToolType(object): """Enum all msc tool types""" BASE = "base" + WEIGHT = "weight" PRUNER = "pruner" QUANTIZER = "quantizer" DISTILLER = "distiller" @@ -346,6 +348,61 @@ def setup(self) -> dict: "debug_level": self._debug_level, } + def _parse_strategys(self, strategy_list: dict) -> Dict[str, Strategy]: + """Parse the strategy to get valid strategy + + Parameters + ------- + strategy_list: dict + The given strategy + + Returns + ------- + strategys: dict + The parsed strategy. + """ + + strategys = {} + assert isinstance(strategy_list, list) and all( + isinstance(s, dict) for s in strategy_list + ), "Strategy should be given as list of dict" + for stra in strategy_list: + method_cls_name = stra.pop("method_cls") if "method_cls" in stra else "default" + method_cls = msc_utils.get_registered_tool_method( + self.framework(), self.tool_type(), method_cls_name + ) + method_name = stra.pop("method") if "method" in stra else "default" + if hasattr(method_cls, method_name): + method = getattr(method_cls, method_name) + else: + default_cls = msc_utils.get_registered_tool_method( + MSCFramework.MSC, self.tool_type(), method_cls_name + ) + assert hasattr( + default_cls, method_name + ), "Can not find method {} from neighter {} nor {}".format( + method_name, method_cls, default_cls + ) + method = getattr(default_cls, method_name) + tensor_types = stra.pop("tensor_types") if "tensor_types" in stra else ["all"] + if "op_types" in stra: + op_types = stra.pop("op_types") + marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)] + elif "op_names" in stra: + op_names = stra.pop("op_names") + marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)] + else: + marks = [("default", "all")] + stages = stra.pop("stages") if "stages" in stra else ["default"] + for mark, t_type in marks: + if mark not in strategys: + strategys[mark] = Strategy(mark, t_type, self._stage) + for stage in stages: + strategys[mark].add_executor( + stage, Executor(method_name, method, copy.deepcopy(stra)) + ) + return strategys + def reset( self, graphs: List[MSCGraph], @@ -358,7 +415,7 @@ def reset( ---------- graphs: list The msc graphs. - weights: list> + weights: list> The weights cache_dir: MSCDirectory cache path for save/load info @@ -367,7 +424,7 @@ def reset( ------- graphs: list The msc graphs. 
- weights: list> + weights: list> The weights """ @@ -410,14 +467,14 @@ def load_graphs( ---------- graphs: list The msc graphs. - weights: list> + weights: list> The weights Returns ------- graphs: list The msc graphs. - weights: list> + weights: list> The weights """ @@ -612,10 +669,12 @@ def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> A if not self._support_scope(scope): return tensor strategys = self._get_tensor_strategys(name, consumer) - strategy_mark = ".".join([s.get_executor().name for s in strategys]) - cached_tensor = self._get_processed(name, consumer, strategy_mark) + t_mark = ".".join([s.get_executor().name for s in strategys]) + if scope: + t_mark += "." + scope + cached_tensor = self._get_processed(name, consumer, t_mark) if cached_tensor is not None: - self.debug_tensor(cached_tensor, name, consumer, "cached({})".format(strategy_mark)) + self.debug_tensor(cached_tensor, name, consumer, "cached({})".format(t_mark)) return cached_tensor process = self._get_tensor_cache(name, consumer, "process") if process is None: @@ -625,9 +684,9 @@ def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> A self._logger.debug("%sprocess tensor %s-%s", self.msg_mark(), name, consumer) if not process: return tensor - tensor = self._process_tensor(tensor, name, consumer, strategys) - self._save_processed(name, consumer, tensor, strategy_mark) - self.debug_tensor(tensor, name, consumer, "processed({})".format(strategy_mark)) + tensor = self._process_tensor(tensor, name, consumer, scope, strategys) + self._save_processed(name, consumer, tensor, t_mark) + self.debug_tensor(tensor, name, consumer, "processed({})".format(t_mark)) return tensor def _support_scope(self, scope: str) -> bool: @@ -710,7 +769,7 @@ def _check_tensor(self, name: str, consumer: str) -> bool: return len(strategys) > 0 def _process_tensor( - self, tensor: Any, name: str, consumer: str, strategys: List[Strategy] + self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[Strategy] ) -> Any: """Process tensor @@ -722,6 +781,8 @@ def _process_tensor( The name of the tensor. consumer: str The name of the consumer. + scope: str + The scope mark teacher| student| null. strategys: list The strategys for the tensor. @@ -733,6 +794,22 @@ def _process_tensor( return tensor + def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]: + """Update the generate configs + + Parameters + ---------- + generate_config: dict + The generate_config. + + Returns + ------- + generate_config: dict + The updated generate_config. + """ + + return generate_config + def visualize(self, visual_dir: msc_utils.MSCDirectory): """Visualize MSCGraphs @@ -1039,61 +1116,6 @@ def get_data(self, name: str) -> np.ndarray: return msc_utils.cast_array(self._weights[name]) raise Exception("Can not find data {} from {} weights".format(name, len(self._weights))) - def _parse_strategys(self, strategy_list: dict) -> Dict[str, Strategy]: - """Parse the strategy to get valid strategy - - Parameters - ------- - strategy_list: dict - The given strategy - - Returns - ------- - strategys: dict - The parsed strategy. 
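With the change above the processed-tensor cache key encodes both the executor chain and the scope, so teacher and student copies of the same tensor no longer collide. For a single `per_channel` executor in teacher scope the mark works out as:

executor_names = ["per_channel"]   # s.get_executor().name for each strategy
t_mark = ".".join(executor_names)  # "per_channel"
scope = "teacher"
if scope:
    t_mark += "." + scope          # "per_channel.teacher"
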
- """ - - strategys = {} - assert isinstance(strategy_list, list) and all( - isinstance(s, dict) for s in strategy_list - ), "Strategy should be given as list of dict" - for stra in strategy_list: - method_cls_name = stra.pop("method_cls") if "method_cls" in stra else "default" - method_cls = msc_utils.get_registered_tool_method( - self.framework(), self.tool_type(), method_cls_name - ) - method_name = stra.pop("method") if "method" in stra else "default" - if hasattr(method_cls, method_name): - method = getattr(method_cls, method_name) - else: - default_cls = msc_utils.get_registered_tool_method( - MSCFramework.MSC, self.tool_type(), method_cls_name - ) - assert hasattr( - default_cls, method_name - ), "Can not find method {} from neighter {} nor {}".format( - method_name, method_cls, default_cls - ) - method = getattr(default_cls, method_name) - tensor_types = stra.pop("tensor_types") if "tensor_types" in stra else ["all"] - if "op_types" in stra: - op_types = stra.pop("op_types") - marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)] - elif "op_names" in stra: - op_names = stra.pop("op_names") - marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)] - else: - marks = [("default", "all")] - stages = stra.pop("stages") if "stages" in stra else ["default"] - for mark, t_type in marks: - if mark not in strategys: - strategys[mark] = Strategy(mark, t_type, self._stage) - for stage in stages: - strategys[mark].add_executor( - stage, Executor(method_name, method, copy.deepcopy(stra)) - ) - return strategys - def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any): """Save the data to tensor cache @@ -1229,3 +1251,217 @@ def framework(cls): @classmethod def tool_style(cls): return "base" + + +class WeightTool(BaseTool): + """Basic tool with weight graphs""" + + def reset( + self, + graphs: List[MSCGraph], + weights: List[Dict[str, tvm.nd.array]], + cache_dir: msc_utils.MSCDirectory = None, + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Reset the tool with graphs and weights + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + cache_dir: MSCDirectory + cache path for save/load info + + Returns + ------- + graphs: list + The msc graphs. + weights: list> + The weights + """ + + graphs, weights = super().reset(graphs, weights, cache_dir) + assert len(graphs) == len( + self._weight_graphs + ), "Graphs {} mismatch with weight graphs {}".format(len(graphs), len(self._weight_graphs)) + if self.on_debug(3): + for idx, graph in enumerate(self._weight_graphs): + self._logger.debug( + msc_utils.msg_block("PRUNER.WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) + ) + return graphs, weights + + def load_graphs( + self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + """Load the graphs and weights + + Parameters + ---------- + graphs: list + The msc graphs. + weights: list> + The weights + as_cache: bool + Whether the graphs and weights are loaded from cache + + Returns + ------- + graphs: list + The msc graphs. 
+        weights: list<dict<str, tvm.nd.array>>
+            The weights
+        """
+
+        graphs, weights = super().load_graphs(graphs, weights)
+        main_wtypes, relation_wtypes = self._get_wtypes()
+        assert main_wtypes, "main_wtypes should be given to build weight graphs"
+        self._weight_graphs = [
+            _ffi_api.WeightGraph(graph, main_wtypes, relation_wtypes) for graph in graphs
+        ]
+        return graphs, weights
+
+    def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
+        """Get the weight types from options
+
+        Returns
+        -------
+        main_wtypes: dict<str, list<str>>
+            The main weight types.
+        relation_wtypes: dict<str, str>
+            The relation weight types.
+        """
+
+        raise NotImplementedError("_get_wtypes is not implemented in WeightTool")
+
+    def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict):
+        """Load tool from cache
+
+        Parameters
+        -------
+        cache_dir: MSCDirectory
+            cache path for save/load info
+        cache_info: dict
+            The cache_info
+        """
+
+        assert (
+            "weight_graphs" in cache_info
+        ), "weight_graphs should be given in cache_info, get " + str(cache_info)
+        self._weight_graphs = [
+            WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"]
+        ]
+
+    def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict:
+        """Save tool to cache
+
+        Parameters
+        -------
+        cache_dir: MSCDirectory
+            cache path for save/load info
+
+        Returns
+        -------
+        cache_info: dict
+            The cache_info.
+        """
+
+        cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]}
+        with cache_dir:
+            for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]):
+                with open(f_path, "w") as f_graph:
+                    f_graph.write(graph.to_json())
+        return cache_info
+
+    def visualize(self, visual_dir: msc_utils.MSCDirectory):
+        """Visualize MSCGraphs
+
+        Parameters
+        -------
+        visual_dir: MSCDirectory
+            Visualize path for saving graph
+        """
+
+        for w_graph in self._weight_graphs:
+            w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt"))
+
+    def get_w_nodes(self) -> Iterable[WeightJoint]:
+        """Get all the weight nodes in the weight_graphs.
+
+        Returns
+        -------
+        nodes: generator
+            The generator of weight nodes.
+        """
+
+        for g in self._weight_graphs:
+            for n in g.get_nodes():
+                yield n
+
+    def has_w_node(self, name: str) -> bool:
+        """Check if name in weight_graphs.
+
+        Parameters
+        ----------
+        name: string
+            The name of the node.
+
+        Returns
+        -------
+        has_node: bool
+            Whether node in weight_graphs.
+        """
+
+        for g in self._weight_graphs:
+            if g.has_node(name):
+                return True
+        return False
+
+    def find_w_node(self, name: str) -> WeightJoint:
+        """Find weight node by name.
+
+        Parameters
+        ----------
+        name: string
+            The name of the node.
+
+        Returns
+        -------
+        node: WeightJoint
+            The found node.
+        """
+
+        for g in self._weight_graphs:
+            if g.has_node(name):
+                return g.find_node(name)
+        raise Exception("Can not find node {} from graphs".format(name))
+
+    def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
+        """Get the input output axes
+
+        Parameters
+        ----------
+        w_node: WeightJoint
+            The weight node.
+
+        Returns
+        -------
+        axes: (int, int)
+            The input output axis.
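The cache index written by `save_cache` above stays small; for a single weight graph named "main_weights" (an assumed name, following the `graph->name + "_weights"` convention from graph.cc) the resulting cache_info would be:

cache_info = {"weight_graphs": ["main_weights_graph.json"]}  # one "<name>_graph.json" per graph
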
+ """ + + if w_node.weight.ndim == 1: + return 0, 0 + if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"): + return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis")) + in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") + if in_axis >= 0 and out_axis >= 0: + return in_axis, out_axis + if w_node.weight.layout_of("C") >= 0: + return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") + raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) + + @classmethod + def tool_type(cls): + return ToolType.WEIGHT diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index 4936bf28e0a4..a56abc7b6a7e 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -21,7 +21,7 @@ import tempfile import types from functools import partial -from typing import List, Any +from typing import List, Any, Union from importlib.machinery import SourceFileLoader from .namespace import MSCMap, MSCKey, MSCFramework @@ -257,7 +257,7 @@ def msc_dir(path: str = None, keep_history: bool = True, cleanup: bool = False) def set_workspace( - path: str = None, keep_history: bool = True, cleanup: bool = False + path: Union[str, MSCDirectory] = None, keep_history: bool = True, cleanup: bool = False ) -> MSCDirectory: """Create MSCDirectory as worksapce and set to map @@ -276,6 +276,9 @@ def set_workspace( The created dir. """ + if isinstance(path, MSCDirectory): + MSCMap.set(MSCKey.WORKSPACE, path) + return path path = path or "msc_workspace" workspace = MSCDirectory(path, keep_history, cleanup) MSCMap.set(MSCKey.WORKSPACE, workspace) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py index b59865c9ce0f..9b3d9d4326db 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py @@ -25,7 +25,20 @@ class TensorflowPrunerFactory(object): """Pruner factory for tensorflow""" - def create(self, base_cls: BasePruner): + def create(self, base_cls: BasePruner) -> BasePruner: + """Create adaptive pruner + + Parameters + ---------- + base_cls: BasePruner + The base pruner class + + Returns + ------- + pruner_cls: BasePruner + The pruner class. + """ + class Pruner(base_cls): """Adaptive pruner for tensorflow""" diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py index de7ccb0747be..da591d9cebb6 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py @@ -25,7 +25,20 @@ class TensorRTPrunerFactory(object): """Pruner factory for tensorrt""" - def create(self, base_cls: BasePruner): + def create(self, base_cls: BasePruner) -> BasePruner: + """Create adaptive pruner + + Parameters + ---------- + base_cls: BasePruner + The base pruner class + + Returns + ------- + pruner_cls: BasePruner + The pruner class. 
+ """ + class Pruner(base_cls): """Adaptive pruner for tensorrt""" diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py index 171c639ceaa3..4dfcf21dca55 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py @@ -25,7 +25,20 @@ class TorchPrunerFactory(object): """Pruner factory for torch""" - def create(self, base_cls: BasePruner): + def create(self, base_cls: BasePruner) -> BasePruner: + """Create adaptive pruner + + Parameters + ---------- + base_cls: BasePruner + The base pruner class + + Returns + ------- + pruner_cls: BasePruner + The pruner class. + """ + class Pruner(base_cls): """Adaptive pruner for torch""" diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py index 788f1090bd79..198a6985466a 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py @@ -25,7 +25,20 @@ class TVMPrunerFactory(object): """Pruner factory for tvm""" - def create(self, base_cls: BasePruner): + def create(self, base_cls: BasePruner) -> BasePruner: + """Create adaptive pruner + + Parameters + ---------- + base_cls: BasePruner + The base pruner class + + Returns + ------- + pruner_cls: BasePruner + The pruner class. + """ + class Pruner(base_cls): """Adaptive pruner for tvm""" diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 97e91ca1839b..acfa7cbd2e05 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -841,11 +841,11 @@ void MSCGraphNode::AnalysisGraph() { } } -WeightGraph::WeightGraph(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types) { +WeightGraph::WeightGraph(const MSCGraph& graph, const Map>& main_wtypes, + const Map& relation_wtypes) { ObjectPtr n = make_object(); n->name = graph->name + "_weights"; - n->Build(graph, prunable_types, relation_types); + n->Build(graph, main_wtypes, relation_wtypes); data_ = std::move(n); } @@ -861,13 +861,13 @@ WeightGraph::WeightGraph(const std::string& json_str) { data_ = std::move(n); } -void WeightGraphNode::Build(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types) { +void WeightGraphNode::Build(const MSCGraph& graph, const Map>& main_wtypes, + const Map& relation_wtypes) { auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) { return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index; }; - auto find_parents = [this, &prunable_types, &relation_types, &sort_nodes](const MSCJoint& node) { + auto find_parents = [this, &main_wtypes, &relation_wtypes, &sort_nodes](const MSCJoint& node) { std::vector parents; std::queue frontier; std::set explored; @@ -881,13 +881,13 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapoptype)) { - for (const auto& t_type : prunable_types[current->optype]) { + if (main_wtypes.count(current->optype)) { + for (const auto& t_type : main_wtypes[current->optype]) { if (current->weights.count(t_type)) { parents.push_back(FindNode(current->WeightAt(t_type)->name)); } } - } else if (relation_types.count(current->optype)) { + } else if (relation_wtypes.count(current->optype)) { parents.push_back(FindNode(current->OutputAt(0)->name)); } else { for (const auto& p : current->parents) { @@ -914,11 +914,11 @@ void 
WeightGraphNode::Build(const MSCGraph& graph, const Mapshared_ref.size() > 0) { continue; } - if (prunable_types.count(node->optype) || relation_types.count(node->optype) || + if (main_wtypes.count(node->optype) || relation_wtypes.count(node->optype) || node->weights.size() > 0) { const auto& w_parents = find_parents(node); bool bind_friends = true; - if (relation_types.count(node->optype) && relation_types[node->optype] == "multi_inputs") { + if (relation_wtypes.count(node->optype) && relation_wtypes[node->optype] == "multi_inputs") { bind_friends = false; } if (w_parents.size() > 1 && bind_friends) { @@ -926,13 +926,13 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map(p)->friends = w_parents; } } - if (prunable_types.count(node->optype)) { - for (const auto& wtype : prunable_types[node->optype]) { + if (main_wtypes.count(node->optype)) { + for (const auto& wtype : main_wtypes[node->optype]) { if (node->weights.count(wtype)) { const auto& weight = node->WeightAt(wtype); Map attrs; attrs.Set("producer_type", node->optype); - attrs.Set("prune_strategy", "prune"); + attrs.Set("weight_strategy", "main"); const auto& w_node = WeightJoint(node_names.size(), weight->name, "", wtype, weight, w_parents, attrs); for (const auto& p : w_parents) { @@ -947,7 +947,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapname)) { Map attrs; attrs.Set("producer_type", node->optype); - attrs.Set("prune_strategy", "follow"); + attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, pair.second, {head}, attrs); head->AddChild(w_node); @@ -955,16 +955,16 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapname); } } - } else if (relation_types.count(node->optype)) { + } else if (relation_wtypes.count(node->optype)) { const auto& tensor = node->OutputAt(0); Map attrs; attrs.Set("producer_type", node->optype); if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 && node->OutputAt(0)->LayoutOf("C") >= 0 && node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) { - attrs.Set("prune_strategy", "passby"); + attrs.Set("weight_strategy", "passby"); } else { - attrs.Set("prune_strategy", relation_types[node->optype]); + attrs.Set("weight_strategy", relation_wtypes[node->optype]); } const auto& t_node = WeightJoint(node_names.size(), tensor->name, "", "output", tensor, w_parents, attrs); @@ -978,7 +978,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapname)) { Map attrs; attrs.Set("producer_type", node->optype); - attrs.Set("prune_strategy", "follow"); + attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, pair.second, w_parents, attrs); for (const auto& p : w_parents) { @@ -1294,9 +1294,9 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraph") }); TVM_REGISTER_GLOBAL("msc.core.WeightGraph") - .set_body_typed([](const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types) -> WeightGraph { - return WeightGraph(graph, prunable_types, relation_types); + .set_body_typed([](const MSCGraph& graph, const Map>& main_wtypes, + const Map& relation_wtypes) -> WeightGraph { + return WeightGraph(graph, main_wtypes, relation_wtypes); }); // MSC Graph APIS From e51b1d5d47e3a88b3702625405fd0deefb5177f2 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 1 Dec 2023 18:19:31 +0800 Subject: [PATCH 3/3] change debug level --- .../contrib/msc/core/tools/prune/pruner.py | 4 +- 
python/tvm/contrib/msc/core/tools/tool.py | 78 +++++++++---------- python/tvm/contrib/msc/core/utils/file.py | 1 + src/contrib/msc/core/ir/graph.cc | 35 ++++++--- src/contrib/msc/core/ir/graph.h | 2 +- 5 files changed, 66 insertions(+), 54 deletions(-) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index d5dc2ee5a7a9..5fa625c2adc7 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -387,7 +387,7 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): pruned_tensors[out.name] = _prune_by_channel( out, pruned_tensors[node.input_at(0).name].dim_at("C") ) - if self.on_debug(3): + if self.on_debug(3, in_forward=False): self._logger.debug(msc_utils.msg_block("Pruned Tensors", pruned_tensors)) pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors) new_graphs.append(pruned_graph) @@ -404,7 +404,7 @@ def _flatten_size(weights): raw_size = _flatten_size(weights) new_size = _flatten_size(new_weights) self._logger.info( - "{} weights pruned, compress to {:g}% ({:g} M->{:g} M)".format( + "Prune {} weights, compress to {:g}% ({:g} M->{:g} M)".format( pruned_weights_cnt, new_size * 100 / raw_size, raw_size / 2**20, diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index dad96d5b9eb2..c37ec3db974a 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -432,7 +432,10 @@ def reset( self._tensor_cache = {} if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")): cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json")) - self.load_cache(cache_dir, cache_info) + else: + cache_info = {} + if self.tool_type() in cache_info: + self.load_cache(cache_dir, cache_info[self.tool_type()]) else: graphs, weights = self.load_graphs(graphs, weights) self._graphs, self._weights = graphs, {} @@ -440,12 +443,18 @@ def reset( self._weights.update(sub_weights) self._logger.debug( "%s load %d graphs and %d weights", - self.tool_type().upper(), + self.tool_type(), len(self._graphs), len(self._weights), ) + self._reset() return self._graphs, weights + def _reset(self): + """Extra reset for tool""" + + return None + def change_stage(self, stage: str): """Change the stage of tools and strategy""" @@ -523,7 +532,8 @@ def execute_before_build(self, *args, **kwargs): if self._enabled: self._graph_id = self._infer_graph_id(kwargs) self._processed_tensor = {} - self._logger.debug("%sStart Build", self.msg_mark(in_forward=False)) + if self.on_debug(3, in_forward=False): + self._logger.debug("%sStart Build", self.msg_mark(in_forward=False)) self._execute_before_build(*args, **kwargs) def _execute_before_build(self, *args, **kwargs): @@ -555,7 +565,8 @@ def execute_after_build(self, output: Any) -> Any: if self._enabled: output = self._execute_after_build(output) - self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False)) + if self.on_debug(3, in_forward=False): + self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False)) return output def _execute_after_build(self, output: Any) -> Any: @@ -588,7 +599,7 @@ def execute_before_forward(self, *args, **kwargs): if self._enabled: self._graph_id = self._infer_graph_id(kwargs) self._processed_tensor = {} - if self.on_debug(2): + if self.on_debug(3): self._logger.debug("%sStart Forward", self.msg_mark()) self._execute_before_forward(*args, **kwargs) @@ -621,7 +632,7 @@ def 
execute_after_forward(self, output: Any) -> Any: if self._enabled: output = self._execute_after_forward(output) - if self.on_debug(2): + if self.on_debug(3): self._logger.debug( "%sEnd Forward, process %d tensors", self.msg_mark(), @@ -925,13 +936,15 @@ def is_weight(self, name: str) -> bool: return name in self._weights - def on_debug(self, debug_level: int = 1) -> bool: + def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool: """Check if should log Parameters ------- debug_level: int The given debug_level. + in_forward: bool + Whether to check forward_cnt. Returns ------- @@ -939,13 +952,18 @@ def on_debug(self, debug_level: int = 1) -> bool: Whether to log debug info. """ - if self._forward_cnt % self._verbose_step != 0: + if in_forward and self._forward_cnt % self._verbose_step != 0: return False return self._debug_level >= debug_level - def msg_mark(self, in_forward=True) -> str: + def msg_mark(self, in_forward: bool = True) -> str: """Get the debug title + Parameters + ------- + in_forward: bool + Whether to add forward mark. + Returns ------- msg_mark: str @@ -959,7 +977,7 @@ def msg_mark(self, in_forward=True) -> str: return title def debug_tensor( - self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 2 + self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 3 ) -> str: """Get the debug tensor info @@ -1256,41 +1274,21 @@ def tool_style(cls): class WeightTool(BaseTool): """Basic tool with weight graphs""" - def reset( - self, - graphs: List[MSCGraph], - weights: List[Dict[str, tvm.nd.array]], - cache_dir: msc_utils.MSCDirectory = None, - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: - """Reset the tool with graphs and weights + def _reset(self): + """Extra reset for tool""" - Parameters - ---------- - graphs: list - The msc graphs. - weights: list> - The weights - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - graphs: list - The msc graphs. 
- weights: list> - The weights - """ - - graphs, weights = super().reset(graphs, weights, cache_dir) - assert len(graphs) == len( + super()._reset() + assert len(self._graphs) == len( self._weight_graphs - ), "Graphs {} mismatch with weight graphs {}".format(len(graphs), len(self._weight_graphs)) - if self.on_debug(3): + ), "Graphs {} mismatch with weight graphs {}".format( + len(self._graphs), len(self._weight_graphs) + ) + self._logger.debug("%s load %d weight graphs", self.tool_type(), len(self._weight_graphs)) + if self.on_debug(2, in_forward=False): for idx, graph in enumerate(self._weight_graphs): self._logger.debug( - msc_utils.msg_block("PRUNER.WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) + msc_utils.msg_block("WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) ) - return graphs, weights def load_graphs( self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index a56abc7b6a7e..146cfaf50434 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -352,3 +352,4 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T get_dataset_dir = partial(get_workspace_subdir, name="Dataset") get_output_dir = partial(get_workspace_subdir, name="Output") get_visual_dir = partial(get_workspace_subdir, name="Visual") +get_weights_dir = partial(get_workspace_subdir, name="Weights") diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index acfa7cbd2e05..c563ccb4101c 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -708,20 +708,21 @@ const Array MSCGraphNode::GetExits() const { } const bool MSCGraphNode::HasTensor(const String& name) const { - if (weight_holders.count(name)) { + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + if (weight_holders.count(tensor_name)) { return true; } - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); return nodes.count(host) > 0 ? true : false; } const MSCTensor MSCGraphNode::FindTensor(const String& name) const { - if (weight_holders.count(name)) { - const auto& node = FindNode(weight_holders[name][0]); + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + if (weight_holders.count(tensor_name)) { + const auto& node = FindNode(weight_holders[tensor_name][0]); for (const auto& pair : node->weights) { - if (pair.second->name == name) { + if (pair.second->name == tensor_name) { return pair.second; } } @@ -732,8 +733,9 @@ const MSCTensor MSCGraphNode::FindTensor(const String& name) const { } const MSCJoint MSCGraphNode::FindProducer(const String& name) const { - if (weight_holders.count(name)) { - return FindNode(weight_holders[name][0]); + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + if (weight_holders.count(tensor_name)) { + return FindNode(weight_holders[tensor_name][0]); } const auto& pair = FindProducerAndIdx(name); return pair.first; @@ -744,8 +746,8 @@ const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const { } const std::pair MSCGraphNode::FindProducerAndIdx(const String& name) const { - ICHECK(!weight_holders.count(name)) << "Weight " << name << " has no producer with index"; const String& tensor_name = tensor_alias.count(name) ? 
tensor_alias[name] : name; + ICHECK(!weight_holders.count(tensor_name)) << "Weight " << name << " has no producer with index"; String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); if (index.size() == 0) { @@ -762,8 +764,9 @@ const std::pair MSCGraphNode::FindProducerAndIdx(const MSCTens const Array MSCGraphNode::FindConsumers(const String& name) const { Array consumers; - if (weight_holders.count(name)) { - for (const auto& h : weight_holders[name]) { + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + if (weight_holders.count(tensor_name)) { + for (const auto& h : weight_holders[tensor_name]) { consumers.push_back(FindNode(h)); } } else { @@ -781,7 +784,8 @@ const Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const const std::vector> MSCGraphNode::FindConsumersAndIndices( const String& name) const { - ICHECK(!weight_holders.count(name)) << "Weight has no index"; + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; std::vector> consumers; for (const auto& c : FindConsumers(name)) { bool find_tensor = false; @@ -836,6 +840,9 @@ void MSCGraphNode::AnalysisGraph() { weight_holders.Set(w_name, holders); } else { weight_holders.Set(w_name, Array({n})); + if (pair.second->alias.size() > 0) { + tensor_alias.Set(pair.second->alias, pair.second->name); + } } } } @@ -1320,6 +1327,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") return graph->FindTensor(name); }); +TVM_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") + .set_body_typed([](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { + tensor->alias = alias; + graph->tensor_alias.Set(alias, tensor->name); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindProducer(name); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 85880841d4d8..67855deb97ab 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -716,7 +716,7 @@ class MSCGraphNode : public BaseGraphNode { /*! \brief The output names of graph. */ Array output_names; /*! \brief The tensor alias in graph, get by AnalysisGraph. */ - Map tensor_alias; + mutable Map tensor_alias; /*! \brief The weights in graph, get by AnalysisGraph. */ Map> weight_holders; /*! \brief Export graph to json. */
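The `MSCGraphSetTensorAlias` global registered above is reachable from Python through the generic registry; a minimal sketch, assuming only the registered name from this diff (`graph` and `tensor` stand for an existing MSCGraph and MSCTensor):

import tvm

set_alias = tvm.get_global_func("msc.core.MSCGraphSetTensorAlias")
set_alias(graph, tensor, "classifier_weight")
# FindTensor/FindProducer/FindConsumers now also resolve "classifier_weight"
# through the (now mutable) tensor_alias map
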