Skip to content

Commit

Permalink
AOT with LLVM Codegen on Hexagon
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 19, 2022
1 parent a945586 commit 6ebbb79
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 36 deletions.
4 changes: 2 additions & 2 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
assert self._workspace
self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename))

def start_session(self) -> Session:
def start_session(self, name="hexagon-rpc") -> Session:
"""Connect to the RPC server.
Returns
Expand All @@ -197,7 +197,7 @@ def start_session(self) -> Session:
"timeout": 0,
"key": self._device_key,
}
return Session(self, hexagon_remote_kw)
return Session(self, hexagon_remote_kw, session_name=name)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.
Expand Down
30 changes: 24 additions & 6 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@ def __enter__(self):
self._rpc_receive_buffer_size_bytes,
],
)
self.device = self._rpc.hexagon(0)
if self._session_name == "cpu-rpc":
self.device = self._rpc.cpu(0)
elif self._session_name == "hexagon-rpc":
self.device = self._rpc.hexagon(0)
else:
raise RuntimeError("Incorrect session name: %s", self._session_name)
return self

except RuntimeError as exception:
Expand Down Expand Up @@ -286,6 +291,11 @@ def _aot_executor_from_factory(
for target in module.target.values()
if "hexagon" in target.keys
)
assert len(module.target.values()) == 1

for target in module.target.values():
target_kind = str(target).split()[0]

assert hexagon_arch, "No hexagon target architecture found"
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
hexagon_arch = hexagon_arch.pop()
Expand All @@ -295,11 +305,19 @@ def _aot_executor_from_factory(
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name

module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
if target_kind == "hexagon":
module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
elif target_kind == "llvm":
module.export_library(
str(binary_path),
cc=hexagon.hexagon_clang_plus(),
)
else:
raise ValueError("Incorrect Target kind.")

self.upload(binary_path, binary_name)

Expand Down
6 changes: 2 additions & 4 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1234,12 +1234,10 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
Target target_host;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
if (!target_host.defined() && ((it.second->kind->device_type == kDLCPU) ||
(it.second->kind->device_type == kDLHexagon))) {
target_host = it.second;
}
if (!target_host.defined() && it.second->kind->device_type == kDLHexagon) {
target_host = *(new Target("c"));
}
ICHECK(dev_type);
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
}
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,16 @@ struct HexagonWorkspacePool : public WorkspacePool {
};

/*!
 * \brief Allocate a workspace buffer from the thread-local HexagonWorkspacePool.
 *
 * \param dev The device the workspace is requested for. Either a Hexagon or a CPU
 *            device type is accepted; any other type fails the CHECK below.
 * \param size The requested size, forwarded to the pool.
 * \param type_hint Not used by this implementation.
 *
 * \return The pointer returned by the pool's AllocWorkspace.
 */
void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
  // A single, relaxed device check. An additional Hexagon-only CHECK ahead of this one
  // would abort for CPU devices and make the kDLCPU alternative unreachable.
  bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
                         (DLDeviceType(dev.device_type) == kDLCPU);
  CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
  return dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size);
}

void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
CHECK(hexagon_buffer_map_.count(data) != 0)
<< "Attempt made to free unknown or already freed workspace allocation";
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);
Expand Down
25 changes: 16 additions & 9 deletions src/runtime/hexagon/rpc/hexagon/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ class HexagonIOHandler {
read_buffer_size_bytes_{read_buffer_size_bytes},
write_buffer_available_length_{0} {}

// Called when a message begins; only logs. message_size_bytes is not used.
void MessageStart(size_t message_size_bytes) { LOG(INFO) << "MessageStart called."; }

ssize_t PosixWrite(const uint8_t* buf, size_t write_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes
LOG(INFO) << "HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes
<< ")";
int32_t written_size = write_buffer_.sputn(reinterpret_cast<const char*>(buf), write_len_bytes);
if (written_size != write_len_bytes) {
Expand All @@ -72,10 +72,10 @@ class HexagonIOHandler {
return (ssize_t)written_size;
}

// Called when a message is complete; only logs. LOG(INFO) is already an INFO-level
// statement, so the message text carries no extra "INFO: " prefix.
void MessageDone() { LOG(INFO) << "Message Done."; }

ssize_t PosixRead(uint8_t* buf, size_t read_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
LOG(INFO) << "HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << ")";

uint32_t bytes_to_read = 0;
Expand All @@ -99,7 +99,7 @@ class HexagonIOHandler {
* \return The status
*/
AEEResult SetReadBuffer(const uint8_t* data, size_t data_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
LOG(INFO) << "HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << "), read_buffer_size_bytes_("
<< read_buffer_size_bytes_ << ")";
if (data_size_bytes > read_buffer_size_bytes_) {
Expand All @@ -121,7 +121,7 @@ class HexagonIOHandler {
* \return The size of data that is read in bytes.
*/
int64_t ReadFromWriteBuffer(uint8_t* buf, size_t read_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
LOG(INFO) << "HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
<< read_size_bytes;
int64_t size = (int64_t)write_buffer_.sgetn(reinterpret_cast<char*>(buf), read_size_bytes);
write_buffer_available_length_ -= size;
Expand All @@ -133,7 +133,7 @@ class HexagonIOHandler {
return size;
}

// Only logs that Close was called; no other work is done here.
void Close() { LOG(INFO) << "HexagonIOHandler Close called"; }

// Terminates the process by calling exit() with the given status code.
void Exit(int code) { exit(code); }

Expand All @@ -156,13 +156,19 @@ class HexagonRPCServer {
* \param data The data pointer
* \param data_size_bytes The data size in bytes.
*
* \return The size of data written to IOHandler.
* \return The size of data written to IOHandler if no error.
* Otherwise, returns -1;
*/
int64_t Write(const uint8_t* data, size_t data_size_bytes) {
  // Hand the incoming bytes to the IO handler so the RPC server can read them.
  if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) {
    LOG(ERROR) << "ERROR: SetReadBuffer failed";
    return -1;
  }

  // Process exactly one packet per write and report failure to the caller.
  // A second, unchecked ProcessOnePacket() call here would consume another packet
  // and silently discard its result.
  if (!rpc_server_.ProcessOnePacket()) {
    LOG(ERROR) << "ERROR: ProcessOnePacket failed";
    return -1;
  }
  return static_cast<int64_t>(data_size_bytes);
}

Expand Down Expand Up @@ -211,6 +217,7 @@ const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) {
void reset_device_api() {
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api);
}

int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/hexagon/rpc/simulator/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ int main() {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

tvm::runtime::hexagon::SimulatorRPCServer server;

Expand Down
34 changes: 33 additions & 1 deletion src/runtime/hexagon/rpc/simulator/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ class SimulatorRPCChannel final : public RPCChannel {
std::string runmain; // Path to run_main_on_hexagon.
};

// Thin wrapper around a Message that adds a printable form.
struct Message_ {
// The wrapped message.
Message msg;
// Returns a string describing msg; defined out of line.
std::string str() const;
};

Message SendMsg(Message msg);
Message SendMsg(uint32_t code, uint32_t len, uint32_t va);
void ReadFromProcess(void* host_dst, HEX_VA_t src, size_t len);
Expand Down Expand Up @@ -461,6 +466,27 @@ std::string SimulatorRPCChannel::Cpu_::str() const {
return default_cpu_;
}

/*!
 * \brief Return a printable name for the wrapped message's code.
 *
 * \return The enumerator name for the known codes; for any other value, a string
 *         containing the numeric code. Every path returns a value, so an unknown
 *         code no longer flows off the end of the function.
 */
std::string SimulatorRPCChannel::Message_::str() const {
  switch (msg.code) {
    case Message::kNone:
      return "kNone";
    case Message::kAck:
      return "kAck";
    case Message::kTerminate:
      return "kTerminate";
    case Message::kReceiveStart:
      return "kReceiveStart";
    case Message::kReceiveEnd:
      return "kReceiveEnd";
    case Message::kSendStart:
      return "kSendStart";
    case Message::kSendEnd:
      return "kSendEnd";
    default:
      break;
  }
  // Unknown code: still produce something useful for the log message.
  return "unknown message code " + std::to_string(static_cast<unsigned>(msg.code));
}

SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std::string& cpu)
: root(sdk_root) {
// For v69 chips, still look for v68 in the directory names.
Expand Down Expand Up @@ -524,6 +550,7 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT");
ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT";
Expand Down Expand Up @@ -651,9 +678,14 @@ Message SimulatorRPCChannel::SendMsg(Message msg) {
HEX_4u_t result;

core = sim_->Run(&result);
ICHECK_EQ(core, HEX_CORE_BREAKPOINT);
Core_ core_ = {core};
ICHECK_EQ(core, HEX_CORE_BREAKPOINT)
<< "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str();
};

Message_ msg_ = {msg};
LOG(INFO) << "Sending message: " << msg_.str();

WriteToProcess(message_buffer_v_, &msg, sizeof msg);
run();

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
loaders += name.substr(loadkey.size());
}
}
LOG(FATAL) << "Binary was created using " << type_key
<< " but a loader of that name is not registered. Available loaders are " << loaders
LOG(FATAL) << "Binary was created using {" << type_key
<< "} but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}

Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,12 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {

TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);

// Registers the global function "tvm.codegen.llvm.target_hexagon". Each call
// heap-allocates a CodeGenHexagon and stores it in the return value as a void*.
// NOTE(review): nothing in this block frees the object; presumably the caller
// takes ownership of the returned pointer — confirm.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenHexagon();
*rv = static_cast<void*>(cg);
});

} // namespace codegen
} // namespace tvm

Expand Down
35 changes: 25 additions & 10 deletions tests/python/contrib/test_hexagon/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@
# under the License.

import os
import pathlib
import sys
import pytest
import numpy as np
import logging

import tvm.testing
from tvm import te
from tvm import relay
from tvm.relay.backend import Executor, Runtime
from tvm.contrib import utils, ndk
from tvm.contrib.hexagon.build import HexagonLauncher
import tvm.contrib.hexagon as hexagon

from .conftest import requires_hexagon_toolchain

# Target strings the AOT tests are run with: the "c" kind, and an LLVM target
# configured for Hexagon v68 (consumed by tests that take an `aot_target_kind` argument).
aot_target_kind = tvm.testing.parameter(
"c",
"llvm -keys=hexagon -link-params=0 -mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv68 -mtriple=hexagon",
)


@requires_hexagon_toolchain
def test_add(hexagon_session):
Expand Down Expand Up @@ -270,8 +271,18 @@ def _workaround_create_aot_shared():
)


def get_target_and_session(target_kind: str):
    """Map an AOT target kind to the build target and the RPC session name.

    Parameters
    ----------
    target_kind : str
        Either "c" or a target string starting with "llvm".

    Returns
    -------
    tuple
        ``(target_hexagon, session_key)``: for "c", the ``tvm.target.hexagon("v68")``
        target with the "hexagon-rpc" session; for an "llvm..." string, that string
        unchanged with the "cpu-rpc" session.

    Raises
    ------
    ValueError
        If ``target_kind`` is neither "c" nor an "llvm..." string. Without this,
        the function would fail later with an UnboundLocalError at the return.
    """
    if target_kind == "c":
        target_hexagon = tvm.target.hexagon("v68")
        session_key = "hexagon-rpc"
    elif target_kind.startswith("llvm"):
        target_hexagon = target_kind
        session_key = "cpu-rpc"
    else:
        raise ValueError(f"Unsupported AOT target kind: {target_kind!r}")
    return target_hexagon, session_key


@requires_hexagon_toolchain
def test_aot_executor(hexagon_session):
def test_aot_executor(hexagon_launcher, aot_target_kind):
dtype = "float32"
input_shape = (1, 128, 128, 3)
w_shape = (5, 5, 3, 8)
Expand All @@ -290,7 +301,7 @@ def test_aot_executor(hexagon_session):
relay_mod = tvm.IRModule.from_expr(f)
relay_mod = relay.transform.InferType()(relay_mod)

target_hexagon = tvm.target.hexagon("v68")
target_hexagon, session_key = get_target_and_session(aot_target_kind)

weight_data = np.random.rand(w_shape[0], w_shape[1], w_shape[2], w_shape[3]).astype(dtype=dtype)
input_data = np.random.rand(
Expand All @@ -304,11 +315,13 @@ def test_aot_executor(hexagon_session):
lowered = tvm.relay.build(
relay_mod,
params=params,
target=tvm.target.Target(target_hexagon, host="c"),
target=tvm.target.Target(target_hexagon, host=aot_target_kind),
runtime=Runtime("cpp"),
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
)

hexagon_session = hexagon_launcher.start_session(name=session_key)
hexagon_session.__enter__()
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
Expand All @@ -332,7 +345,7 @@ def test_aot_executor(hexagon_session):


@requires_hexagon_toolchain
def test_aot_executor_multiple_conv2d(hexagon_session):
def test_aot_executor_multiple_conv2d(hexagon_launcher, aot_target_kind):
dtype = "float32"
input_shape = (1, 8, 8, 3)
w1_shape = (5, 5, 3, 1)
Expand Down Expand Up @@ -362,7 +375,7 @@ def test_aot_executor_multiple_conv2d(hexagon_session):
relay_mod = tvm.IRModule.from_expr(f)
relay_mod = relay.transform.InferType()(relay_mod)

target_hexagon = tvm.target.hexagon("v68")
target_hexagon, session_key = get_target_and_session(aot_target_kind)

weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype(
dtype=dtype
Expand All @@ -381,11 +394,13 @@ def test_aot_executor_multiple_conv2d(hexagon_session):
lowered = tvm.relay.build(
relay_mod,
params=params,
target=tvm.target.Target(target_hexagon, host="c"),
target=tvm.target.Target(target_hexagon, host=aot_target_kind),
runtime=Runtime("cpp"),
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
)

hexagon_session = hexagon_launcher.start_session(name=session_key)
hexagon_session.__enter__()
aot_mod = hexagon_session.get_executor_from_factory(lowered)
aot_mod.set_input(**inputs)
aot_mod.run()
Expand Down

0 comments on commit 6ebbb79

Please sign in to comment.