Skip to content

Commit

Permalink
Address Eric suggesions
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 21, 2022
1 parent 2ce8445 commit 31a7e80
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 46 deletions.
74 changes: 53 additions & 21 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ class Session:
Remote configs for RPC tracker.
session_name : str
Hexagon RPC session name. Options are [hexagon-rpc, cpu-rpc].
`hexagon-rpc` is used with hexagon target and `cpu-rpc` used with
hexagon as a sub-target of LLVM.
Hexagon RPC session name.
remote_stack_size_bytes : int
The stack size of the remote device, to be passed to
Expand All @@ -67,10 +65,10 @@ def __init__(
self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes
self._remote_kw: dict = remote_kw
self._rpc = None
self.device = None
self._device = None

def __enter__(self):
if self.device:
if self._rpc:
# Already initialized
return self

Expand All @@ -88,15 +86,6 @@ def __enter__(self):
self._rpc_receive_buffer_size_bytes,
],
)
if self._session_name == "cpu-rpc":
self.device = self._rpc.cpu(0)
elif self._session_name == "hexagon-rpc":
self.device = self._rpc.hexagon(0)
else:
raise RuntimeError(
f"Incorrect session name: {self._session_name}.\n"
f"Options for session name are [hexagon-rpc, cpu-rpc]."
)
return self

except RuntimeError as exception:
Expand All @@ -105,6 +94,25 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass

@property
def device(self):
"""Session device."""

if hasattr(self, "_device") and self._device is not None:
return self._device

if not hasattr(self, "_requires_cpu_device"):
assert (
False
), "Device type is not set. 'set_device_type' should be called before accessing device."

if self._requires_cpu_device:
self._device = self._rpc.cpu(0)
else:
self._device = self._rpc.hexagon(0)

return self._device

def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
"""Upload a local file to the remote workspace.
Expand Down Expand Up @@ -143,9 +151,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
TVM module object.
"""

assert (
self.device is not None
), "Hexagon session must be started using __enter__ prior to use"
assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use"

if isinstance(module, tvm.runtime.Module):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -189,6 +195,7 @@ def get_graph_executor(
"""

graph_mod = self.load_module(module_name)
self.set_device_type(graph_mod)
return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)

def get_aot_executor(
Expand Down Expand Up @@ -216,6 +223,7 @@ def get_aot_executor(
"""

aot_mod = self.load_module(module_name)
self.set_device_type(aot_mod)
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))

def get_executor_from_factory(self, module: ExecutorFactoryModule):
Expand All @@ -236,6 +244,28 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule):

raise TypeError(f"Unsupported executor type: {type(module)}")

def set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]):
"""Set session device type(hexagon, cpu) based on target in module.
Parameters
----------
module: TVMModule
TVM module object.
"""
# for cases when module is a single schedule without target attribute.
if not hasattr(module, "target"):
self._requires_cpu_device = False
else:
assert len(module.target.values()) == 1
for target in module.target.values():
target_type = str(target).split()[0]

if target_type == "llvm":
self._requires_cpu_device = True
else:
self._requires_cpu_device = False

def _graph_executor_from_factory(
self,
module: Union[str, pathlib.Path, GraphExecutorFactoryModule],
Expand Down Expand Up @@ -263,6 +293,7 @@ def _graph_executor_from_factory(

graph_json = module.get_graph_json()
graph_mod = self.load_module(module.get_lib())
self.set_device_type(module)

return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)

Expand Down Expand Up @@ -296,10 +327,11 @@ def _aot_executor_from_factory(
for target in module.target.values()
if "hexagon" in target.keys
)
assert len(module.target.values()) == 1

self.set_device_type(module)

for target in module.target.values():
target_kind = str(target).split()[0]
target_type = str(target).split()[0]

assert hexagon_arch, "No hexagon target architecture found"
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
Expand All @@ -310,13 +342,13 @@ def _aot_executor_from_factory(
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name

if target_kind == "hexagon":
if target_type == "hexagon":
module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
elif target_kind == "llvm":
elif target_type == "llvm":
module.export_library(
str(binary_path),
cc=hexagon.hexagon_clang_plus(),
Expand Down
12 changes: 1 addition & 11 deletions tests/python/contrib/test_hexagon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,11 @@ def terminate_rpc_servers():
)


@pytest.fixture()
@tvm.testing.fixture
def aot_target(aot_host_target):
if aot_host_target == "c":
yield tvm.target.hexagon("v68")
elif aot_host_target.startswith("llvm"):
yield aot_host_target
else:
assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]."


@pytest.fixture()
def rpc_session_name(aot_host_target):
if aot_host_target == "c":
yield "hexagon-rpc"
elif aot_host_target.startswith("llvm"):
yield "cpu-rpc"
else:
assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]."
28 changes: 14 additions & 14 deletions tests/python/contrib/test_hexagon/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_add(hexagon_session):
)

mod = hexagon_session.load_module(func)
hexagon_session.set_device_type(mod)

A_data = tvm.nd.array(np.array([2, 3], dtype=dtype), device=hexagon_session.device)
assert (A_data.numpy() == np.array([2, 3])).all()
Expand All @@ -68,6 +69,8 @@ def test_add_vtcm(hexagon_session):
)

mod = hexagon_session.load_module(func)
hexagon_session.set_device_type(mod)

A_data = tvm.nd.empty(A.shape, A.dtype, hexagon_session.device, "global.vtcm")
A_data.copyfrom(np.array([2, 3]))

Expand Down Expand Up @@ -101,6 +104,7 @@ def test_matmul(self, hexagon_session, M, N, K):
)

mod = hexagon_session.load_module(func)
hexagon_session.set_device_type(mod)

x = np.random.uniform(size=[i.value for i in X.shape]).astype(X.dtype)
y = np.random.uniform(size=[i.value for i in Y.shape]).astype(Y.dtype)
Expand Down Expand Up @@ -267,7 +271,7 @@ def _workaround_create_aot_shared():


@requires_hexagon_toolchain
def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session_name):
def test_aot_executor(hexagon_session, aot_host_target, aot_target):
dtype = "float32"
input_shape = (1, 128, 128, 3)
w_shape = (5, 5, 3, 8)
Expand Down Expand Up @@ -303,11 +307,10 @@ def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
)

with hexagon_launcher.start_session(session_name=rpc_session_name) as hexagon_session:
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()

target_llvm = tvm.target.Target("llvm")
with tvm.transform.PassContext(opt_level=3):
Expand All @@ -327,9 +330,7 @@ def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session


@requires_hexagon_toolchain
def test_aot_executor_multiple_conv2d(
hexagon_launcher, aot_host_target, aot_target, rpc_session_name
):
def test_aot_executor_multiple_conv2d(hexagon_session, aot_host_target, aot_target):
dtype = "float32"
input_shape = (1, 8, 8, 3)
w1_shape = (5, 5, 3, 1)
Expand Down Expand Up @@ -381,11 +382,10 @@ def test_aot_executor_multiple_conv2d(
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
)

with hexagon_launcher.start_session(session_name=rpc_session_name) as hexagon_session:
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()

target_llvm = tvm.target.Target("llvm")
with tvm.transform.PassContext(opt_level=3):
Expand Down

0 comments on commit 31a7e80

Please sign in to comment.