Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] AoT with LLVM Codegen on Hexagon #11065

Merged
merged 6 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,14 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
assert self._workspace
self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename))

def start_session(self) -> Session:
def start_session(self, session_name: str = "hexagon-rpc") -> Session:
"""Connect to the RPC server.

Parameters
----------
session_name : str
RPC session name.

Returns
-------
Session :
Expand All @@ -197,7 +202,7 @@ def start_session(self) -> Session:
"timeout": 0,
"key": self._device_key,
}
return Session(self, hexagon_remote_kw)
return Session(self, hexagon_remote_kw, session_name=session_name)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.
Expand Down
83 changes: 68 additions & 15 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def __init__(
rpc_receive_buffer_size_bytes: int = 2 * 1024 * 1024,
):
self._launcher = launcher
self._session_name = session_name
self._remote_stack_size_bytes = remote_stack_size_bytes
self._rpc_receive_buffer_size_bytes = rpc_receive_buffer_size_bytes
self._remote_kw = remote_kw
self._session_name: str = session_name
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
self._remote_stack_size_bytes: int = remote_stack_size_bytes
self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes
self._remote_kw: dict = remote_kw
self._rpc = None
self.device = None
self._requires_cpu_device = False
self._device = None

def __enter__(self):
if self.device:
if self._rpc:
# Already initialized
return self

Expand All @@ -86,7 +87,6 @@ def __enter__(self):
self._rpc_receive_buffer_size_bytes,
],
)
self.device = self._rpc.hexagon(0)
return self

except RuntimeError as exception:
Expand All @@ -95,6 +95,20 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass

@property
def device(self):
    """Device handle for this session, created on first access and then cached.

    The handle is obtained from ``self._rpc.cpu(0)`` when
    ``self._requires_cpu_device`` is set, and from ``self._rpc.hexagon(0)``
    otherwise. Later accesses return the cached handle unchanged.
    """
    if self._device is None:
        # Only the selected factory attribute is looked up on the RPC session.
        make_device = self._rpc.cpu if self._requires_cpu_device else self._rpc.hexagon
        self._device = make_device(0)
    return self._device

def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
"""Upload a local file to the remote workspace.

Expand Down Expand Up @@ -133,9 +147,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
TVM module object.
"""

assert (
self.device is not None
), "Hexagon session must be started using __enter__ prior to use"
assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use"

if isinstance(module, tvm.runtime.Module):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -179,6 +191,7 @@ def get_graph_executor(
"""

graph_mod = self.load_module(module_name)
self._set_device_type(graph_mod)
return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)

def get_aot_executor(
Expand Down Expand Up @@ -206,6 +219,7 @@ def get_aot_executor(
"""

aot_mod = self.load_module(module_name)
self._set_device_type(aot_mod)
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))

def get_executor_from_factory(self, module: ExecutorFactoryModule):
Expand All @@ -226,6 +240,28 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule):

raise TypeError(f"Unsupported executor type: {type(module)}")

def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]):
    """Choose the session device kind (cpu or hexagon) from the module's target.

    Parameters
    ----------
    module : Union[str, pathlib.Path, GraphExecutorFactoryModule]
        Module to inspect. An object without a ``target`` attribute (for
        example a single schedule) selects the hexagon device.
    """
    if not hasattr(module, "target"):
        # No target information available: keep the hexagon device.
        self._requires_cpu_device = False
        return

    targets = list(module.target.values())
    assert len(targets) == 1

    # The first whitespace-separated token of the target string is its kind.
    target_kind = str(targets[-1]).split()[0]
    self._requires_cpu_device = target_kind == "llvm"

def _graph_executor_from_factory(
self,
module: Union[str, pathlib.Path, GraphExecutorFactoryModule],
Expand Down Expand Up @@ -286,6 +322,12 @@ def _aot_executor_from_factory(
for target in module.target.values()
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
if "hexagon" in target.keys
)
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved

self._set_device_type(module)

for target in module.target.values():
target_type = str(target).split()[0]

assert hexagon_arch, "No hexagon target architecture found"
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
hexagon_arch = hexagon_arch.pop()
Expand All @@ -295,11 +337,22 @@ def _aot_executor_from_factory(
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name

module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
if target_type == "hexagon":
module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
elif target_type == "llvm":
module.export_library(
str(binary_path),
cc=hexagon.hexagon_clang_plus(),
)
else:
raise ValueError(
f"Incorrect Target kind.\n"
f"Target kind should be from these options: [hexagon, llvm]."
)

self.upload(binary_path, binary_name)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def alloc_buffer(
"""
special_stmt - Reads/Writes
"""

@overload
def reads(read_regions: List[BufferSlice]) -> None: ...
@overload
Expand Down Expand Up @@ -337,6 +338,7 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr:
"""
Scope handler - Loops
"""

@overload
def serial(
begin: Union[PrimExpr, int],
Expand Down
8 changes: 4 additions & 4 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1234,12 +1234,12 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
Target target_host;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
// TODO(tvm-team): AoT only works with kDLCPU device type. We can remove kDLHexagon
// here once we refactored kDLHexagon to kDLCPU.
if (!target_host.defined() && ((it.second->kind->device_type == kDLCPU) ||
(it.second->kind->device_type == kDLHexagon))) {
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
target_host = it.second;
}
if (!target_host.defined() && it.second->kind->device_type == kDLHexagon) {
target_host = *(new Target("c"));
}
ICHECK(dev_type);
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
}
Expand Down
12 changes: 10 additions & 2 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh

void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t alignment,
DLDataType type_hint) {
// Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
Expand All @@ -94,6 +95,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t align
}

void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) {
// Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
Expand All @@ -107,12 +109,18 @@ struct HexagonWorkspacePool : public WorkspacePool {
};

void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
// Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
return dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size);
}

void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
// Added kDLCPU since we use hexagon as a sub-target of LLVM which by default maps to kDLCPU;
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
CHECK(hexagon_buffer_map_.count(data) != 0)
<< "Attempt made to free unknown or already freed workspace allocation";
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);
Expand Down
28 changes: 18 additions & 10 deletions src/runtime/hexagon/rpc/hexagon/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class HexagonIOHandler {
/*!
 * \brief Start-of-message notification; this handler takes no action.
 * \param message_size_bytes Byte count supplied by the caller; ignored.
 */
void MessageStart(size_t message_size_bytes) {
  static_cast<void>(message_size_bytes);  // deliberately unused
}

ssize_t PosixWrite(const uint8_t* buf, size_t write_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes
<< ")";
LOG(INFO) << "HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes << ")";
int32_t written_size = write_buffer_.sputn(reinterpret_cast<const char*>(buf), write_len_bytes);
if (written_size != write_len_bytes) {
LOG(ERROR) << "written_size(" << written_size << ") != write_len_bytes(" << write_len_bytes
Expand All @@ -72,10 +71,10 @@ class HexagonIOHandler {
return (ssize_t)written_size;
}

void MessageDone() { LOG(INFO) << "INFO: Message Done."; }
void MessageDone() { LOG(INFO) << "Message Done."; }

ssize_t PosixRead(uint8_t* buf, size_t read_len_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
LOG(INFO) << "HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << ")";

uint32_t bytes_to_read = 0;
Expand All @@ -99,7 +98,7 @@ class HexagonIOHandler {
* \return The status
*/
AEEResult SetReadBuffer(const uint8_t* data, size_t data_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
LOG(INFO) << "HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
<< "), read_buffer_index_(" << read_buffer_index_ << "), read_buffer_size_bytes_("
<< read_buffer_size_bytes_ << ")";
if (data_size_bytes > read_buffer_size_bytes_) {
Expand All @@ -121,7 +120,7 @@ class HexagonIOHandler {
* \return The size of data that is read in bytes.
*/
int64_t ReadFromWriteBuffer(uint8_t* buf, size_t read_size_bytes) {
LOG(INFO) << "INFO: HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
LOG(INFO) << "HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
<< read_size_bytes;
int64_t size = (int64_t)write_buffer_.sgetn(reinterpret_cast<char*>(buf), read_size_bytes);
write_buffer_available_length_ -= size;
Expand All @@ -133,7 +132,7 @@ class HexagonIOHandler {
return size;
}

void Close() { LOG(INFO) << "INFO: HexagonIOHandler Close called"; }
void Close() { LOG(INFO) << "HexagonIOHandler Close called"; }

/*!
 * \brief Terminate the process by calling exit().
 * \param code Status value forwarded unchanged to exit().
 */
void Exit(int code) {
  exit(code);
}

Expand All @@ -156,13 +155,20 @@ class HexagonRPCServer {
* \param data The data pointer
* \param data_size_bytes The data size in bytes.
*
* \return The size of data written to IOHandler.
* \return The size of data written to IOHandler if no error.
* Otherwise, returns -1;
*/
int64_t Write(const uint8_t* data, size_t data_size_bytes) {
if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) {
AEEResult rc = io_.SetReadBuffer(data, data_size_bytes);
if (rc != AEE_SUCCESS) {
LOG(ERROR) << "ERROR: SetReadBuffer failed: " << rc;
return -1;
}

if (!rpc_server_.ProcessOnePacket()) {
LOG(ERROR) << "ERROR: ProcessOnePacket failed";
return -1;
}
rpc_server_.ProcessOnePacket();
return (int64_t)data_size_bytes;
}

Expand Down Expand Up @@ -211,6 +217,8 @@ const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) {
/*!
 * \brief Point the generic device API names at the Hexagon v2 device API.
 *
 * Looks up "device_api.hexagon.v2" through get_runtime_func() and registers
 * that function under both "device_api.hexagon" and "device_api.cpu".
 */
void reset_device_api() {
  const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
  tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
  // Hexagon is also used as a sub-target of LLVM, so "device_api.cpu" has to
  // resolve to the same implementation as "device_api.hexagon".
  tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api);
}

int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/hexagon/rpc/simulator/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ int main() {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

tvm::runtime::hexagon::SimulatorRPCServer server;

Expand Down
34 changes: 33 additions & 1 deletion src/runtime/hexagon/rpc/simulator/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ class SimulatorRPCChannel final : public RPCChannel {
std::string runmain; // Path to run_main_on_hexagon.
};

/*! \brief Pairs a Message with a str() helper that names its code. */
struct Message_ {
// The message being described; str() switches on msg.code.
Message msg;
// Text form of msg.code, used when logging a message that is being sent.
std::string str() const;
};

Message SendMsg(Message msg);
Message SendMsg(uint32_t code, uint32_t len, uint32_t va);
void ReadFromProcess(void* host_dst, HEX_VA_t src, size_t len);
Expand Down Expand Up @@ -461,6 +466,27 @@ std::string SimulatorRPCChannel::Cpu_::str() const {
return default_cpu_;
}

/*!
 * \brief Render the wrapped message's code as its enumerator name.
 * \return The name (for example "kAck") for a recognized code; for any other
 *         value, "kUnknown(<n>)" where <n> is the numeric value of the code.
 */
std::string SimulatorRPCChannel::Message_::str() const {
  switch (msg.code) {
    case Message::kNone:
      return "kNone";
    case Message::kAck:
      return "kAck";
    case Message::kTerminate:
      return "kTerminate";
    case Message::kReceiveStart:
      return "kReceiveStart";
    case Message::kReceiveEnd:
      return "kReceiveEnd";
    case Message::kSendStart:
      return "kSendStart";
    case Message::kSendEnd:
      return "kSendEnd";
    default:
      break;
  }
  // Every path has to return a value: flowing off the end of a non-void
  // function is undefined behavior, so unrecognized codes get a placeholder.
  return "kUnknown(" + std::to_string(static_cast<unsigned long long>(msg.code)) + ")";
}

SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std::string& cpu)
: root(sdk_root) {
// For v69 chips, still look for v68 in the directory names.
Expand Down Expand Up @@ -524,6 +550,7 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) {
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
ICHECK(api_v2 != nullptr);
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);

const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT");
ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT";
Expand Down Expand Up @@ -651,9 +678,14 @@ Message SimulatorRPCChannel::SendMsg(Message msg) {
HEX_4u_t result;

core = sim_->Run(&result);
ICHECK_EQ(core, HEX_CORE_BREAKPOINT);
Core_ core_ = {core};
ICHECK_EQ(core, HEX_CORE_BREAKPOINT)
<< "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str();
};

Message_ msg_ = {msg};
LOG(INFO) << "Sending message: " << msg_.str();

WriteToProcess(message_buffer_v_, &msg, sizeof msg);
run();

Expand Down
Loading