Skip to content

Commit

Permalink
Merge pull request #1901 from cloudflare/kenton/jsrpc-response-stream
Browse files Browse the repository at this point in the history
JSRPC: Fix waitUntil() and response streams
  • Loading branch information
kentonv authored Mar 27, 2024
2 parents 7e6f215 + 9e9dea7 commit c42aec2
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 25 deletions.
3 changes: 2 additions & 1 deletion src/workerd/api/hibernatable-web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ jsg::Ref<WebSocket> HibernatableWebSocketEvent::claimWebSocket(jsg::Lock& lock,

kj::Promise<WorkerInterface::CustomEvent::Result> HibernatableWebSocketCustomEventImpl::run(
kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) {
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) {
// Mark the request as delivered because we're about to run some JS.
auto& context = incomingRequest->getContext();
incomingRequest->delivered();
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/hibernatable-web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class HibernatableWebSocketCustomEventImpl final: public WorkerInterface::Custom

kj::Promise<Result> run(
kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) override;
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) override;

kj::Promise<Result> sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/queue.c++
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ jsg::Ref<QueueEvent> startQueueEvent(

kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::run(
kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) {
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) {
incomingRequest->delivered();
auto& context = incomingRequest->getContext();

Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ class QueueCustomEventImpl final: public WorkerInterface::CustomEvent, public kj

kj::Promise<Result> run(
kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) override;
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) override;

kj::Promise<Result> sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
Expand Down
55 changes: 54 additions & 1 deletion src/workerd/api/streams/readable.c++
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,57 @@ private:
kj::Maybe<uint64_t> expectedLength;
};

// Wrapper around ReadableStreamSource that prevents deferred proxying. We need this for RPC
// streams because although they are "system streams", they become disconnected when the IoContext
// is destroyed, due to the JsRpcCustomEventImpl being canceled.
//
// TODO(someday): Devise a better way for RPC streams to extend the lifetime of the RPC session
// beyond the destruction of the IoContext, if it is being used for deferred proxying.
class NoDeferredProxyReadableStream: public ReadableStreamSource {
public:
NoDeferredProxyReadableStream(kj::Own<ReadableStreamSource> inner, IoContext& ioctx)
: inner(kj::mv(inner)), ioctx(ioctx) {}

kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return inner->tryRead(buffer, minBytes, maxBytes);
}

kj::Promise<DeferredProxy<void>> pumpTo(WritableStreamSink& output, bool end) override {
// Move the deferred proxy part of the task over to the non-deferred part. To do this,
// we use `ioctx.waitForDeferredProxy()`, which returns a single promise covering both parts
// (and, importantly, registering pending events where needed). Then, we add a noop deferred
// proxy to the end of that.
return addNoopDeferredProxy(ioctx.waitForDeferredProxy(inner->pumpTo(output, end)));
}

StreamEncoding getPreferredEncoding() override {
return inner->getPreferredEncoding();
}

kj::Maybe<uint64_t> tryGetLength(StreamEncoding encoding) override {
return inner->tryGetLength(encoding);
}

void cancel(kj::Exception reason) override {
return inner->cancel(kj::mv(reason));
}

kj::Maybe<Tee> tryTee(uint64_t limit) override {
return inner->tryTee(limit).map([&](Tee tee) {
return Tee {
.branches = {
kj::heap<NoDeferredProxyReadableStream>(kj::mv(tee.branches[0]), ioctx),
kj::heap<NoDeferredProxyReadableStream>(kj::mv(tee.branches[1]), ioctx),
}
};
});
}

private:
kj::Own<ReadableStreamSource> inner;
IoContext& ioctx;
};

} // namespace

void ReadableStream::serialize(jsg::Lock& js, jsg::Serializer& serializer) {
Expand Down Expand Up @@ -693,7 +744,9 @@ jsg::Ref<ReadableStream> ReadableStream::deserialize(

externalHandler->setLastStream(ioctx.getByteStreamFactory().kjToCapnp(kj::mv(out)));

return jsg::alloc<ReadableStream>(ioctx, newSystemStream(kj::mv(in), encoding, ioctx));
return jsg::alloc<ReadableStream>(ioctx,
kj::heap<NoDeferredProxyReadableStream>(
newSystemStream(kj::mv(in), encoding, ioctx), ioctx));
}

kj::StringPtr ReaderImpl::jsgGetMemoryName() const { return "ReaderImpl"_kjc; }
Expand Down
59 changes: 48 additions & 11 deletions src/workerd/api/tests/js-rpc-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,28 @@ export let nonClass = {

async fetch(req, env, ctx) {
// This is used in the stream test to fetch some gziped data.
return new Response("this text was gzipped", {
headers: {
"Content-Encoding": "gzip"
}
});
if (req.url.endsWith("/gzip")) {
return new Response("this text was gzipped", {
headers: {
"Content-Encoding": "gzip"
}
});
} else if (req.url.endsWith("/stream-from-rpc")) {
let stream = await env.MyService.returnReadableStream();
return new Response(stream);
} else {
throw new Error("unknown route");
}
}
}

// Globals used to test passing RPC promises or properties across I/O contexts (which is expected
// to fail).
let globalRpcPromise;

// Promise initialized by testWaitUntil() and then resolved shortly later, in a waitUntil task.
let globalWaitUntilPromise;

export class MyService extends WorkerEntrypoint {
constructor(ctx, env) {
super(ctx, env);
Expand Down Expand Up @@ -347,6 +357,19 @@ export class MyService extends WorkerEntrypoint {
cf: {foo: 123, bar: "def"},
});
}

testWaitUntil() {
// Initialize globalWaitUntilPromise to a promise that will be resolved in a waitUntil task
// later on. We'll perform a cross-context wait to verify that the waitUntil task actually
// completes and resolves the promise.
let resolve;
globalWaitUntilPromise = new Promise(r => { resolve = r; });

this.ctx.waitUntil((async () => {
await scheduler.wait(100);
resolve();
})());
}
}

export class MyActor extends DurableObject {
Expand Down Expand Up @@ -789,11 +812,8 @@ export let disposal = {

// If we abort the server's I/O context, though, then the counter is disposed.
await assert.rejects(obj.abort(), {
// TODO(someday): This ought to propagate the abort exception, but that requires a bunch
// more work...
name: "Error",
message: "The destination execution context for this RPC was canceled while the " +
"call was still running."
name: "RangeError",
message: "foo bar abort reason"
});

await counter.onDisposed();
Expand Down Expand Up @@ -852,6 +872,16 @@ export let crossContextSharingDoesntWork = {
},
}

export let waitUntilWorks = {
async test(controller, env, ctx) {
globalWaitUntilPromise = null;
await env.MyService.testWaitUntil();

assert.strictEqual(globalWaitUntilPromise instanceof Promise, true);
await globalWaitUntilPromise;
}
}

function stripDispose(obj) {
assert.deepEqual(!!obj[Symbol.dispose], true);
delete obj[Symbol.dispose];
Expand Down Expand Up @@ -1062,7 +1092,7 @@ export let streams = {

// Send an encoded ReadableStream
{
let gzippedResp = await env.self.fetch("http://foo");
let gzippedResp = await env.self.fetch("http://foo/gzip");

let text = await env.MyService.readFromStream(gzippedResp.body);

Expand All @@ -1087,6 +1117,13 @@ export let streams = {

assert.strictEqual(await readPromise, "foo, bar, baz!");
}

// Perform an HTTP request whose response uses a ReadableStream obtained over RPC.
{
let resp = await env.self.fetch("http://foo/stream-from-rpc");

assert.strictEqual(await resp.text(), "foo, bar, baz!");
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/trace.c++
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ kj::Promise<void> sendTracesToExportedHandler(
} // namespace

auto TraceCustomEventImpl::run(
kj::Own<IoContext::IncomingRequest> incomingRequest, kj::Maybe<kj::StringPtr> entrypointNamePtr)
kj::Own<IoContext::IncomingRequest> incomingRequest, kj::Maybe<kj::StringPtr> entrypointNamePtr,
kj::TaskSet& waitUntilTasks)
-> kj::Promise<Result> {
// Don't bother to wait around for the handler to run, just hand it off to the waitUntil tasks.
waitUntilTasks.add(
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,8 @@ class TraceCustomEventImpl final: public WorkerInterface::CustomEvent {

kj::Promise<Result> run(
kj::Own<IoContext::IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) override;
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) override;

kj::Promise<Result> sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
Expand Down
14 changes: 10 additions & 4 deletions src/workerd/api/worker-rpc.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,8 @@ private:

kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::run(
kj::Own<IoContext::IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) {
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) {
IoContext& ioctx = incomingRequest->getContext();

incomingRequest->delivered();
Expand All @@ -1685,11 +1686,16 @@ kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::r
mapAddRef(incomingRequest->getWorkerTracer())),
kj::refcounted<ServerTopLevelMembrane>(kj::mv(doneFulfiller))));

KJ_DEFER({
// waitUntil() should allow extending execution on the server side even when the client
// disconnects.
waitUntilTasks.add(incomingRequest->drain().attach(kj::mv(incomingRequest)));
});

// `donePromise` resolves once there are no longer any capabilities pointing between the client
// and server as part of this session.
co_await donePromise
.then([&ir = *incomingRequest]() { return ir.drain(); })
.exclusiveJoin(ioctx.onAbort());
co_await donePromise.exclusiveJoin(ioctx.onAbort());

co_return WorkerInterface::CustomEvent::Result {
.outcome = EventOutcome::OK
};
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/worker-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ class JsRpcSessionCustomEventImpl final: public WorkerInterface::CustomEvent {

kj::Promise<Result> run(
kj::Own<IoContext::IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) override;
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) override;

kj::Promise<Result> sendRpc(
capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/io/worker-entrypoint.c++
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,8 @@ kj::Promise<WorkerInterface::CustomEvent::Result>
this->incomingRequest = kj::none;

auto& context = incomingRequest->getContext();
auto promise = event->run(kj::mv(incomingRequest), entrypointName).attach(kj::mv(event));
auto promise = event->run(kj::mv(incomingRequest), entrypointName, waitUntilTasks)
.attach(kj::mv(event));

// TODO(cleanup): In theory `context` may have been destroyed by now if `event->run()` dropped
// the `incomingRequest` synchronously. No current implementation does that, and
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/io/worker-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ class WorkerInterface: public kj::HttpService {
// for this event.
virtual kj::Promise<Result> run(
kj::Own<IoContext_IncomingRequest> incomingRequest,
kj::Maybe<kj::StringPtr> entrypointName) = 0;
kj::Maybe<kj::StringPtr> entrypointName,
kj::TaskSet& waitUntilTasks) = 0;

// Forward the event over RPC.
virtual kj::Promise<Result> sendRpc(
Expand Down

0 comments on commit c42aec2

Please sign in to comment.