Skip to content

Commit

Permalink
[ESI] Add optional non-blocking write API to WriteChannelPort (#7555)
Browse files Browse the repository at this point in the history
[ESI] Add optional non-blocking write API to `WriteChannelPort`
  • Loading branch information
mortbopet authored Aug 30, 2024
1 parent b937bcf commit e594996
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 14 deletions.
21 changes: 20 additions & 1 deletion frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import esiaccel as esi

import sys
import time

platform = sys.argv[1]
acc = esi.AcceleratorConnection(platform, sys.argv[2])
Expand Down Expand Up @@ -87,12 +88,30 @@ def read_offset_check(i: int, add_amt: int):
################################################################################

data = 10234
# Blocking write interface
send.write(data)
got_data = False
resp = recv.read()

print(f"data: {data}")
print(f"resp: {resp}")
assert resp == data + add_amt

# Non-blocking write interface
data = 10235
nb_wr_start = time.time()

# Timeout of 5 seconds
nb_timeout = nb_wr_start + 5
write_succeeded = False
while time.time() < nb_timeout:
write_succeeded = send.try_write(data)
if write_succeeded:
break

assert (write_succeeded, "Non-blocking write failed")
resp = recv.read()
print(f"data: {data}")
print(f"resp: {resp}")
assert resp == data + add_amt

print("PASS")
8 changes: 7 additions & 1 deletion lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,15 @@ class WriteChannelPort : public ChannelPort {
virtual void disconnect() override { connected = false; }
virtual bool isConnected() const override { return connected; }

/// A very basic write API. Will likely change for performance reasons.
/// A very basic blocking write API. Will likely change for performance
/// reasons.
virtual void write(const MessageData &) = 0;

/// A basic non-blocking write API. Returns true if the data was written.
/// It is invalid for backends to always return false (i.e. backends must
/// eventually ensure that writes may succeed).
virtual bool tryWrite(const MessageData &data) = 0;

private:
volatile bool connected = false;
};
Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ class WriteCosimChannelPort : public WriteChannelPort {
". Details: " + sendStatus.error_details());
}

bool tryWrite(const MessageData &data) override {
write(data);
return true;
}

protected:
ChannelServer::Stub *rpcClient;
/// The channel description as provided by the server.
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/RpcServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class RpcServerWritePort : public WriteChannelPort {
public:
RpcServerWritePort(Type *type) : WriteChannelPort(type) {}
void write(const MessageData &data) override { writeQueue.push(data); }
bool tryWrite(const MessageData &data) override {
writeQueue.push(data);
return true;
}

utils::TSQueue<MessageData> writeQueue;
};
Expand Down
14 changes: 10 additions & 4 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct esi::backends::trace::TraceAccelerator::Impl {
void adoptChannelPort(ChannelPort *port) { channels.emplace_back(port); }

void write(const AppIDPath &id, const std::string &portName, const void *data,
size_t size);
size_t size, const std::string &prefix = "");
std::ostream &write(std::string service) {
assert(traceWrite && "traceWrite is null");
*traceWrite << "[" << service << "] ";
Expand All @@ -93,14 +93,15 @@ struct esi::backends::trace::TraceAccelerator::Impl {

void TraceAccelerator::Impl::write(const AppIDPath &id,
const std::string &portName,
const void *data, size_t size) {
const void *data, size_t size,
const std::string &prefix) {
if (!isWriteable())
return;
std::string b64data;
utils::encodeBase64(data, size, b64data);

*traceWrite << "write " << id << '.' << portName << ": " << b64data
<< std::endl;
*traceWrite << prefix << (prefix.empty() ? "w" : "W") << "rite " << id << '.'
<< portName << ": " << b64data << std::endl;
}

std::unique_ptr<AcceleratorConnection>
Expand Down Expand Up @@ -192,6 +193,11 @@ class WriteTraceChannelPort : public WriteChannelPort {
impl.write(id, portName, data.getBytes(), data.getSize());
}

bool tryWrite(const MessageData &data) override {
impl.write(id, portName, data.getBytes(), data.getSize(), "try");
return true;
}

protected:
TraceAccelerator::Impl &impl;
AppIDPath id;
Expand Down
11 changes: 9 additions & 2 deletions lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,18 @@ PYBIND11_MODULE(esiCppAccel, m) {
py::return_value_policy::reference);

py::class_<WriteChannelPort, ChannelPort>(m, "WriteChannelPort")
.def("write", [](WriteChannelPort &p, py::bytearray &data) {
.def("write",
[](WriteChannelPort &p, py::bytearray &data) {
py::buffer_info info(py::buffer(data).request());
std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
(uint8_t *)info.ptr + info.size);
p.write(dataVec);
})
.def("tryWrite", [](WriteChannelPort &p, py::bytearray &data) {
py::buffer_info info(py::buffer(data).request());
std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
(uint8_t *)info.ptr + info.size);
p.write(dataVec);
return p.tryWrite(dataVec);
});
py::class_<ReadChannelPort, ChannelPort>(m, "ReadChannelPort")
.def(
Expand Down
19 changes: 13 additions & 6 deletions lib/Dialect/ESI/runtime/python/esiaccel/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,26 @@ def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
super().__init__(owner, cpp_port)
self.cpp_port: cpp.WriteChannelPort = cpp_port

def write(self, msg=None) -> bool:
"""Write a typed message to the channel. Attempts to serialize 'msg' to what
the accelerator expects, but will fail if the object is not convertible to
the port type."""

def __serialize_msg(self, msg=None) -> bytearray:
valid, reason = self.type.is_valid(msg)
if not valid:
raise ValueError(
f"'{msg}' cannot be converted to '{self.type}': {reason}")
msg_bytes: bytearray = self.type.serialize(msg)
self.cpp_port.write(msg_bytes)
return msg_bytes

def write(self, msg=None) -> bool:
"""Write a typed message to the channel. Attempts to serialize 'msg' to what
the accelerator expects, but will fail if the object is not convertible to
the port type."""
self.cpp_port.write(self.__serialize_msg(msg))
return True

def try_write(self, msg=None) -> bool:
"""Like 'write', but uses the non-blocking tryWrite method of the underlying
port. Returns True if the write was successful, False otherwise."""
return self.cpp_port.tryWrite(self.__serialize_msg(msg))


class ReadPort(Port):
"""A unidirectional communication channel from the accelerator to the host."""
Expand Down

0 comments on commit e594996

Please sign in to comment.