Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESI][Runtime] Poll method and optional service thread polling #7460

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions integration_test/Dialect/ESI/runtime/loopback.mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,7 @@
print(f"result: {result}")
if platform != "trace":
assert result == [-21, -22]

acc = None

print("PASS")
22 changes: 19 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class Accelerator : public HWModule {

/// Abstract class representing a connection to an accelerator. Actual
/// connections (e.g. to a co-simulation or actual device) are implemented by
/// subclasses.
/// subclasses. No methods in here are thread safe.
class AcceleratorConnection {
public:
AcceleratorConnection(Context &ctxt);
virtual ~AcceleratorConnection() = default;
virtual ~AcceleratorConnection();
Context &getCtxt() const { return ctxt; }

/// Disconnect from the accelerator cleanly.
Expand All @@ -89,7 +89,12 @@ class AcceleratorConnection {
virtual std::map<std::string, ChannelPort &>
requestChannelsFor(AppIDPath, const BundleType *) = 0;

AcceleratorServiceThread *getServiceThread() { return serviceThread.get(); }
/// Return a pointer to the accelerator 'service' thread (or threads). If the
/// thread(s) are not running, they will be started when this method is
/// called. `std::thread` is used. If users don't want the runtime to spin up
/// threads, don't call this method. `AcceleratorServiceThread` is owned by
/// AcceleratorConnection and governed by the lifetime of the this object.
AcceleratorServiceThread *getServiceThread();

using Service = services::Service;
/// Get a typed reference to a particular service type. Caller does *not* take
Expand All @@ -109,6 +114,10 @@ class AcceleratorConnection {
ServiceImplDetails details = {},
HWClientDetails clients = {});

/// Assume ownership of an accelerator object. Ties the lifetime of the
/// accelerator to this connection. Returns a raw pointer to the object.
Accelerator *takeOwnership(std::unique_ptr<Accelerator> accel);

protected:
/// Called by `getServiceImpl` exclusively. It wraps the pointer returned by
/// this in a unique_ptr and caches it. Separate this from the
Expand All @@ -128,6 +137,10 @@ class AcceleratorConnection {
std::map<ServiceCacheKey, std::unique_ptr<Service>> serviceCache;

std::unique_ptr<AcceleratorServiceThread> serviceThread;

/// List of accelerator objects owned by this connection. These are destroyed
/// when the connection dies or is shutdown.
std::vector<std::unique_ptr<Accelerator>> ownedAccelerators;
};

namespace registry {
Expand Down Expand Up @@ -173,6 +186,9 @@ class AcceleratorServiceThread {
addListener(std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback);

/// Poll this module.
void addPoll(HWModule &module);

/// Instruct the service thread to stop running.
void stop();

Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Design.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class HWModule {
return portIndex;
}

/// Master poll method. Calls the `poll` method on all locally owned ports and
/// the master `poll` method on all of the children. Returns true if any of
/// the `poll` calls returns true.
bool poll();

protected:
const std::optional<ModuleInfo> info;
const std::vector<std::unique_ptr<Instance>> children;
Expand Down
7 changes: 4 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class Manifest {
// Modules which have designer specified metadata.
std::vector<ModuleInfo> getModuleInfos() const;

// Build a dynamic design hierarchy from the manifest.
std::unique_ptr<Accelerator>
buildAccelerator(AcceleratorConnection &acc) const;
// Build a dynamic design hierarchy from the manifest. The
// AcceleratorConnection owns the returned pointer so its lifetime is
// determined by the connection.
Accelerator *buildAccelerator(AcceleratorConnection &acc) const;

/// The Type Table is an ordered list of types. The offset can be used to
/// compactly and uniquely within a design. It does not include all of the
Expand Down
52 changes: 46 additions & 6 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,38 @@ namespace esi {
class ChannelPort {
public:
ChannelPort(const Type *type) : type(type) {}
virtual ~ChannelPort() { disconnect(); }
virtual ~ChannelPort() {}

/// Set up a connection to the accelerator. The buffer size is optional and
/// should be considered merely a hint. Individual implementations use it
/// however they like. The unit is number of messages of the port type.
virtual void connect(std::optional<unsigned> bufferSize = std::nullopt) {
connectImpl(bufferSize);
virtual void connect(std::optional<unsigned> bufferSize = std::nullopt) = 0;
virtual void disconnect() = 0;
virtual bool isConnected() const = 0;

/// Poll for incoming data. Returns true if data was read or written into a
/// buffer as a result of the poll. Calling the call back could (will) also
/// happen in that case. Some backends need this to be called periodically. In
/// the usual case, this will be called by a background thread, but the ESI
/// runtime does not want to assume that the host processes use standard
/// threads. If the user wants to provide their own threads, they need to call
/// this on each port occasionally. This is also called from the 'master' poll
/// method in the Accelerator class.
bool poll() {
if (isConnected())
return pollImpl();
return false;
}
virtual void disconnect() {}

const Type *getType() const { return type; }

private:
protected:
const Type *type;

/// Method called by poll() to actually poll the channel if the channel is
/// connected.
virtual bool pollImpl() { return false; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't a user get some kind of feedback if they called this on something which they expected to have a poll implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feedback of what sort? Whether or not one has to run this function is very much backend dependent so I don't want like an exception being thrown if a user calls it on a backend which doesn't need it.


/// Called by all connect methods to let backends initiate the underlying
/// connections.
virtual void connectImpl(std::optional<unsigned> bufferSize) {}
Expand All @@ -58,8 +75,19 @@ class WriteChannelPort : public ChannelPort {
public:
using ChannelPort::ChannelPort;

virtual void
connect(std::optional<unsigned> bufferSize = std::nullopt) override {
connectImpl(bufferSize);
connected = true;
}
virtual void disconnect() override { connected = false; }
virtual bool isConnected() const override { return connected; }

/// A very basic write API. Will likely change for performance reasons.
virtual void write(const MessageData &) = 0;

private:
volatile bool connected = false;
};

/// A ChannelPort which reads data from the accelerator. It has two modes:
Expand All @@ -72,6 +100,9 @@ class ReadChannelPort : public ChannelPort {
ReadChannelPort(const Type *type)
: ChannelPort(type), mode(Mode::Disconnected) {}
virtual void disconnect() override { mode = Mode::Disconnected; }
virtual bool isConnected() const override {
return mode != Mode::Disconnected;
}

//===--------------------------------------------------------------------===//
// Callback mode: To use a callback, connect with a callback function which
Expand Down Expand Up @@ -121,7 +152,7 @@ class ReadChannelPort : public ChannelPort {
protected:
/// Indicates the current mode of the channel.
enum Mode { Disconnected, Callback, Polling };
Mode mode;
volatile Mode mode;

/// Backends call this callback when new data is available.
std::function<bool(MessageData)> callback;
Expand Down Expand Up @@ -178,6 +209,15 @@ class BundlePort {
return const_cast<T *>(dynamic_cast<const T *>(this));
}

/// Calls `poll` on all channels in the bundle and returns true if any of them
/// returned true.
bool poll() {
bool result = false;
for (auto &channel : channels)
result |= channel.second.poll();
return result;
}

private:
AppID id;
std::map<std::string, ChannelPort &> channels;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class TraceAccelerator : public esi::AcceleratorConnection {
/// is opened for writing. For 'Read' mode, this file is opened for reading.
TraceAccelerator(Context &, Mode mode, std::filesystem::path manifestJson,
std::filesystem::path traceFile);
~TraceAccelerator() override;

/// Parse the connection string and instantiate the accelerator. Format is:
/// "<mode>:<manifest path>[:<traceFile>]".
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class XrtAccelerator : public esi::AcceleratorConnection {
struct Impl;

XrtAccelerator(Context &, std::string xclbin, std::string kernelName);
~XrtAccelerator();
static std::unique_ptr<AcceleratorConnection>
connect(Context &, std::string connectionString);

Expand Down
48 changes: 43 additions & 5 deletions lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ using namespace esi::services;

namespace esi {
AcceleratorConnection::AcceleratorConnection(Context &ctxt)
: ctxt(ctxt), serviceThread(std::make_unique<AcceleratorServiceThread>()) {}
: ctxt(ctxt), serviceThread(nullptr) {}
AcceleratorConnection::~AcceleratorConnection() { disconnect(); }

AcceleratorServiceThread *AcceleratorConnection::getServiceThread() {
if (!serviceThread)
serviceThread = std::make_unique<AcceleratorServiceThread>();
return serviceThread.get();
}

services::Service *AcceleratorConnection::getService(Service::Type svcType,
AppIDPath id,
Expand All @@ -54,6 +61,13 @@ services::Service *AcceleratorConnection::getService(Service::Type svcType,
return cacheEntry.get();
}

Accelerator *
AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
Accelerator *ret = acc.get();
ownedAccelerators.push_back(std::move(acc));
return ret;
}

/// Get the path to the currently running executable.
static std::filesystem::path getExePath() {
#ifdef __linux__
Expand Down Expand Up @@ -224,18 +238,27 @@ struct AcceleratorServiceThread::Impl {
addListener(std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback);

void addTask(std::function<void(void)> task) {
std::lock_guard<std::mutex> g(m);
taskList.push_back(task);
}

private:
void loop();
volatile bool shutdown = false;
std::thread me;

// Protect the listeners std::map.
std::mutex listenerMutex;
// Protect the shared data structures.
std::mutex m;

// Map of read ports to callbacks.
std::map<ReadChannelPort *,
std::pair<std::function<void(ReadChannelPort *, MessageData)>,
std::future<MessageData>>>
listeners;

/// Tasks which should be called on every loop iteration.
std::vector<std::function<void(void)>> taskList;
};

void AcceleratorServiceThread::Impl::loop() {
Expand All @@ -245,6 +268,7 @@ void AcceleratorServiceThread::Impl::loop() {
std::function<void(ReadChannelPort *, MessageData)>,
MessageData>>
portUnlockWorkList;
std::vector<std::function<void(void)>> taskListCopy;
MessageData data;

while (!shutdown) {
Expand All @@ -256,7 +280,7 @@ void AcceleratorServiceThread::Impl::loop() {
// Check and gather data from all the read ports we are monitoring. Put the
// callbacks to be called later so we can release the lock.
{
std::lock_guard<std::mutex> g(listenerMutex);
std::lock_guard<std::mutex> g(m);
for (auto &[channel, cbfPair] : listeners) {
assert(channel && "Null channel in listener list");
std::future<MessageData> &f = cbfPair.second;
Expand All @@ -273,13 +297,22 @@ void AcceleratorServiceThread::Impl::loop() {

// Clear the worklist for the next iteration.
portUnlockWorkList.clear();

// Call any tasks that have been added. Copy it first so we can release the
// lock ASAP.
{
std::lock_guard<std::mutex> g(m);
taskListCopy = taskList;
}
for (auto &task : taskListCopy)
task();
}
}

void AcceleratorServiceThread::Impl::addListener(
std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback) {
std::lock_guard<std::mutex> g(listenerMutex);
std::lock_guard<std::mutex> g(m);
for (auto port : listenPorts) {
if (listeners.count(port))
throw std::runtime_error("Port already has a listener");
Expand Down Expand Up @@ -312,6 +345,11 @@ void AcceleratorServiceThread::addListener(
impl->addListener(listenPorts, callback);
}

void AcceleratorServiceThread::addPoll(HWModule &module) {
assert(impl && "Service thread not running");
impl->addTask([&module]() { module.poll(); });
}

void AcceleratorConnection::disconnect() {
if (serviceThread) {
serviceThread->stop();
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/Design.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,13 @@ HWModule::HWModule(std::optional<ModuleInfo> info,
childIndex(buildIndex(this->children)), services(services),
ports(std::move(ports)), portIndex(buildIndex(this->ports)) {}

bool HWModule::poll() {
bool result = false;
for (auto &port : ports)
result |= port->poll();
for (auto &child : children)
result |= child->poll();
return result;
}

} // namespace esi
5 changes: 2 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,8 @@ std::vector<ModuleInfo> Manifest::getModuleInfos() const {
return ret;
}

std::unique_ptr<Accelerator>
Manifest::buildAccelerator(AcceleratorConnection &acc) const {
return impl->buildAccelerator(acc);
Accelerator *Manifest::buildAccelerator(AcceleratorConnection &acc) const {
return acc.takeOwnership(impl->buildAccelerator(acc));
}

const std::vector<const Type *> &Manifest::getTypeTable() const {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void ReadChannelPort::connect(std::function<bool(MessageData)> callback,
throw std::runtime_error("Channel already connected");
mode = Mode::Callback;
this->callback = callback;
ChannelPort::connect(bufferSize);
connectImpl(bufferSize);
}

void ReadChannelPort::connect(std::optional<unsigned> bufferSize) {
Expand All @@ -71,7 +71,7 @@ void ReadChannelPort::connect(std::optional<unsigned> bufferSize) {
}
return true;
};
ChannelPort::connect(bufferSize);
connectImpl(bufferSize);
}

std::future<MessageData> ReadChannelPort::readAsync() {
Expand Down
Loading
Loading