Skip to content

Commit

Permalink
AOT with LLVM Codegen on Hexagon
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 19, 2022
1 parent a945586 commit 3cd73c4
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 41 deletions.
4 changes: 2 additions & 2 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
assert self._workspace
self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename))

def start_session(self) -> Session:
def start_session(self, name="hexagon-rpc") -> Session:
"""Connect to the RPC server.
Returns
Expand All @@ -197,7 +197,7 @@ def start_session(self) -> Session:
"timeout": 0,
"key": self._device_key,
}
return Session(self, hexagon_remote_kw)
return Session(self, hexagon_remote_kw, session_name=name)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.
Expand Down
38 changes: 28 additions & 10 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def __init__(
rpc_receive_buffer_size_bytes: int = 2 * 1024 * 1024,
):
self._launcher = launcher
self._session_name = session_name
self._remote_stack_size_bytes = remote_stack_size_bytes
self._rpc_receive_buffer_size_bytes = rpc_receive_buffer_size_bytes
self._remote_kw = remote_kw
self._session_name: str = session_name
self._remote_stack_size_bytes: int = remote_stack_size_bytes
self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes
self._remote_kw: dict = remote_kw
self._rpc = None
self.device = None

Expand All @@ -86,7 +86,12 @@ def __enter__(self):
self._rpc_receive_buffer_size_bytes,
],
)
self.device = self._rpc.hexagon(0)
if self._session_name == "cpu-rpc":
self.device = self._rpc.cpu(0)
elif self._session_name == "hexagon-rpc":
self.device = self._rpc.hexagon(0)
else:
raise RuntimeError(f"Incorrect session name: {self._session_name}")
return self

except RuntimeError as exception:
Expand Down Expand Up @@ -286,6 +291,11 @@ def _aot_executor_from_factory(
for target in module.target.values()
if "hexagon" in target.keys
)
assert len(module.target.values()) == 1

for target in module.target.values():
target_kind = str(target).split()[0]

assert hexagon_arch, "No hexagon target architecture found"
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
hexagon_arch = hexagon_arch.pop()
Expand All @@ -295,11 +305,19 @@ def _aot_executor_from_factory(
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name

module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
if target_kind == "hexagon":
module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
elif target_kind == "llvm":
module.export_library(
str(binary_path),
cc=hexagon.hexagon_clang_plus(),
)
else:
raise ValueError("Incorrect Target kind.")

self.upload(binary_path, binary_name)

Expand Down
6 changes: 2 additions & 4 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1234,12 +1234,10 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
Target target_host;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
if (!target_host.defined() && ((it.second->kind->device_type == kDLCPU) ||
(it.second->kind->device_type == kDLHexagon))) {
target_host = it.second;
}
if (!target_host.defined() && it.second->kind->device_type == kDLHexagon) {
target_host = *(new Target("c"));
}
ICHECK(dev_type);
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
}
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,16 @@ struct HexagonWorkspacePool : public WorkspacePool {
};

void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
return dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size);
}

void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
CHECK(hexagon_buffer_map_.count(data) != 0)
<< "Attempt made to free unknown or already freed workspace allocation";
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);
Expand Down
26 changes: 16 additions & 10 deletions src/runtime/hexagon/rpc/hexagon/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ class HexagonIOHandler {
read_buffer_size_bytes_{read_buffer_size_bytes},
write_buffer_available_length_{0} {}

void MessageStart(size_t message_size_bytes) {}
void MessageStart(size_t message_size_bytes) { LOG(INFO) << "MessageStart called."; }

ssize_t PosixWrite(const uint8_t* buf, size_t write_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes
<< ")";
LOG(INFO) << "HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes << ")";
int32_t written_size = write_buffer_.sputn(reinterpret_cast<const char*>(buf), write_len_bytes);
if (written_size != write_len_bytes) {
LOG(ERROR) << "written_size(" << written_size << ") != write_len_bytes(" << write_len_bytes
Expand All @@ -72,10 +71,10 @@ class HexagonIOHandler {
return (ssize_t)written_size;
}

void MessageDone() { LOG(INFO) << "INFO: Message Done."; }
void MessageDone() { LOG(INFO) << "Message Done."; }

ssize_t PosixRead(uint8_t* buf, size_t read_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
LOG(INFO) << "HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << ")";

uint32_t bytes_to_read = 0;
Expand All @@ -99,7 +98,7 @@ class HexagonIOHandler {
* \return The status
*/
AEEResult SetReadBuffer(const uint8_t* data, size_t data_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
LOG(INFO) << "HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << "), read_buffer_size_bytes_("
<< read_buffer_size_bytes_ << ")";
if (data_size_bytes > read_buffer_size_bytes_) {
Expand All @@ -121,7 +120,7 @@ class HexagonIOHandler {
* \return The size of data that is read in bytes.
*/
int64_t ReadFromWriteBuffer(uint8_t* buf, size_t read_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
LOG(INFO) << "HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
<< read_size_bytes;
int64_t size = (int64_t)write_buffer_.sgetn(reinterpret_cast<char*>(buf), read_size_bytes);
write_buffer_available_length_ -= size;
Expand All @@ -133,7 +132,7 @@ class HexagonIOHandler {
return size;
}

void Close() { LOG(INFO) << "INFO: HexagonIOHandler Close called"; }
void Close() { LOG(INFO) << "HexagonIOHandler Close called"; }

void Exit(int code) { exit(code); }

Expand All @@ -156,13 +155,19 @@ class HexagonRPCServer {
* \param data The data pointer
* \param data_size_bytes The data size in bytes.
*
* \return The size of data written to IOHandler.
* \return The size of data written to IOHandler if no error.
* Otherwise, returns -1;
*/
int64_t Write(const uint8_t* data, size_t data_size_bytes) {
if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) {
LOG(ERROR) << "ERROR: SetReadBuffer failed";
return -1;
}

if (!rpc_server_.ProcessOnePacket()) {
LOG(ERROR) << "ERROR: ProcessOnePacket failed";
return -1;
}
rpc_server_.ProcessOnePacket();
return (int64_t)data_size_bytes;
}

Expand Down Expand Up @@ -211,6 +216,7 @@ const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) {
void reset_device_api() {
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api);
}

int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/hexagon/rpc/simulator/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ int main() {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

tvm::runtime::hexagon::SimulatorRPCServer server;

Expand Down
34 changes: 33 additions & 1 deletion src/runtime/hexagon/rpc/simulator/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ class SimulatorRPCChannel final : public RPCChannel {
std::string runmain; // Path to run_main_on_hexagon.
};

struct Message_ {
Message msg;
std::string str() const;
};

Message SendMsg(Message msg);
Message SendMsg(uint32_t code, uint32_t len, uint32_t va);
void ReadFromProcess(void* host_dst, HEX_VA_t src, size_t len);
Expand Down Expand Up @@ -461,6 +466,27 @@ std::string SimulatorRPCChannel::Cpu_::str() const {
return default_cpu_;
}

std::string SimulatorRPCChannel::Message_::str() const {
switch (msg.code) {
case Message::kNone:
return "kNone";
case Message::kAck:
return "kAck";
case Message::kTerminate:
return "kTerminate";
case Message::kReceiveStart:
return "kReceiveStart";
case Message::kReceiveEnd:
return "kReceiveEnd";
case Message::kSendStart:
return "kSendStart";
case Message::kSendEnd:
return "kSendEnd";
default:
break;
}
}

SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std::string& cpu)
: root(sdk_root) {
// For v69 chips, still look for v68 in the directory names.
Expand Down Expand Up @@ -524,6 +550,7 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT");
ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT";
Expand Down Expand Up @@ -651,9 +678,14 @@ Message SimulatorRPCChannel::SendMsg(Message msg) {
HEX_4u_t result;

core = sim_->Run(&result);
ICHECK_EQ(core, HEX_CORE_BREAKPOINT);
Core_ core_ = {core};
ICHECK_EQ(core, HEX_CORE_BREAKPOINT)
<< "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str();
};

Message_ msg_ = {msg};
LOG(INFO) << "Sending message: " << msg_.str();

WriteToProcess(message_buffer_v_, &msg, sizeof msg);
run();

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
loaders += name.substr(loadkey.size());
}
}
LOG(FATAL) << "Binary was created using " << type_key
<< " but a loader of that name is not registered. Available loaders are " << loaders
LOG(FATAL) << "Binary was created using {" << type_key
<< "} but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}

Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,12 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {

TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenHexagon();
*rv = static_cast<void*>(cg);
});

} // namespace codegen
} // namespace tvm

Expand Down
Loading

0 comments on commit 3cd73c4

Please sign in to comment.