From 89d37772d50050f9b56cf09b844033fac17f4308 Mon Sep 17 00:00:00 2001
From: Archermmt
Date: Fri, 16 Feb 2024 21:30:08 +0800
Subject: [PATCH] enable plugin with manager

---
 .../msc/framework/torch/frontend/translate.py |  9 ++-
 .../tvm/relax/frontend/torch/fx_translator.py | 33 ++++++++++-
 tests/python/contrib/test_msc/test_plugin.py  | 58 +++++++++++++++++++
 3 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index 3ac1b81a2c73..2509f1abfcbe 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -70,6 +70,7 @@ def from_torch(
     build_config: Optional[Dict[str, str]] = None,
     opt_config: Optional[Dict[str, str]] = None,
     as_msc: bool = True,
+    custom_convert_map: dict = None,
 ) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
     """Change torch nn.Module to MSCGraph.
 
@@ -91,6 +92,8 @@ def from_torch(
         The config for optimizing the relay before translation.
     as_msc: bool
         Set to True to return MSCGraph, otherwise return relax mod
+    custom_convert_map: dict
+        The convert map for plugin ops
 
     Returns
     -------
@@ -103,7 +106,7 @@ def from_torch(
     if via_relax:
         graph_model, params = torch.fx.symbolic_trace(model), None
         with torch.no_grad():
-            relax_mod = from_fx(graph_model, input_info)
+            relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
     else:
         datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
         torch_datas = [torch.from_numpy(i) for i in datas]
@@ -116,7 +119,9 @@ def from_torch(
         if input_names:
             shape_list = list(zip(input_names, input_info))
         else:
             shape_list = [("input" + str(idx), i_info) for idx, i_info in enumerate(input_info)]
-        relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model, shape_list)
+        relay_mod, params = tvm.relay.frontend.from_pytorch(
+            scripted_model, shape_list, custom_convert_map=custom_convert_map
+        )
         relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
     if not as_msc:
         return relax_mod, params
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5e581e81f3ea..49e9fc4495f9 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1459,6 +1459,17 @@ def create_convert_map(self):
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
         }
 
+    def update_convert_map(self, custom_convert_map: dict):
+        """Update self.convert_map with custom convert map
+
+        Parameters
+        ----------
+        custom_convert_map : Dictionary of str to Relax op
+            A custom op conversion map in the same format as self.convert_map
+        """
+
+        self.convert_map.update(custom_convert_map)
+
     def from_fx(
         self,
         model,
@@ -1466,10 +1477,16 @@ def from_fx(
         keep_params_as_input: bool,
         unwrap_unit_return_tuple: bool,
         no_bind_return_tuple: bool,
+        custom_convert_map: dict = None,
     ) -> tvm.IRModule:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
 
+        if custom_convert_map:
+            custom_ops = set(custom_convert_map.keys())
+            self.update_convert_map(custom_convert_map)
+        else:
+            custom_ops = set()
         self.named_modules = dict(model.named_modules())
 
         graph: fx.Graph = model.graph
@@ -1548,7 +1565,10 @@ def from_fx(
                     assert (
                         func_name in self.convert_map
                     ), f"Unsupported function type {func_name}"
-                    self.env[node] = self.convert_map[func_name](node)
+                    if func_name in custom_ops:
+                        self.env[node] = self.convert_map[func_name](node, self)
+                    else:
+                        self.env[node] = self.convert_map[func_name](node)
                 elif node.op == "call_method":
                     assert (
                         node.target in self.convert_map
@@ -1572,6 +1592,7 @@ def from_fx(
     keep_params_as_input: bool = False,
     unwrap_unit_return_tuple: bool = False,
     no_bind_return_tuple: bool = False,
+    custom_convert_map: dict = None,
 ) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
@@ -1594,6 +1615,9 @@ def from_fx(
         A boolean flag indicating whether to bind the return tuple as a relax var.
         If the flag is true and the return value is a tuple, it will not bind it to a var.
 
+    custom_convert_map : Dictionary of str to Relax op
+        A custom op conversion map in the same format as TorchFXImporter.convert_map
+
     Returns
     -------
     output : tvm.IRModule
@@ -1662,5 +1686,10 @@ def forward(self, input):
     check the placeholder rows in the beginning of the tabular.
     """
     return TorchFXImporter().from_fx(
-        model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
+        model,
+        input_info,
+        keep_params_as_input,
+        unwrap_unit_return_tuple,
+        no_bind_return_tuple,
+        custom_convert_map=custom_convert_map,
     )
diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py
index 277268f8aee8..e2d3b5fcd3d3 100644
--- a/tests/python/contrib/test_msc/test_plugin.py
+++ b/tests/python/contrib/test_msc/test_plugin.py
@@ -26,6 +26,7 @@
 from tvm import relax
 from tvm.relax.transform import BindParams
 from tvm.script import relax as R
+from tvm.contrib.msc.pipeline import MSCManager
 from tvm.contrib.msc.plugin import build_plugins
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils
@@ -287,6 +288,39 @@ def _test_torch_plugin(manager):
     assert outputs.min() >= 0 and outputs.max() <= 0.5
 
 
+def _test_with_manager(plugins, compile_type, expected_info):
+    """Test the plugin with manager"""
+
+    path = "test_plugin_" + compile_type
+    model = _get_torch_model(plugins[MSCFramework.TORCH])
+    if torch.cuda.is_available():
+        model = model.to(torch.device("cuda:0"))
+    config = {
+        "workspace": msc_utils.msc_dir(path),
+        "model_type": MSCFramework.TORCH,
+        "verbose": "critical",
+        "inputs": [["input_0", [1, 3, 224, 224], "float32"]],
+        "outputs": ["output"],
+        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
+        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
+        "baseline": {
+            "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
+        },
+        "compile": {
+            "run_type": compile_type,
+            "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
+        },
+    }
+    manager = MSCManager(model, config, plugins=plugins)
+    report = manager.run_pipe()
+    model_info = manager.runner.model_info
+    manager.destory()
+    assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type)
+    assert msc_utils.dict_equal(
+        model_info, expected_info
+    ), "Model info {} mismatch with expected {}".format(model_info, expected_info)
+
+
 def test_plugin():
     """Test the plugins"""
 
@@ -302,6 +336,30 @@ def test_plugin():
     _test_tvm_plugin(managers[MSCFramework.TVM], "cuda")
     _test_torch_plugin(managers[MSCFramework.TORCH])
 
+    # test the plugin with manager
+    model_info = {
+        "inputs": [
+            {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
+        ],
+        "outputs": [
+            {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": "NCHW"}
+        ],
+        "nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, "nn.max_pool2d": 1},
"input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, "nn.max_pool2d": 1}, + } + _test_with_manager(managers, MSCFramework.TORCH, model_info) + _test_with_manager(managers, MSCFramework.TVM, model_info) + if tvm.get_global_func("relax.ext.tensorrt", True) is not None: + byoc_info = { + "inputs": [ + {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": ""} + ], + "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, + } + _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info) + plugin_root.destory()