
Commit

[Relay][Runtime] Add set_input/output_zero_copy in python (apache#13623)

* add set_output and test for set_output_zero_copy in python

* clean up

* clean up test

* test finished

* remove set output

* remove setoutput from header

* use zero copy for params

* fix typo

* address comments

* address comments

* add second test for set_input params

* add requires_torch

* add requires torch

* remove pytest

* add error handling for c graph executor

* better handling
Yuanjing Shi authored and Mikael Sevenier committed Dec 29, 2022
1 parent 8eb6ab3 commit 999690b
Showing 3 changed files with 101 additions and 16 deletions.
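
For orientation, the sketch below shows how the new Python bindings are intended to be called. It mirrors the test added in this commit and is illustrative only: it assumes torch is installed (to produce DLPack-capable tensors) and an LLVM-enabled TVM build.

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Build a trivial graph computing z = x + y.
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.add(x, y)))
lib = relay.build(mod, target="llvm")
gm = graph_executor.GraphModule(lib["default"](tvm.cpu()))

# Bind input and output buffers directly via DLPack; no extra host copies are made.
x_t, y_t, z_t = torch.rand(1, 10), torch.rand(1, 10), torch.rand(1, 10)
gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_t))
gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_t))
gm.set_output_zero_copy(0, tvm.nd.from_dlpack(z_t))
gm.run()
np.testing.assert_allclose(z_t.numpy(), (x_t + y_t).numpy())
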
58 changes: 57 additions & 1 deletion python/tvm/contrib/graph_executor.py
@@ -153,6 +153,21 @@ class GraphModule(object):
    def __init__(self, module):
        self.module = module
        self._set_input = module["set_input"]

        # TODO(shingjan): The graph_executor in C doesn't have
        # set_input/output_zero_copy implemented.
        try:
            self._set_input_zero_copy = module["set_input_zero_copy"]
        except AttributeError:
            self._set_input_zero_copy = lambda *_: (_ for _ in ()).throw(
                Exception("set_input_zero_copy is not implemented for C graph executor")
            )
        try:
            self._set_output_zero_copy = module["set_output_zero_copy"]
        except AttributeError:
            self._set_output_zero_copy = lambda *_: (_ for _ in ()).throw(
                Exception("set_output_zero_copy is not implemented for C graph executor")
            )
        self._run = module["run"]
        self._get_output = module["get_output"]
        self._get_input = module["get_input"]
@@ -172,7 +187,7 @@ def set_input(self, key=None, value=None, **params):
           The input key

        value : the input value.
-          The input key
+          The input value

        params : dict of str to NDArray
           Additional arguments
@@ -195,6 +210,47 @@ def set_input(self, key=None, value=None, **params):
            if val:
                self._get_input(k).copyfrom(params[k])

    def set_input_zero_copy(self, key=None, value=None, **params):
        """Set inputs to the module via kwargs with zero memory copy

        Parameters
        ----------
        key : int or str
           The input key

        value : the input value in DLPack
           The input value

        params : dict of str to NDArray
           Additional arguments
        """
        if key is not None:
            self._set_input_zero_copy(key, value)

        if params:
            keys = list(params.keys())

            for k in keys:
                # TODO(zhiics) Skip the weights for submodule in a better way.
                # We should use ConstLoaderModule for initialization and remove
                # params from set_input
                val = self._get_input(k)
                if val:
                    self._set_input_zero_copy(k, params[k])

    def set_output_zero_copy(self, key, value):
        """Set outputs to the module with zero memory copy

        Parameters
        ----------
        key : int or str
           The output key

        value : the output value in DLPack
           The output value
        """
        self._set_output_zero_copy(key, value)

    def run(self, **input_dict):
        """Run forward execution of the graph
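
A side note on the (_ for _ in ()).throw(...) construct used above: it is a way of raising an exception from inside a lambda, since raise is a statement and cannot appear in a lambda body. A more explicit stand-in (a sketch, not what the patch itself uses) would be a small factory returning a raising callable:

def _unsupported(name):
    """Return a callable that raises when the C graph executor lacks `name`."""

    def _raise(*_args, **_kwargs):
        raise Exception(f"{name} is not implemented for C graph executor")

    return _raise

# Hypothetical usage inside __init__:
#     self._set_input_zero_copy = _unsupported("set_input_zero_copy")
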
20 changes: 5 additions & 15 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -54,16 +54,6 @@ def main(a: T.handle, b: T.handle) -> None: # type: ignore
# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument


-def _has_torch():
-    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
-
-    spec = importlib.util.find_spec("torch")
-    return spec is not None
-
-
-requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")


def test_meta_schedule_dynamic_loop_extent():
    a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
    b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -72,7 +62,7 @@ def test_meta_schedule_dynamic_loop_extent():
    assert not extracted_tasks


-@requires_torch
+@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_resnet():
    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params)
@@ -108,7 +98,7 @@ def test_meta_schedule_integration_extract_from_resnet():
        assert t.task_name in expected_task_names, t.task_name


-@requires_torch
+@tvm.testing.requires_package("torch")
def test_task_extraction_winograd_tensorcore():
    mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
    seq = tvm.transform.Sequential(
@@ -126,7 +116,7 @@ def test_task_extraction_winograd_tensorcore():
    assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4


-@requires_torch
+@tvm.testing.requires_package("torch")
def test_task_extraction_anchor_block():
    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.relay_integration.extract_tasks(
@@ -161,7 +151,7 @@ def test_task_extraction_anchor_block():
        assert t.task_name in expected_task_names, t.task_name


-@requires_torch
+@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_bert_base():
    pytest.importorskip(
        "transformers", reason="transformers package is required to import bert_base"
@@ -259,7 +249,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
        assert expected_shape == shape, t.task_name


-@requires_torch
+@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
    @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True)
    def filter_func(args, _) -> bool:
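
The hunks above replace the file-local _has_torch/requires_torch helper with the shared tvm.testing.requires_package("torch") decorator. The same skip can also be expressed inside a test body with pytest.importorskip, which this diff already uses for transformers; a minimal sketch for comparison (illustrative, not part of the patch):

import pytest

def test_something_that_needs_torch():
    # Skips this test if torch cannot be imported, roughly equivalent to
    # decorating it with @tvm.testing.requires_package("torch").
    torch = pytest.importorskip("torch")
    assert torch.rand(2, 2).shape == (2, 2)
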
39 changes: 39 additions & 0 deletions tests/python/unittest/test_runtime_module_based_interface.py
@@ -688,6 +688,44 @@ def test_num_threads():
    assert reported == hardware_threads or reported == hardware_threads // 2


@tvm.testing.requires_llvm
@tvm.testing.requires_package("torch")
def test_graph_module_zero_copy():
    mod = tvm.IRModule()
    params = {}
    dev = tvm.cpu()
    x = relay.var("x", shape=(1, 10))
    y = relay.var("y", shape=(1, 10))
    z = relay.add(x, y)
    mod["main"] = relay.Function([x, y], z)

    # need torch to do the from_dlpack trick
    import torch

    compiled_graph_lib = relay.build(mod, target="llvm", params=params)
    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
    x_data = torch.rand((1, 10))
    y_data = torch.rand((1, 10))
    z_data = torch.rand((1, 10))
    z_torch = x_data + y_data

    # zero copy run
    assert not np.allclose(z_data.numpy(), z_torch.numpy())
    gm.set_input_zero_copy("x", tvm.nd.from_dlpack(x_data))
    gm.set_input_zero_copy("y", tvm.nd.from_dlpack(y_data))
    gm.set_output_zero_copy(0, tvm.nd.from_dlpack(z_data))
    gm.run()

    tvm.testing.assert_allclose(z_data.numpy(), z_torch.numpy())

    # zero input copy with params
    gm = graph_executor.GraphModule(compiled_graph_lib["default"](dev))
    gm.set_input_zero_copy(x=tvm.nd.from_dlpack(x_data), y=tvm.nd.from_dlpack(y_data))
    gm.run()

    tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())


if __name__ == "__main__":
    test_legacy_compatibility()
    test_cpu()
@@ -699,3 +737,4 @@ def test_num_threads():
    test_cpu_get_graph_json()
    test_cpu_get_graph_params_run()
    test_cpu_get_graph_params_compare()
    test_graph_module_zero_copy()
