Skip to content

Commit

Permalink
Merge pull request #2863 from cloudflare/dlapid/prewarm_promise
Browse files Browse the repository at this point in the history
Change WorkerInterface so prewarm returns a promise.
  • Loading branch information
danlapid committed Oct 11, 2024
1 parent 5e13cf1 commit 1dae2d0
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 72 deletions.
8 changes: 2 additions & 6 deletions src/workerd/api/hibernatable-web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ kj::Promise<WorkerInterface::CustomEvent::Result> HibernatableWebSocketCustomEve
kj::Promise<WorkerInterface::CustomEvent::Result> HibernatableWebSocketCustomEventImpl::sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) {
auto req = dispatcher.castAs<rpc::HibernatableWebSocketEventDispatcher>()
.hibernatableWebSocketEventRequest();
Expand Down Expand Up @@ -194,15 +193,12 @@ HibernatableWebSocketEvent::ItemsForRelease::ItemsForRelease(
tags(kj::mv(tags)) {}

HibernatableWebSocketCustomEventImpl::HibernatableWebSocketCustomEventImpl(uint16_t typeId,
kj::TaskSet& waitUntilTasks,
kj::Own<HibernationReader> params,
kj::Maybe<Worker::Actor::HibernationManager&> manager)
: typeId(typeId),
params(kj::mv(params)) {}
HibernatableWebSocketCustomEventImpl::HibernatableWebSocketCustomEventImpl(uint16_t typeId,
kj::TaskSet& waitUntilTasks,
HibernatableSocketParams params,
Worker::Actor::HibernationManager& manager)
HibernatableWebSocketCustomEventImpl::HibernatableWebSocketCustomEventImpl(
uint16_t typeId, HibernatableSocketParams params, Worker::Actor::HibernationManager& manager)
: typeId(typeId),
params(kj::mv(params)),
manager(manager) {}
Expand Down
8 changes: 2 additions & 6 deletions src/workerd/api/hibernatable-web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,17 @@ class HibernatableWebSocketCustomEventImpl final: public WorkerInterface::Custom
public kj::Refcounted {
public:
HibernatableWebSocketCustomEventImpl(uint16_t typeId,
kj::TaskSet& waitUntilTasks,
kj::Own<HibernationReader> params,
kj::Maybe<Worker::Actor::HibernationManager&> manager = kj::none);
HibernatableWebSocketCustomEventImpl(uint16_t typeId,
kj::TaskSet& waitUntilTasks,
HibernatableSocketParams params,
Worker::Actor::HibernationManager& manager);
HibernatableWebSocketCustomEventImpl(
uint16_t typeId, HibernatableSocketParams params, Worker::Actor::HibernationManager& manager);

kj::Promise<Result> run(kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) override;

kj::Promise<Result> sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) override;

uint16_t getType() override {
Expand Down
1 change: 0 additions & 1 deletion src/workerd/api/queue.c++
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,6 @@ kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::run(
kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) {
auto req = dispatcher.castAs<rpc::EventDispatcher>().queueRequest();
KJ_SWITCH_ONEOF(params) {
Expand Down
1 change: 0 additions & 1 deletion src/workerd/api/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ class QueueCustomEventImpl final: public WorkerInterface::CustomEvent, public kj

kj::Promise<Result> sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) override;

static const uint16_t EVENT_TYPE = 5;
Expand Down
1 change: 0 additions & 1 deletion src/workerd/api/trace.c++
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,6 @@ auto TraceCustomEventImpl::run(kj::Own<IoContext::IncomingRequest> incomingReque

auto TraceCustomEventImpl::sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
workerd::rpc::EventDispatcher::Client dispatcher) -> kj::Promise<Result> {
auto req = dispatcher.sendTracesRequest();
auto out = req.initTraces(traces.size());
Expand Down
4 changes: 1 addition & 3 deletions src/workerd/api/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,7 @@ class UnsafeTraceMetrics final: public jsg::Object {

class TraceCustomEventImpl final: public WorkerInterface::CustomEvent {
public:
TraceCustomEventImpl(
uint16_t typeId, kj::TaskSet& waitUntilTasks, kj::Array<kj::Own<Trace>> traces)
TraceCustomEventImpl(uint16_t typeId, kj::Array<kj::Own<Trace>> traces)
: typeId(typeId),
traces(kj::mv(traces)) {}

Expand All @@ -615,7 +614,6 @@ class TraceCustomEventImpl final: public WorkerInterface::CustomEvent {

kj::Promise<Result> sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) override;

uint16_t getType() override {
Expand Down
1 change: 0 additions & 1 deletion src/workerd/api/worker-rpc.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,6 @@ kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::r
kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) {
// We arrange to revoke all capabilities in this session as soon as `sendRpc()` completes or is
// canceled. Normally, the server side doesn't return if any capabilities still exist, so this
Expand Down
1 change: 0 additions & 1 deletion src/workerd/api/worker-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ class JsRpcSessionCustomEventImpl final: public WorkerInterface::CustomEvent {

kj::Promise<Result> sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) override;

uint16_t getType() override {
Expand Down
4 changes: 2 additions & 2 deletions src/workerd/io/hibernation-manager.c++
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ kj::Promise<void> HibernationManagerImpl::handleSocketTermination(
auto workerInterface = loopback->getWorker(IoChannelFactory::SubrequestMetadata{});
event = workerInterface
->customEvent(kj::heap<api::HibernatableWebSocketCustomEventImpl>(
hibernationEventType, readLoopTasks, kj::mv(KJ_REQUIRE_NONNULL(params)), *this))
hibernationEventType, kj::mv(KJ_REQUIRE_NONNULL(params)), *this))
.ignoreResult()
.attach(kj::mv(workerInterface));
}
Expand Down Expand Up @@ -366,7 +366,7 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {
// Dispatch the event.
auto workerInterface = loopback->getWorker(IoChannelFactory::SubrequestMetadata{});
co_await workerInterface->customEvent(kj::heap<api::HibernatableWebSocketCustomEventImpl>(
hibernationEventType, readLoopTasks, kj::mv(params), *this));
hibernationEventType, kj::mv(params), *this));
if (isClose) {
co_return;
}
Expand Down
5 changes: 3 additions & 2 deletions src/workerd/io/worker-entrypoint.c++
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public:
kj::AsyncIoStream& connection,
ConnectResponse& response,
kj::HttpConnectSettings settings) override;
void prewarm(kj::StringPtr url) override;
kj::Promise<void> prewarm(kj::StringPtr url) override;
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override;
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override;
kj::Promise<bool> test() override;
Expand Down Expand Up @@ -447,7 +447,7 @@ kj::Promise<void> WorkerEntrypoint::connect(kj::StringPtr host,
JSG_FAIL_REQUIRE(TypeError, "Incoming CONNECT on a worker not supported");
}

void WorkerEntrypoint::prewarm(kj::StringPtr url) {
kj::Promise<void> WorkerEntrypoint::prewarm(kj::StringPtr url) {
// Nothing to do, the worker is already loaded.
TRACE_EVENT("workerd", "WorkerEntrypoint::prewarm()", "url", url.cStr());
auto incomingRequest =
Expand All @@ -459,6 +459,7 @@ void WorkerEntrypoint::prewarm(kj::StringPtr url) {
// TODO(someday): Ideally, middleware workers would forward prewarm() to the next stage. At
// present we don't have a good way to decide what stage that is, especially given that we'll
// be switching to `next` being a binding in the future.
return kj::READY_NOW;
}

kj::Promise<WorkerInterface::ScheduledResult> WorkerEntrypoint::runScheduled(
Expand Down
44 changes: 16 additions & 28 deletions src/workerd/io/worker-interface.c++
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ namespace {
// interface the promise resolved to.
class PromisedWorkerInterface final: public kj::Refcounted, public WorkerInterface {
public:
PromisedWorkerInterface(
kj::TaskSet& waitUntilTasks, kj::Promise<kj::Own<WorkerInterface>> promise)
: waitUntilTasks(waitUntilTasks),
promise(promise.then([this](kj::Own<WorkerInterface> result) { worker = kj::mv(result); })
PromisedWorkerInterface(kj::Promise<kj::Own<WorkerInterface>> promise)
: promise(promise.then([this](kj::Own<WorkerInterface> result) { worker = kj::mv(result); })
.fork()) {}

kj::Promise<void> request(kj::HttpMethod method,
Expand Down Expand Up @@ -49,18 +47,12 @@ public:
}
}

void prewarm(kj::StringPtr url) override {
kj::Promise<void> prewarm(kj::StringPtr url) override {
KJ_IF_SOME(w, worker) {
w.get()->prewarm(url);
co_return co_await w.get()->prewarm(url);
} else {
static auto constexpr handlePrewarm =
[](kj::Promise<void> promise, kj::String url,
kj::Own<PromisedWorkerInterface> self) -> kj::Promise<void> {
co_await promise;
KJ_ASSERT_NONNULL(self->worker)->prewarm(url);
};

waitUntilTasks.add(handlePrewarm(promise.addBranch(), kj::str(url), kj::addRef(*this)));
co_await promise;
co_return co_await KJ_ASSERT_NONNULL(worker)->prewarm(url);
}
}

Expand Down Expand Up @@ -92,15 +84,13 @@ public:
}

private:
kj::TaskSet& waitUntilTasks;
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<WorkerInterface>> worker;
};
} // namespace

kj::Own<WorkerInterface> newPromisedWorkerInterface(
kj::TaskSet& waitUntilTasks, kj::Promise<kj::Own<WorkerInterface>> promise) {
return kj::refcounted<PromisedWorkerInterface>(waitUntilTasks, kj::mv(promise));
kj::Own<WorkerInterface> newPromisedWorkerInterface(kj::Promise<kj::Own<WorkerInterface>> promise) {
return kj::refcounted<PromisedWorkerInterface>(kj::mv(promise));
}

kj::Own<kj::HttpClient> asHttpClient(kj::Own<WorkerInterface> workerInterface) {
Expand Down Expand Up @@ -238,7 +228,7 @@ public:
kj::AsyncIoStream& connection,
ConnectResponse& response,
kj::HttpConnectSettings settings) override;
void prewarm(kj::StringPtr url) override;
kj::Promise<void> prewarm(kj::StringPtr url) override;
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override;
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override;
kj::Promise<CustomEvent::Result> customEvent(kj::Own<CustomEvent> event) override;
Expand Down Expand Up @@ -273,8 +263,8 @@ RevocableWebSocketWorkerInterface::RevocableWebSocketWorkerInterface(
: worker(worker),
revokeProm(revokeProm.fork()) {}

void RevocableWebSocketWorkerInterface::prewarm(kj::StringPtr url) {
worker.prewarm(url);
kj::Promise<void> RevocableWebSocketWorkerInterface::prewarm(kj::StringPtr url) {
return worker.prewarm(url);
}

kj::Promise<WorkerInterface::ScheduledResult> RevocableWebSocketWorkerInterface::runScheduled(
Expand Down Expand Up @@ -324,8 +314,9 @@ public:
kj::throwFatalException(kj::mv(exception));
}

void prewarm(kj::StringPtr url) override {
kj::Promise<void> prewarm(kj::StringPtr url) override {
// ignore
return kj::READY_NOW;
}

kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override {
Expand Down Expand Up @@ -354,11 +345,9 @@ kj::Own<WorkerInterface> WorkerInterface::fromException(kj::Exception&& e) {

RpcWorkerInterface::RpcWorkerInterface(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher)
: httpOverCapnpFactory(httpOverCapnpFactory),
byteStreamFactory(byteStreamFactory),
waitUntilTasks(waitUntilTasks),
dispatcher(kj::mv(dispatcher)) {}

kj::Promise<void> RpcWorkerInterface::request(kj::HttpMethod method,
Expand All @@ -381,10 +370,10 @@ kj::Promise<void> RpcWorkerInterface::connect(kj::StringPtr host,
return promise.attach(kj::mv(inner));
}

void RpcWorkerInterface::prewarm(kj::StringPtr url) {
kj::Promise<void> RpcWorkerInterface::prewarm(kj::StringPtr url) {
auto req = dispatcher.prewarmRequest(capnp::MessageSize{url.size() / sizeof(capnp::word) + 4, 0});
req.setUrl(url);
waitUntilTasks.add(req.send().ignoreResult());
return req.send().ignoreResult();
}

kj::Promise<WorkerInterface::ScheduledResult> RpcWorkerInterface::runScheduled(
Expand Down Expand Up @@ -414,8 +403,7 @@ kj::Promise<WorkerInterface::AlarmResult> RpcWorkerInterface::runAlarm(

kj::Promise<WorkerInterface::CustomEvent::Result> RpcWorkerInterface::customEvent(
kj::Own<CustomEvent> event) {
return event->sendRpc(httpOverCapnpFactory, byteStreamFactory, waitUntilTasks, dispatcher)
.attach(kj::mv(event));
return event->sendRpc(httpOverCapnpFactory, byteStreamFactory, dispatcher).attach(kj::mv(event));
}

// ======================================================================================
Expand Down
12 changes: 3 additions & 9 deletions src/workerd/io/worker-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class WorkerInterface: public kj::HttpService {
// to be invoked.
//
// If prewarm() has to do anything asynchronous, it should use "waitUntil" tasks.
virtual void prewarm(kj::StringPtr url) = 0;
virtual kj::Promise<void> prewarm(kj::StringPtr url) = 0;

struct ScheduledResult {
bool retry = true;
Expand Down Expand Up @@ -116,7 +116,6 @@ class WorkerInterface: public kj::HttpService {
// Forward the event over RPC.
virtual kj::Promise<Result> sendRpc(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher) = 0;

// Get the type for this event for logging / metrics purposes. This is intended for use by the
Expand Down Expand Up @@ -147,10 +146,7 @@ class WorkerInterface: public kj::HttpService {

// Given a Promise for a WorkerInterface, return a WorkerInterface whose methods will first wait
// for the promise, then invoke the destination object.
kj::Own<WorkerInterface> newPromisedWorkerInterface(
kj::TaskSet& waitUntilTasks, kj::Promise<kj::Own<WorkerInterface>> promise);
// TODO(cleanup): `waitUntilTasks` is only needed to handle `prewarm` since they
// don't return promises. We should maybe change them to return promises?
kj::Own<WorkerInterface> newPromisedWorkerInterface(kj::Promise<kj::Own<WorkerInterface>> promise);

// Adapts WorkerInterface to HttpClient, including taking ownership.
//
Expand All @@ -169,7 +165,6 @@ class RpcWorkerInterface: public WorkerInterface {
public:
RpcWorkerInterface(capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
capnp::ByteStreamFactory& byteStreamFactory,
kj::TaskSet& waitUntilTasks,
rpc::EventDispatcher::Client dispatcher);

kj::Promise<void> request(kj::HttpMethod method,
Expand All @@ -184,15 +179,14 @@ class RpcWorkerInterface: public WorkerInterface {
ConnectResponse& tunnel,
kj::HttpConnectSettings settings) override;

void prewarm(kj::StringPtr url) override;
kj::Promise<void> prewarm(kj::StringPtr url) override;
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override;
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override;
kj::Promise<CustomEvent::Result> customEvent(kj::Own<CustomEvent> event) override;

private:
capnp::HttpOverCapnpFactory& httpOverCapnpFactory;
capnp::ByteStreamFactory& byteStreamFactory;
kj::TaskSet& waitUntilTasks;
rpc::EventDispatcher::Client dispatcher;
};

Expand Down
6 changes: 3 additions & 3 deletions src/workerd/io/worker.c++
Original file line number Diff line number Diff line change
Expand Up @@ -3968,7 +3968,7 @@ public:
kj::AsyncIoStream& connection,
kj::HttpService::ConnectResponse& tunnel,
kj::HttpConnectSettings settings) override;
void prewarm(kj::StringPtr url) override;
kj::Promise<void> prewarm(kj::StringPtr url) override;
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override;
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override;
kj::Promise<CustomEvent::Result> customEvent(kj::Own<CustomEvent> event) override;
Expand Down Expand Up @@ -4189,8 +4189,8 @@ kj::Promise<void> Worker::Isolate::SubrequestClient::connect(kj::StringPtr host,
}

// TODO(someday): Log other kinds of subrequests?
void Worker::Isolate::SubrequestClient::prewarm(kj::StringPtr url) {
inner->prewarm(url);
kj::Promise<void> Worker::Isolate::SubrequestClient::prewarm(kj::StringPtr url) {
return inner->prewarm(url);
}
kj::Promise<WorkerInterface::ScheduledResult> Worker::Isolate::SubrequestClient::runScheduled(
kj::Date scheduledTime, kj::StringPtr cron) {
Expand Down
Loading

0 comments on commit 1dae2d0

Please sign in to comment.