diff --git a/gallery/how_to/work_with_msc/using_tools.py b/gallery/how_to/work_with_msc/using_tools.py index 28cbc4c198bd..c8187d218d9b 100644 --- a/gallery/how_to/work_with_msc/using_tools.py +++ b/gallery/how_to/work_with_msc/using_tools.py @@ -57,11 +57,12 @@ parser.add_argument("--test_iter", type=int, default=100, help="The iter for test") parser.add_argument("--calibrate_iter", type=int, default=100, help="The iter for calibration") parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train") -parser.add_argument("--train_iter", type=int, default=200, help="The iter for train") -parser.add_argument("--train_epoch", type=int, default=100, help="The epoch for train") +parser.add_argument("--train_iter", type=int, default=100, help="The iter for train") +parser.add_argument("--train_epoch", type=int, default=5, help="The epoch for train") parser.add_argument( "--verbose", type=str, default="info", help="The verbose level, info|debug:1,2,3|critical" ) +parser.add_argument("--dynamic", action="store_true", help="Whether to use dynamic wrapper") args = parser.parse_args() @@ -88,8 +89,8 @@ def get_config(calib_loader, train_loader): compile_type=args.compile_type, dataset=dataset, tools=tools, - skip_config={"all": "check"}, verbose=args.verbose, + dynamic=args.dynamic, ) @@ -100,13 +101,13 @@ def _get_calib_datas(): for i, (inputs, _) in enumerate(testloader, 0): if i >= args.calibrate_iter > 0: break - yield {"input": inputs} + yield inputs if args.dynamic else {"input": inputs} def _get_train_datas(): for i, (inputs, _) in enumerate(trainloader, 0): if i >= args.train_iter > 0: break - yield {"input": inputs} + yield inputs if args.dynamic else {"input": inputs} model = resnet50(pretrained=args.checkpoint) if torch.cuda.is_available(): diff --git a/python/tvm/contrib/msc/core/gym/environment/method.py b/python/tvm/contrib/msc/core/gym/environment/method.py index 405318c447d9..296688eceace 100644 --- a/python/tvm/contrib/msc/core/gym/environment/method.py +++ b/python/tvm/contrib/msc/core/gym/environment/method.py @@ -105,7 +105,7 @@ def _get_loss(golden, result): outputs = runner.run(inputs) baseline = loader[idx] for name, data in outputs.items(): - loss += _get_loss(baseline[name], data) + loss += _get_loss(baseline[name], msc_utils.cast_array(data)) return {"loss": loss / len(loader)} @classmethod diff --git a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py index 72dee8e5de67..fcedcf5f7f88 100644 --- a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py +++ b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py @@ -70,7 +70,7 @@ def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str] continue info.update(strategys[name].get_executor(msc_utils.MSCStage.QUANTIZE).config) summary_file = msc_utils.get_cache_dir().relpath("gym_summary.json") - return msc_utils.dump_dict(plan, summary_file) + return msc_utils.save_dict(plan, summary_file) @classmethod def role_type(cls): diff --git a/python/tvm/contrib/msc/core/runtime/__init__.py b/python/tvm/contrib/msc/core/runtime/__init__.py index a0ccca5b2bc4..6eb9f6df5ffd 100644 --- a/python/tvm/contrib/msc/core/runtime/__init__.py +++ b/python/tvm/contrib/msc/core/runtime/__init__.py @@ -17,3 +17,4 @@ """tvm.contrib.msc.core.runtime""" from .runner import * +from .jit import * diff --git a/python/tvm/contrib/msc/core/runtime/jit.py b/python/tvm/contrib/msc/core/runtime/jit.py new file mode 100644 index 
000000000000..5b1d9a8c3c02 --- /dev/null +++ b/python/tvm/contrib/msc/core/runtime/jit.py @@ -0,0 +1,365 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.core.runtime.jit_model""" + +import logging +from typing import Any, List, Tuple, Union, Dict + +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from .runner import BaseRunner + + +class BaseJIT(object): + """Base Just-In-Time compile for msc + + Parameters + ---------- + model: + The model to be jit compile. + inputs: list + The input names. + outputs: list + The output names. + device: str + The device to build runnable. + training: bool + Whether compile model to trainable. + hooks: dict + The hooks for runners. + logger: logging.Logger + The logger + """ + + def __init__( + self, + model: Any, + inputs: List[str], + outputs: List[str], + device: str = "cpu", + training: bool = False, + hooks: dict = None, + logger: logging.Logger = None, + ): + self._model = model + self._jit_model = model + self._inputs = inputs + self._outputs = outputs + self._device = device if self.support_device(device) else "cpu" + self._training, self._trained = training, training + self._hooks = hooks or {} + self._runner_ctxs = {} + self._logger = logger or msc_utils.get_global_logger() + self._logger.info(msc_utils.msg_block(self.jit_mark("SETUP"), self.setup())) + + def setup(self) -> dict: + """Setup the jit + + Returns + ------- + info: dict + The setup info. + """ + + return { + "inputs": self._inputs, + "outputs": self._outputs, + "device": self._device, + "training": self._training, + "hooks": self._hooks, + } + + def run( + self, inputs: Union[List[Any], Dict[str, Any]], ret_type="native" + ) -> Union[List[Any], Dict[str, Any]]: + """Run the jit to get outputs + + Parameters + ------- + inputs: list or dict + The inputs in list or dict. + ret_type: str + The return type list| dict + + Returns + ------- + outputs: dict + The outputs in dict. + """ + + inputs = msc_utils.format_datas(inputs, self._inputs, style="dict") + outputs = self._call_jit(inputs) + if ret_type == "native": + return outputs + return msc_utils.format_datas(outputs, self._outputs, style=ret_type) + + def _call_jit(self, inputs: Dict[str, Any]) -> Any: + """Run the jit model + + Parameters + ---------- + inputs: + The inputs of model. + """ + + raise NotImplementedError("_call_jit is not implemented in " + str(self.__class__)) + + def set_runner(self, runner_name: str, runner: BaseRunner): + """Set runner in runner ctx + + Parameters + ---------- + runner_name: str + The runner name. + runner: BaseRunner + The runner. 
+ """ + + self.get_runner_ctx(runner_name)["runner"] = runner + + def build(self): + """Build the jit model""" + + self._jit_model = self._build(self._model) + + def _build(self, model: Any) -> Any: + """Build the jit model + + Parameters + ---------- + model: + The model. + + Returns + ------- + jit_model: + The jit model. + """ + + raise NotImplementedError("_build is not implemented in " + str(self.__class__)) + + def make_plan(self, tool_type: str, data_loader: Any = None) -> str: + """Execute tool and get plan + + Parameters + ------- + tool_type: str + The tool type, should be in ToolType + data_loader: + The data loader. + + Returns + ------- + plan_file: str + The saved plan file. + """ + + tools = {n: r["runner"].get_tool(tool_type) for n, r in self._runner_ctxs.items()} + + def _finalize_tool( + checker: callable, post_batch: callable = None, post_iter: callable = None + ): + while any(not checker(t) for t in tools.values()): + assert data_loader, "data_loader should be given to make plan for " + tool_type + for inputs in data_loader(): + outputs = self.run(inputs, ret_type="native") + if post_batch: + for t in tools.values(): + post_batch(t, outputs) + if all(checker(t) for t in tools.values()): + break + if post_iter: + for t in tools.values(): + post_iter(t) + return {n: t.finalize() for n, t in tools.items()} + + if tool_type == ToolType.PRUNER: + plans = _finalize_tool(lambda t: t.pruned) + elif tool_type == ToolType.QUANTIZER: + plans = _finalize_tool(lambda t: t.calibrated, post_iter=lambda t: t.calibrate()) + elif tool_type == ToolType.DISTILLER: + plans = _finalize_tool( + lambda t: t.distilled, + post_batch=lambda t, outputs: t.learn(outputs), + post_iter=lambda t: t.distill(), + ) + elif tool_type == ToolType.TRACKER: + plans = _finalize_tool(lambda t: t.tracked) + else: + plans = {n: t.finalize() for n, t in tools.items()} + plans_info = ", ".join(["{}({})".format(n, len(p)) for n, p in plans.items()]) + self._logger.debug("Made %s plans for %s", plans_info, tool_type) + + def _redirect_run(self, *args, runner_name: str = "worker", **kwargs) -> Any: + """Redirect forward of model + + Parameters + ---------- + args: + The arguments. + runner_name: str + The runner name. + kwargs: + The kwargs. + + Returns + ------- + outputs: + The outputs. + """ + + assert runner_name in self._runner_ctxs, "Failed to create runner " + runner_name + inputs = self._to_msc_inputs(runner_name, *args, **kwargs) + for hook in self._hooks.get("pre_forward", []): + hook(runner_name, inputs) + outputs = self._run_ctx(self.get_runner_ctx(runner_name), inputs) + for hook in self._hooks.get("post_forward", []): + outputs = hook(runner_name, outputs) + return self._from_msc_outputs(runner_name, outputs) + + def _to_msc_inputs(self, runner_name: str, *args, **kwargs) -> List[Tuple[str, Any]]: + """Change inputs to msc format + + Parameters + ---------- + runner_name: str + The runner name. + args: + The arguments. + kwargs: + The kwargs. + + Returns + ------- + inputs: + The msc format inputs. + """ + + raise NotImplementedError("_to_msc_inputs is not implemented in " + str(self.__class__)) + + def _from_msc_outputs(self, runner_name: str, outputs: List[Tuple[str, Any]]) -> Any: + """Change inputs from msc format + + Parameters + ---------- + runner_name: str + The runner name. + outputs: list<(str, tensor)> + The msc format outputs. + + Returns + ------- + outputs: + The framework outputs. 
+ """ + + raise NotImplementedError("_from_msc_outputs is not implemented in " + str(self.__class__)) + + def _run_ctx(self, runner_ctx: dict, inputs: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: + """Forward by runner context + + Parameters + ---------- + runner_ctx: dict + The runner context + inputs: list<(str, tensor)> + The inputs. + + Returns + ------- + outputs: list<(str, tensor)> + The outputs. + """ + + raise NotImplementedError("_run_ctx is not implemented in " + str(self.__class__)) + + def get_runner_ctx(self, runner_name: str) -> dict: + """Get the runner context + + Parameters + ---------- + runner_name: str + The runner name + + Returns + ------- + runner_cts: dict + The runner context. + """ + + assert runner_name in self._runner_ctxs, "Can not finc runner_context " + str(runner_name) + return self._runner_ctxs[runner_name] + + def train(self): + """Change status to train""" + + if not self._training: + self._training = True + for runner_ctx in self._runner_ctxs.values(): + if "runner" in runner_ctx: + runner_ctx["runner"].train() + + def eval(self): + """Change status to eval""" + + if self._training: + self._training, self._trained = False, True + for runner_ctx in self._runner_ctxs.values(): + if "runner" in runner_ctx: + runner_ctx["runner"].eval() + + def jit_mark(self, msg: str): + """Mark the message with jit info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "JIT({}) {}".format(self.framework, msg) + + @property + def trained(self): + return self._trained + + @property + def jit_model(self): + return self._jit_model + + @property + def framework(self): + return MSCFramework.MSC + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + return True diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index e4a9aaa1d39b..8b0646b1d927 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -55,7 +55,7 @@ class BaseRunner(object): device: str The device to build runnable. training: bool - Whether compile model to trainable + Whether compile model to trainable. stage: str The stage of runner. plugin: PluginManager @@ -94,7 +94,7 @@ def __init__( self._translate_config = msc_utils.copy_dict(translate_config) self._generate_config = msc_utils.copy_dict(generate_config) self._build_config = msc_utils.copy_dict(build_config) - self._device = device if self._device_enabled(device) else "cpu" + self._device = device if self.support_device(device) else "cpu" self._stage = stage self._plugin = plugin self._name = name @@ -274,7 +274,7 @@ def _build_scope_model(scope: str, apply_hooks: bool): build_msg += "runnable({}, {}) on {}".format( self.framework, "train" if self._training else "eval", self._device ) - self._logger.info(build_msg) + self._logger.info(self.runner_mark(build_msg)) return self._runnable def run( @@ -295,45 +295,13 @@ def run( The outputs in dict. 
""" - model_inputs = self.get_inputs() - model_outputs = self.get_outputs() - if isinstance(inputs, (list, tuple)): - assert len(inputs) == len( - model_inputs - ), "inputs({}) mismatch with model inputs {}".format(len(inputs), model_inputs) - inputs = {info["name"]: data for info, data in zip(model_inputs, inputs)} - assert isinstance(inputs, dict), "Expect inputs as list or dict, get {}({})".format( - inputs, type(inputs) - ) - assert all( - msc_utils.is_array(data) for data in inputs.values() - ), "Expected all inputs as array like" - inputs = {i["name"]: inputs[i["name"]] for i in model_inputs} + in_names = [i["name"] for i in self.get_inputs()] + inputs = msc_utils.format_datas(inputs, in_names, style="dict") outputs = self._call_runnable(self._runnable, inputs, self._device) if ret_type == "native": return outputs - if ret_type == "dict": - if isinstance(outputs, (list, tuple, tvm.ir.container.Array)): - assert len(outputs) == len( - model_outputs - ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) - outputs = {info["name"]: data for info, data in zip(model_outputs, outputs)} - if not isinstance(outputs, dict): - assert len(model_outputs) == 1, "Expect model_outputs with len 1, get " + str( - model_outputs - ) - outputs = {model_outputs[0]["name"]: outputs} - return {name: msc_utils.cast_array(data) for name, data in outputs.items()} - if ret_type == "list": - if isinstance(outputs, dict): - assert len(outputs) == len( - model_outputs - ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) - outputs = [outputs[o["name"]] for o in model_outputs] - if not isinstance(outputs, (list, tuple)): - outputs = [outputs] - return [msc_utils.cast_array(data) for data in outputs] - return outputs + out_names = [o["name"] for o in self.get_outputs()] + return msc_utils.format_datas(outputs, out_names, style=ret_type) def save_cache( self, @@ -548,7 +516,7 @@ def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: The exported module """ - raise NotImplementedError("export_module is not supported in BaseRunner") + raise NotImplementedError("export_module is not implemented for " + str(self.__class__)) def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: """Export the runnable @@ -564,7 +532,23 @@ def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: The runnable info. """ - raise NotImplementedError("export_runnable is not supported in BaseRunner") + raise NotImplementedError("export_runnable is not implemented for " + str(self.__class__)) + + def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the graphs + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The graphs info. + """ + + raise NotImplementedError("export_graphs is not implemented for " + str(self.__class__)) def train(self): """Change status to train""" @@ -584,8 +568,7 @@ def eval(self): """Change status to eval""" if self._training: - self._trained = True - self._training = False + self._training, self._trained = False, True for tool in self.get_tools(): tool.eval() self._eval() @@ -657,47 +640,42 @@ def make_plan(self, tool_type: str, data_loader: Any = None) -> str: The saved plan file. 
""" + def _finalize_tool( + checker: callable, post_batch: callable = None, post_iter: callable = None + ): + tool = self.get_tool(tool_type) + while not checker(tool): + assert data_loader, "data_loader should be given to make plan for " + tool_type + for inputs in data_loader(): + outputs = self.run(inputs, ret_type="native") + if post_batch: + post_batch(tool, outputs) + if checker(tool): + break + if post_iter: + post_iter(tool) + return tool.finalize() + assert tool_type in self._tools, "Can not find tool " + str(tool_type) if tool_type == ToolType.PRUNER: - pruner = self.get_tool(ToolType.PRUNER) - if not pruner.pruned: - assert data_loader, "data_loader should be given to plan prune" - for inputs in data_loader(): - self.run(inputs, ret_type="native") - break - plan = pruner.finalize() + plan = _finalize_tool(lambda t: t.pruned) elif tool_type == ToolType.QUANTIZER: - quantizer = self.get_tool(ToolType.QUANTIZER) - while not quantizer.calibrated: - assert data_loader, "data_loader should be given to plan prune" - for inputs in data_loader(): - self.run(inputs, ret_type="native") - quantizer.calibrate() - plan = quantizer.finalize() + plan = _finalize_tool(lambda t: t.calibrated, post_iter=lambda t: t.calibrate()) elif tool_type == ToolType.DISTILLER: - distiller = self.get_tool(ToolType.DISTILLER) - while not distiller.distilled: - assert data_loader, "data_loader should be given to plan prune" - for inputs in data_loader(): - loss = self.run(inputs, ret_type="native") - distiller.learn(loss) - distiller.distill() - plan = distiller.finalize() + plan = _finalize_tool( + lambda t: t.distilled, + post_batch=lambda t, outputs: t.learn(outputs), + post_iter=lambda t: t.distill(), + ) elif tool_type == ToolType.TRACKER: - tracker = self.get_tool(ToolType.TRACKER) - if not tracker.tracked: - assert data_loader, "data_loader should be given to plan prune" - for inputs in data_loader(): - self.run(inputs, ret_type="native") - if tracker.tracked: - break - plan = tracker.finalize() + plan = _finalize_tool(lambda t: t.tracked) else: plan = self.get_tool(tool_type).finalize() self._logger.debug("Made %d plan for %s", len(plan), tool_type) plan_file = self._tools_config[tool_type]["plan_file"] - with open(plan_file, "w") as f: - f.write(json.dumps(plan, indent=2)) + if plan: + with open(plan_file, "w") as f: + f.write(json.dumps(plan, indent=2)) return plan_file def _apply_hook(self, desc: str, hook_def: dict, *args, **kwargs) -> Any: @@ -744,17 +722,22 @@ def _update_codegen(self, config: Dict[str, Any]): else: raise TypeError("Unexpecet codegen config " + str(codegen)) - def visualize(self, visual_dir: msc_utils.MSCDirectory): + def visualize(self, visual_dir: msc_utils.MSCDirectory, export_graph: bool = False): """Visualize MSCGraphs Parameters ------- visual_dir: MSCDirectory Visualize path for saving graph + export_graph: bool + Whether to export the graph """ for graph in self._graphs: graph.visualize(visual_dir.relpath(graph.name + ".prototxt")) + if export_graph: + with open(visual_dir.relpath(graph.name + "_graph.json"), "w") as f_graph: + f_graph.write(graph.to_json()) for tool in self._tools.values(): tool.visualize(visual_dir) @@ -976,17 +959,6 @@ def _call_runnable( raise NotImplementedError("_call_runnable is not implemented for " + str(self.__class__)) - def _device_enabled(self, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. 
- """ - - return True - def runner_mark(self, msg: Any) -> str: """Mark the message with runner info @@ -1001,7 +973,7 @@ def runner_mark(self, msg: Any) -> str: The message with mark. """ - return "RUNNER({} @ {}) {}".format(self.framework, self._stage, msg) + return "RUNNER[{}]({} @ {}) {}".format(self._name, self.framework, self._stage, msg) @property def stage(self): @@ -1011,6 +983,10 @@ def stage(self): def debug_level(self): return self._debug_level + @property + def trained(self): + return self._trained + @property def model(self): return self._model @@ -1058,6 +1034,66 @@ def load_native(cls, model: Any, config: dict) -> Tuple[Any, str, bool]: return model, "cpu", False + @classmethod + def run_native( + cls, + model: Any, + inputs: Dict[str, np.ndarray], + input_names: List[str], + output_names: List[str], + warm_up: int = 10, + repeat: int = 0, + ) -> Tuple[Dict[str, np.ndarray], float]: + """Run the datas and get outputs + + Parameters + ------- + model: + The nativate model. + inputs: dict + The inputs in dict. + input_names: list + The input names. + output_names: list + The outut names. + warm_up: int + The warm_up num for profile. + repeat: int + The repeat num for profile. + + Returns + ------- + outputs: dict + The outputs in dict. + avg_time: float + The average time. + """ + + raise NotImplementedError("run_native is not implemented for " + str(cls)) + + @classmethod + def dump_nativate( + cls, model: Any, folder: msc_utils.MSCDirectory, dump_config: dict = None + ) -> str: + """Dump the nativate model + + Parameters + ------- + model: + The native model. + folder: MSCDirectory + The export folder. + dump_config: dict + The dump config. + + Returns + ------- + export_path: str + The exported path + """ + + raise NotImplementedError("dump_nativate is not implemented for " + str(cls)) + @classmethod def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: """Update the config for parse @@ -1094,6 +1130,18 @@ def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: config[stage]["run_config"] = run_config return config + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + return True + class ModelRunner(BaseRunner): """Model runner of MSC""" @@ -1218,6 +1266,25 @@ def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: ) return module + def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the graphs + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The graphs info. 
+ """ + + graphs = {"main": folder.relpath(self._graphs[0].name + "_graph.json")} + with open(graphs["main"], "w") as f_graph: + f_graph.write(self._graphs[0].to_json()) + return graphs + class BYOCRunner(BaseRunner): """BYOC runner of MSC""" @@ -1235,17 +1302,22 @@ def setup(self) -> dict: self._executable = None return super().setup() - def visualize(self, visual_dir: msc_utils.MSCDirectory): + def visualize(self, visual_dir: msc_utils.MSCDirectory, export_graph: bool = False): """Visualize MSCGraphs Parameters ------- visual_dir: MSCDirectory Visualize path for saving graph + export_graph: bool + Whether to export the graph """ super().visualize(visual_dir) self._byoc_graph.visualize(visual_dir.relpath(self._byoc_graph.name + ".prototxt")) + if export_graph: + with open(visual_dir.relpath(self._byoc_graph.name + "_graph.json"), "w") as f_graph: + f_graph.write(self._byoc_graph.to_json()) def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Translate IRModule to MSCgraphs @@ -1350,10 +1422,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra """ extra_option = self._generate_config.get("extra_option", {}) - if self._stage == MSCStage.COMPILE and not self.get_tool(ToolType.TRACKER): - extra_option["tool_tag"] = "" - else: - extra_option["tool_tag"] = self._name + extra_option["tool_tag"] = "" if self._stage == MSCStage.COMPILE else self._name return self.codegen_func( self._byoc_mod, graphs, @@ -1438,24 +1507,31 @@ def _inspect_model(self) -> dict: self._logger.debug(msc_utils.msg_block(title, sub_graphs)) return self._byoc_graph.inspect() - def _device_enabled(self, device: str) -> bool: - """Check if the device is enabled + def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the runnable + + Parameters + ------- + folder: MSCDirectory + The export folder. Returns ------- - enabled: bool - Whether the device is enabled. + info: dict + The runnable info. """ - if device == "cpu": - return True - if device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - return tvm.cuda(dev_id).exist - return False + export_lib = folder.relpath("lib.so") + self._executable.export_library(export_lib) + return { + "lib": export_lib, + "device": self.device, + "model_type": self.framework, + "abstract": self.model_info, + } - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the runnable + def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the graphs Parameters ------- @@ -1465,13 +1541,37 @@ def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: Returns ------- info: dict - The runnable info. + The graphs info. """ - export_path = folder.relpath("model.so") - self._executable.export_library(export_path) - return {"model": export_path} + graphs = { + "byoc_graph": folder.relpath(self._byoc_graph.name + "_graph.json"), + "sub_graphs": {g.name: folder.relpath(g.name + "_graph.json") for g in self._graphs}, + } + with open(graphs["byoc_graph"], "w") as f: + f.write(self._byoc_graph.to_json()) + for graph in self._graphs: + with open(graphs["sub_graphs"][graph.name], "w") as f: + f.write(graph.to_json()) + return graphs @property def partition_func(self): raise NotImplementedError("partition_func is not implemented for " + str(self.__class__)) + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. 
+ """ + + if device == "cpu": + return True + if device.startswith("cuda"): + dev_id = int(device.split(":")[1]) if ":" in device else 0 + return tvm.cuda(dev_id).exist + return False diff --git a/python/tvm/contrib/msc/core/tools/configer.py b/python/tvm/contrib/msc/core/tools/configer.py index 2c6789591721..1dffd1b10fef 100644 --- a/python/tvm/contrib/msc/core/tools/configer.py +++ b/python/tvm/contrib/msc/core/tools/configer.py @@ -93,7 +93,7 @@ def config_gym(self, gym_config: Union[dict, str]) -> dict: raise NotImplementedError("config_gym is not implemented in ToolConfiger") def config_apply(self) -> dict: - """Get the config fro apply + """Get the config for apply Returns ------- diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py index 39e06b701bbe..55b7947a6e20 100644 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py @@ -37,7 +37,7 @@ def setup(self) -> dict: The setup info. """ - self._max_iter = self._options.get("max_iter", 5) + self._max_iter = self._options.get("max_iter", 1) self._save_step = self._options.get("save_step", 50) if "weights_folder" in self._options: self._weights_folder = msc_utils.msc_dir(self._options["weights_folder"]) @@ -72,7 +72,8 @@ def _reset( with open(self._weights_path, "rb") as f: distilled_weights = tvm.runtime.load_param_dict(f.read()) weights.update({k: v for k, v in distilled_weights.items() if k in weights}) - self._logger.info("Update %d distilled weights", len(distilled_weights)) + msg = "Update {} distilled weights".format(len(distilled_weights)) + self._logger.info(self.tool_mark(msg)) return super()._reset(graphs, weights) def build_model(self, teacher: Any, student: Any) -> Any: @@ -103,7 +104,8 @@ def learn(self, loss: Any): """ if self.on_debug(3, in_forward=False): - self._logger.debug("%s start learn[%d]", self.tool_type(), self._current_iter) + msg = "Start learn[{}]".format(self._current_iter) + self._logger.debug(self.tool_mark(msg)) self._total_loss += float(self._learn(loss)) def _learn(self, loss: Any): @@ -134,9 +136,10 @@ def distill(self) -> Dict[str, Any]: if self._current_iter >= self._max_iter: self._distilled = True self._plan = {n: msc_utils.inspect_array(d, False) for n, d in weights.items()} - self._logger.info( - "Distill[%d] loss(%d batch) %f", self._current_iter, self._forward_cnt, self._total_loss + msg = "Distill[{}] loss({} batch) {}".format( + self._current_iter, self._forward_cnt, self._total_loss ) + self._logger.info(self.tool_mark(msg)) self._current_iter += 1 self._total_loss, self._forward_cnt = 0, 0 return weights @@ -165,8 +168,9 @@ def _save_weights(self, weights: Dict[str, Any]): weights_path = self._weights_folder.relpath("distill_{}.bin".format(self._current_iter)) with open(weights_path, "wb") as f_params: f_params.write(tvm.runtime.save_param_dict(weights)) - if self.on_debug(2, in_forward=False): - self._logger.debug("Save weights[%d] to %s", self._current_iter, weights_path) + if self._debug_level >= 2: + msg = "Save weights[{}] to {}".format(self._current_iter, weights_path) + self._logger.debug(self.tool_mark(msg)) def _support_scope(self, scope: str) -> bool: """Check if the scope si supported @@ -244,24 +248,6 @@ def _distill_tensor( self._plan[name][scope] = plan return tensor - def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: - """Export the config for tool - - Parameters - ------- - config: dict - The source config. 
- folder: MSCDirectory - The export folder. - - Returns - ------- - config: dict - The exported config. - """ - - return {} - @property def distilled(self): return self._distilled @@ -270,6 +256,10 @@ def distilled(self): def tool_type(cls): return ToolType.DISTILLER + @classmethod + def exportable(cls): + return False + @msc_utils.register_tool class DefaultDistiller(BaseDistiller): diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py index 22cb52a60b6d..2a47d755619e 100644 --- a/python/tvm/contrib/msc/core/tools/execute.py +++ b/python/tvm/contrib/msc/core/tools/execute.py @@ -70,8 +70,8 @@ def add_tool(tool: BaseTool, tool_type: str, tag: str = "main"): return tool -def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> BaseTool: - """Create tool by type, config and tag +def get_tool_cls(framework: str, tool_type: str, config: dict) -> BaseTool: + """Get the tool class Parameters ------- @@ -79,8 +79,6 @@ def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> The framework for implement tool_type: str The type of the tool prune| quantize| distill... - tag: str - The tag of the tool. config: dict The config of tool. """ @@ -90,7 +88,26 @@ def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> assert tool_cls, "Can not find tool class for {}:{} @ {}".format( tool_type, tool_style, framework ) - return add_tool(tool_cls(**config), tool_type, tag) + return tool_cls + + +def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> BaseTool: + """Create tool by type, config and tag + + Parameters + ------- + framework: str + The framework for implement + tool_type: str + The type of the tool prune| quantize| distill... + tag: str + The tag of the tool. + config: dict + The config of tool. + """ + + tool_cls = get_tool_cls(framework, tool_type, config) + return add_tool(tool_cls(tag, **config), tool_type, tag) def get_tool(tool_type: str, tag: str = "main") -> BaseTool: diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 9f20240cf218..90273e25416b 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -123,6 +123,7 @@ def _reset( The weights. 
""" + self._unpruned_tensors = {} self._meta_weights = weights graphs, weights = super()._reset(graphs, weights) if self._plan and self._enabled: @@ -423,7 +424,9 @@ def _is_pruned(tensor: MSCTensor, graph: MSCGraph) -> bool: pruned_tensors = {k: v for k, v in pruned_tensors.items() if _is_pruned(v, graph)} if self.on_debug(3, in_forward=False): - self._logger.debug(msc_utils.msg_block("Pruned Tensors", pruned_tensors)) + self._logger.debug( + msc_utils.msg_block(self.tool_mark("Pruned Tensors"), pruned_tensors) + ) if pruned_tensors: pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors) @@ -439,15 +442,12 @@ def _flatten_size(weights): # log compress rate if pruned_cnt > 0: new_size = _flatten_size(pruned_weights) - self._logger.info( - "Prune %d weights, compress to %.2f%% (%.4f M->%.4f M)", - pruned_cnt, - new_size * 100 / raw_size, - raw_size, - new_size, + msg = "Prune {} weights, compress to {:.2f}% ({:.4f} M->{:.4f} M)".format( + pruned_cnt, new_size * 100 / raw_size, raw_size, new_size ) else: - self._logger.info("No weights pruned, size %.4f M", raw_size) + msg = "No weights pruned, size {:.4f} M".format(raw_size) + self._logger.info(self.tool_mark(msg)) return pruned_graphs, pruned_weights def get_meta_data(self, name: str) -> np.ndarray: @@ -514,24 +514,6 @@ def finalize(self) -> dict: self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} return super().finalize() - def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: - """Export the config for tool - - Parameters - ------- - config: dict - The source config. - folder: MSCDirectory - The export folder. - - Returns - ------- - config: dict - The exported config. - """ - - return {} - @property def pruned(self): return len(self._plan) > 0 @@ -540,6 +522,10 @@ def pruned(self): def tool_type(cls): return ToolType.PRUNER + @classmethod + def exportable(cls): + return False + @msc_utils.register_tool class DefaultPruner(BasePruner): diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py index 3d706002d6c6..bb6567810c90 100644 --- a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py @@ -76,9 +76,8 @@ def calibrate(self) -> dict: self._plan[name] = {k: v for k, v in plan.items() if k not in ("calibrated")} self.change_stage(MSCStage.QUANTIZE) calib_type = "calibrate" if self._calibrated else "gather" - self._logger.info( - "Quantizer %s %d plan after %d batch", calib_type, len(new_plan), self._forward_cnt - ) + msg = "{} {} plan after {} batch".format(calib_type, len(new_plan), self._forward_cnt) + self._logger.info(self.tool_mark(msg)) self._forward_cnt = 0 return new_plan diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 7cd0742c0753..626ae312bcf4 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -21,7 +21,7 @@ import copy import logging from itertools import product -from typing import List, Iterable, Any, Tuple, Dict +from typing import List, Iterable, Any, Tuple, Dict, Union import numpy as np import tvm @@ -288,8 +288,10 @@ class BaseTool(object): Parameters ---------- + tag: str + The tag of tool. stage: str - The stage of tool + The stage of tool. plan_file: str The plan file path. 
strategys: list[dict] @@ -310,6 +312,7 @@ class BaseTool(object): def __init__( self, + tag: str, stage: str, plan_file: str, strategys: List[dict], @@ -320,6 +323,7 @@ def __init__( verbose_step: int = 50, logger: logging.Logger = None, ): + self._tag = tag self._stage = stage self._plan_file = plan_file if os.path.isfile(plan_file): @@ -334,7 +338,13 @@ def __init__( self._verbose_step = verbose_step self._logger = logger or msc_utils.get_global_logger() title = self.tool_mark("APPLY_PLAN" if self._plan else "MAKE_PLAN") - self._logger.info(msc_utils.msg_block(title, self.setup(), width=0)) + self._logger.info(msc_utils.msg_block(title, self.setup())) + + def __str__(self): + msg = "forward[{}] {} graphs, {} weights".format( + self._forward_cnt, len(self._graphs), len(self._weights) + ) + return self.tool_mark(msg) def setup(self) -> dict: """Setup the tool @@ -554,11 +564,10 @@ def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: The exported config. """ - config = msc_utils.copy_dict(config) plan_file = msc_utils.to_abs_path(config["plan_file"], msc_utils.get_config_dir()) if os.path.isfile(plan_file): - config["plan_file"] = folder.create_dir("tools").copy(plan_file) - return config + return {"plan_file": folder.create_dir("tools").copy(plan_file)} + return {} def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): """Save runner to cache @@ -755,8 +764,7 @@ def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> A t_mark += "." + scope cached_tensor = self._get_processed(name, consumer, t_mark) if cached_tensor is not None: - if msc_utils.is_array(cached_tensor): - self.debug_tensors(name, consumer, t_mark, {"cached": cached_tensor}) + self.debug_tensors(name, consumer, t_mark, {"cached": cached_tensor}) return cached_tensor process = self._get_tensor_cache(name, consumer, "process") if process is None: @@ -764,10 +772,20 @@ def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> A self._save_tensor_cache(name, consumer, "process", process) if not process: return tensor - new_tensor = self._process_tensor(tensor, name, consumer, scope, strategys) + if isinstance(tensor, dict): + new_tensor = self._process_tensor( + msc_utils.copy_dict(tensor), name, consumer, scope, strategys + ) + else: + new_tensor = self._process_tensor(tensor, name, consumer, scope, strategys) self._save_processed(name, consumer, new_tensor, t_mark) if msc_utils.is_array(tensor) and id(new_tensor) != id(tensor): - tensors = {"pre": tensor, "post": new_tensor, "diff": tensor - new_tensor} + tensors = {"org": tensor, "new": new_tensor, "dif": tensor - new_tensor} + self.debug_tensors(name, consumer, t_mark, tensors) + elif isinstance(tensor, dict) and len(tensor.get("processed", [])) != len( + new_tensor.get("processed", []) + ): + tensors = {"org": tensor, "new": new_tensor} self.debug_tensors(name, consumer, t_mark, tensors) return new_tensor @@ -1016,7 +1034,7 @@ def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool: return False return self._debug_level >= debug_level - def tool_mark(self, msg: Any) -> dict: + def tool_mark(self, msg: Any) -> str: """Mark the message with tool info Parameters @@ -1030,7 +1048,9 @@ def tool_mark(self, msg: Any) -> dict: The message with mark. 
""" - return "{}({} @ {}) {}".format(self.tool_type().upper(), self.framework(), self._stage, msg) + return "{}[{}]({} @ {}) {}".format( + self.tool_type().upper(), self._tag, self.framework(), self._stage, msg + ) def msg_mark(self, msg: Any, in_forward: bool = True) -> str: """Mark the message with debug info @@ -1048,11 +1068,12 @@ def msg_mark(self, msg: Any, in_forward: bool = True) -> str: The message with mark. """ - mark = "{}.G[{}]".format(self.tool_type().upper(), self._graph_id) + mark = "{}({} @ {}) G[{}]".format( + self.tool_type().upper(), self._tag, self._stage, self._graph_id + ) if in_forward: mark += ".F[{}]".format(self._forward_cnt) - mark += "({}) ".format(self._stage) - return mark + str(msg) + return mark + " " + str(msg) def debug_tensors( self, name: str, consumer: str, t_mark: str, tensors: Dict[str, Any], debug_level: int = 3 @@ -1074,10 +1095,18 @@ def debug_tensors( """ if self.on_debug(debug_level): + + def _t_info(tensor): + if msc_utils.is_array(tensor): + return msc_utils.inspect_array(tensor) + if isinstance(tensor, dict) and "processed" in tensor: + return "{}({} processed)".format( + self.find_tensor(name), len(tensor["processed"]) + ) + return str(tensor) + msg = "{}-{}({})".format(name, consumer, t_mark) - tensor_des = "\n ".join( - ["{:6s}:{}".format(k, msc_utils.inspect_array(v)) for k, v in tensors.items()] - ) + tensor_des = "\n ".join(["{:6s}:{}".format(k, _t_info(v)) for k, v in tensors.items()]) self._logger.debug("%s\n %s", self.msg_mark(msg), tensor_des) def _infer_graph_id(self, kwargs: dict) -> int: @@ -1136,7 +1165,7 @@ def get_tensors(self) -> Iterable[MSCTensor]: Returns ------- tensors: generator - The generator of nodes. + The generator of tensors. """ for graph in self._graphs: @@ -1149,7 +1178,7 @@ def get_tensor_ids(self) -> Iterable[MSCTensor]: Returns ------- tensors: generator - The generator of nodes. + The generator of tensor ids. """ for graph in self._graphs: @@ -1159,13 +1188,13 @@ def get_tensor_ids(self) -> Iterable[MSCTensor]: for weight in node.get_weights().values(): yield self.to_tensor_id(weight.name, node.name) - def find_tensor(self, name: str) -> MSCTensor: - """Find tensor by name. + def find_tensor(self, t_ref: Union[str, MSCTensor]) -> MSCTensor: + """Find tensor by tensor ref. Parameters ---------- - name: string - The name of the tensor. + t_ref: string| MSCTensor + The name of the tensor or tensor. Returns ------- @@ -1173,18 +1202,19 @@ def find_tensor(self, name: str) -> MSCTensor: The found tensor. """ + t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref for g in self._graphs: - if g.has_tensor(name): - return g.find_tensor(name) - raise Exception("Can not find tensor {} from {} graphs".format(name, len(self._graphs))) + if g.has_tensor(t_name): + return g.find_tensor(t_name) + raise Exception("Can not find tensor {} from {} graphs".format(t_name, len(self._graphs))) - def find_producer(self, name: str) -> MSCJoint: - """Find producer by tensor_name . + def find_producer(self, t_ref: Union[str, MSCTensor]) -> MSCJoint: + """Find producer by tensor ref. Parameters ---------- - name: string - The name of the tensor. + t_ref: string| MSCTensor + The name of the tensor or tensor. Returns ------- @@ -1192,20 +1222,21 @@ def find_producer(self, name: str) -> MSCJoint: The found prducer. 
""" + t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref for g in self._graphs: - if g.has_tensor(name): - return g.find_producer(name) + if g.has_tensor(t_name): + return g.find_producer(t_name) raise Exception( - "Can not find producer of {} from {} graphs".format(name, len(self._graphs)) + "Can not find producer of {} from {} graphs".format(t_name, len(self._graphs)) ) - def find_consumers(self, name: str) -> List[MSCJoint]: - """Find consumers by tensor_name. + def find_consumers(self, t_ref: Union[str, MSCTensor]) -> List[MSCJoint]: + """Find consumers by tensor ref. Parameters ---------- - name: string - The name of the tensor. + t_ref: string| MSCTensor + The name of the tensor or tensor. Returns ------- @@ -1213,11 +1244,12 @@ def find_consumers(self, name: str) -> List[MSCJoint]: The found consumers. """ + t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref for g in self._graphs: - if g.has_tensor(name): - return g.find_consumers(name) + if g.has_tensor(t_name): + return g.find_consumers(t_name) raise Exception( - "Can not find consumers of {} from {} graphs".format(name, len(self._graphs)) + "Can not find consumers of {} from {} graphs".format(t_name, len(self._graphs)) ) def get_data(self, name: str) -> np.ndarray: @@ -1383,6 +1415,14 @@ def framework(cls): def tool_style(cls): return "base" + @classmethod + def apply_once(cls): + return False + + @classmethod + def exportable(cls): + return True + class WeightTool(BaseTool): """Basic tool with weight graphs""" @@ -1433,9 +1473,8 @@ def _reset( _ffi_api.WeightGraph(graph, self._main_wtypes, self._relation_wtypes) for graph in graphs ] - self._logger.debug( - "%s build %d weight graphs", self.tool_type(), len(self._weight_graphs) - ) + msg = "build {} weight graphs".format(len(self._weight_graphs)) + self._logger.debug(self.tool_mark(msg)) if self.on_debug(2, in_forward=False): weight_graphs = {g.name: g.inspect() for g in self._weight_graphs} title = self.tool_mark("WEIGHT_GRAPHS({})".format(len(weight_graphs))) @@ -1472,12 +1511,8 @@ def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): self._weight_graphs = [ WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"] ] - self._logger.debug( - "%s load %d weight graphs from %s", - self.tool_type(), - len(self._weight_graphs), - cache_dir, - ) + msg = "load {} weight graphs from {}".format(len(self._weight_graphs), cache_dir) + self._logger.debug(self.tool_mark(msg)) def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: """Save runner to cache @@ -1511,6 +1546,7 @@ def visualize(self, visual_dir: msc_utils.MSCDirectory): for w_graph in self._weight_graphs: w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt")) + super().visualize(visual_dir) def get_w_nodes(self) -> Iterable[WeightJoint]: """Get all the weight nodes in the weight_graphs. diff --git a/python/tvm/contrib/msc/core/tools/track/configer.py b/python/tvm/contrib/msc/core/tools/track/configer.py index ef9c18c3f72e..82ab634e92b5 100644 --- a/python/tvm/contrib/msc/core/tools/track/configer.py +++ b/python/tvm/contrib/msc/core/tools/track/configer.py @@ -25,19 +25,6 @@ class TrackConfiger(ToolConfiger): """Configer for track""" - def config_apply(self) -> dict: - """Get the config fro apply - - Returns - ------- - config: dict - The apply config. 
- """ - - config = super().config_apply() - config.update({"apply_once": True}) - return config - @classmethod def tool_type(cls): return ToolType.TRACKER diff --git a/python/tvm/contrib/msc/core/tools/track/tracker.py b/python/tvm/contrib/msc/core/tools/track/tracker.py index 510153a5c4e5..3c36d80bd200 100644 --- a/python/tvm/contrib/msc/core/tools/track/tracker.py +++ b/python/tvm/contrib/msc/core/tools/track/tracker.py @@ -87,7 +87,7 @@ def _execute_after_forward(self, output: Any) -> Any: msg += "; ".join( ["{}: {}/{}".format(s, i["passed"], i["total"]) for s, i in passed.items()] ) - self._logger.info(msg) + self._logger.info(self.msg_mark(msg, in_forward=False)) else: self._tracked = True return output @@ -184,6 +184,10 @@ def tracked(self): def tool_type(cls): return ToolType.TRACKER + @classmethod + def apply_once(cls): + return True + @msc_utils.register_tool class DefaultTracker(BaseTracker): diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index fe8882f7f296..c6d7113f44f5 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -22,6 +22,7 @@ import tvm from tvm.relax.transform import _ffi_api as relax_api from tvm.relay.transform import _ffi_api as relay_api +from tvm.contrib.msc.core import utils as msc_utils def SetExprName( @@ -49,12 +50,8 @@ def SetExprName( """ if as_relax: - - def _get_name(name): - return name.replace("/", "_").replace(".", "_").strip("_") - var_names = var_names or {} - var_names = {k: _get_name(v) for k, v in var_names.items()} + var_names = {k: msc_utils.legalize_expr_name(v) for k, v in var_names.items()} return relax_api.SetRelaxExprName(entry_name, target, var_names) # type: ignore return relay_api.SetRelayExprName(entry_name) # type: ignore diff --git a/python/tvm/contrib/msc/core/utils/arguments.py b/python/tvm/contrib/msc/core/utils/arguments.py index a1b8e918e8ac..f09c411648e3 100644 --- a/python/tvm/contrib/msc/core/utils/arguments.py +++ b/python/tvm/contrib/msc/core/utils/arguments.py @@ -77,7 +77,7 @@ def save_dict(dict_obj: Any, path: str, indent: int = 2) -> str: return path -def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = True) -> dict: +def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = False) -> dict: """Update src_dict with new_dict. Parameters @@ -95,14 +95,18 @@ def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = True) -> dic The updated dict. 
""" + if not new_dict: + return src_dict assert isinstance(src_dict, dict) and isinstance( new_dict, dict ), "update_dict only support dict, get src {} and new {}".format(type(src_dict), type(new_dict)) for k, v in new_dict.items(): - if isinstance(v, dict): + if not src_dict.get(k): + src_dict[k] = v + elif isinstance(v, dict): v = update_dict(src_dict.get(k, {}), v, soft_update) src_dict[k] = v - elif not soft_update or k not in src_dict: + elif not soft_update: src_dict[k] = v return src_dict diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py index 3da57abb4384..e6461d107941 100644 --- a/python/tvm/contrib/msc/core/utils/dataset.py +++ b/python/tvm/contrib/msc/core/utils/dataset.py @@ -23,8 +23,45 @@ from typing import List, Union, Dict, Any import numpy as np +import tvm from .arguments import load_dict -from .info import cast_array +from .info import cast_array, is_array + + +def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any: + """Format datas to style format + + Parameters + ---------- + datas: + The source datas. + names: list + The data names. + style: str + The style of format, dict|list. + + Returns + ------- + datas: + The formated datas. + """ + + if isinstance(datas, (list, tuple, tvm.ir.container.Array)): + assert len(datas) == len(names), "datas({}) mismatch with names {}".format( + len(datas), names + ) + datas = dict(zip(names, datas)) + if not isinstance(datas, dict): + assert len(names) == 1, "Expect 1 names, get " + str(names) + datas = {names[0]: datas} + elif len(datas) > len(names): + datas = {n: datas[n] for n in datas} + assert all(is_array(d) for d in datas.values()), "Expected all tensors as array like" + if style == "dict": + return datas + if style == "list": + return [datas[n] for n in names] + raise TypeError("Unexpected style " + str(style)) class BaseDataLoader(object): @@ -168,6 +205,10 @@ def _data_info(self, name: str) -> dict: raise NotImplementedError("_data_info is not implemented for BaseDataLoader") + @property + def num_datas(self): + return self.info["num_datas"] + @property def folder(self): return self._folder @@ -302,12 +343,12 @@ def __enter__(self): return self def __exit__(self, exception_type, exception_value, traceback): - self._info["num_datas"] = self._current self.finalize() def finalize(self): """Finalize the saver""" + self._info["num_datas"] = self._current with open(os.path.join(self._folder, "datas_info.json"), "w") as f: f.write(json.dumps(self._info, indent=2)) @@ -375,6 +416,12 @@ def _save_batch(self, *args, **kwargs) -> dict: raise NotImplementedError("_save_batch is not implemented for BaseDataSaver") + @property + def num_datas(self): + if self.is_finalized(): + return self.info["num_datas"] + return self._current + @property def folder(self): return self._folder @@ -424,13 +471,19 @@ def setup(self, options: dict): assert "input_names" in options, "input_names should be given to setup IODataSaver" self._input_names = options["input_names"] self._output_names = options.get("output_names", []) - return {"inputs": {}, "outputs": {}, "num_datas": 0} + return { + "inputs": {}, + "outputs": {}, + "num_datas": 0, + "input_names": self._input_names, + "output_names": self._output_names, + } def finalize(self): """Finalize the saver""" super().finalize() - if "inputs" not in self._info: + if any(n not in self._info["inputs"] for n in self._input_names): return with open(os.path.join(self._folder, "datas_info.txt"), "w") as f: for 
name in self._input_names: @@ -475,29 +528,11 @@ def save_batch( The current batch cnt. """ - if isinstance(inputs, dict): - assert set(inputs.keys()) == set( - self._input_names - ), "Input names mismatch {} with {}".format(inputs.keys(), self._input_names) - elif isinstance(inputs, (tuple, list)): - assert len(inputs) == len( - self._input_names - ), "Inputs size {} mismatch with input_names {}".format(len(inputs), self._input_names) - inputs = dict(zip(self._input_names, inputs)) + inputs = format_datas(inputs, self._input_names, style="dict") for name, data in inputs.items(): self._save_data(self._current, name, data, "inputs") - if outputs: - if isinstance(outputs, dict): - assert set(outputs.keys()) == set( - self._output_names - ), "Output names mismatch {} with {}".format(outputs.keys(), self._output_names) - elif isinstance(outputs, (tuple, list)): - assert len(outputs) == len( - self._output_names - ), "Outputs size {} mismatch with input_names {}".format( - len(outputs), self._output_names - ) - outputs = dict(zip(self._output_names, outputs)) + if outputs is not None: + outputs = format_datas(outputs, self._output_names, style="dict") for name, data in outputs.items(): self._save_data(self._current, name, data, "outputs") self._current += 1 @@ -512,7 +547,9 @@ def is_io_dataset(folder: str) -> bool: if not os.path.isfile(os.path.join(folder, "datas_info.json")): return False data_info = load_dict(os.path.join(folder, "datas_info.json")) - return "inputs" in data_info and "outputs" in data_info + if any(key not in data_info for key in ["inputs", "outputs", "num_datas"]): + return False + return data_info["num_datas"] > 0 def is_simple_dataset(folder: str) -> bool: @@ -521,4 +558,6 @@ def is_simple_dataset(folder: str) -> bool: if not os.path.isfile(os.path.join(folder, "datas_info.json")): return False data_info = load_dict(os.path.join(folder, "datas_info.json")) - return "datas" in data_info + if any(key not in data_info for key in ["datas", "num_datas"]): + return False + return data_info["num_datas"] > 0 diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py index b18e88888723..cc87976b801e 100644 --- a/python/tvm/contrib/msc/core/utils/expr.py +++ b/python/tvm/contrib/msc/core/utils/expr.py @@ -17,7 +17,7 @@ """tvm.contrib.msc.core.utils.expr""" import copy -from typing import Dict +from typing import Dict, List import tvm from tvm import relax @@ -25,6 +25,30 @@ from tvm.contrib.msc.core import _ffi_api +def legalize_expr_name(name: str, symbols: List[str] = None, dst: str = "_") -> str: + """Legalize expr name + + Parameters + ---------- + name: str + The source name. + symbols: list + The symbols to be replaced. + dst: str + The symbol for replace. + + Returns + ------- + name: str + The legialized name. + """ + + symbols = symbols or ["::", "/", "."] + for sym in symbols: + name = name.replace(sym, dst) + return name.strip(dst) + + def get_expr_name(expr: relax.Expr) -> str: """Get name hint for expr @@ -46,11 +70,11 @@ def get_expr_name(expr: relax.Expr) -> str: def make_span(kwargs: Dict[str, str], span: relax.Span = None) -> relax.Span: - """Change name to span + """Make a span from kwargs Parameters ---------- - kwargs: dict + kwargs: dict The attrs in span. span: relax.Span The source span. 
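Note on the expr.py hunk above: the new `legalize_expr_name` helper replaces "::", "/" and "." with "_" and strips leading/trailing underscores so Relax expression names stay valid identifiers; `SetExprName` in transform.py now delegates to it. A minimal usage sketch (the input name here is made up for illustration), using the `msc_utils` alias that the transform.py hunk itself relies on:

    from tvm.contrib.msc.core import utils as msc_utils

    # "::", "/" and "." are each replaced with "_", then edge underscores are stripped
    assert msc_utils.legalize_expr_name("model::layer1/conv.weight") == "model_layer1_conv_weight"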
diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index b1eb8fa8bfa1..6b5400a40535 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -110,27 +110,6 @@ def __exit__(self, exception_type, exception_value, traceback): def __del__(self): self.clean_up() - def finalize(self): - """Finalize the directory""" - - if not os.path.isdir(self._path): - return self._path - - def _remove_empty(path: str): - sub_paths = [os.path.join(path, f) for f in os.listdir(path)] - for s_path in sub_paths: - if not os.path.isdir(s_path): - continue - if len(os.listdir(s_path)) == 0: - shutil.rmtree(s_path) - else: - _remove_empty(s_path) - if len(os.listdir(path)) == 0: - shutil.rmtree(path) - return path - - return _remove_empty(self._path) - def clean_up(self): """Clean up the dir""" @@ -187,7 +166,7 @@ def move(self, src_path: str, dst_path: str = None): os.rename(src_path, dst_path) return dst_path - def copy(self, src_path: str, dst_path: str = None): + def copy(self, src_path: str, dst_path: str = None) -> str: """Copy a file to another folder Parameters @@ -203,6 +182,8 @@ def copy(self, src_path: str, dst_path: str = None): The abs file path. """ + if not src_path: + return None if src_path != os.path.abspath(src_path): src_path = os.path.join(self.relpath(src_path)) assert os.path.exists(src_path), "Source path {} not exist".format(src_path) @@ -214,10 +195,26 @@ def copy(self, src_path: str, dst_path: str = None): shutil.copy2(src_path, dst_path) else: if os.path.isdir(dst_path): - os.remove(dst_path) + shutil.rmtree(dst_path) shutil.copytree(src_path, dst_path) return dst_path + def copy_to(self, dst_path: str): + """Copy dir to another folder + + Parameters + ---------- + dst_path: str + The target folder path. + + Returns + ------- + path: str + The abs file path. + """ + + return self.copy(self._path, dst_path) + def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = False) -> Any: """Add a dir under the folder @@ -283,6 +280,27 @@ def listdir(self, as_abs: bool = False) -> List[str]: return [os.path.join(self._path, f) for f in os.listdir(self._path)] return os.listdir(self._path) + def finalize(self): + """Finalize the directory""" + + if not os.path.isdir(self._path): + return self._path + + def _remove_empty(path: str): + sub_paths = [os.path.join(path, f) for f in os.listdir(path)] + for s_path in sub_paths: + if not os.path.isdir(s_path): + continue + if len(os.listdir(s_path)) == 0: + shutil.rmtree(s_path) + else: + _remove_empty(s_path) + if len(os.listdir(path)) == 0: + shutil.rmtree(path) + return path + + return _remove_empty(self._path) + def destory(self): """Destory the dir.""" @@ -358,6 +376,38 @@ def get_workspace() -> MSCDirectory: return workspace +class ChangeWorkspace(object): + """Change the workspace + + Parameters + ---------- + new_workspace: MSCDirectory + The new workspace. + """ + + def __init__(self, new_workspace: MSCDirectory): + self._src_workspace = get_workspace() + self._new_workspace = new_workspace + + def __enter__(self): + set_workspace(self._new_workspace) + + def __exit__(self, exception_type, exception_value, traceback): + set_workspace(self._src_workspace) + + +def change_workspace(new_workspace: MSCDirectory): + """Change the workspace + + Parameters + ---------- + new_workspace: MSCDirectory + The new workspace. 
+ """ + + return ChangeWorkspace(new_workspace) + + def get_workspace_subdir( name: str = None, keep_history: bool = True, cleanup: bool = False ) -> MSCDirectory: @@ -405,13 +455,50 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T return root_dir.relpath(path, keep_history) -def pack_folder(path: str, style="tar.gz"): +def pack_folder(path: str, dst: str = None, style="tar.gz"): """Pack the folder Parameters ---------- path: str The path of the folder. + dst: str + The pakced path. + style: str + The pack style. + + Returns + ------- + pack_path: str + The packed path. + """ + + dst = dst or path + "." + style + root = os.path.dirname(path) + if style == "tar.gz": + cmd = "tar --exculde={0} -zcvf {0} {1} && rm -rf {1}".format(dst, path) + else: + raise NotImplementedError("Pack style {} is not supported".format(style)) + if root: + with msc_dir(root): + retcode = subprocess.call(cmd, shell=True) + else: + retcode = subprocess.call(cmd, shell=True) + assert retcode == 0, "Failed to pack the folder {}->{}({}): {}".format( + path, dst, style, retcode + ) + return dst + + +def unpack_folder(path: str, dst: str = None, style="tar.gz"): + """UnPack the folder + + Parameters + ---------- + path: str + The path of the folder. + dst: str + The pakced path. style: str The pack style. @@ -421,9 +508,10 @@ def pack_folder(path: str, style="tar.gz"): The packed path. """ + dst = dst or path.split(".")[0] root = os.path.dirname(path) if style == "tar.gz": - cmd = "tar --exculde={0}.tar.gz -zcvf {0}.tar.gz {0} && rm -rf {0}".format(path) + cmd = "tar -zxvf {} {}".format(path, dst) else: raise NotImplementedError("Pack style {} is not supported".format(style)) if root: @@ -431,8 +519,10 @@ def pack_folder(path: str, style="tar.gz"): retcode = subprocess.call(cmd, shell=True) else: retcode = subprocess.call(cmd, shell=True) - assert retcode == 0, "Failed to pack the folder {}({}): {}".format(path, style, retcode) - return path + "." 
+ style + assert retcode == 0, "Failed to unpack the folder {}->{}({}): {}".format( + path, dst, style, retcode + ) + return dst get_build_dir = partial(get_workspace_subdir, name="Build") @@ -440,6 +530,7 @@ def pack_folder(path: str, style="tar.gz"): get_config_dir = partial(get_workspace_subdir, name="Config") get_dataset_dir = partial(get_workspace_subdir, name="Dataset") get_gym_dir = partial(get_workspace_subdir, name="Gym") +get_info_dir = partial(get_workspace_subdir, name="Info") get_output_dir = partial(get_workspace_subdir, name="Output") get_visual_dir = partial(get_workspace_subdir, name="Visual") get_weights_dir = partial(get_workspace_subdir, name="Weights") diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 26afedfa282d..4fea45f8fab2 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -72,14 +72,11 @@ def abstract(self) -> str: """Get abstract describe of the data""" data = self._to_ndarray() + prefix = "[{},{}]".format(";".join([str(s) for s in data.shape]), data.dtype.name) if data.size < 10: - return ",".join([str(i) for i in data.flatten()]) - return "[{},{}] Max {:g}, Min {:g}, Avg {:g}".format( - ";".join([str(s) for s in data.shape]), - data.dtype.name, - data.max(), - data.min(), - data.sum() / data.size, + return "{} {}".format(prefix, ",".join([str(i) for i in data.flatten()])) + return "{} Max {:g}, Min {:g}, Avg {:g}".format( + prefix, data.max(), data.min(), data.sum() / data.size ) def _to_ndarray(self) -> np.ndarray: @@ -299,23 +296,26 @@ def inspect_array(data: Any, as_str: bool = True) -> Union[Dict[str, Any], str]: def compare_arrays( - golden: Dict[str, np.ndarray], - datas: Dict[str, np.ndarray], + golden: Dict[str, Any], + datas: Dict[str, Any], atol: float = 1e-2, rtol: float = 1e-2, + report_detail: bool = False, ) -> dict: """Compare elements in array Parameters ---------- - golden: dict + golden: dict The golden datas. - datas: dict + datas: dict The datas to be compared. atol: float The atol for compare. rtol: float The rtol for compare. 
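
Taken together, pack_folder and unpack_folder above give a round trip along these lines (a sketch with a hypothetical path; note that packing removes the source folder):

from tvm.contrib.msc.core import utils as msc_utils

packed = msc_utils.pack_folder("msc_export")              # -> "msc_export.tar.gz"
restored = msc_utils.unpack_folder("msc_export.tar.gz")   # -> "msc_export"
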
+    report_detail: bool
+        Whether to report detail.
 
     Returns
     -------
@@ -326,27 +326,53 @@ def compare_arrays(
     assert golden.keys() == datas.keys(), "golden {} and datas {} mismatch".format(
         golden.keys(), datas.keys()
     )
+    golden = {k: cast_array(v) for k, v in golden.items()}
+    datas = {k: cast_array(v) for k, v in datas.items()}
     report = {"total": 0, "passed": 0, "info": {}}
+
+    def _add_report(name: str, gol: Any, data: Any, passed: bool):
+        diff = MSCArray(gol - data)
+        if passed:
+            if report_detail:
+                report["info"][name] = {
+                    "data": MSCArray(data).abstract(),
+                    "d_pass": diff.abstract(),
+                }
+            else:
+                report["info"][name] = "d_pass: {}".format(diff.abstract())
+            report["passed"] += 1
+        else:
+            if report_detail:
+                report["info"][name] = {
+                    "gold": MSCArray(gol).abstract(),
+                    "data": MSCArray(data).abstract(),
+                    "d_fail": diff.abstract(),
+                }
+            else:
+                report["info"][name] = "d_fail: {}".format(diff.abstract())
+
     for name, gol in golden.items():
         report["total"] += 1
         data = datas[name]
         if list(gol.shape) != list(data.shape):
-            report["info"][name] = " shape mismatch [G]{} vs [D]{}".format(
+            report["info"][name] = "fail: shape mismatch [G]{} vs [D]{}".format(
                 gol.shape, data.shape
             )
             continue
         if gol.dtype != data.dtype:
-            report["info"][name] = " dtype mismatch [G]{} vs [D]{}".format(
+            report["info"][name] = "fail: dtype mismatch [G]{} vs [D]{}".format(
                 gol.dtype, data.dtype
             )
             continue
-        diff = MSCArray(gol - data)
+        if gol.dtype.name in ("int32", "int64"):
+            passed = np.abs(gol - data).max() == 0
+            _add_report(name, gol, data, passed)
+            continue
         try:
             np.testing.assert_allclose(gol, data, rtol=rtol, atol=atol, verbose=False)
-            report["info"][name] = " diff {}".format(diff.abstract())
-            report["passed"] += 1
+            _add_report(name, gol, data, True)
         except:  # pylint: disable=bare-except
-            report["info"][name] = " diff {}".format(diff.abstract())
+            _add_report(name, gol, data, False)
     return report
diff --git a/python/tvm/contrib/msc/core/utils/log.py b/python/tvm/contrib/msc/core/utils/log.py
index 1422ad9a1bd0..8847d1948dbc 100644
--- a/python/tvm/contrib/msc/core/utils/log.py
+++ b/python/tvm/contrib/msc/core/utils/log.py
@@ -137,9 +137,50 @@ def get_global_logger() -> logging.Logger:
     return MSCMap.get(MSCKey.GLOBALE_LOGGER)
 
 
+def get_log_file(logger: logging.Logger) -> str:
+    """Get the log file from logger
+
+    Parameters
+    ----------
+    logger: logging.Logger
+        The logger.
+
+    Returns
+    -------
+    log_file: str
+        The log file.
+    """
+
+    for log_h in logger.handlers:
+        if isinstance(log_h, logging.FileHandler):
+            return log_h.baseFilename
+    return None
+
+
 def remove_loggers():
     """Remove the logger handlers"""
 
     logger = MSCMap.get(MSCKey.GLOBALE_LOGGER)
     if logger:
         logger.handlers.clear()
+
+
+def split_line(msg: str, symbol: str = "#", width: int = 100) -> str:
+    """Mark message to split line
+
+    Parameters
+    ----------
+    msg: str
+        The message.
+    symbol: str
+        The split symbol.
+    width: int
+        The line width.
+
+    Returns
+    -------
+    split_line: str
+        The split line with message.
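
As a sketch, the reworked comparator returns a report shaped like this (values are illustrative):

import numpy as np
from tvm.contrib.msc.core.utils.info import compare_arrays

golden = {"out": np.array([1.0, 2.0, 3.0], dtype="float32")}
datas = {"out": np.array([1.0, 2.0, 3.001], dtype="float32")}
report = compare_arrays(golden, datas, atol=1e-2, rtol=1e-2)
# report -> {"total": 1, "passed": 1, "info": {"out": "d_pass: [3,float32] ..."}}
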
+ """ + + return "\n{0}{1}{0}".format(20 * symbol, msg.center(width - 40)) diff --git a/python/tvm/contrib/msc/core/utils/message.py b/python/tvm/contrib/msc/core/utils/message.py index d7b64ee22ea3..57fce501fc0b 100644 --- a/python/tvm/contrib/msc/core/utils/message.py +++ b/python/tvm/contrib/msc/core/utils/message.py @@ -21,7 +21,7 @@ from typing import List, Tuple from .arguments import dump_dict, map_dict -from .log import get_global_logger +from .log import get_global_logger, split_line from .namespace import MSCMap, MSCKey @@ -69,7 +69,7 @@ def time_stamp(stage: str, log_stage: bool = True, logger: logging.Logger = None stage: str The stage name. log_stage: bool - Whether to log the stage + Whether to log the stage. logger: logging.Logger The logger. """ @@ -82,14 +82,14 @@ def time_stamp(stage: str, log_stage: bool = True, logger: logging.Logger = None if log_stage: last_stage = MSCMap.get(MSCKey.MSC_STAGE) if last_stage: - end_msg = "[MSC] End {}".format(last_stage.upper()) - logger.info("\n{0} {1} {0}\n".format("#" * 20, end_msg.center(40))) - start_msg = "[MSC] Start {}".format(stage.upper()) - logger.info("\n{0} {1} {0}".format("#" * 20, start_msg.center(40))) + end_msg = "End {}".format(last_stage.upper()) + logger.info("%s\n", split_line(end_msg)) + start_msg = "Start {}".format(stage.upper()) + logger.info(split_line(start_msg)) MSCMap.set(MSCKey.MSC_STAGE, stage.upper()) elif log_stage: start_msg = "Start {}".format(stage) - logger.debug("\n{0} {1} {0}".format("+" * 20, start_msg.center(40))) + logger.debug(split_line(start_msg, "+")) def get_duration() -> dict: @@ -163,7 +163,7 @@ def msg_block(title: str, msg: str, width: int = 100, symbol: str = "-"): if isinstance(msg, dict): msg = dump_dict(msg, "table:" + str(width)) - return "\n{0} {1} {0}\n{2}".format(symbol * 20, title.center(40), msg) + return "{}\n{}".format(split_line(title, symbol), msg) def current_stage(): diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index 2fff6d1c75dc..2297b3e82523 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -159,22 +159,6 @@ def _call_runnable( feed_dict = {i + ":0": msc_utils.cast_array(inputs[i]) for i in input_names} return runnable.run(self._tf_outputs, feed_dict) - def _device_enabled(self, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - device_protos = device_lib.list_local_devices() - return any(dev.device_type == "GPU" for dev in device_protos) - return False - @property def codegen_func(self): return to_tensorflow @@ -217,40 +201,6 @@ def load_native(cls, model: Any, config: dict) -> Tuple[tf_v1.GraphDef, str, boo device = "cpu" return native_model, device, False - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. 
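
The stage banners in message.py now go through the split_line helper above; for reference it renders like this:

from tvm.contrib.msc.core.utils.log import split_line

print(split_line("Start PREPARE"))
# prints the message centered between two runs of 20 '#' characters
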
- """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - config["parse"]["parser"] = from_tensorflow - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - { - "shape_dict": {i[0]: i[1] for i in config["inputs"]}, - "outputs": config["outputs"], - } - ) - config["parse"]["parse_config"] = parse_config - return config - @classmethod def run_native( cls, @@ -302,3 +252,54 @@ def run_native( avg_time = -1 outputs = dict(zip(output_names, outputs)) return outputs, avg_time + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + config["parse"]["parser"] = from_tensorflow + parse_config = config["parse"].get("parse_config", {}) + parse_config.update( + { + "shape_dict": {i[0]: i[1] for i in config["inputs"]}, + "outputs": config["outputs"], + } + ) + config["parse"]["parse_config"] = parse_config + return config + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + if device == "cpu": + return True + if device.startswith("cuda"): + device_protos = device_lib.list_local_devices() + return any(dev.device_type == "GPU" for dev in device_protos) + return False diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index e38c5d7482a4..3dd392c7d8ac 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -117,12 +117,13 @@ def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: The runnable info. """ - info = super().export_runnable(folder) - info["engines"] = {} - for graph in self._graphs: + def _get_engine(graph: MSCGraph) -> str: engine_file = msc_utils.get_output_dir().relpath(graph.name + ".trt") assert os.path.isfile(engine_file), "Missing engine file " + engine_file - info["engines"] = folder.copy(engine_file) + return engine_file + + info = super().export_runnable(folder) + info["engines"] = {g.name: _get_engine(g) for g in self._graphs} return info @classmethod diff --git a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py index 83a1830b29b6..1871e4847a25 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py @@ -17,3 +17,4 @@ """tvm.contrib.msc.framework.torch.runtime""" from .runner import * +from .jit import * diff --git a/python/tvm/contrib/msc/framework/torch/runtime/jit.py b/python/tvm/contrib/msc/framework/torch/runtime/jit.py new file mode 100644 index 000000000000..aefa4b459148 --- /dev/null +++ b/python/tvm/contrib/msc/framework/torch/runtime/jit.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
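
The per-framework device probe is now exposed as the classmethod support_device, so it can be queried before any runner is built; e.g. for the torch backend (defined later in this patch):

from tvm.contrib.msc.framework.torch.runtime import TorchRunner

# fall back to cpu when no CUDA device is visible
device = "cuda" if TorchRunner.support_device("cuda") else "cpu"
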
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""tvm.contrib.msc.framework.torch.runtime.jit_model""" + +from typing import Any, List, Tuple, Dict +from functools import partial + +import torch +from torch import fx +from torch import _dynamo as dynamo + +from tvm.contrib.msc.core.runtime import BaseJIT +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils +from .runner import TorchRunner + + +class TorchJIT(BaseJIT): + """JIT of Torch""" + + def _call_jit(self, inputs: Dict[str, Any]) -> Any: + """Run the jit model + + Parameters + ---------- + inputs: + The inputs of model. + """ + + torch_inputs = [ + msc_utils.cast_array(inputs[i], MSCFramework.TORCH, self._device) for i in self._inputs + ] + return self._jit_model(*torch_inputs) + + def _build(self, model: Any) -> Any: + """Build the jit model + + Parameters + ---------- + model: + The model. + + Returns + ------- + jit_model: + The jit model. + """ + + # pylint: disable=unused-argument + def _compile(graph_module: fx.GraphModule, example_inputs): + graph_module = graph_module.train() if self._training else graph_module.eval() + name = "jit_" + str(len(self._runner_ctxs)) + self._runner_ctxs[name] = {"model": graph_module} + return partial(self._redirect_run, runner_name=name) + + dynamo.reset() + return torch.compile(self._model, backend=_compile) + + def _to_msc_inputs(self, runner_name: str, *args, **kwargs) -> List[Tuple[str, Any]]: + """Change inputs to msc format + + Parameters + ---------- + runner_name: str + The runner name. + args: + The arguments. + kwargs: + The kwargs. + + Returns + ------- + inputs: + The msc format inputs. + """ + + assert not kwargs, "TorchJIT do not support kwargs" + return [("input_" + str(i), d) for i, d in enumerate(args)] + + def _from_msc_outputs(self, runner_name: str, outputs: List[Tuple[str, Any]]) -> Any: + """Change inputs from msc format + + Parameters + ---------- + runner_name: str + The runner name. + outputs: list<(str, tensor)> + The msc format outputs. + + Returns + ------- + outputs: + The framework outputs. + """ + + torch_outputs = [o[1] for o in outputs] + unpack_outputs = self.get_runner_ctx(runner_name).get("unpack_outputs", True) + if not unpack_outputs: + return torch_outputs + return torch_outputs[0] if len(torch_outputs) == 1 else torch_outputs + + def _run_ctx(self, runner_ctx: dict, inputs: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: + """Forward by runner context + + Parameters + ---------- + runner_ctx: dict + The runner context + inputs: list<(str, tensor)> + The inputs. + + Returns + ------- + outputs: list<(str, tensor)> + The outputs. 
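
The _build hook above leans on torch.compile's custom-backend mechanism; stripped of the MSC bookkeeping, the idea is just this (a standalone PyTorch 2.x sketch, not the patch's code):

import torch
from torch import fx

def my_backend(graph_module: fx.GraphModule, example_inputs):
    # inspect or rewrite the captured sub-graph here, then return a callable
    return graph_module.forward

compiled = torch.compile(torch.nn.Linear(4, 4), backend=my_backend)
out = compiled(torch.randn(1, 4))
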
+ """ + + if "runner" in runner_ctx: + runner = runner_ctx["runner"] + if runner.framework == MSCFramework.TORCH: + outputs = runner.run({i[0]: i[1] for i in inputs}, ret_type="native") + else: + outputs = runner.run({i[0]: i[1] for i in inputs}, ret_type="list") + outputs = [ + msc_utils.cast_array(o, MSCFramework.TORCH, runner.device) for o in outputs + ] + else: + torch_inputs = [i[1] for i in inputs] + outputs = runner_ctx["model"](*torch_inputs) + if isinstance(outputs, (list, tuple)) and len(outputs) == 1: + runner_ctx["unpack_outputs"] = False + if isinstance(outputs, (list, tuple)): + return [("output_" + str(i), o) for i, o in enumerate(outputs)] + return [("output", outputs)] + + @property + def framework(self): + return MSCFramework.TORCH + + @classmethod + def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bool]: + """Load the native model + + Parameters + ------- + model: + The native model. + config: dict + The config for pipeline. + + Returns + ------- + model: torch.nn.Module + The loaded native model. + device: str + The device of the model. + training: + Whether the model is for training. + """ + + return TorchRunner.load_native(model, config) + + @classmethod + def dump_nativate( + cls, model: torch.nn.Module, folder: msc_utils.MSCDirectory, dump_config: dict = None + ) -> str: + """Dump the nativate model + + Parameters + ------- + model: torch.nn.Module + The runnable model. + folder: MSCDirectory + The export folder. + dump_config: dict + The dump config. + + Returns + ------- + export_path: str + The exported path + """ + + dump_config = dump_config or {} + assert dump_config.get("mode", "fx") == "fx", "TorchJIT only support dump nativate as fx" + return TorchRunner.dump_nativate(model, folder, dump_config) + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + return TorchRunner.support_device(device) diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index 67812e7e5219..27773cecdc6d 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -17,7 +17,6 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.torch.runtime.runner""" -import os import time from typing import Dict, List, Union, Tuple, Any import numpy as np @@ -130,21 +129,6 @@ def _get_runtime_params(self) -> Dict[str, tvm.nd.array]: ) return params - def _device_enabled(self, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - return torch.cuda.is_available() - return False - @property def codegen_func(self): return to_torch @@ -174,8 +158,8 @@ def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bo Whether the model is for training. 
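
Later in this patch dump_nativate gains a "pt" mode built on torch.jit.trace; in isolation that export flow is just (standalone sketch with a toy model):

import torch

model = torch.nn.Linear(8, 2).eval()
example = torch.randn(1, 8)
with torch.no_grad():
    traced = torch.jit.trace(model, (example,))
torch.jit.save(traced, "model.pt")   # what the "pt" dump mode writes out
reloaded = torch.jit.load("model.pt")
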
""" - if isinstance(model, dict) and "model" in model: - native_model = msc_utils.load_callable(model["model"]) + if isinstance(model, str) and ":" in model: + native_model = msc_utils.load_callable(model) elif isinstance(model, torch.nn.Module): native_model = model else: @@ -193,42 +177,6 @@ def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bo device = "cpu" return native_model, device, model.training - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - config["parse"]["parser"] = from_torch - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - { - "input_info": [ - [i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"] - ], - "input_names": [i[0] for i in config["inputs"]], - } - ) - config["parse"]["parse_config"] = parse_config - return config - @classmethod def run_native( cls, @@ -302,7 +250,12 @@ def _run_once(): return outputs, avg_time @classmethod - def dump_nativate(cls, model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> str: + def dump_nativate( + cls, + model: torch.nn.Module, + folder: msc_utils.MSCDirectory, + dump_config: dict = None, + ) -> str: """Dump the nativate model Parameters @@ -311,6 +264,8 @@ def dump_nativate(cls, model: torch.nn.Module, folder: msc_utils.MSCDirectory) - The runnable model. folder: MSCDirectory The export folder. + dump_config: dict + The dump config. Returns ------- @@ -318,7 +273,74 @@ def dump_nativate(cls, model: torch.nn.Module, folder: msc_utils.MSCDirectory) - The exported path """ - graph_model = torch.fx.symbolic_trace(model) - exp_path = folder.create_dir("model") - graph_model.to_folder(exp_path.path, "native_model") - return {"model": exp_path.relpath("module.py") + ":native_model"} + dump_config = dump_config or {} + mode = dump_config.get("mode", "fx") + if mode == "fx": + graph_model = torch.fx.symbolic_trace(model) + exp_path = folder.create_dir("model") + graph_model.to_folder(exp_path.path, "native_model") + return exp_path.relpath("module.py") + ":native_model" + if mode == "pt": + assert "inputs" in dump_config, "inputs are needed for torch.jit.trace" + parameters = list(model.parameters()) + device = parameters[0].device if parameters else torch.device("cpu") + datas = [np.random.rand(i[1]).astype(i[2]) for i in dump_config["inputs"]] + torch_datas = [torch.from_numpy(d).to(device) for d in datas] + with torch.no_grad(): + scriptde_model = torch.jit.trace(model, tuple(torch_datas)).eval() + exp_path = folder.relpath("model.pt") + torch.jit.save(scriptde_model, exp_path) + return exp_path + raise TypeError("Unexpeceted dump mode " + str(mode)) + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. 
+ """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + config["parse"]["parser"] = from_torch + parse_config = config["parse"].get("parse_config", {}) + parse_config.update( + { + "input_info": [ + [i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"] + ], + "input_names": [i[0] for i in config["inputs"]], + } + ) + config["parse"]["parse_config"] = parse_config + return config + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + if device == "cpu": + return True + if device.startswith("cuda"): + return torch.cuda.is_available() + return False diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index b4f052f08dfe..642a88c93386 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -17,6 +17,7 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.runtime.tvm.runner""" +import os import time from typing import Dict, List, Union, Any, Tuple import numpy as np @@ -139,22 +140,6 @@ def _call_runnable( ] return runnable(*tvm_inputs) - def _device_enabled(self, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - return tvm.cuda(dev_id).exist - return False - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: """Export the runnable @@ -169,9 +154,14 @@ def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: The runnable info. """ - export_path = folder.relpath("model.so") - self._executable.export_library(export_path) - return {"model": export_path} + export_lib = folder.relpath("lib.so") + self._executable.export_library(export_lib) + return { + "lib": export_lib, + "device": self.device, + "model_type": self.framework, + "abstract": self.model_info, + } @property def codegen_func(self): @@ -202,8 +192,8 @@ def load_native(cls, model: Any, config: dict) -> Tuple[tvm.IRModule, str, bool] Whether the model is for training. """ - if isinstance(model, dict) and "model" in model: - with open(model["model"], "r") as f: + if isinstance(model, str) and os.path.isfile(model): + with open(model, "r") as f: native_model = tvm.ir.load_json(f.read()) elif isinstance(model, tvm.IRModule): native_model = model @@ -217,36 +207,6 @@ def load_native(cls, model: Any, config: dict) -> Tuple[tvm.IRModule, str, bool] device = "cpu" return native_model, device, False - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. 
- """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - # pylint: disable=unused-argument - def passby(mod, *args, **kwargs): - return mod, None - - config["parse"]["parser"] = passby - return config - @classmethod def run_native( cls, @@ -320,3 +280,50 @@ def _run_once(): o_name: msc_utils.cast_array(o_data) for o_name, o_data in zip(output_names, outputs) } return outputs, avg_time + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + # pylint: disable=unused-argument + def passby(mod, *args, **kwargs): + return mod, None + + config["parse"]["parser"] = passby + return config + + @classmethod + def support_device(cls, device: str) -> bool: + """Check if the device is enabled + + Returns + ------- + enabled: bool + Whether the device is enabled. + """ + + if device == "cpu": + return True + if device.startswith("cuda"): + dev_id = int(device.split(":")[1]) if ":" in device else 0 + return tvm.cuda(dev_id).exist + return False diff --git a/python/tvm/contrib/msc/pipeline/dynamic.py b/python/tvm/contrib/msc/pipeline/dynamic.py new file mode 100644 index 000000000000..3e1e8b654a90 --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/dynamic.py @@ -0,0 +1,492 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.pipeline.dynamic""" + +from typing import Tuple, Any, List + +from tvm.contrib.msc.core.runtime import BaseJIT +from tvm.contrib.msc.core.utils.message import MSCStage +from tvm.contrib.msc.core import utils as msc_utils +from .pipeline import BasePipeline +from .worker import MSCPipeWorker + + +class MSCDynamic(BasePipeline): + """Dynamic of Pipeline, process dynamic model""" + + def setup(self) -> dict: + """Setup the pipeline + + Returns + ------- + info: dict + The setup info. + """ + + self._jit, self._jit_caches = None, {} + self._worker_ctxs = {} + return super().setup() + + def change_stage(self, stage: str, log_stage: bool = True) -> str: + """Change stage + + Parameters + ---------- + stage: str + The stage name. + log_stage: bool + Whether to log the stage. + + Returns + ------- + stage: str + The stage name. + """ + + self._jit_caches = {} + return super().change_stage(stage, log_stage) + + def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: + """Prepare datas for the pipeline. 
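
On the TVM side the support_device check above simply probes the CUDA context; e.g. (illustrative):

import tvm

def cuda_ready(device: str = "cuda:0") -> bool:
    dev_id = int(device.split(":")[1]) if ":" in device else 0
    return tvm.cuda(dev_id).exist
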
+ + Parameters + ---------- + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of prepare. + report: dict + The report of prepare. + """ + + hooks = {"pre_forward": [self.pre_forward], "post_forward": [self.post_forward]} + if isinstance(self._model, dict) and "model" in self._model: + worker_models = self._model["worker_models"] + self._model, device, training = self.jit_cls.load_native( + self._model["model"], self._config + ) + else: + worker_models = {} + self._model, device, training = self.jit_cls.load_native(self._model, self._config) + self._jit = self.jit_cls( + self._model, + inputs=[i[0] for i in self._config["inputs"]], + outputs=self._config["outputs"], + device=device, + training=training, + hooks=hooks, + logger=self._logger, + ) + self._jit.build() + assert MSCStage.PREPARE in self._config["dataset"], "prepare dataset is needed" + cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) + for inputs in data_loader(): + if cnt >= max_golden > 0: + break + self._jit.run(inputs) + cnt += 1 + + # create workers + def _get_worker_config(name: str, cache: dict): + saver = cache.get("saver") + assert saver, "Failed to record datas for " + name + saver.finalize() + + def _to_input(i_name): + i_info = saver.info["inputs"][i_name] + return (i_name, i_info["shape"], i_info["dtype"]) + + w_config = msc_utils.copy_dict(self._config) + w_config.update( + { + "inputs": [_to_input(i) for i in saver.info["input_names"]], + "outputs": saver.info["output_names"], + } + ) + w_config["dataset"]["golden"] = {"loader": saver.folder} + for tool in w_config.get("tools", []): + worker_config = tool.get("worker_configs", {}).get(name) + if worker_config: + tool["tool_config"] = msc_utils.update_dict(tool["tool_config"], worker_config) + return w_config + + info, report = {}, {} + for name, cache in self._jit_caches.items(): + runner_ctx = self._jit.get_runner_ctx(name) + w_model = worker_models.get(name, runner_ctx["model"]) + self._worker_ctxs[name] = { + "worker": self.create_worker(w_model, name, _get_worker_config(name, cache)), + "workspace": self._workspace.create_dir(name), + } + with msc_utils.change_workspace(self._worker_ctxs[name]["workspace"]): + info[name], report[name] = self._worker_ctxs[name]["worker"].prepare() + return info, report + + def _parse(self) -> Tuple[dict, dict]: + """Parse relax module for the pipeline. + + Returns + ------- + info: dict + The info of parse. + report: dict + The report of parse. + """ + + info, report = {}, {} + for name, w_ctx in self._worker_ctxs.items(): + with msc_utils.change_workspace(w_ctx["workspace"]): + info[name], report[name] = w_ctx["worker"].parse() + return info, report + + def _tool_applied(self, tool_type: str) -> bool: + """Check if the tool is applied + + Parameters + ---------- + tool_type: str + The tool type. + + Returns + ------- + applied: bool + Whether the tool is applied. + """ + + return all(w["worker"].tool_applied(tool_type) for w in self._worker_ctxs.values()) + + def _apply_tool( + self, tool_type: str, knowledge: dict = None, data_loader: Any = None + ) -> Tuple[dict, dict]: + """Apply tool with runner + + Parameters + ---------- + tool_type: str + The tool type to apply. + knowledge: dict + The pre knowledge. + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of apply tool. + report: dict + The report of apply tool. 
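
The per-worker isolation above relies on the workspace switch added in file.py; the pattern is simply (directory name hypothetical):

from tvm.contrib.msc.core import utils as msc_utils

worker_dir = msc_utils.get_workspace().create_dir("jit_0")  # hypothetical sub-dir
with msc_utils.change_workspace(worker_dir):
    # get_workspace()/get_*_dir() now resolve under jit_0
    cache = msc_utils.get_cache_dir()
# the previous workspace is restored on exit
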
+ """ + + if knowledge: + raise NotImplementedError("Apply tool with knowledge is not supported") + + self._jit.make_plan(tool_type, data_loader) + info, report = {}, {} + for name, w_ctx in self._worker_ctxs.items(): + with msc_utils.change_workspace(w_ctx["workspace"]): + info[name], report[name] = w_ctx["worker"].apply_tool(tool_type) + return info, report + + def _create_runtime( + self, + stage: str, + tools: List[str] = None, + run_type: str = None, + run_config: dict = None, + visualize: bool = True, + profile: bool = True, + use_cache: bool = True, + ) -> Tuple[dict, dict]: + """Create runtime. + + Parameters + ---------- + stage: str + The pipeline stage. + tools: list + The tools to apply. + run_type: str + The type of runner. + run_config: dict + The config of runner. + visualize: bool + Whether to visualize the runner + profile: bool + Whether to profile the runner. + use_cache: bool + Whether to use cache. + + Returns + ------- + info: dict + The info of stage. + report: dict + The report of stage. + """ + + info, report = {}, {} + for name, w_ctx in self._worker_ctxs.items(): + with msc_utils.change_workspace(w_ctx["workspace"]): + info[name], report[name] = w_ctx["worker"].create_runner( + stage, tools, run_type, run_config, visualize, profile, use_cache + ) + self._jit.set_runner(name, w_ctx["worker"].runner) + return info, report + + def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: + """Export the model + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + exported: + The exported model. + """ + + if dump: + model = self.jit_cls.dump_nativate(self._model, folder, self._config[MSCStage.EXPORT]) + else: + model = self._model + worker_models = { + n: w["worker"].export_model(stage, folder.create_dir(n), dump) + for n, w in self._worker_ctxs.items() + } + return {"model": model, "worker_models": worker_models} + + def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the tool + + Parameters + ---------- + tool_type: str + The tool type. + folder: MSCDirectory + The export folder. + + Returns + ------- + configs: dict + The exported tool configs. + """ + + configs = {} + for name, w_ctx in self._worker_ctxs.items(): + with msc_utils.change_workspace(w_ctx["workspace"]): + configs[name] = w_ctx["worker"].export_tool(tool_type, folder.create_dir(name)) + assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) + return msc_utils.update_dict(self._tools_config[tool_type], {"worker_configs": configs}) + + def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the info of pipeline + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The info. 
+ """ + + info = super()._export_info(stage, folder) + if stage in (MSCStage.OPTIMIZE, MSCStage.COMPILE): + info["worker_infos"] = {} + for name, w_ctx in self._worker_ctxs.items(): + with msc_utils.change_workspace(w_ctx["workspace"]): + info["worker_infos"][name] = w_ctx["worker"].export_info( + stage, folder.create_dir(name) + ) + return info + + def _destory(self): + """Destory the pipeline""" + + for w_ctx in self._worker_ctxs.values(): + w_ctx["worker"].destory() + + def get_runtime(self, ret_type: str = "runner") -> Any: + """Get the runtime of pipeline + + Parameters + ---------- + ret_type: str + The return type runner| runnable| model. + + Returns + ------- + runnable: + The runnable object. + """ + + if ret_type == "runner": + return self._jit + if ret_type in ("model", "runnable"): + return self._jit.jit_model + raise TypeError("Unexpect return type " + str(ret_type)) + + def pre_forward(self, runner_name: str, inputs: List[Tuple[str, Any]]) -> Any: + """pre forward hook for jit model + + Parameters + ---------- + runner_name: str + The runner name. + inputs: + The msc format inputs. + """ + + if self._current_stage == MSCStage.PREPARE: + cache = self._jit_caches.setdefault(runner_name, {}) + cache["inputs"] = inputs + self._pre_forward(runner_name, inputs) + + def _pre_forward(self, runner_name: str, inputs: List[Tuple[str, Any]]) -> Any: + """pre forward hook for jit model + + Parameters + ---------- + runner_name: str + The runner name. + inputs: + The msc format inputs. + """ + + return None + + def post_forward( + self, runner_name: str, outputs: List[Tuple[str, Any]] + ) -> List[Tuple[str, Any]]: + """pre forward hook for jit model + + Parameters + ---------- + runner_name: str + The runner name. + outputs: + The outputs. + + Returns + ------- + outputs: + The outputs. + """ + + if self._current_stage == MSCStage.PREPARE: + cache = self._jit_caches[runner_name] + assert "inputs" in cache, "Failed to record inputs" + if "saver" not in cache: + golden = ( + msc_utils.get_dataset_dir().create_dir(runner_name).relpath("Golden", False) + ) + saver_options = { + "input_names": [i[0] for i in cache["inputs"]], + "output_names": [o[0] for o in outputs], + } + cache["saver"] = msc_utils.IODataSaver(golden, saver_options) + cache["saver"].save_batch([i[1] for i in cache["inputs"]], [o[1] for o in outputs]) + return self._post_forward(runner_name, outputs) + + def _post_forward( + self, runner_name: str, outputs: List[Tuple[str, Any]] + ) -> List[Tuple[str, Any]]: + """pre forward hook for jit model + + Parameters + ---------- + runner_name: str + The runner name. + outputs: + The outputs. + + Returns + ------- + outputs: + The outputs. + """ + + return outputs + + def _record_stage(self, stage: str, info: dict = None, report: dict = None): + """Record the stage + + Parameters + ------- + stage: str + The compile stage + info: dict + The info of stage. + report: dict + The report of stage. + """ + + stage_report = {} + for name, w_report in report.items(): + for k, v in w_report.items(): + stage_report.setdefault(k, {})[name] = v + info = {k: v for k, v in info.items() if v} + super()._record_stage(stage, info, stage_report) + + def pipe_mark(self, msg: Any) -> str: + """Mark the message with pipeline info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. 
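
_record_stage above transposes the per-worker reports from worker-major to metric-major; roughly:

report = {"jit_0": {"latency": "1.2 ms"}, "jit_1": {"latency": "0.8 ms"}}  # made-up values
stage_report = {}
for name, w_report in report.items():
    for k, v in w_report.items():
        stage_report.setdefault(k, {})[name] = v
# stage_report -> {"latency": {"jit_0": "1.2 ms", "jit_1": "0.8 ms"}}
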
+ """ + + return "DYNAMIC " + str(msg) + + @property + def jit_cls(self): + return BaseJIT + + @property + def worker_cls(self): + return MSCPipeWorker + + +class TorchDynamic(MSCDynamic): + """Dynamic of Pipeline, process torch dynamo""" + + @property + def jit_cls(self): + # pylint: disable=import-outside-toplevel + from tvm.contrib.msc.framework.torch.runtime import TorchJIT + + return TorchJIT diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index e0f734af6cb5..54052dccc6cb 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -14,114 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-outside-toplevel """tvm.contrib.msc.pipeline.manager""" -import os -import time -import json -import logging -from typing import Dict, Any, Union, List -import traceback -import numpy as np +from typing import Any, List, Tuple -import tvm -from tvm.contrib.msc.core.runtime import BaseRunner -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey +from tvm.contrib.msc.core.gym.control import create_controller from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.control import create_controller -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.plugin.utils import export_plugins, load_plugins -from .config import support_tool - - -class BaseManager(object): - """Base Manager of MSC - - Parameters - ---------- - model: Any - The raw model in framwork. - config: dict - The config for pipeline. - plugins: dict - The plugins for pipeline. - root: str - The root path for files. - run_optimize: bool - Whether to run optimize. - run_compile: bool - Whether to run compile. 
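
Assuming the BasePipeline entry points mirror the former manager (run_pipe and friends), driving a torch model through the dynamic pipeline would look roughly like this (config keys illustrative, not exhaustive):

from tvm.contrib.msc.pipeline.dynamic import TorchDynamic

# model: a torch.nn.Module; config: the usual MSC pipeline config dict
pipeline = TorchDynamic(model, config)
report = pipeline.run_pipe()
runnable = pipeline.get_runtime("runnable")  # the torch.compile wrapped model
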
- """ - - def __init__( - self, - model: Any, - config: dict, - plugins: dict = None, - root: str = None, - run_optimize: bool = True, - run_compile: bool = True, - ): - # change path to root path - if root: +from .pipeline import BasePipeline +from .worker import MSCPipeWorker - def _from_root_mark(val): - if isinstance(val, str) and MSCKey.ROOT_MARK in val: - return val.replace(MSCKey.ROOT_MARK, root) - return val - model = _from_root_mark(model) - config = msc_utils.map_dict(config, _from_root_mark) - plugins = msc_utils.map_dict(plugins, _from_root_mark) +class MSCManager(BasePipeline): + """Manager of Pipeline, process static model""" - # check stage - for stage in [ - "inputs", - "outputs", - "dataset", - MSCStage.PREPARE, - MSCStage.PARSE, - MSCStage.COMPILE, - MSCStage.EXPORT, - ]: - config.setdefault(stage, {}) - - MSCMap.reset() - use_cache = config.get("use_cache", True) - self._workspace = msc_utils.set_workspace(config.get("workspace"), use_cache) - self._model_type = config["model_type"] - runner_cls = self._get_runner_cls(self._model_type) - self._model, self._device, self._training = runner_cls.load_native(model, config) - self._plugins = load_plugins(plugins) if plugins else {} - self._verbose = config.get("verbose", "info") - if "logger" in config: - self._logger = config["logger"] - MSCMap.set(MSCKey.GLOBALE_LOGGER, self._logger) - else: - log_path = config.get("log_path") or self._workspace.relpath( - "MSC_LOG", keep_history=False - ) - self._logger = msc_utils.set_global_logger(self._verbose, log_path) - self._optimized, self._compiled = False, False - msc_utils.time_stamp(MSCStage.SETUP) - self._logger.info( - msc_utils.msg_block("SETUP", self.setup(config, run_optimize, run_compile)) - ) - - def setup(self, config: dict, run_optimize: bool = True, run_compile: bool = True) -> dict: - """Setup the manager - - Parameters - ---------- - config: dict - The config for manager. - run_optimize: bool - Whether to run optimize. - run_compile: bool - Whether to run compile. + def setup(self) -> dict: + """Setup the pipeline Returns ------- @@ -129,582 +37,103 @@ def setup(self, config: dict, run_optimize: bool = True, run_compile: bool = Tru The setup info. """ - self._meta_config = config - self._optimize_type = config.get(MSCStage.OPTIMIZE, {}).get("run_type", self._model_type) - self._compile_type = config.get(MSCStage.COMPILE, {}).get("run_type", self._model_type) - # register plugins - if self._plugins: - for t in [self._model_type, self._optimize_type, self._compile_type]: - assert t in self._plugins, "Missing plugin for {}".format(t) - for name, plugin in self._plugins[self._model_type].get_ops_info().items(): - _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) - self._config, self._debug_levels = self.update_config(config) - if not run_optimize and MSCStage.OPTIMIZE in self._config: - self._config.pop(MSCStage.OPTIMIZE) - if not run_compile and MSCStage.COMPILE in self._config: - self._config.pop(MSCStage.COMPILE) - self._tools_config = [] - self._relax_mod, self._runner = None, None - self._sample_inputs = None - self._report = { - "success": False, - "info": { - "workspace": self._workspace.path, - "model_type": "{}({})".format(self._model_type, self._device), - }, - "duration": {}, - "profile": {}, - } - return {"workspace": self._workspace.path, "plugins": self._plugins, "config": self._config} - - def update_config(self, config: dict) -> dict: - """Update config - - Parameters - ---------- - config: dict - The config for manager. 
- - Returns - ------- - config: dict - The updated config. - """ - - assert "inputs" in config, "inputs should be given to run manager" - assert "outputs" in config, "outputs should be given to run manager" - config, debug_levels = msc_utils.copy_dict(config), {} - config = self._get_runner_cls(self._model_type).update_config( - MSCStage.PARSE, config, self._model - ) - - # update runner config - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in config: - continue - if "run_type" not in config[stage]: - config[stage]["run_type"] = self._model_type - runner_cls = self._get_runner_cls(config[stage]["run_type"]) - config = runner_cls.update_config(stage, config, self._model) - - # update tool config - if config.get("tools"): - config["tools"] = self._update_tools_config(config["tools"]) - - # update export config - config[MSCStage.EXPORT].update({"inputs": config["inputs"], "outputs": config["outputs"]}) - - def _set_debug_level(stage: str, sub_config: dict, default: int = None) -> dict: - if "debug_level" in sub_config: - debug_levels[stage] = sub_config["debug_level"] - elif default is not None: - debug_levels[stage] = default - sub_config["debug_level"] = default - return debug_levels - - if self._verbose.startswith("debug:"): - debug_level = int(self._verbose.split(":")[1]) - else: - debug_level = 0 - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in config: - continue - debug_levels = _set_debug_level(stage, config[stage]["run_config"], debug_level) - for t_config in config.get("tools", []): - if not support_tool(t_config, stage, config[stage]["run_type"]): - continue - t_stage = stage + "." + self._get_tool_stage(t_config["tool_type"]) - debug_levels = _set_debug_level(t_stage, t_config["tool_config"], debug_level) - ordered_keys = [ - "model_type", - "inputs", - "outputs", - "dataset", - "tools", - MSCStage.PREPARE, - MSCStage.PARSE, - MSCStage.BASELINE, - MSCStage.OPTIMIZE, - MSCStage.COMPILE, - MSCStage.EXPORT, - ] - return {k: config[k] for k in ordered_keys if k in config}, debug_levels - - def run_pipe(self) -> dict: - """Run the pipeline and return object. - - Returns - ------- - report: - The pipeline report. - """ - - err_msg, err_info = None, None - try: - self.prepare() - self.parse() - if MSCStage.BASELINE in self._config: - self.baseline() - if MSCStage.OPTIMIZE in self._config: - self.optimize() - if MSCStage.COMPILE in self._config: - self.compile() - except Exception as exc: # pylint: disable=broad-exception-caught - err_msg = "Pipeline failed: " + str(exc) - err_info = traceback.format_exc() - self.summary(err_msg, err_info) - self._logger.info(msc_utils.msg_block("SUMMARY", self._report, 0)) - self._workspace.finalize() - return self._report + self._worker = self.create_worker(self._model, "main") + self._config = self._worker._config + return super().setup() - def prepare(self) -> Dict[str, np.ndarray]: + def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: """Prepare datas for the pipeline. - Returns - ------- - dataloader: - The dataloader - sample_inputs: dict - The sample inputs. 
- """ - - msc_utils.time_stamp(MSCStage.PREPARE) - stage_config = self._config[MSCStage.PREPARE] - use_cache = self._config.get("use_cache", True) - runner_cls = self._get_runner_cls(self._model_type) - run_func = runner_cls.run_native if hasattr(runner_cls, "run_native") else None - input_names = [i[0] for i in self._config["inputs"]] - - # create golden - if "golden" in self._config["dataset"]: - golden_folder = self._config["dataset"]["golden"]["loader"] - else: - golden_folder = msc_utils.get_dataset_dir().relpath("Golden", use_cache) - report = {"golden_folder": golden_folder} - if msc_utils.is_io_dataset(golden_folder): - loader, source_type = msc_utils.IODataLoader(golden_folder), "Cache" - self._sample_inputs = loader[0][0] - report["datas_info"] = loader.info - self._logger.debug("Load %d golden from %s", len(loader), golden_folder) - elif run_func: - loader, source_type = self._get_loader(MSCStage.PREPARE), "Native" - saver_options = {"input_names": input_names, "output_names": self._config["outputs"]} - cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) - with msc_utils.IODataSaver(golden_folder, saver_options) as saver: - for inputs in loader(): - if cnt >= max_golden > 0: - break - if not self._sample_inputs: - self._sample_inputs = { - k: msc_utils.cast_array(v) for k, v in inputs.items() - } - outputs, _ = run_func(self._model, inputs, input_names, self._config["outputs"]) - cnt = saver.save_batch(inputs, outputs) - report["datas_info"] = saver.info - self._logger.debug("Saved %d golden to %s", cnt, golden_folder) - else: - raise Exception("golden_folder or runner should given to save golden") - self._config["dataset"]["golden"] = {"loader": golden_folder, "max_batch": -1} - - def _to_abstract(info: dict) -> dict: - def _to_tensor_str(info): - return "{},{}".format(";".join([str(s) for s in info["shape"]]), info["dtype"]) - - return { - "num_datas": info["num_datas"], - "inputs": {n: _to_tensor_str(i) for n, i in info["inputs"].items()}, - "outputs": {n: _to_tensor_str(o) for n, o in info["outputs"].items()}, - } - - report["datas_info"] = _to_abstract(report["datas_info"]) - report["sample_inputs"] = self._sample_inputs - self._logger.info(msc_utils.msg_block("GOLDEN({})".format(source_type), report)) - - # profile - if "profile" in stage_config and run_func: - benchmark = stage_config["profile"].get("benchmark", {}) - benchmark["repeat"] = self._get_repeat(benchmark) - self._logger.debug("Prepare profile with %s(%s)", run_func.__name__, benchmark) - _, avg_time = run_func( - self._model, self._sample_inputs, input_names, self._config["outputs"], **benchmark - ) - msg = "{:.2f} ms @ {}".format(avg_time, self._device) - self._report["profile"][MSCStage.PREPARE] = {"latency": msg} - self._logger.info("Profile(prepare) %d times -> %s", benchmark["repeat"], msg) - - return self._sample_inputs - - def parse(self) -> tvm.IRModule: - """Parse the model to IRModule. - - Returns - ------- - relax_mod: tvm.IRModule - The parsed module. 
- """ - - msc_utils.time_stamp(MSCStage.PARSE) - stage_config = self._config[MSCStage.PARSE] - if self._config.get("use_cache", True): - cache_path = ( - msc_utils.get_cache_dir().create_dir(MSCStage.PARSE).relpath("parsed_relax.json") - ) - else: - cache_path = None - if cache_path and os.path.isfile(cache_path): - with open(cache_path, "r") as f: - self._relax_mod = tvm.ir.load_json(f.read()) - self._logger.info("Load parsed mod from %s", cache_path) - else: - parse_config = msc_utils.copy_dict(stage_config.get("parse_config", {})) - parse_info = {"parser": stage_config["parser"], "config": parse_config} - self._logger.info(msc_utils.msg_block("PARSE", parse_info)) - parse_config["as_msc"] = False - if self._model_type in self._plugins: - plugin = self._plugins[self._model_type] - parse_config["custom_convert_map"] = plugin.get_convert_map() - self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) - transformed = set() - for stage in [MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in self._config: - continue - run_type = self._config[stage]["run_type"] - if run_type in transformed: - continue - transformed.add(run_type) - runner_cls = self._get_runner_cls(run_type) - if hasattr(runner_cls, "target_transform"): - self._logger.info("Transform for %s(%s)", run_type, stage) - self._relax_mod = runner_cls.target_transform(self._relax_mod) - if cache_path: - with open(cache_path, "w") as f: - f.write(tvm.ir.save_json(self._relax_mod)) - self._logger.debug("Save parsed mod to %s", cache_path) - return self._relax_mod - - def _run_stage(self, stage: str) -> BaseRunner: - """Run the stage. - Parameters ---------- - stage: str - The compile stage. - - Returns - ------- - runner: BaseRunner - The runner. - """ - - msc_utils.time_stamp(stage) - self.apply_tools(stage) - self._runner = self._create_runner( - stage, - self._config[stage], - use_cache=self._config.get("use_cache", True), - ) - return self._runner - - def baseline(self) -> BaseRunner: - """Run the baseline. + data_loader: + The data loader. Returns ------- - runner: BaseRunner - The runner. - """ - - return self._run_stage(MSCStage.BASELINE) - - def optimize(self) -> BaseRunner: - """Run the optimize and return object. - - Returns - ------- - runner: BaseRunner - The runner. - """ - - runner = self._run_stage(MSCStage.OPTIMIZE) - self._optimized = True - return runner - - def compile(self) -> BaseRunner: - """Run the compile and return object. - - Returns - ------- - runner: BaseRunner - The runner. - """ - - runner = self._run_stage(MSCStage.COMPILE) - self._compiled = True - return runner - - def apply_tools(self, stage: str): - """Apply tools for a stage. - - Parameters - ---------- - stage: str - The compile stage. + info: dict + The info of prepare. + report: dict + The report of prepare. """ - self._tools_config = [] - for tool in self._config.get("tools", []): - run_type = tool.get("run_type", self._config[stage]["run_type"]) - if not support_tool(tool, stage, run_type): - continue - self._apply_tool(tool, stage) - if tool.get("apply_once", False): - self._logger.debug("Remove apply once tool %s", tool["tool_type"]) - self._tools_config = self._tools_config[:-1] - - def summary(self, err_msg=None, err_info: str = None): - """Summary the pipeline. + return self._worker.prepare(data_loader) - Parameters - ---------- - err_msg: str - The error message. - err_info: str - The error info. + def _parse(self) -> Tuple[dict, dict]: + """Parse relax module for the pipeline. 
Returns ------- + info: dict + The info of parse. report: dict - The report of the pipeline. + The report of parse. """ - msc_utils.time_stamp(MSCStage.SUMMARY, False) - if err_msg: - self._report.update({"success": False, "err_msg": err_msg, "err_info": err_info}) - else: - self._report["success"] = True - self._report["duration"] = msc_utils.get_duration() - return self._report + return self._worker.parse() - def export(self, path: str = None, dump: bool = True) -> Union[str, dict]: - """Export the pipeline + def _tool_applied(self, tool_type: str) -> bool: + """Check if the tool is applied Parameters ---------- - path: str - The export path. - dump: bool - Whether to dump the info. - - Returns - ------- - export_path/pipeline: str/dict - The exported path/pipeline info. - """ - - path = path or "msc_export" - if path.endswith(".tar.gz"): - folder, dump = msc_utils.msc_dir(path.replace(".tar.gz", ""), keep_history=False), True - else: - folder = msc_utils.msc_dir(path, keep_history=False) - - def _to_root_mark(val): - if isinstance(val, str) and folder.path != val and folder.path in val: - return val.replace(folder.path, MSCKey.ROOT_MARK) - return val - - # export compiled - if self._compiled: - if not dump: - return self._runner.runnable - model = self._runner.export_runnable(folder) - if self._plugins: - plugin = self._plugins[self.compile_type] - model["plugins"] = plugin.copy_libs(folder.create_dir("plugins")) - model.update( - { - "device": self._runner.device, - "model_type": self.compile_type, - "abstract": self._runner.model_info, - } - ) - # save golden - num_golden = self._config[MSCStage.EXPORT].get("num_golden", 0) - if num_golden > 0: - saver_options = { - "input_names": [i[0] for i in self._config["inputs"]], - "output_names": self._config["outputs"], - } - batch_cnt, model["golden"] = 0, folder.create_dir("golden").path - with msc_utils.IODataSaver(model["golden"], saver_options) as saver: - for inputs in self._get_loader()(): - if batch_cnt >= num_golden: - break - batch_cnt = saver.save_batch(inputs, self._runner.run(inputs)) - model = msc_utils.map_dict(model, _to_root_mark) - with open(folder.relpath("model.json"), "w") as f: - f.write(json.dumps(model, indent=2)) - else: - if dump: - plugins = export_plugins(self._plugins, folder.create_dir("plugins")) - else: - plugins = self._plugins - - pipeline = { - "model": self.export_model(folder.create_dir("model"), dump), - "config": self.export_config(folder, dump), - "plugins": plugins, - "root": folder.path, - } - pipeline = msc_utils.map_dict(pipeline, _to_root_mark) - if not dump: - return pipeline - with open(folder.relpath("pipeline.json"), "w") as f: - f.write(json.dumps(pipeline, indent=2)) - # copy common files - if self._optimized or self._compiled: - stage = MSCStage.COMPILE if self._compiled else MSCStage.OPTIMIZE - msc_utils.get_visual_dir().copy(stage, folder.relpath("visualize")) - for log_h in self._logger.handlers: - if isinstance(log_h, logging.FileHandler): - folder.copy(log_h.baseFilename) - with open(folder.relpath("report.json"), "w") as f: - f.write(json.dumps(self._report, indent=2)) - folder.finalize() - if path.endswith(".tar.gz"): - msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar.gz") - return path - - def export_model(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: - """Export the model - - Parameters - ---------- - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. + tool_type: str + The tool type. 
Returns ------- - exported: - The exported model. + applied: bool + Whether the tool is applied. """ - if self._optimized: - module = self._runner.export_module(folder) - if not dump: - return module - path = folder.relpath("model.json") - with open(path, "w") as f: - f.write(tvm.ir.save_json(module)) - return {"model": path} - if not dump: - return self._model - return self._get_runner_cls(self._model_type).dump_nativate( - self._model, folder, **self._config[MSCStage.EXPORT] - ) + return self._worker.tool_applied(tool_type) - def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> dict: - """Export the config + def _apply_tool( + self, tool_type: str, knowledge: dict = None, data_loader: Any = None + ) -> Tuple[dict, dict]: + """Apply tool with runner Parameters ---------- - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. + tool_type: str + The tool type to apply. + knowledge: dict + The pre knowledge. + data_loader: + The data loader. Returns ------- - config: dict - The updated config. - """ - - # dump the dataloader - def _save_dataset(name, info, dump: bool): - loader, max_batch = info["loader"], info.get("max_batch", -1) - data_folder = folder.create_dir("dataset") - if isinstance(loader, str) and msc_utils.is_callable(loader): - path, func_name = loader.split(":") - exp_loader = data_folder.copy(path) + ":" + func_name - elif msc_utils.is_io_dataset(loader): - exp_loader = data_folder.copy(loader, name) - elif callable(loader) and dump: - saver_options = { - "input_names": [i[0] for i in self._config["inputs"]], - "output_names": self._config["outputs"], - } - batch_cnt = 0 - exp_loader = data_folder.create_dir(name).path - with msc_utils.IODataSaver(exp_loader, saver_options) as saver: - for inputs in loader(): - if batch_cnt >= max_batch > 0: - break - batch_cnt = saver.save_batch(inputs) - else: - exp_loader = loader - return {"loader": exp_loader, "max_batch": max_batch} - - config = msc_utils.copy_dict(self._meta_config) - config["dataset"] = { - k: _save_dataset(k, v, dump) for k, v in self._config["dataset"].items() - } - if self._optimized: - config["model_type"] = MSCFramework.TVM - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE]: - if stage in config: - config.pop(stage) - if "profile" in config[MSCStage.COMPILE]: - config[MSCStage.COMPILE]["profile"].setdefault("check", {})["err_rate"] = -1 - config["tools"] = [] - for tool in self._config.get("tools", []): - if not support_tool(tool, MSCStage.COMPILE, self._compile_type): - continue - run_tool = self.runner.get_tool(tool["tool_type"]) - tool["tool_config"] = run_tool.export_config(tool["tool_config"], folder) - if tool["tool_config"]: - config["tools"].append(tool) - else: - self._logger.info( - "Skip compile with tool %s as no config exported", tool["tool_type"] - ) - # remove not serializable items - if dump: - remove_keys = {"workspace", "logger"} - config = {k: v for k, v in config.items() if k not in remove_keys} - return config - - def destory(self, keep_workspace: bool = False): - """Destroy the manager - - Parameters - ---------- - keep_workspace: bool - Whether to keep workspace. + info: dict + The info of apply tool. + report: dict + The report of apply tool. 
""" - if self._runner: - self._runner.destory() - if not keep_workspace: - self._workspace.destory() - msc_utils.remove_loggers() + return self._worker.apply_tool(tool_type, knowledge, data_loader) - def _create_runner( + def _create_runtime( self, stage: str, - stage_config: dict, + tools: List[str] = None, + run_type: str = None, + run_config: dict = None, visualize: bool = True, profile: bool = True, use_cache: bool = True, - ) -> BaseRunner: - """Create runner. + ) -> Tuple[dict, dict]: + """Create runtime. Parameters ---------- stage: str - The stage name - stage_config: dict - The config of this stage. + The pipeline stage. + tools: list + The tools to apply. + run_type: str + The type of runner. + run_config: dict + The config of runner. visualize: bool Whether to visualize the runner profile: bool @@ -714,387 +143,145 @@ def _create_runner( Returns ------- - runner: BaseRunner - The runner. + info: dict + The info of stage. + report: dict + The report of stage. """ - if self._runner: - self._runner.destory() - cache_dir = msc_utils.get_cache_dir().create_dir(stage) if use_cache else None - msc_utils.time_stamp(stage + ".build", False) - runner_cls = self._get_runner_cls(stage_config["run_type"]) - run_config = msc_utils.copy_dict(stage_config.get("run_config")) - if "generate_config" not in run_config: - run_config["generate_config"] = {} - cleanup = self._debug_levels.get(stage, 0) == 0 - run_config["generate_config"]["build_folder"] = msc_utils.get_build_dir().create_dir( - stage, cleanup=cleanup + return self._worker.create_runner( + stage, tools, run_type, run_config, visualize, profile, use_cache ) - if "device" not in run_config: - run_config["device"] = self._device - if "training" not in run_config: - run_config["training"] = self._training - # Build runner - runner = runner_cls( - self._relax_mod, - tools_config=self._tools_config, - plugin=self._plugins.get(stage_config["run_type"]), - stage=stage, - logger=self._logger, - **run_config, - ) - runner.build(cache_dir=cache_dir) - self._report["info"][stage + "_type"] = "{}({})".format(runner.framework, runner.device) - if visualize: - runner.visualize(msc_utils.get_visual_dir().create_dir(stage.split(".")[0])) - if profile and "profile" in stage_config: - self._report["profile"][stage] = self._profile_runner(runner, stage_config) - if use_cache: - runner.save_cache(cache_dir) - return runner - def _apply_tool(self, tool: dict, stage: str) -> str: - """Apply tool with runner + def _run_gym(self, stage: str, config: dict, knowledge: dict, data_loader: Any) -> dict: + """Run gym. Parameters ---------- - tool: dict - The tool config. stage: str - The compile stage. + The pipeline stage. + config: dict + The gym config. + knowledge: dict + The pre knowledge. + data_loader: + The data loader. Returns ------- - plan_file: str - The plan_file path. + knowledge: dict + The learned knowledge. """ - self._tools_config.append(tool) - tool_type, tool_config = tool["tool_type"], tool["tool_config"] - tool_stage = self._get_tool_stage(tool_type) - plan_file = tool_config["plan_file"] - if os.path.isfile(plan_file): - self._logger.info("Skip %s with plan %s", tool_type, plan_file) - return plan_file - t_stage = stage + "." 
+ tool_stage - msc_utils.time_stamp(t_stage) - stage_config = { - "run_type": tool.get("run_type", self._config[stage]["run_type"]), - "run_config": self._config[stage]["run_config"], + extra_config = { + "env": { + "runner": self._worker.runner, + "data_loader": data_loader, + "knowledge": knowledge, + }, + "verbose": self._verbose, } - runner = self._create_runner( - t_stage, stage_config, visualize=False, profile=False, use_cache=False - ) - if "gym_configs" in tool: - knowledge = None - for idx, config in enumerate(tool["gym_configs"]): - knowledge_file = msc_utils.get_config_dir().relpath( - "gym_knowledge_{}.json".format(idx) - ) - gym_mark = "GYM[{}/{}]({} @ {}) ".format( - idx, len(tool["gym_configs"]), runner.framework, t_stage - ) - if os.path.isfile(knowledge_file): - knowledge = knowledge_file - self._logger.info("%sLoad from %d", gym_mark, knowledge) - else: - msc_utils.time_stamp(t_stage + ".gym_{}".format(idx)) - self._logger.info("%sStart search", gym_mark) - extra_config = { - "env": { - "runner": runner, - "data_loader": self._get_loader(tool_stage), - "knowledge": knowledge, - }, - "verbose": self._verbose, - } - controller = create_controller(tool_stage, config, extra_config) - knowledge = controller.run() - msc_utils.save_dict(knowledge, knowledge_file) - plan = msc_utils.load_dict(knowledge) - self._logger.info("%sFound %d plan", gym_mark, len(plan)) - return msc_utils.save_dict(plan, plan_file) - msc_utils.time_stamp(t_stage + ".make_plan", False) - plan_file = runner.make_plan(tool_type, self._get_loader(tool_stage)) - if tool.get("visualize", False): - runner.visualize(msc_utils.get_visual_dir().create_dir(stage)) - return plan_file - - def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: - """Profile the runner. - - Parameters - ---------- - runner: BaseRunner - The runner to be profiled - stage_config: dict - The config of this stage. + controller = create_controller(stage, config, extra_config) + return controller.run() - Returns - ------- - report: dict - The profile report. 
- """ - - stage = runner.stage - msc_utils.time_stamp(stage + ".profile", False) - profile_config = stage_config["profile"] - msg, report = "Profile({})".format(stage), {} - - # check accuracy - check_config = profile_config.get("check", {}) - if check_config: - loader = msc_utils.IODataLoader(self._config["dataset"]["golden"]["loader"]) - total, passed = 0, 0 - acc_report = {"config": check_config} - for idx, (inputs, outputs) in enumerate(loader): - results = runner.run(inputs) - iter_report = msc_utils.compare_arrays( - outputs, - results, - atol=check_config.get("atol", 1e-2), - rtol=check_config.get("rtol", 1e-2), - ) - total += iter_report["total"] - passed += iter_report["passed"] - acc_report["iter_" + str(idx)] = iter_report["info"] - pass_rate = float(passed) / total - report["accuracy"] = "{}/{}({:.2f}%)".format(passed, total, pass_rate * 100) - title = "Check({}) pass {}".format(stage, report["accuracy"]) - self._logger.debug(msc_utils.msg_block(title, acc_report, width=0)) - msg += " acc {} iters -> {}".format(len(loader), report["accuracy"]) - if runner.get_tool(ToolType.PRUNER) or runner.get_tool(ToolType.QUANTIZER): - self._logger.debug("Disable accuracy check(%s) by tools", stage) - else: - required_err, err_rate = check_config.get("err_rate", 0), (1 - pass_rate) - if err_rate > required_err >= 0: - raise Exception( - "Failed to profile the runner({}), err_rate {} > required {}".format( - stage, err_rate, required_err - ) - ) - - # benchmark model - if runner.get_tool(ToolType.TRACKER): - benchmark_config = None - self._logger.debug("Disable benchmark(%s) by tools", stage) - else: - benchmark_config = profile_config.get("benchmark", {}) - if benchmark_config: - for _ in range(benchmark_config.get("warm_up", 10)): - runner.run(self._sample_inputs) - start = time.time() - repeat = self._get_repeat(benchmark_config, runner.device) - for _ in range(repeat): - runner.run(self._sample_inputs) - avg_time = (time.time() - start) * 1000 / repeat - report["latency"] = "{:.2f} ms @ {}".format(avg_time, runner.device) - msg += " latency {} times -> {}".format(repeat, report["latency"]) - self._logger.info(msg) - return report - - def _update_tools_config(self, tools: List[dict]) -> List[dict]: - """Update tool in stage config. + def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: + """Export the model Parameters ---------- - tools: list - The config of tools. + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. Returns ------- - tools: list - The updated config of tools. + exported: + The exported model. """ - for tool in tools: - tool_config = tool["tool_config"] - if "plan_file" not in tool_config: - tool_config["plan_file"] = "msc_{}.json".format(tool["tool_type"]) - tool_config["plan_file"] = msc_utils.to_abs_path( - tool_config["plan_file"], msc_utils.get_config_dir() - ) - return tools + return self._worker.export_model(stage, folder, dump) - def _get_tool_stage(self, tool_type: str) -> str: - """Map the stage according to tool_type + def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the tool Parameters ---------- tool_type: str The tool type. + folder: MSCDirectory + The export folder. Returns ------- - stage: str - The stage. 
- """ - - if tool_type == ToolType.PRUNER: - return MSCStage.PRUNE - if tool_type == ToolType.QUANTIZER: - return MSCStage.QUANTIZE - if tool_type == ToolType.DISTILLER: - return MSCStage.DISTILL - if tool_type == ToolType.TRACKER: - return MSCStage.TRACK - return tool_type - - def get_runnable(self, ret_type: str = "runner") -> Any: - """Return object by type. - - Parameters - ---------- - ret_type: str - The return type runner| model. - - Returns - ------- - runnable: - The runner or model. + config: dict + The exported tool config. """ - assert self._runner, "Failed to create runner, call run_pipe first" - if ret_type == "runner": - return self._runner - elif ret_type == "runnable": - return self._runner.runnable - elif ret_type == "model": - return self._runner.model - raise Exception("Unexpect return type " + str(ret_type)) + assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) + exp_config = {"tool_config": self._worker.export_tool(tool_type, folder)} + return msc_utils.update_dict(self._tools_config[tool_type], exp_config) - def _get_runner_cls(self, run_type: str) -> BaseRunner: - """Get the runner cls by type + def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the info of pipeline Parameters ---------- - run_type: str - The run type. + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. Returns ------- - runner_cls: class - The runner class. + info: dict + The info. """ - raise NotImplementedError("_get_runner_cls is not implemented for BaseManager") - - def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: - """Get the data loader""" - - config = self._config["dataset"].get(name, self._config["dataset"][MSCStage.PREPARE]) - source_loader = config.get("loader") - assert source_loader, "Dataset loader should be given for msc pipeline" - if source_loader == "from_random": - max_batch = config.get("max_batch", 5) + info = super()._export_info(stage, folder) + if stage in (MSCStage.OPTIMIZE, MSCStage.COMPILE): + info.update(self._worker.export_info(stage, folder)) + return info - def get_random(): - for _ in range(max_batch): - yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]} + def _destory(self): + """Destory the pipeline""" - loader, source_type = get_random, "Random" - elif msc_utils.is_io_dataset(source_loader): - max_batch = config.get("max_batch", -1) + self._worker.destory() - def load_datas(): - for inputs, _ in msc_utils.IODataLoader(source_loader, end=max_batch): - yield inputs - - loader, source_type = load_datas, "IOData" - elif callable(source_loader): - max_batch = config.get("max_batch", -1) - load_kwargs = config.get("load_kwargs", {}) - - def get_source(): - for idx, inputs in enumerate(source_loader(**load_kwargs)): - if idx >= max_batch > 0: - break - yield inputs - - loader, source_type = get_source, "Custom" - else: - raise TypeError( - "Unexpected source loader {}({})".format(source_loader, type(source_loader)) - ) - self._logger.debug("Create data loader(%s) %s(%s)", name, loader.__name__, source_type) - return loader - - def _get_repeat(self, benchmark: dict, device: str = None) -> int: - """Get the repeat number for benchmark + def get_runtime(self, ret_type: str = "runner") -> Any: + """Get the runtime of pipeline Parameters ---------- - benchmark: dict - The benchmark config. - device: str - The device name + ret_type: str + The return type runner| runnable| model. Returns ------- - repeat: int - The repeat number. 
+ runnable: + The runnable object. """ - device = device or self._device - repeat = benchmark.get("repeat", -1) - if repeat == -1: - repeat = 500 if device.startswith("cuda") else 10 - return repeat + return self._worker.get_runnable(ret_type) - @property - def runner(self): - return self._runner - - @property - def report(self): - return self._report - - @property - def model_type(self): - return self._model_type - - @property - def optimize_type(self): - return self._optimize_type - - @property - def compile_type(self): - return self._compile_type - - -class MSCManager(BaseManager): - """Normal manager in MSC""" - - def _get_runner_cls(self, run_type: str) -> BaseRunner: - """Get the runner cls by type + def pipe_mark(self, msg: Any) -> str: + """Mark the message with pipeline info Parameters - ---------- - run_type: str - The run type. + ------- + msg: str + The message Returns ------- - runner_cls: class - The runner class. + msg: str + The message with mark. """ - if run_type == MSCFramework.TVM: - from tvm.contrib.msc.framework.tvm.runtime import TVMRunner - - runner_cls = TVMRunner - elif run_type == MSCFramework.TORCH: - from tvm.contrib.msc.framework.torch.runtime import TorchRunner + return "MANAGER " + str(msg) - runner_cls = TorchRunner - elif run_type == MSCFramework.TENSORFLOW: - from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner - - runner_cls = TensorflowRunner - elif run_type == MSCFramework.TENSORRT: - from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner - - runner_cls = TensorRTRunner - else: - raise Exception("Unexpect run_type " + str(run_type)) - return runner_cls + @property + def worker_cls(self): + return MSCPipeWorker diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py new file mode 100644 index 000000000000..f02503a113ca --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/pipeline.py @@ -0,0 +1,845 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.pipeline.pipeline""" + +import os +import json +from typing import Any, Union, List, Tuple +import traceback +import numpy as np + +from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool +from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey +from tvm.contrib.msc.core.utils.message import MSCStage +from tvm.contrib.msc.plugin.utils import export_plugins, load_plugins +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core import _ffi_api +from .utils import support_tool, get_tool_stage, map_tools +from .worker import BasePipeWorker + + +class BasePipeline(object): + """Base Pipeline of MSC + + Parameters + ---------- + model: Any + The raw model in framwork. 
+ config: dict + The config for pipeline. + plugins: dict + The plugins for pipeline. + run_optimize: bool + Whether to run optimize. + run_compile: bool + Whether to run compile. + root: str + The root path for files. + """ + + def __init__( + self, + model: Any, + config: dict, + plugins: dict = None, + run_optimize: bool = True, + run_compile: bool = True, + root: str = None, + ): + # change path to root path + if root: + + def _from_root_mark(val): + if isinstance(val, str) and MSCKey.ROOT_MARK in val: + return val.replace(MSCKey.ROOT_MARK, root) + return val + + if isinstance(model, dict): + model = msc_utils.map_dict(model, _from_root_mark) + elif isinstance(model, str): + model = _from_root_mark(model) + config = msc_utils.map_dict(config, _from_root_mark) + plugins = msc_utils.map_dict(plugins, _from_root_mark) + + MSCMap.reset() + self._model, self._meta_config = model, config + self._config = msc_utils.copy_dict(config) + if not run_optimize and MSCStage.OPTIMIZE in self._config: + self._config.pop(MSCStage.OPTIMIZE) + if not run_compile and MSCStage.COMPILE in self._config: + self._config.pop(MSCStage.COMPILE) + for stage in [MSCStage.PREPARE, MSCStage.PARSE, MSCStage.EXPORT]: + self._config.setdefault(stage, {}) + self._verbose = self._config.get("verbose", "info") + use_cache = self._config.get("use_cache", True) + if "workspace" in self._config: + self._workspace = msc_utils.set_workspace(self._config.pop("workspace"), use_cache) + else: + self._workspace = msc_utils.set_workspace("msc_workspace", use_cache) + if "logger" in self._config: + self._logger = self._config.pop("logger") + MSCMap.set(MSCKey.GLOBALE_LOGGER, self._logger) + else: + if "log_file" in self._config: + log_file = self._config.pop("log_file") + else: + log_file = self._workspace.relpath("MSC_LOG", keep_history=False) + self._logger = msc_utils.set_global_logger(self._verbose, log_file) + self._plugins = load_plugins(plugins) if plugins else {} + self.change_stage(MSCStage.SETUP) + self._logger.info(msc_utils.msg_block(self.pipe_mark("SETUP"), self.setup())) + + def setup(self) -> dict: + """Setup the pipeline + + Returns + ------- + info: dict + The setup info. + """ + + # define run type + self._model_type = self._config["model_type"] + self._optimize_type = self._config.get(MSCStage.OPTIMIZE, {}).get( + "run_type", self._model_type + ) + self._compile_type = self._config.get(MSCStage.COMPILE, {}).get( + "run_type", self._model_type + ) + self._optimized, self._compiled = False, False + + # map tools + self._tools_config = map_tools(self._config.get("tools", [])) + + # register plugins + if self._plugins: + for t in [self._model_type, self._optimize_type, self._compile_type]: + assert t in self._plugins, "Missing plugin for {}".format(t) + for name, plugin in self._plugins[self._model_type].get_ops_info().items(): + _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) + + # status + self._current_stage = None + self._report = { + "success": False, + "info": {}, + "duration": {}, + } + return { + "workspace": self._workspace.path, + "log_file": msc_utils.get_log_file(self._logger), + "verbose": self._verbose, + "plugins": self._plugins, + "config": self._config, + } + + def run_pipe(self) -> dict: + """Run the pipeline and return object. + + Returns + ------- + report: + The pipeline report. 
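+
+        Examples
+        --------
+        A minimal sketch of a full run; the import paths, the ``create_config``
+        call and ``torch_model`` (any ``torch.nn.Module``) are assumptions for
+        illustration only:
+
+        >>> from tvm.contrib.msc.pipeline import MSCManager
+        >>> from tvm.contrib.msc.pipeline.utils import create_config
+        >>> config = create_config(
+        ...     inputs=[["input", [1, 3, 224, 224], "float32"]],
+        ...     outputs=["output"],
+        ...     model_type="torch",
+        ...     compile_type="tvm",
+        ...     dataset={"prepare": {"loader": "from_random", "max_batch": 5}},
+        ... )
+        >>> report = MSCManager(torch_model, config).run_pipe()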
+ """ + + err_msg, err_info = None, None + try: + self.prepare() + self.parse() + if MSCStage.BASELINE in self._config: + self.baseline() + if MSCStage.OPTIMIZE in self._config: + self.optimize() + if MSCStage.COMPILE in self._config: + self.compile() + except Exception as exc: # pylint: disable=broad-exception-caught + err_msg = "Pipeline failed: " + str(exc) + err_info = traceback.format_exc() + self.summary(err_msg, err_info) + self._logger.info(msc_utils.msg_block(self.pipe_mark("SUMMARY"), self._report, 0)) + self._workspace.finalize() + return self._report + + def change_stage(self, stage: str, log_stage: bool = True) -> str: + """Change stage + + Parameters + ---------- + stage: str + The stage name. + log_stage: bool + Whether to log the stage. + + Returns + ------- + stage: str + The stage name. + """ + + self._current_stage = stage + msc_utils.time_stamp(stage, log_stage) + return stage + + def prepare(self): + """Prepare datas for the pipeline.""" + + self.change_stage(MSCStage.PREPARE) + info, report = self._prepare(self._get_loader(MSCStage.PREPARE)) + self._record_stage(MSCStage.PREPARE, info, report) + + def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: + """Prepare datas for the pipeline. + + Parameters + ---------- + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of prepare. + report: dict + The report of prepare. + """ + + raise NotImplementedError("_prepare is not implemented in " + str(self.__class__)) + + def parse(self): + """Parse relax module for the pipeline.""" + + self.change_stage(MSCStage.PARSE) + info, report = self._parse() + self._record_stage(MSCStage.PARSE, info, report) + + def _parse(self) -> Tuple[dict, dict]: + """Parse relax module for the pipeline. + + Returns + ------- + info: dict + The info of parse. + report: dict + The report of parse. + """ + + raise NotImplementedError("_parse is not implemented in " + str(self.__class__)) + + def baseline(self): + """Run the baseline.""" + + self._run_stage(MSCStage.BASELINE) + + def optimize(self) -> Tuple[dict, dict]: + """Run the optimize. + + Returns + ------- + info: dict + The info of stage. + report: dict + The report of stage. + """ + + self._run_stage(MSCStage.OPTIMIZE) + self._optimized = True + + def compile(self) -> Tuple[dict, dict]: + """Run the compile. + + Returns + ------- + info: dict + The info of stage. + report: dict + The report of stage. + """ + + self._run_stage(MSCStage.COMPILE) + self._compiled = True + + def _run_stage(self, stage: str) -> Tuple[dict, dict]: + """Run the stage. + + Parameters + ---------- + stage: str + The pipeline stage. + + Returns + ------- + info: dict + The info of stage. + report: dict + The report of stage. + """ + + self.change_stage(stage) + tools = [] + for tool in self._config.get("tools", []): + run_type = tool.get("run_type", self._config[stage]["run_type"]) + if not support_tool(tool, stage, run_type): + continue + tools.append(tool["tool_type"]) + tool_cls, tool_stage = self.get_tool_cls(tool, run_type), get_tool_stage( + tool["tool_type"] + ) + t_stage = self.change_stage(stage + "." 
+ tool_stage) + if self._tool_applied(tool["tool_type"]): + if tool_cls.apply_once(): + msg = "Remove apply once tool " + str(tool["tool_type"]) + self._logger.info(self.pipe_mark(msg)) + tools = tools[:-1] + else: + self._logger.info(self.pipe_mark("Apply planed tool " + str(tool["tool_type"]))) + continue + self.change_stage(t_stage + ".build", False) + info, report = self._create_runtime( + t_stage, tools, run_type=run_type, visualize=False, profile=False, use_cache=False + ) + self._record_stage(t_stage, info, report) + knowledge, loader = None, self._get_loader(tool_stage) + if "gym_configs" in tool: + for idx, config in enumerate(tool["gym_configs"]): + knowledge_file = self._workspace.create_dir("Gym").relpath( + "knowledge_{}.json".format(idx) + ) + gym_mark = "GYM[{}/{}]({} @ {}) ".format( + idx, len(tool["gym_configs"]), self._config[stage]["run_type"], tool_stage + ) + if os.path.isfile(knowledge_file): + knowledge = knowledge_file + msg = "{}Load from {}".format(gym_mark, knowledge) + self._logger.info(self.pipe_mark(msg)) + else: + self.change_stage(tool_stage + ".gym_{}".format(idx)) + self._logger.info(self.pipe_mark(gym_mark + "Start search")) + knowledge = self._run_gym(tool_stage, config, knowledge, loader) + msc_utils.save_dict(knowledge, knowledge_file) + knowledge = msc_utils.load_dict(knowledge) + self.change_stage(t_stage + ".apply", False) + info, report = self._apply_tool(tool["tool_type"], knowledge, loader) + self._record_stage(t_stage, info, report) + if tool_cls.apply_once(): + msg = "Remove apply once tool " + str(tool["tool_type"]) + self._logger.info(self.pipe_mark(msg)) + tools = tools[:-1] + self.change_stage(stage + ".build", False) + info, report = self._create_runtime(stage, tools) + self._record_stage(stage, info, report) + + def _tool_applied(self, tool_type: str) -> bool: + """Check if the tool is applied + + Parameters + ---------- + tool_type: str + The tool type. + + Returns + ------- + applied: bool + Whether the tool is applied. + """ + + return False + + def _apply_tool( + self, tool_type: str, knowledge: dict = None, data_loader: Any = None + ) -> Tuple[dict, dict]: + """Apply tool with runner + + Parameters + ---------- + tool_type: str + The tool type to apply. + knowledge: dict + The pre knowledge. + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of apply tool. + report: dict + The report of apply tool. + """ + + raise NotImplementedError("_apply_tool is not implemented in " + str(self.__class__)) + + def _create_runtime( + self, + stage: str, + tools: List[str] = None, + run_type: str = None, + run_config: dict = None, + visualize: bool = True, + profile: bool = True, + use_cache: bool = True, + ) -> Tuple[dict, dict]: + """Create runtime. + + Parameters + ---------- + stage: str + The pipeline stage. + tools: list + The tools to apply. + run_type: str + The type of runner. + run_config: dict + The config of runner. + visualize: bool + Whether to visualize the runner + profile: bool + Whether to profile the runner. + use_cache: bool + Whether to use cache. + + Returns + ------- + info: dict + The info of stage. + report: dict + The report of stage. + """ + + raise NotImplementedError("_create_runtime is not implemented in " + str(self.__class__)) + + def _run_gym(self, stage: str, config: dict, knowledge: dict, data_loader: Any) -> dict: + """Run gym. + + Parameters + ---------- + stage: str + The pipeline stage. + config: dict + The gym config. + knowledge: dict + The pre knowledge. 
+ data_loader: + The data loader. + + Returns + ------- + knowledge: dict + The learned knowledge. + """ + + raise NotImplementedError("_run_gym is not implemented in " + str(self.__class__)) + + def summary(self, err_msg: str = None, err_info: str = None) -> dict: + """Summary the pipeline. + + Parameters + ---------- + err_msg: str + The error message. + err_info: str + The error info. + + Returns + ------- + report: dict + The report of the pipeline. + """ + + self.change_stage(MSCStage.SUMMARY, False) + if err_msg: + self._report.update({"success": False, "err_msg": err_msg, "err_info": err_info}) + else: + self._report["success"] = True + self._report["duration"] = msc_utils.get_duration() + return self._report + + def export(self, path: str = None, dump: bool = True) -> Union[str, dict]: + """Export the pipeline + + Parameters + ---------- + path: str + The export path. + dump: bool + Whether to dump the info. + + Returns + ------- + export_path/pipeline: str/dict + The exported path/pipeline info. + """ + + path = path or "msc_export" + if path.endswith(".tar.gz"): + folder, dump = msc_utils.msc_dir(path.replace(".tar.gz", ""), keep_history=False), True + else: + folder = msc_utils.msc_dir(path, keep_history=False) + + if self._compiled: + stage = MSCStage.COMPILE + elif self._optimized: + stage = MSCStage.OPTIMIZE + else: + stage = MSCStage.SETUP + + def _to_root_mark(val): + if isinstance(val, str) and folder.path != val and folder.path in val: + return val.replace(folder.path, MSCKey.ROOT_MARK) + return val + + def _export_plugins(folder: msc_utils.MSCDirectory): + if self._compiled: + if dump and self.compile_type in self._plugins: + return self._plugins[self.compile_type].copy_libs(folder) + return self._plugins.get(self.compile_type) + if dump: + return export_plugins(self._plugins, folder) + return self._plugins + + export = { + "logger": folder.copy(msc_utils.get_log_file(self._logger)), + "report": self._report, + "info": self._export_info(stage, folder.create_dir("info")), + "model": self._export_model(stage, folder.create_dir("model"), dump), + "plugins": _export_plugins(folder.create_dir("plugins")), + } + if self._compiled: + # save golden + num_golden = self._config[MSCStage.EXPORT].get("num_golden", 5) + if num_golden > 0: + saver_options = { + "input_names": [i[0] for i in self._config["inputs"]], + "output_names": self._config["outputs"], + } + batch_cnt, export["golden"] = 0, folder.create_dir("golden").path + with msc_utils.IODataSaver(export["golden"], saver_options) as saver: + for inputs in self._get_loader()(): + if batch_cnt >= num_golden: + break + batch_cnt = saver.save_batch(inputs, self.get_runtime().run(inputs)) + else: + export["config"] = self.export_config(folder, dump) + export = msc_utils.map_dict(export, _to_root_mark) + if not dump: + return export + with open(folder.relpath("export.json"), "w") as f: + f.write(json.dumps(export, indent=2)) + folder.finalize() + if path.endswith(".tar.gz"): + msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar.gz") + return path + + def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> dict: + """Export the config + + Parameters + ---------- + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + config: dict + The updated config. 
+ """ + + # dump the dataloader + def _export_dataset(name, info, dump: bool): + loader, max_batch = info["loader"], info.get("max_batch", -1) + data_folder = folder.create_dir("dataset") + if isinstance(loader, str) and msc_utils.is_callable(loader): + path, func_name = loader.split(":") + exp_loader = data_folder.copy(path) + ":" + func_name + elif msc_utils.is_io_dataset(loader): + exp_loader = data_folder.copy(loader, name) + elif callable(loader) and dump: + saver_options = {"input_names": [i[0] for i in self._config["inputs"]]} + batch_cnt, exp_loader = 0, data_folder.create_dir(name).path + with msc_utils.IODataSaver(exp_loader, saver_options) as saver: + for inputs in loader(): + if batch_cnt >= max_batch > 0: + break + batch_cnt = saver.save_batch(inputs) + else: + exp_loader = loader + return {"loader": exp_loader, "max_batch": max_batch} + + config = msc_utils.copy_dict(self._meta_config) + config["dataset"] = { + k: _export_dataset(k, v, dump) for k, v in self._config["dataset"].items() + } + if self._optimized: + config["model_type"] = MSCFramework.TVM + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE]: + if stage in config: + config.pop(stage) + if "profile" in config[MSCStage.COMPILE] and self.get_runtime().trained: + config[MSCStage.COMPILE]["profile"].setdefault("check", {})["err_rate"] = -1 + config["tools"] = [] + for tool in self._config.get("tools", []): + tool_type = tool["tool_type"] + skip_msg = "Skip export tool " + tool_type + if not support_tool(tool, MSCStage.COMPILE, self._compile_type): + self._logger.info(self.pipe_mark(skip_msg + "(unsupported)")) + continue + tool_cls = self.get_tool_cls(tool, self._optimize_type) + if not tool_cls.exportable(): + self._logger.info(self.pipe_mark(skip_msg + "(unexportable)")) + continue + config["tools"].append(self._export_tool(tool_type, folder)) + # remove not serializable items + if dump: + remove_keys = {"workspace", "logger"} + config = {k: v for k, v in config.items() if k not in remove_keys} + return config + + def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: + """Export the model + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + exported: + The exported model. + """ + + raise NotImplementedError("_export_model is not implemented in " + str(self.__class__)) + + def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the tool + + Parameters + ---------- + tool_type: str + The tool type. + folder: MSCDirectory + The export folder. + + Returns + ------- + tool: dict + The exported tool. + """ + + raise NotImplementedError("_export_tool is not implemented in " + str(self.__class__)) + + def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the info of pipeline + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The info. 
+ """ + + return {} + + def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: + """Get the data loader""" + + config = self._config["dataset"].get(name, self._config["dataset"][MSCStage.PREPARE]) + source_loader = config.get("loader") + assert source_loader, "Dataset loader should be given for msc pipeline" + if source_loader == "from_random": + max_batch = config.get("max_batch", 5) + + def get_random(): + for _ in range(max_batch): + yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]} + + loader, source_type = get_random, "random" + elif msc_utils.is_io_dataset(source_loader): + max_batch = config.get("max_batch", -1) + + def load_datas(): + for inputs, _ in msc_utils.IODataLoader(source_loader, end=max_batch): + yield inputs + + loader, source_type = load_datas, "io_data" + elif callable(source_loader): + max_batch = config.get("max_batch", -1) + load_kwargs = config.get("load_kwargs", {}) + if max_batch == -1 and not load_kwargs: + loader, source_type = source_loader, "custom" + else: + + def get_source(): + for idx, inputs in enumerate(source_loader(**load_kwargs)): + if idx >= max_batch > 0: + break + yield inputs + + loader, source_type = get_source, "loaded_custom" + else: + raise TypeError( + "Unexpected source loader {}({})".format(source_loader, type(source_loader)) + ) + msg = "Create data loader({}) {}({})".format(name, loader.__name__, source_type) + self._logger.debug(self.pipe_mark(msg)) + return loader + + def _record_stage(self, stage: str, info: dict = None, report: dict = None): + """Record the stage + + Parameters + ------- + stage: str + The compile stage + info: dict + The info of stage. + report: dict + The report of stage. + """ + + if info: + self._logger.info(msc_utils.msg_block(self.pipe_mark(stage.upper()), info)) + if report: + self._report["info"].setdefault(stage, {}).update(report) + + def destory(self, keep_workspace: bool = False): + """Destroy the pipeline + + Parameters + ---------- + keep_workspace: bool + Whether to keep workspace. + """ + + self._destory() + if not keep_workspace: + self._workspace.destory() + msc_utils.remove_loggers() + + def _destory(self): + """Destroy the pipeline.""" + + raise NotImplementedError("_destory is not implemented in " + str(self.__class__)) + + def get_tool_cls(self, tool: dict, framework: str) -> BaseTool: + """Get the tool class from tool config + + Parameters + ---------- + tool: dict + The tool config. + framework: str + The framework. + + Returns + ------- + tool_cls: + The tool class. + """ + + return get_tool_cls(framework, tool["tool_type"], tool["tool_config"]) + + def get_runtime(self, ret_type: str = "runner") -> Any: + """Get the runtime of pipeline + + Parameters + ---------- + ret_type: str + The return type runner| runnable| model. + + Returns + ------- + runnable: + The runnable object. + """ + + raise NotImplementedError("get_runtime is not implemented in " + str(self.__class__)) + + def create_worker(self, model: Any, name: str, config: dict = None): + """Create pipe worker + + Parameters + ------- + model: Any + The raw model in framwork. + name: str + The name of worker. + worker_config: dict + The extra config for worker. + + Returns + ------- + worker: str + The message with mark. 
+ """ + + return self.worker_cls( + model, + config or self._config, + self._workspace, + self._plugins, + self._logger, + name=name, + ) + + def pipe_mark(self, msg: Any) -> str: + """Mark the message with pipeline info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "PIPE " + str(msg) + + @property + def worker_cls(self): + return BasePipeWorker + + @property + def report(self): + return self._report + + @property + def model_type(self): + return self._model_type + + @property + def optimize_type(self): + return self._optimize_type + + @property + def compile_type(self): + return self._compile_type diff --git a/python/tvm/contrib/msc/pipeline/utils.py b/python/tvm/contrib/msc/pipeline/utils.py new file mode 100644 index 000000000000..e4d91ee14b62 --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/utils.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.pipeline.config""" + +from typing import List, Union, Dict, Tuple + +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.message import MSCStage +from tvm.contrib.msc.core import utils as msc_utils + + +def get_tool_stage(tool_type: str) -> str: + """Map the stage according to tool_type + + Parameters + ---------- + tool_type: str + The tool type. + + Returns + ------- + stage: str + The stage. + """ + + if tool_type == ToolType.PRUNER: + return MSCStage.PRUNE + if tool_type == ToolType.QUANTIZER: + return MSCStage.QUANTIZE + if tool_type == ToolType.DISTILLER: + return MSCStage.DISTILL + if tool_type == ToolType.TRACKER: + return MSCStage.TRACK + return tool_type + + +def map_tools(tools: List[dict]) -> dict: + """Map tools from list + + Parameters + ---------- + tools: list + The tools config, + + Returns + ------- + tools: dict + The tools map. + """ + + tools_map = {t["tool_type"]: t for t in tools} + assert len(tools_map) == len(tools), "Duplicate tools: " + str([t["tool_type"] for t in tools]) + return tools_map + + +def support_tool(tool: dict, stage: str, run_type: str) -> bool: + """Check if the tool is supported + + Parameters + ---------- + tool: dict + The tool config, + stage: str + The pipeline stage. + run_type: str + The runtime type. + + Returns + ------- + supported: bool + Whether the tool is supported. + """ + + run_type = tool.get("run_type", run_type) + if stage == MSCStage.BASELINE: + return tool["tool_type"] == ToolType.TRACKER + return True + + +def config_tool(tool_type: str, raw_config: Union[dict, str]) -> dict: + """Config the tool + + Parameters + ---------- + tool_type: str + The tool type, + raw_config: str| dict + The tool config or style. + + Returns + ------- + config: dict + The config for tool. 
+ """ + + if isinstance(raw_config, dict): + if "config_style" in raw_config: + config_style = raw_config.pop("config_style") + else: + config_style = "default" + else: + config_style, raw_config = raw_config, None + configer_cls = msc_utils.get_registered_tool_configer(tool_type, config_style) + assert configer_cls, "Can not find configer for {}:{}".format(tool_type, config_style) + return {"tool_type": tool_type, **configer_cls().config(raw_config)} + + +def create_config( + inputs: List[dict], + outputs: List[str], + model_type: str, + baseline_type: str = None, + optimize_type: str = None, + compile_type: str = None, + dataset: Dict[str, dict] = None, + tools: List[Tuple[str, Union[dict, str]]] = None, + dynamic: bool = False, + skip_config: Dict[str, str] = None, + **extra_config, +) -> dict: + """Create config for msc pipeline + + Parameters + ---------- + inputs: list + The inputs info, + outputs: list + The output names. + model_type: str + The model type. + baseline_type: str + The baseline type. + compile_type: str + The compile type. + optimize_type: str + The optimize type. + dataset: dict + The datasets for compile pipeline. + tools: list + The tools config. + dynamic: bool + Whether to config dyanmic mode. + skip_config: dict + The skip config for compile. + extra_config: dict + The extra config. + """ + + baseline_type = baseline_type or model_type + optimize_type = optimize_type or baseline_type + compile_type = compile_type or optimize_type + tools = tools or [] + tools = [config_tool(t_type, t_config) for t_type, t_config in tools] + # basic config + config = { + "model_type": model_type, + "dynamic": dynamic, + "inputs": inputs, + "outputs": outputs, + "dataset": dataset, + "tools": tools, + MSCStage.PREPARE: {"profile": {"benchmark": {"repeat": -1}}}, + MSCStage.BASELINE: { + "run_type": baseline_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + }, + } + + # config optimize + opt_tools = [t for t in tools if support_tool(t, MSCStage.OPTIMIZE, optimize_type)] + if opt_tools: + config[MSCStage.OPTIMIZE] = { + "run_type": optimize_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + } + + # config compile + config[MSCStage.COMPILE] = { + "run_type": compile_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + } + + # update config + if extra_config: + config = msc_utils.update_dict(config, extra_config) + + # skip stages + skip_config = skip_config or {} + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in config: + continue + for key in ["all", stage]: + if key not in skip_config: + continue + if skip_config[key] == "stage": + config.pop(stage) + elif skip_config[key] == "profile": + config[stage].pop("profile") + elif skip_config[key] == "check": + config[stage]["profile"].pop("check") + elif skip_config[key] == "benchmark": + config[stage]["profile"].pop("benchmark") + else: + raise TypeError("Unexpected skip type " + str(skip_config[key])) + + return config diff --git a/python/tvm/contrib/msc/pipeline/worker.py b/python/tvm/contrib/msc/pipeline/worker.py new file mode 100644 index 000000000000..e22e52903f63 --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/worker.py @@ -0,0 +1,786 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, unused-argument +"""tvm.contrib.msc.pipeline.worker""" + +import os +import time +import logging +from typing import Any, List, Tuple + +import tvm +from tvm.contrib.msc.core.runtime import BaseRunner +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core.utils.message import MSCStage +from tvm.contrib.msc.core import utils as msc_utils +from .utils import support_tool, get_tool_stage, map_tools + + +class BasePipeWorker(object): + """Base Worker of MSC pipeline + + Parameters + ---------- + model: Any + The raw model in framwork. + config: dict + The config for pipeline. + workspace: MSCDirectory + The workspace. + plugins: dict + The plugins for pipeline. + run_optimize: bool + Whether to run optimize. + run_compile: bool + Whether to run compile. + logger: logging.Logger + The logger. + name: str + The name of the worker. + """ + + def __init__( + self, + model: Any, + config: dict, + workspace: msc_utils.MSCDirectory, + plugins: dict = None, + logger: logging.Logger = None, + name: str = "main", + ): + # check/set default stage + for key in ["inputs", "outputs", "dataset"]: + assert key in config, "Missing {} in config".format(key) + + self._config = msc_utils.copy_dict(config) + self._workspace = workspace + self._plugins = plugins + self._model_type = config["model_type"] + self._optimize_type = config.get(MSCStage.OPTIMIZE, {}).get("run_type", self._model_type) + self._compile_type = config.get(MSCStage.COMPILE, {}).get("run_type", self._model_type) + runner_cls = self._get_runner_cls(self._model_type) + self._model, self._device, self._training = runner_cls.load_native(model, config) + self._verbose = config.get("verbose", "info") + self._logger = logger or msc_utils.get_global_logger() + self._name = name + self._optimized, self._compiled = False, False + self.setup() + + def setup(self) -> dict: + """Setup the manager + + Returns + ------- + config: dict + The updated config. + """ + + self._debug_levels = self.update_config() + self._tools_config = map_tools(self._config.get("tools", [])) + self._relax_mod, self._sample_inputs = None, None + self._runner = None + + def update_config(self) -> dict: + """Update config + + Returns + ------- + debug_levels: dict + The debug_levels. 
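+
+        Examples
+        --------
+        An illustrative (not authoritative) mapping: with ``verbose="debug:2"``
+        every run_config and tool_config without an explicit ``debug_level``
+        is assigned level 2, so the returned dict may look like
+        ``{"optimize": 2, "optimize.quantize": 2, "compile": 2}``.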
+ """ + + debug_levels = {} + self._config = self._get_runner_cls(self._model_type).update_config( + MSCStage.PARSE, self._config, self._model + ) + + # update runner config + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in self._config: + continue + if "run_type" not in self._config[stage]: + self._config[stage]["run_type"] = self._model_type + runner_cls = self._get_runner_cls(self._config[stage]["run_type"]) + self._config = runner_cls.update_config(stage, self._config, self._model) + + # update tool config + if self._config.get("tools"): + self._config["tools"] = self._update_tools_config(self._config["tools"]) + + # update export config + self._config[MSCStage.EXPORT].update( + {"inputs": self._config["inputs"], "outputs": self._config["outputs"]} + ) + + def _set_debug_level(stage: str, sub_config: dict, default: int = None) -> dict: + if "debug_level" in sub_config: + debug_levels[stage] = sub_config["debug_level"] + elif default is not None: + debug_levels[stage] = default + sub_config["debug_level"] = default + return debug_levels + + if self._verbose.startswith("debug:"): + debug_level = int(self._verbose.split(":")[1]) + else: + debug_level = 0 + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in self._config: + continue + debug_levels = _set_debug_level(stage, self._config[stage]["run_config"], debug_level) + for t_config in self._config.get("tools", []): + if not support_tool(t_config, stage, self._config[stage]["run_type"]): + continue + t_stage = stage + "." + get_tool_stage(t_config["tool_type"]) + debug_levels = _set_debug_level(t_stage, t_config["tool_config"], debug_level) + ordered_keys = [ + "model_type", + "inputs", + "outputs", + "dataset", + "tools", + MSCStage.PREPARE, + MSCStage.PARSE, + MSCStage.BASELINE, + MSCStage.OPTIMIZE, + MSCStage.COMPILE, + MSCStage.EXPORT, + ] + self._config = {k: self._config[k] for k in ordered_keys if k in self._config} + return debug_levels + + def _update_tools_config(self, tools: List[dict]) -> List[dict]: + """Update tool in stage config. + + Parameters + ---------- + tools: list + The config of tools. + + Returns + ------- + tools: list + The updated config of tools. + """ + + for tool in tools: + tool_config = tool["tool_config"] + if "plan_file" not in tool_config: + tool_config["plan_file"] = "msc_{}.json".format(tool["tool_type"]) + tool_config["plan_file"] = msc_utils.to_abs_path( + tool_config["plan_file"], msc_utils.get_config_dir() + ) + return tools + + def prepare(self, data_loader: Any = None) -> Tuple[dict, dict]: + """Prepare datas for the pipeline. + + Parameters + ---------- + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of prepare. + report: dict + The report of prepare. 
+ """ + + stage_config = self._config[MSCStage.PREPARE] + use_cache = self._config.get("use_cache", True) + runner_cls = self._get_runner_cls(self._model_type) + run_func = runner_cls.run_native if hasattr(runner_cls, "run_native") else None + input_names = [i[0] for i in self._config["inputs"]] + + # create golden + if "golden" in self._config["dataset"]: + golden_folder = self._config["dataset"]["golden"]["loader"] + else: + golden_folder = msc_utils.get_dataset_dir().relpath("Golden", use_cache) + if msc_utils.is_io_dataset(golden_folder): + loader, source_type = msc_utils.IODataLoader(golden_folder), "cache" + self._sample_inputs = loader[0][0] + datas_info = loader.info + msg = "Load {} golden from {}".format(len(loader), golden_folder) + self._logger.debug(self.worker_mark(msg)) + elif run_func: + source_type = "native" + saver_options = {"input_names": input_names, "output_names": self._config["outputs"]} + cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) + with msc_utils.IODataSaver(golden_folder, saver_options) as saver: + for inputs in data_loader(): + if cnt >= max_golden > 0: + break + if not self._sample_inputs: + self._sample_inputs = { + k: msc_utils.cast_array(v) for k, v in inputs.items() + } + try: + outputs, _ = run_func( + self._model, inputs, input_names, self._config["outputs"] + ) + except Exception as exc: # pylint: disable=broad-exception-caught + if cnt == 0: + msg = "Failed to test native: {}".format(exc) + self._logger.warning(self.worker_mark(msg)) + outputs = None + cnt = saver.save_batch(inputs, outputs) + datas_info = saver.info + msg = "Save {} golden to {}".format(cnt, golden_folder) + self._logger.debug(self.worker_mark(msg)) + else: + raise Exception("golden_folder or runner should given to save golden") + self._config["dataset"]["golden"] = {"loader": golden_folder, "max_batch": -1} + + def _to_abstract(info: dict) -> dict: + def _to_tensor_str(info): + return "{},{}".format(";".join([str(s) for s in info["shape"]]), info["dtype"]) + + return { + "num_datas": info["num_datas"], + "inputs": {n: _to_tensor_str(i) for n, i in info["inputs"].items()}, + "outputs": {n: _to_tensor_str(o) for n, o in info["outputs"].items()}, + } + + info = { + "golden_folder({})".format(source_type): golden_folder, + "datas_info": _to_abstract(datas_info), + "smaple_inputs": self._sample_inputs, + } + + # profile + report = {} + if "profile" in stage_config and run_func: + benchmark = stage_config["profile"].get("benchmark", {}) + benchmark["repeat"] = self._get_repeat(benchmark) + try: + _, avg_time = run_func( + self._model, + self._sample_inputs, + input_names, + self._config["outputs"], + **benchmark, + ) + latency = "{:.2f} ms @ {}".format(avg_time, self._device) + info["latency"] = latency + " (X{})".format(benchmark["repeat"]) + report["profile"] = latency + except Exception as exc: # pylint: disable=broad-exception-caught + msg = "Failed to profile native: {}".format(exc) + self._logger.warning(self.worker_mark(msg)) + report["profile"] = "failed run native" + return info, report + + def parse(self) -> Tuple[dict, dict]: + """Parse the model to IRModule. + + Returns + ------- + info: dict + The info of parse. + report: dict + The report of parse. 
+ """ + + stage_config = self._config[MSCStage.PARSE] + if self._config.get("use_cache", True): + cache_path = ( + msc_utils.get_cache_dir().create_dir(MSCStage.PARSE).relpath("parsed_relax.json") + ) + else: + cache_path = None + info = {} + if cache_path and os.path.isfile(cache_path): + with open(cache_path, "r") as f: + self._relax_mod = tvm.ir.load_json(f.read()) + info["cache"] = cache_path + else: + info = {"parser": stage_config["parser"], "config": stage_config.get("parse_config")} + parse_config = msc_utils.copy_dict(stage_config.get("parse_config", {})) + parse_config["as_msc"] = False + if self._model_type in self._plugins: + plugin = self._plugins[self._model_type] + parse_config["custom_convert_map"] = plugin.get_convert_map() + self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) + transformed = set() + for stage in [MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in self._config: + continue + run_type = self._config[stage]["run_type"] + if run_type in transformed: + continue + transformed.add(run_type) + runner_cls = self._get_runner_cls(run_type) + if hasattr(runner_cls, "target_transform"): + msg = "Transform for {}({})".format(run_type, stage) + self._logger.info(self.worker_mark(msg)) + self._relax_mod = runner_cls.target_transform(self._relax_mod) + if cache_path: + with open(cache_path, "w") as f: + f.write(tvm.ir.save_json(self._relax_mod)) + msg = "Save parsed mod to " + cache_path + self._logger.debug(self.worker_mark(msg)) + return info, {} + + def get_tool_config(self, tool_type: str, key: str = "tool_config", default: Any = None) -> Any: + """Get the tool config + + Parameters + ---------- + tool_type: str + The tool type. + key: str + The config key + + Returns + ------- + config: + The tool config or info. + """ + + assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) + return self._tools_config[tool_type].get(key, default) + + def tool_applied(self, tool_type: str) -> bool: + """Check if the tool is applied + + Parameters + ---------- + tool_type: str + The tool type. + + Returns + ------- + applied: bool + Whether the tool is applied. + """ + + config = self.get_tool_config(tool_type) + return os.path.isfile(config["plan_file"]) + + def apply_tool( + self, tool_type: str, knowledge: dict = None, data_loader: Any = None + ) -> Tuple[dict, dict]: + """Apply tool with runner + + Parameters + ---------- + tool_type: str + The tool type to apply. + knowledge: dict + The pre knowledge. + data_loader: + The data loader. + + Returns + ------- + info: dict + The info of apply tool. + report: dict + The report of apply tool. + """ + + plan_file = self.get_tool_config(tool_type)["plan_file"] + if knowledge: + self._logger.info("Plan by %d knowledge for %s", len(knowledge), tool_type) + msc_utils.save_dict(knowledge, plan_file) + else: + self._runner.make_plan(tool_type, data_loader) + if self.get_tool_config(tool_type, "visualize", False): + self._runner.visualize( + msc_utils.get_visual_dir().create_dir(self._runner.stage.split(".")[0]) + ) + report = {} + if os.path.isfile(plan_file): + report["plan_num"] = len(msc_utils.load_dict(plan_file)) + return {}, report + + def create_runner( + self, + stage: str, + tools: List[str] = None, + run_type: str = None, + run_config: dict = None, + visualize: bool = True, + profile: bool = True, + use_cache: bool = True, + ) -> Tuple[dict, dict]: + """Create runner. + + Parameters + ---------- + stage: str + The stage name + tools: list + The tools to apply. 
+ run_type: str + The type of runner. + run_config: dict + The config of runner. + visualize: bool + Whether to visualize the runner + profile: bool + Whether to profile the runner. + use_cache: bool + Whether to use cache. + + Returns + ------- + info: dict + The info of create runner. + report: dict + The report of create runner. + """ + + if self._runner: + self._runner.destory() + tools = tools or [] + assert all(t in self._tools_config for t in tools), "Missing some tools " + str(tools) + main_stage = stage.split(".")[0] + if not run_type: + run_type = self._config[main_stage]["run_type"] + if not run_config: + run_config = self._config[main_stage].get("run_config", {}) + runner_cls = self._get_runner_cls(run_type) + if "generate_config" not in run_config: + run_config["generate_config"] = {} + cleanup = self._debug_levels.get(stage, 0) == 0 + run_config["generate_config"]["build_folder"] = msc_utils.get_build_dir().create_dir( + stage, cleanup=cleanup + ) + if "device" not in run_config: + run_config["device"] = self._device + if "training" not in run_config: + run_config["training"] = self._training + # Build runner + runner = runner_cls( + self._relax_mod, + tools_config=[self._tools_config[t] for t in tools], + plugin=self._plugins.get(run_type), + stage=stage, + name=self._name, + logger=self._logger, + **run_config, + ) + cache_dir = msc_utils.get_cache_dir().create_dir(stage) if use_cache else None + runner.build(cache_dir=cache_dir) + if visualize: + runner.visualize(msc_utils.get_visual_dir().create_dir(main_stage)) + if use_cache: + runner.save_cache(cache_dir) + info, report = {}, {"runtime": "{} @ {}".format(runner.framework, runner.device)} + if profile and "profile" in self._config[main_stage]: + profile_config = self._config[main_stage]["profile"] + info["profile"], report["profile"] = self._profile_runner(runner, profile_config) + self._runner = runner + return info, report + + def _profile_runner(self, runner: BaseRunner, profile_config: dict) -> Tuple[dict, str]: + """Profile the runner. + + Parameters + ---------- + runner: BaseRunner + The runner to be profiled + profile_config: dict + The config of profile. + + Returns + ------- + info: dict + The info of profile. + report: str + The report of profile. 
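# A minimal driver sketch for create_runner (illustration only, not part of this patch).
# It assumes an already-constructed pipeline worker whose stages and tools are configured;
# the helper name and the choice of the compile stage are assumptions.
from tvm.contrib.msc.core.utils.message import MSCStage

def build_compile_runner(worker, tools=None):
    """Create the compile-stage runner and return its info/report dicts."""
    info, report = worker.create_runner(
        MSCStage.COMPILE,   # stage name
        tools=tools or [],  # tool types that must exist in the worker's tools config
        visualize=True,     # dump the runner visualization
        profile=True,       # run the stage's profile config if present
        use_cache=False,
    )
    return info, report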
+ """ + + stage = runner.stage + info, report = {}, "" + + # check accuracy + check_config = profile_config.get("check", {}) + if check_config: + loader = msc_utils.IODataLoader(self._config["dataset"]["golden"]["loader"]) + acc_info = {"passed": ""} + total, passed = 0, 0 + for idx, (inputs, outputs) in enumerate(loader): + results = runner.run(inputs) + if outputs: + iter_info = msc_utils.compare_arrays( + outputs, + results, + atol=check_config.get("atol", 1e-2), + rtol=check_config.get("rtol", 1e-2), + report_detail=runner.debug_level >= 2, + ) + else: + iter_info = { + "total": len(results), + "passed": len(results), + "info": {k: msc_utils.MSCArray(v).abstract() for k, v in results.items()}, + } + total += iter_info["total"] + passed += iter_info["passed"] + acc_info["iter_" + str(idx)] = iter_info["info"] + pass_rate = float(passed) / total + accuracy = "{}/{}({:.2f}%)".format(passed, total, pass_rate * 100) + acc_info["passed"] = "{} {}".format(accuracy, check_config) + info["accuracy"] = acc_info if runner.debug_level >= 1 else accuracy + report = "pass " + accuracy + if runner.get_tool(ToolType.PRUNER) or runner.get_tool(ToolType.QUANTIZER): + disable_msg = "Disable accuracy check({}) by tools".format(stage) + self._logger.debug(self.worker_mark(disable_msg)) + else: + required_err, err_rate = check_config.get("err_rate", 0), (1 - pass_rate) + if err_rate > required_err >= 0: + self._logger.error(msc_utils.msg_block(self.worker_mark("ACCURACY"), acc_info)) + raise Exception( + "Failed to profile the runner({}), err_rate {} > required {}".format( + stage, err_rate, required_err + ) + ) + + # benchmark model + benchmark_config = profile_config.get("benchmark", {}) + if benchmark_config: + for _ in range(benchmark_config.get("warm_up", 10)): + runner.run(self._sample_inputs) + start = time.time() + repeat = self._get_repeat(benchmark_config, runner.device) + for _ in range(repeat): + runner.run(self._sample_inputs) + avg_time = (time.time() - start) * 1000 / repeat + latency = "{:.2f} ms @ {}".format(avg_time, runner.device) + info["latency"] = latency + " (X{})".format(repeat) + report += (", " if report else "") + latency + return info, report + + def export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: + """Export the model + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + exported: + The exported model. + """ + + if stage == MSCStage.COMPILE: + if not dump: + return self._runner.runnable + return self._runner.export_runnable(folder) + + if stage == MSCStage.OPTIMIZE: + module = self._runner.export_module(folder) + if not dump: + return module + path = folder.relpath("model.json") + with open(path, "w") as f: + f.write(tvm.ir.save_json(module)) + return path + + if not dump: + return self._model + dump_func = self._get_runner_cls(self._model_type).dump_nativate + return dump_func(self._model, folder, self._config[MSCStage.EXPORT]) + + def export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the tool + + Parameters + ---------- + tool_type: str + The tool type. + folder: MSCDirectory + The export folder. + + Returns + ------- + config: dict + The exported tool config. 
+ """ + + run_tool = self._runner.get_tool(tool_type) + assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) + return run_tool.export_config(self._tools_config[tool_type]["tool_config"], folder) + + def export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: + """Export the info of worker + + Parameters + ---------- + stage: str + The pipeline stage. + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The info. + """ + + return { + "visualize": msc_utils.get_visual_dir().copy_to(folder.relpath("visualize")), + "graphs": self._runner.export_graphs(folder.create_dir("graphs")), + } + + def get_runnable(self, ret_type: str = "runner") -> Any: + """Return object by type. + + Parameters + ---------- + ret_type: str + The return type runner| runnable| model. + + Returns + ------- + runnable: + The runner or model. + """ + + assert self._runner, "Failed to create runner, call run_pipe first" + if ret_type == "runner": + return self._runner + if ret_type == "runnable": + return self._runner.runnable + if ret_type == "model": + return self._runner.model + raise TypeError("Unexpect return type " + str(ret_type)) + + def _get_repeat(self, benchmark: dict, device: str = None) -> int: + """Get the repeat number for benchmark + + Parameters + ---------- + benchmark: dict + The benchmark config. + device: str + The device name + + Returns + ------- + repeat: int + The repeat number. + """ + + device = device or self._device + repeat = benchmark.get("repeat", -1) + if repeat == -1: + repeat = 500 if device.startswith("cuda") else 10 + return repeat + + def _get_runner_cls(self, run_type: str) -> BaseRunner: + """Get the runner cls by type + + Parameters + ---------- + run_type: str + The run type. + + Returns + ------- + runner_cls: class + The runner class. + """ + + raise NotImplementedError("_get_runner_cls is not implemented in " + str(self.__class__)) + + def destory(self): + """Destroy the worker""" + + if self._runner: + self._runner.destory() + + def worker_mark(self, msg: Any) -> str: + """Mark the message with worker info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "WORKER[{}] {}".format(self._name, msg) + + @property + def runner(self): + return self._runner + + @property + def model_type(self): + return self._model_type + + @property + def optimize_type(self): + return self._optimize_type + + @property + def compile_type(self): + return self._compile_type + + +class MSCPipeWorker(BasePipeWorker): + """Normal manager in MSC""" + + def _get_runner_cls(self, run_type: str) -> BaseRunner: + """Get the runner cls by type + + Parameters + ---------- + run_type: str + The run type. + + Returns + ------- + runner_cls: class + The runner class. 
+ """ + + if run_type == MSCFramework.TVM: + from tvm.contrib.msc.framework.tvm.runtime import TVMRunner + + runner_cls = TVMRunner + elif run_type == MSCFramework.TORCH: + from tvm.contrib.msc.framework.torch.runtime import TorchRunner + + runner_cls = TorchRunner + elif run_type == MSCFramework.TENSORFLOW: + from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner + + runner_cls = TensorflowRunner + elif run_type == MSCFramework.TENSORRT: + from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner + + runner_cls = TensorRTRunner + else: + raise Exception("Unexpect run_type " + str(run_type)) + return runner_cls diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py index 2b69034cab70..1332b3c79115 100644 --- a/python/tvm/contrib/msc/pipeline/wrapper.py +++ b/python/tvm/contrib/msc/pipeline/wrapper.py @@ -19,12 +19,12 @@ import shutil from typing import Any, Union, List -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils from .manager import MSCManager -from .config import create_config +from .dynamic import MSCDynamic, TorchDynamic +from .utils import create_config class BaseWrapper(object): @@ -41,22 +41,19 @@ class BaseWrapper(object): """ def __init__( - self, - model: Any, - config: dict, - workspace: str = "msc_workspace", - plugins: dict = None, + self, model: Any, config: dict, workspace: str = "msc_workspace", plugins: dict = None ): self._meta_model = model self._optimized_model, self._compiled_model = None, None self._config = config self._plugins = plugins + self._dynamic = self._config.get("dynamic", False) verbose = config.get("verbose", "info") self._debug = verbose.startswith("debug") self._workspace = msc_utils.msc_dir(workspace, keep_history=self._debug) log_path = self._workspace.relpath("MSC_LOG", keep_history=False) self._config["logger"] = msc_utils.create_file_logger(verbose, log_path) - self._manager = None + self._pipeline, self._report = None, None self.setup() def __str__(self): @@ -87,18 +84,18 @@ def optimize(self, workspace: str = "Optimize"): The workspace. 
""" - self.logger.info("[Wrapper] Start optimize model") + self.logger.info(msc_utils.split_line("Start optimize model", "*")) config = msc_utils.copy_dict(self._config) config["workspace"] = self._workspace.create_dir(workspace) if MSCStage.OPTIMIZE not in config: - config[MSCStage.OPTIMIZE] = { - "run_type": self.model_type(), - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - } - self._manager = MSCManager(self._meta_model, config, self._plugins, run_compile=False) - report = self._manager.run_pipe() - if report["success"]: - self._optimized_model = self._manager.get_runnable("runnable") + config[MSCStage.OPTIMIZE] = {"run_type": self.model_type()} + profile = config.get(MSCStage.BASELINE, {}).get("profile") + if profile: + config[MSCStage.OPTIMIZE]["profile"] = profile + self._pipeline = self.pipe_cls(self._meta_model, config, self._plugins, run_compile=False) + self._report = self._pipeline.run_pipe() + if self._report["success"]: + self._optimized_model = self._pipeline.get_runtime("runnable") return self def compile( @@ -117,27 +114,31 @@ def compile( """ if self._optimized_model: - self.logger.info("[Wrapper] Start compile checkpoint") + self.logger.info(msc_utils.split_line("Start compile checkpoint", "*")) ckpt_path = self._workspace.create_dir(ckpt_path).path - pipeline = self.export(ckpt_path, dump=dump) - pipeline["config"]["workspace"] = self._workspace.create_dir(workspace) - self._manager = MSCManager(**pipeline) - report = self._manager.run_pipe() - if report["success"]: - self._compiled_model = self._manager.get_runnable("runnable") + export = self.export(ckpt_path, dump=dump, keep_workspace=True) + export["config"]["workspace"] = self._workspace.create_dir(workspace) + self._pipeline = self.pipe_cls( + export["model"], export["config"], export["plugins"], root=ckpt_path + ) + self._report = self._pipeline.run_pipe() + if self._report["success"]: + self._compiled_model = self._pipeline.get_runtime("runnable") if not self._debug: shutil.rmtree(ckpt_path) else: - self.logger.info("[Wrapper] Start compile model") + self.logger.info(msc_utils.split_line("Start compile model", "*")) config = msc_utils.copy_dict(self._config) config["workspace"] = self._workspace.create_dir(workspace) - self._manager = MSCManager(self._meta_model, config, self._plugins) - report = self._manager.run_pipe() - if report["success"]: - self._compiled_model = self._manager.get_runnable("runnable") + self._pipeline = self.pipe_cls(self._meta_model, config, self._plugins) + self._report = self._pipeline.run_pipe() + if self._report["success"]: + self._compiled_model = self._pipeline.get_runtime("runnable") return self - def export(self, path: str = "msc_export", dump: bool = True) -> Union[str, dict]: + def export( + self, path: str = "msc_export", dump: bool = True, keep_workspace: bool = False + ) -> Union[str, dict]: """Export compile pipeline Parameters @@ -146,6 +147,8 @@ def export(self, path: str = "msc_export", dump: bool = True) -> Union[str, dict The export path. dump: bool Whether to dump the info. + keep_workspace: bool + Whether to keep workspace. Returns ------- @@ -153,66 +156,26 @@ def export(self, path: str = "msc_export", dump: bool = True) -> Union[str, dict The exported path/pipeline info. 
""" - if not self._manager: - self._manager = MSCManager(self._meta_model, self._config, self._plugins) - exported = self._manager.export(path, dump=dump) + if not self._pipeline: + self._pipeline = self.pipe_cls(self._meta_model, self._config, self._plugins) + exported = self._pipeline.export(path, dump=dump) if not self._debug: - self._manager.destory() + self._pipeline.destory() + if not keep_workspace: + self._workspace.destory() return exported - def get_tools(self, tool_types: List[str]) -> List[BaseTool]: - """Get the tools from manager - - Parameters - ---------- - tool_types: list - The tool types. - - Returns - ------- - tools: list - The tools. - """ - - if not self._manager: - return [] - tool_types = tool_types or ToolType.all_types() - tools = [] - for t in tool_types: - tool = self._manager.runner.get_tool(t) - if tool: - tools.append(tool) - return tools - - def disable_tools(self, tool_types: List[str]): - """Disable the tools - - Parameters - ---------- - tool_types: list - The tool types. - """ - - for tool in self.get_tools(tool_types): - tool.disable() - - def enable_tools(self, tool_types: List[str]): - """Enable the tools - - Parameters - ---------- - tool_types: list - The tool types. - """ - - for tool in self.get_tools(tool_types): - tool.enable() - def _get_model(self) -> Any: return self._compiled_model or self._optimized_model or self._meta_model def _get_framework(self) -> str: - return self._manager.runner.framework if self._manager else self.model_type() + return self._pipeline.get_runtime().framework if self._pipeline else self.model_type() + + @property + def pipe_cls(self): + if self._dynamic: + return MSCDynamic + return MSCManager @property def optimized(self): @@ -224,14 +187,18 @@ def compiled(self): @property def device(self): - if self._manager: - return self._manager.runner.device + if self._pipeline: + return self._pipeline.get_runtime().device return "cpu" @property def logger(self): return self._config["logger"] + @property + def report(self): + return self._report + @classmethod def create_config( cls, @@ -252,10 +219,10 @@ def create_config( The output names. baseline_type: str The baseline type. - compile_type: str - The compile type. optimize_type: str The optimize type. + compile_type: str + The compile type. kwargs: dict The config kwargs. 
""" @@ -281,28 +248,34 @@ def __call__(self, *inputs): return outputs if isinstance(outputs, (tuple, list)): return [msc_utils.cast_array(o, MSCFramework.TORCH, self.device) for o in outputs] - return msc_utils.cast_array(outputs, MSCFramework.TORCH) + return msc_utils.cast_array(outputs, MSCFramework.TORCH, self.device) def parameters(self): framework = self._get_framework() if framework == MSCFramework.TORCH: return self._get_model().parameters() - return self._manager.runner.get_weights(MSCFramework.TORCH) + return self._pipeline.get_runtime().get_weights(MSCFramework.TORCH) def train(self): - if self._manager: - self._manager.runner.train() + if self._pipeline: + self._pipeline.get_runtime().train() if self._get_framework() == MSCFramework.TORCH: return self._get_model().train() return self._get_model() def eval(self): - if self._manager: - self._manager.runner.eval() + if self._pipeline: + self._pipeline.get_runtime().eval() if self._get_framework() == MSCFramework.TORCH: return self._get_model().eval() return self._get_model() + @property + def pipe_cls(self): + if self._dynamic: + return TorchDynamic + return MSCManager + @classmethod def model_type(cls): return MSCFramework.TORCH diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_pipeline.py similarity index 70% rename from tests/python/contrib/test_msc/test_manager.py rename to tests/python/contrib/test_msc/test_pipeline.py index bcd12b36b5a3..c7a26bf96efb 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_pipeline.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. -""" Test Managers in MSC. """ +""" Test Pipeline in MSC. """ import json import pytest import torch import tvm.testing -from tvm.contrib.msc.pipeline import MSCManager +from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils @@ -32,13 +32,13 @@ ) -def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1): +def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1): """Get msc config""" - path = "test_manager_{}_{}".format(model_type, compile_type) + path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static") return { "workspace": msc_utils.msc_dir(path), - "verbose": "critical", + "verbose": "info", "model_type": model_type, "inputs": inputs, "outputs": outputs, @@ -95,23 +95,29 @@ def _get_tf_graph(): return None -def _check_manager(manager, expected_info): - """Check the manager results""" +def _check_pipeline(pipeline, expected_info, dynamic=False): + """Check the pipeline results""" - model_info = manager.runner.model_info passed, err = True, "" - if not manager.report["success"]: + if not pipeline.report["success"]: passed = False - err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type) - if not msc_utils.dict_equal(model_info, expected_info): - passed = False - err = "Model info {} mismatch with expected {}".format(model_info, expected_info) - manager.destory() + err = "Failed to run pipe for {} -> {}".format(pipeline.model_type, pipeline.compile_type) + if not dynamic: + model_info = pipeline.get_runtime().model_info + if not msc_utils.dict_equal(model_info, expected_info): + passed = False + err = "Model info {} mismatch with expected {}".format(model_info, expected_info) 
+ pipeline.destory() if not passed: - raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2))) + raise Exception("{}\nReport:{}".format(err, json.dumps(pipeline.report, indent=2))) + +def _test_from_torch( + compile_type, expected_info, training=False, dynamic=False, atol=1e-1, rtol=1e-1 +): + if dynamic and not hasattr(torch, "compile"): + return -def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rtol=1e-1): torch_model = _get_torch_model("resnet50", training) if torch_model: if torch.cuda.is_available(): @@ -121,12 +127,13 @@ def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rto compile_type, inputs=[["input_0", [1, 3, 224, 224], "float32"]], outputs=["output"], + dynamic=dynamic, atol=atol, rtol=rtol, ) - manager = MSCManager(torch_model, config) - manager.run_pipe() - _check_manager(manager, expected_info) + pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config) + pipeline.run_pipe() + _check_pipeline(pipeline, expected_info, dynamic) def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): @@ -143,11 +150,12 @@ def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): config["compile"]["profile"]["check"]["err_rate"] = -1 manager = MSCManager(graphdef, config) manager.run_pipe() - _check_manager(manager, expected_info) + _check_pipeline(manager, expected_info) -def test_tvm_manager(): - """Test manager for tvm""" +@pytest.mark.parametrize("dynamic", [False, True]) +def test_tvm_pipeline(dynamic): + """Test pipeline for tvm""" model_info = { "inputs": [ @@ -168,40 +176,42 @@ def test_tvm_manager(): "msc.linear_bias": 1, }, } - _test_from_torch(MSCFramework.TVM, model_info, training=False) - - model_info = { - "inputs": [ - {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"} - ], - "outputs": [ - { - "name": "MobilenetV2/Predictions/Reshape_1:0", - "shape": [1, 1001], - "dtype": "float32", - "layout": "NC", - } - ], - "nodes": { - "total": 138, - "input": 1, - "msc.conv2d_bias": 36, - "clip": 35, - "nn.conv2d": 17, - "nn.batch_norm": 17, - "get_item": 17, - "add": 10, - "nn.avg_pool2d": 1, - "squeeze": 1, - "reshape": 2, - "nn.softmax": 1, - }, - } - _test_from_tf(MSCFramework.TVM, model_info) - - -def test_torch_manager(): - """Test manager for torch""" + _test_from_torch(MSCFramework.TVM, model_info, training=False, dynamic=dynamic) + + if not dynamic: + model_info = { + "inputs": [ + {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"} + ], + "outputs": [ + { + "name": "MobilenetV2/Predictions/Reshape_1:0", + "shape": [1, 1001], + "dtype": "float32", + "layout": "NC", + } + ], + "nodes": { + "total": 138, + "input": 1, + "msc.conv2d_bias": 36, + "clip": 35, + "nn.conv2d": 17, + "nn.batch_norm": 17, + "get_item": 17, + "add": 10, + "nn.avg_pool2d": 1, + "squeeze": 1, + "reshape": 2, + "nn.softmax": 1, + }, + } + _test_from_tf(MSCFramework.TVM, model_info) + + +@pytest.mark.parametrize("dynamic", [False, True]) +def test_torch_pipeline(dynamic): + """Test pipeline for torch""" model_info = { "inputs": [ @@ -222,10 +232,10 @@ def test_torch_manager(): "msc.linear_bias": 1, }, } - _test_from_torch(MSCFramework.TORCH, model_info, training=False) + _test_from_torch(MSCFramework.TORCH, model_info, training=False, dynamic=dynamic) -def test_tensorflow_manager(): +def test_tensorflow_pipeline(): """Test manager for tensorflow""" model_info = { @@ -259,8 +269,9 @@ def 
test_tensorflow_manager(): @requires_tensorrt -def test_tensorrt_manager(): - """Test manager for tensorrt""" +@pytest.mark.parametrize("dynamic", [False, True]) +def test_tensorrt_pipeline(dynamic): + """Test pipeline for tensorrt""" model_info = { "inputs": [ @@ -269,7 +280,7 @@ def test_tensorrt_manager(): "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, } - _test_from_torch(MSCFramework.TENSORRT, model_info, training=False) + _test_from_torch(MSCFramework.TENSORRT, model_info, training=False, dynamic=dynamic) if __name__ == "__main__": diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py index e2d3b5fcd3d3..81adc2ab4ceb 100644 --- a/tests/python/contrib/test_msc/test_plugin.py +++ b/tests/python/contrib/test_msc/test_plugin.py @@ -313,7 +313,7 @@ def _test_with_manager(plugins, compile_type, expected_info): } manager = MSCManager(model, config, plugins=plugins) report = manager.run_pipe() - model_info = manager.runner.model_info + model_info = manager.get_runtime().model_info manager.destory() assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) assert msc_utils.dict_equal( diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index 3c88c8706a80..55fc9dd43e4f 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -100,7 +100,7 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): golden = [msc_utils.cast_array(golden)] workspace.destory() for gol_r, out_r in zip(golden, outputs): - tvm.testing.assert_allclose(gol_r, out_r, atol=atol, rtol=rtol) + tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol) def test_tvm_runner_cpu(): @@ -162,7 +162,7 @@ def test_tensorflow_runner(): outputs = runner.run([data], ret_type="list") workspace.destory() for gol_r, out_r in zip(golden, outputs): - tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3) + tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=1e-3, rtol=1e-3) if __name__ == "__main__": diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 3a56b255efdb..22354bb2c131 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -144,7 +144,7 @@ def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC): } ], } - tools.append({"tool_type": ToolType.TRACKER, "tool_config": config, "apply_once": True}) + tools.append({"tool_type": ToolType.TRACKER, "tool_config": config}) if use_distill: config = { "plan_file": "msc_distiller.json", @@ -180,7 +180,7 @@ def _get_torch_model(name, training=False): def _check_manager(manager, expected_info): """Check the manager results""" - model_info = manager.runner.model_info + model_info = manager.get_runtime().model_info passed, err = True, "" if not manager.report["success"]: passed = False
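# A closing note on the dynamic test path (illustration only, not part of this patch):
# the dynamic cases rely on torch.compile, so _test_from_torch returns early on older
# PyTorch. The same guard could be written as a pytest marker; the marker name here is
# an assumption, not something the patch adds.
import pytest
import torch

requires_torch_compile = pytest.mark.skipif(
    not hasattr(torch, "compile"), reason="dynamic wrapper requires torch.compile"
)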