Skip to content

Commit

Permalink
[Unity][MSC][M4.2][Step1] Enable plugin with manager, test plugins in…
Browse files Browse the repository at this point in the history
… compile pipeline (#16495)

* add plugin in manager

* remove wrapper
  • Loading branch information
Archermmt authored Feb 5, 2024
1 parent 5ebdd49 commit 343435a
Show file tree
Hide file tree
Showing 19 changed files with 374 additions and 74 deletions.
9 changes: 9 additions & 0 deletions python/tvm/contrib/msc/core/runtime/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class BaseRunner(object):
Whether compile model to trainable
stage: str
The stage of runner.
plugin: PluginManager
The plugin manager.
name: str
The name of the runner
debug_level: int
Expand All @@ -75,6 +77,7 @@ def __init__(
device: str = "cpu",
training: bool = False,
stage: str = "default",
plugin: Any = None,
name: str = "main",
debug_level: int = 0,
logger: logging.Logger = None,
Expand All @@ -86,6 +89,7 @@ def __init__(
self._build_config = msc_utils.copy_dict(build_config)
self._device = device if self._device_enabled(device) else "cpu"
self._stage = stage
self._plugin = plugin
self._name = name
self._debug_level = debug_level
self._training, self._trained = training, training
Expand Down Expand Up @@ -123,8 +127,11 @@ def setup(self) -> dict:
stage=self._stage,
**config,
)
if self._plugin:
self._update_codegen({"use_plugin": True})
return {
"tools": {k: v.tool_style() for k, v in self._tools.items()},
"plugin": self._plugin,
"translate_config": self._translate_config,
"generate_config": self._generate_config,
"build_config": self._build_config,
Expand Down Expand Up @@ -1069,6 +1076,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
codegen_config=self._generate_config.get("codegen"),
print_config=self._generate_config.get("print"),
build_folder=self._generate_config["build_folder"],
plugin=self._plugin,
)

def _inspect_model(self) -> dict:
Expand Down Expand Up @@ -1226,6 +1234,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
extra_options=extra_option,
build_folder=self._generate_config["build_folder"],
output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()),
plugin=self._plugin,
)

def _build_runnable(self, model: Any) -> Any:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/contrib/msc/core/tools/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]:
def _update_stages(strategy):
if "stages" not in strategy:
strategy["stages"] = [msc_utils.MSCStage.PRUNE]
strategy["tensor_types"] = ["weight", "output"]
return strategy

return super()._parse_strategys([_update_stages(s) for s in strategy_list])
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/contrib/msc/core/tools/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def _check_tensor(self, name: str, consumer: str) -> bool:
Whether to process the tensor.
"""

if self._calibrated:
tensor_id = self.to_tensor_id(name, consumer)
if tensor_id not in self._plan:
return False
return self._plan.get(tensor_id, {}).get("nbits", 8) != -1
strategys = self._get_tensor_strategys(name, consumer)
if not strategys:
return False
Expand Down
47 changes: 26 additions & 21 deletions python/tvm/contrib/msc/core/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]
tensor_names = strategy.pop("tensor_names")
marks = [(n, "tensor") for n in tensor_names]
else:
marks = [("default", t) for t in ["input", "output", "weight"]]
marks = [("default." + str(t), t) for t in tensor_types]
stages = strategy.pop("stages") if "stages" in strategy else ["default"]
for mark, t_type in marks:
if mark not in strategys:
Expand Down Expand Up @@ -1212,33 +1212,38 @@ def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]:

tensor_id = self.to_tensor_id(name, consumer)
mark = "strategy.{}".format(self._stage)

def _check_strategy(s_ref):
return s_ref in self._strategys and self._strategys[s_ref].support_stage(self._stage)

if mark not in self._tensor_cache.get(tensor_id, {}):
if self.is_weight(name):
strategys = []
tensor_strategy = self._strategys.get(tensor_id)
if tensor_strategy and tensor_strategy.support_stage(self._stage):
strategys.append(tensor_strategy)
elif self.is_weight(name):
consumer = self.find_node(consumer)
name_refs = [consumer.name + ".weight", consumer.optype + ".weight"]
for ref in [consumer.name, consumer.optype, "default"]:
if _check_strategy(ref + ".weight"):
strategys.append(self._strategys[ref + ".weight"])
break
elif consumer == "exit":
producer = self.find_producer(name)
name_refs = [producer.name + ".output", producer.optype + ".output"]
for ref in [producer.name, producer.optype, "exit", "default"]:
if _check_strategy(ref + ".output"):
strategys.append(self._strategys[ref + ".output"])
break
else:
consumer = self.find_node(consumer)
for ref in [consumer.name, consumer.optype, "default"]:
if _check_strategy(ref + ".input"):
strategys.append(self._strategys[ref + ".input"])
break
producer = self.find_producer(name)
name_refs = [
producer.name + ".output",
producer.optype + ".output",
consumer.name + ".input",
consumer.optype + ".input",
]
strategys = []
tensor_strategy = self._strategys.get(tensor_id)
if tensor_strategy and tensor_strategy.support_stage(self._stage):
strategys.append(tensor_strategy)
if not strategys:
for n in name_refs:
if n in self._strategys and self._strategys[n].support_stage(self._stage):
strategys.append(self._strategys[n])
d_strategy = self._strategys.get("default")
if not strategys and d_strategy and d_strategy.support_stage(self._stage):
strategys.append(d_strategy)
for ref in [producer.name, producer.optype, "default"]:
if _check_strategy(ref + ".output"):
strategys.append(self._strategys[ref + ".output"])
break
self._save_tensor_cache(name, consumer, mark, strategys)
return self._get_tensor_cache(name, consumer, mark)

Expand Down
10 changes: 8 additions & 2 deletions python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.tensorflow.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any

import tvm
from tvm.contrib.msc.core.ir import MSCGraph
Expand All @@ -32,6 +32,7 @@ def to_tensorflow(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> tf_v1.Graph:
"""Change MSCGraph to tensorflow graph.
Expand All @@ -47,6 +48,8 @@ def to_tensorflow(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
Expand All @@ -63,4 +66,7 @@ def _save_weights(folder: msc_utils.MSCDirectory):
codegen = CodeGen(
graph, _ffi_api.GetTensorflowSources, codegen_config, print_config, build_folder
)
return codegen.load(inputs + [weights], pre_load=_save_weights)
model_args = inputs + [weights]
if plugin:
model_args = model_args + [plugin]
return codegen.load(model_args, pre_load=_save_weights)
17 changes: 16 additions & 1 deletion python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import os
import subprocess
from typing import Dict, Optional, List, Union
from typing import Dict, Optional, List, Union, Any
import numpy as np

import tvm
Expand All @@ -38,6 +38,7 @@ def to_sub_tensorrt(
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
output_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> str:
"""Change MSCGraph to TensorRT engine file.
Expand All @@ -55,6 +56,8 @@ def to_sub_tensorrt(
The folder for saving sources and datas.
export_folder: MSCDirectory
The folder for saving outputs.
plugin: PluginManager
The plugin manager.
Returns
-------
Expand Down Expand Up @@ -90,6 +93,10 @@ def _create_depends(folder: msc_utils.MSCDirectory) -> str:
f.write("{}\n".format(len(engine_wts)))
for name, data in engine_wts.items():
write_weight(name, msc_utils.cast_array(data), f)
# copy plugin
if plugin:
plugin.copy_libs("plugin_lib")
plugin.copy_includes("plugin")
# save utils sources
with folder.create_dir("utils") as utils_folder:
for name, source in get_trt_sources().items():
Expand All @@ -115,6 +122,10 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str:

with build_folder as folder:
sub_folder = folder.create_dir(graph.name)
if plugin:
codegen_config["extern_libs"] = [
sub_folder.create_dir("plugin_lib").relpath(f) for f in plugin.list_libs()
]
codegen = CodeGen(
graph,
_ffi_api.GetTensorRTSources,
Expand All @@ -140,6 +151,7 @@ def to_tensorrt(
extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
build_folder: msc_utils.MSCDirectory = None,
output_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> Dict[str, str]:
"""Change all MSCGraphs to TensorRT engine files.
Expand All @@ -161,6 +173,8 @@ def to_tensorrt(
The folder for saving sources and datas.
export_folder: MSCDirectory
The folder for saving outputs.
plugin: PluginManager
The plugin manager.
Returns
-------
Expand All @@ -183,6 +197,7 @@ def to_tensorrt(
print_configs[idx],
build_folder,
output_folder,
plugin=plugin,
)
if extra_options[idx]:
options.update(extra_options[idx])
Expand Down
48 changes: 48 additions & 0 deletions python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from typing import Mapping, Tuple, List, Union, Callable, Dict
from functools import wraps, partial

import tvm
from tvm import relax
from tvm.relax.dpl import pattern
from tvm.relax.transform import PatternCheckContext, FusionPattern
from tvm.relax.backend.pattern_registry import register_patterns
from tvm.contrib.msc.core.transform import pattern as msc_pattern
from tvm.contrib.msc.core import _ffi_api


def basic_pattern(
Expand Down Expand Up @@ -234,6 +236,43 @@ def _take_check(context: PatternCheckContext) -> bool:
return _check_expr(context.annotated_expr["input_1"], ("int32"))


def _plugin_check(context: PatternCheckContext) -> bool:
"""Check if the plugin pattern is correct.
Returns
-------
pass: bool
Whether the pattern is correct.
"""

ext_func = context.annotated_expr["out"].args[0]
return bool(_ffi_api.IsPlugin(ext_func.global_symbol))


def plugin_attrs_getter(
annotated_expr: Dict[str, tvm.relax.Expr],
) -> Dict[str, str]:
"""Get attributes for plugin pattern
Parameters
----------
annotated_expr: dict<str,Expr>
The annotated exprs during fus pattern
anchor: str
The anchor key of expr
Returns
-------
attrs: dict<str,str>
The extra attributes for msc.
"""

attrs = msc_pattern.msc_attrs_getter(annotated_expr, anchor="out")
ext_func = annotated_expr["out"].args[0]
attrs[_ffi_api.ToAttrKey("optype")] = ext_func.global_symbol
return attrs


def wrap_basic_check(
func: Callable[[PatternCheckContext], bool]
) -> Callable[[PatternCheckContext], bool]:
Expand Down Expand Up @@ -410,6 +449,15 @@ def get_patterns(target) -> List[Pattern]:
),
]
)
# plugin ops
patterns.append(
(
target + ".plugin",
*basic_pattern("relax.call_dps_packed", ["input", "input"]),
_plugin_check,
plugin_attrs_getter,
)
)

return patterns

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/contrib/msc/framework/torch/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.torch.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any
import torch

import tvm
Expand All @@ -32,6 +32,7 @@ def to_torch(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> torch.nn.Module:
"""Change MSCGraph to torch nn.Module.
Expand All @@ -47,6 +48,8 @@ def to_torch(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
Expand All @@ -73,4 +76,5 @@ def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> tor
return model

codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config, print_config, build_folder)
return codegen.load([], pre_load=_save_weights, post_load=_bind_weights)
model_args = [plugin] if plugin else []
return codegen.load(model_args, pre_load=_save_weights, post_load=_bind_weights)
10 changes: 8 additions & 2 deletions python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.tvm.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any

import tvm
from tvm.relax.transform import BindParams
Expand All @@ -32,6 +32,7 @@ def to_relax(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> tvm.IRModule:
"""Change MSCGraph to IRModule.
Expand All @@ -47,6 +48,8 @@ def to_relax(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
Expand Down Expand Up @@ -81,4 +84,7 @@ def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModul
)(mod)

codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
return codegen.load(inputs, pre_load=_save_weights, post_load=_post_proc)
model_args = inputs
if plugin:
model_args = model_args + [plugin]
return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc)
Loading

0 comments on commit 343435a

Please sign in to comment.