diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 4697b5dc7fe09..7aedd8027e03a 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -44,9 +44,7 @@ class Session: Remote configs for RPC tracker. session_name : str - Hexagon RPC session name. Options are [hexagon-rpc, cpu-rpc]. - `hexagon-rpc` is used with hexagon target and `cpu-rpc` used with - hexagon as a sub-target of LLVM. + Hexagon RPC session name. remote_stack_size_bytes : int The stack size of the remote device, to be passed to @@ -67,10 +65,10 @@ def __init__( self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes self._remote_kw: dict = remote_kw self._rpc = None - self.device = None + self._device = None def __enter__(self): - if self.device: + if self._rpc: # Already initialized return self @@ -88,15 +86,6 @@ def __enter__(self): self._rpc_receive_buffer_size_bytes, ], ) - if self._session_name == "cpu-rpc": - self.device = self._rpc.cpu(0) - elif self._session_name == "hexagon-rpc": - self.device = self._rpc.hexagon(0) - else: - raise RuntimeError( - f"Incorrect session name: {self._session_name}.\n" - f"Options for session name are [hexagon-rpc, cpu-rpc]." - ) return self except RuntimeError as exception: @@ -105,6 +94,25 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): pass + @property + def device(self): + """Session device.""" + + if hasattr(self, "_device") and self._device is not None: + return self._device + + if not hasattr(self, "_requires_cpu_device"): + assert ( + False + ), "Device type is not set. 'set_device_type' should be called before accessing device." + + if self._requires_cpu_device: + self._device = self._rpc.cpu(0) + else: + self._device = self._rpc.hexagon(0) + + return self._device + def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): """Upload a local file to the remote workspace. @@ -143,9 +151,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): TVM module object. """ - assert ( - self.device is not None - ), "Hexagon session must be started using __enter__ prior to use" + assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" if isinstance(module, tvm.runtime.Module): with tempfile.TemporaryDirectory() as temp_dir: @@ -189,6 +195,7 @@ def get_graph_executor( """ graph_mod = self.load_module(module_name) + self.set_device_type(graph_mod) return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) def get_aot_executor( @@ -216,6 +223,7 @@ def get_aot_executor( """ aot_mod = self.load_module(module_name) + self.set_device_type(aot_mod) return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) def get_executor_from_factory(self, module: ExecutorFactoryModule): @@ -236,6 +244,28 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule): raise TypeError(f"Unsupported executor type: {type(module)}") + def set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]): + """Set session device type(hexagon, cpu) based on target in module. + + Parameters + ---------- + + module: TVMModule + TVM module object. + """ + # for cases when module is a single schedule without target attribute. + if not hasattr(module, "target"): + self._requires_cpu_device = False + else: + assert len(module.target.values()) == 1 + for target in module.target.values(): + target_type = str(target).split()[0] + + if target_type == "llvm": + self._requires_cpu_device = True + else: + self._requires_cpu_device = False + def _graph_executor_from_factory( self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule], @@ -263,6 +293,7 @@ def _graph_executor_from_factory( graph_json = module.get_graph_json() graph_mod = self.load_module(module.get_lib()) + self.set_device_type(module) return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) @@ -296,10 +327,11 @@ def _aot_executor_from_factory( for target in module.target.values() if "hexagon" in target.keys ) - assert len(module.target.values()) == 1 + + self.set_device_type(module) for target in module.target.values(): - target_kind = str(target).split()[0] + target_type = str(target).split()[0] assert hexagon_arch, "No hexagon target architecture found" assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}" @@ -310,13 +342,13 @@ def _aot_executor_from_factory( binary_name = "test_binary.so" binary_path = temp_dir / binary_name - if target_kind == "hexagon": + if target_type == "hexagon": module.export_library( str(binary_path), fcompile=hexagon.create_aot_shared, hexagon_arch=hexagon_arch, ) - elif target_kind == "llvm": + elif target_type == "llvm": module.export_library( str(binary_path), cc=hexagon.hexagon_clang_plus(), diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 1a917724b5214..9700087d475a2 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -218,13 +218,3 @@ def aot_target(aot_host_target): yield aot_host_target else: assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]." - - -@pytest.fixture() -def rpc_session_name(aot_host_target): - if aot_host_target == "c": - yield "hexagon-rpc" - elif aot_host_target.startswith("llvm"): - yield "cpu-rpc" - else: - assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]." diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index a6624969e7e9d..c12f3f936fbf8 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -43,6 +43,7 @@ def test_add(hexagon_session): ) mod = hexagon_session.load_module(func) + hexagon_session.set_device_type(mod) A_data = tvm.nd.array(np.array([2, 3], dtype=dtype), device=hexagon_session.device) assert (A_data.numpy() == np.array([2, 3])).all() @@ -68,6 +69,8 @@ def test_add_vtcm(hexagon_session): ) mod = hexagon_session.load_module(func) + hexagon_session.set_device_type(mod) + A_data = tvm.nd.empty(A.shape, A.dtype, hexagon_session.device, "global.vtcm") A_data.copyfrom(np.array([2, 3])) @@ -101,6 +104,7 @@ def test_matmul(self, hexagon_session, M, N, K): ) mod = hexagon_session.load_module(func) + hexagon_session.set_device_type(mod) x = np.random.uniform(size=[i.value for i in X.shape]).astype(X.dtype) y = np.random.uniform(size=[i.value for i in Y.shape]).astype(Y.dtype) @@ -267,7 +271,7 @@ def _workaround_create_aot_shared(): @requires_hexagon_toolchain -def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session_name): +def test_aot_executor(hexagon_session, aot_host_target, aot_target): dtype = "float32" input_shape = (1, 128, 128, 3) w_shape = (5, 5, 3, 8) @@ -303,11 +307,10 @@ def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) - with hexagon_launcher.start_session(session_name=rpc_session_name) as hexagon_session: - aot_mod = hexagon_session.get_executor_from_factory(lowered) - aot_mod.set_input(**inputs) - aot_mod.run() - hexagon_output = aot_mod.get_output(0).numpy() + aot_mod = hexagon_session.get_executor_from_factory(lowered) + aot_mod.set_input(**inputs) + aot_mod.run() + hexagon_output = aot_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): @@ -327,9 +330,7 @@ def test_aot_executor(hexagon_launcher, aot_host_target, aot_target, rpc_session @requires_hexagon_toolchain -def test_aot_executor_multiple_conv2d( - hexagon_launcher, aot_host_target, aot_target, rpc_session_name -): +def test_aot_executor_multiple_conv2d(hexagon_session, aot_host_target, aot_target): dtype = "float32" input_shape = (1, 8, 8, 3) w1_shape = (5, 5, 3, 1) @@ -381,11 +382,10 @@ def test_aot_executor_multiple_conv2d( executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) - with hexagon_launcher.start_session(session_name=rpc_session_name) as hexagon_session: - aot_mod = hexagon_session.get_executor_from_factory(lowered) - aot_mod.set_input(**inputs) - aot_mod.run() - hexagon_output = aot_mod.get_output(0).numpy() + aot_mod = hexagon_session.get_executor_from_factory(lowered) + aot_mod.set_input(**inputs) + aot_mod.run() + hexagon_output = aot_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3):