diff --git a/src/workerd/api/hibernatable-web-socket.c++ b/src/workerd/api/hibernatable-web-socket.c++
index 3275af50d54..5c2eec8dff1 100644
--- a/src/workerd/api/hibernatable-web-socket.c++
+++ b/src/workerd/api/hibernatable-web-socket.c++
@@ -58,7 +58,8 @@ jsg::Ref<WebSocket> HibernatableWebSocketEvent::claimWebSocket(jsg::Lock& lock,
 
 kj::Promise<WorkerInterface::CustomEvent::Result> HibernatableWebSocketCustomEventImpl::run(
     kj::Own<IoContext::IncomingRequest> incomingRequest,
-    kj::Maybe<kj::StringPtr> entrypointName) {
+    kj::Maybe<kj::StringPtr> entrypointName,
+    kj::TaskSet& waitUntilTasks) {
   // Mark the request as delivered because we're about to run some JS.
   auto& context = incomingRequest->getContext();
   incomingRequest->delivered();
diff --git a/src/workerd/api/hibernatable-web-socket.h b/src/workerd/api/hibernatable-web-socket.h
index f709ef8482b..87d17dbb8ba 100644
--- a/src/workerd/api/hibernatable-web-socket.h
+++ b/src/workerd/api/hibernatable-web-socket.h
@@ -71,7 +71,8 @@ class HibernatableWebSocketCustomEventImpl final: public WorkerInterface::Custom
 
   kj::Promise<Result> run(
       kj::Own<IoContext::IncomingRequest> incomingRequest,
-      kj::Maybe<kj::StringPtr> entrypointName) override;
+      kj::Maybe<kj::StringPtr> entrypointName,
+      kj::TaskSet& waitUntilTasks) override;
 
   kj::Promise<Result> sendRpc(
       capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
diff --git a/src/workerd/api/queue.c++ b/src/workerd/api/queue.c++
index a1da1c670b4..040af12cfc9 100644
--- a/src/workerd/api/queue.c++
+++ b/src/workerd/api/queue.c++
@@ -528,7 +528,8 @@ jsg::Ref<QueueEvent> startQueueEvent(
 
 kj::Promise<WorkerInterface::CustomEvent::Result> QueueCustomEventImpl::run(
     kj::Own<IoContext::IncomingRequest> incomingRequest,
-    kj::Maybe<kj::StringPtr> entrypointName) {
+    kj::Maybe<kj::StringPtr> entrypointName,
+    kj::TaskSet& waitUntilTasks) {
   incomingRequest->delivered();
   auto& context = incomingRequest->getContext();
 
diff --git a/src/workerd/api/queue.h b/src/workerd/api/queue.h
index 5eabce960aa..44cf5e7157e 100644
--- a/src/workerd/api/queue.h
+++ b/src/workerd/api/queue.h
@@ -324,7 +324,8 @@ class QueueCustomEventImpl final: public WorkerInterface::CustomEvent, public kj
 
   kj::Promise<Result> run(
       kj::Own<IoContext::IncomingRequest> incomingRequest,
-      kj::Maybe<kj::StringPtr> entrypointName) override;
+      kj::Maybe<kj::StringPtr> entrypointName,
+      kj::TaskSet& waitUntilTasks) override;
 
   kj::Promise<Result> sendRpc(
       capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
diff --git a/src/workerd/api/streams/readable.c++ b/src/workerd/api/streams/readable.c++
index 9ec71a46e65..21ff1ddbc22 100644
--- a/src/workerd/api/streams/readable.c++
+++ b/src/workerd/api/streams/readable.c++
@@ -611,6 +611,57 @@ private:
   kj::Maybe<uint64_t> expectedLength;
 };
 
+// Wrapper around ReadableStreamSource that prevents deferred proxying. We need this for RPC
+// streams because although they are "system streams", they become disconnected when the IoContext
+// is destroyed, due to the JsRpcCustomEventImpl being canceled.
+//
+// TODO(someday): Devise a better way for RPC streams to extend the lifetime of the RPC session
+// beyond the destruction of the IoContext, if it is being used for deferred proxying.
+class NoDeferredProxyReadableStream: public ReadableStreamSource {
+public:
+  NoDeferredProxyReadableStream(kj::Own<ReadableStreamSource> inner, IoContext& ioctx)
+      : inner(kj::mv(inner)), ioctx(ioctx) {}
+
+  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
+    return inner->tryRead(buffer, minBytes, maxBytes);
+  }
+
+  kj::Promise<DeferredProxy<void>> pumpTo(WritableStreamSink& output, bool end) override {
+    // Move the deferred proxy part of the task over to the non-deferred part. To do this,
+    // we use `ioctx.waitForDeferredProxy()`, which returns a single promise covering both parts
+    // (and, importantly, registering pending events where needed). Then, we add a noop deferred
+    // proxy to the end of that.
+    return addNoopDeferredProxy(ioctx.waitForDeferredProxy(inner->pumpTo(output, end)));
+  }
+
+  StreamEncoding getPreferredEncoding() override {
+    return inner->getPreferredEncoding();
+  }
+
+  kj::Maybe<uint64_t> tryGetLength(StreamEncoding encoding) override {
+    return inner->tryGetLength(encoding);
+  }
+
+  void cancel(kj::Exception reason) override {
+    return inner->cancel(kj::mv(reason));
+  }
+
+  kj::Maybe<Tee> tryTee(uint64_t limit) override {
+    return inner->tryTee(limit).map([&](Tee tee) {
+      return Tee {
+        .branches = {
+          kj::heap<NoDeferredProxyReadableStream>(kj::mv(tee.branches[0]), ioctx),
+          kj::heap<NoDeferredProxyReadableStream>(kj::mv(tee.branches[1]), ioctx),
+        }
+      };
+    });
+  }
+
+private:
+  kj::Own<ReadableStreamSource> inner;
+  IoContext& ioctx;
+};
+
 } // namespace
 
 void ReadableStream::serialize(jsg::Lock& js, jsg::Serializer& serializer) {
@@ -693,7 +744,9 @@ jsg::Ref<ReadableStream> ReadableStream::deserialize(
 
   externalHandler->setLastStream(ioctx.getByteStreamFactory().kjToCapnp(kj::mv(out)));
 
-  return jsg::alloc<ReadableStream>(ioctx, newSystemStream(kj::mv(in), encoding, ioctx));
+  return jsg::alloc<ReadableStream>(ioctx,
+      kj::heap<NoDeferredProxyReadableStream>(
+          newSystemStream(kj::mv(in), encoding, ioctx), ioctx));
 }
 
 kj::StringPtr ReaderImpl::jsgGetMemoryName() const { return "ReaderImpl"_kjc; }
diff --git a/src/workerd/api/tests/js-rpc-test.js b/src/workerd/api/tests/js-rpc-test.js
index dde0e8d1867..110fe610b9c 100644
--- a/src/workerd/api/tests/js-rpc-test.js
+++ b/src/workerd/api/tests/js-rpc-test.js
@@ -74,11 +74,18 @@ export let nonClass = {
 
   async fetch(req, env, ctx) {
     // This is used in the stream test to fetch some gziped data.
-    return new Response("this text was gzipped", {
-      headers: {
-        "Content-Encoding": "gzip"
-      }
-    });
+    if (req.url.endsWith("/gzip")) {
+      return new Response("this text was gzipped", {
+        headers: {
+          "Content-Encoding": "gzip"
+        }
+      });
+    } else if (req.url.endsWith("/stream-from-rpc")) {
+      let stream = await env.MyService.returnReadableStream();
+      return new Response(stream);
+    } else {
+      throw new Error("unknown route");
+    }
   }
 }
 
@@ -86,6 +93,9 @@
 // to fail).
 let globalRpcPromise;
 
+// Promise initialized by testWaitUntil() and then resolved shortly later, in a waitUntil task.
+let globalWaitUntilPromise;
+
 export class MyService extends WorkerEntrypoint {
   constructor(ctx, env) {
     super(ctx, env);
@@ -347,6 +357,19 @@ export class MyService extends WorkerEntrypoint {
       cf: {foo: 123, bar: "def"},
     });
   }
+
+  testWaitUntil() {
+    // Initialize globalWaitUntilPromise to a promise that will be resolved in a waitUntil task
+    // later on. We'll perform a cross-context wait to verify that the waitUntil task actually
+    // completes and resolves the promise.
+    let resolve;
+    globalWaitUntilPromise = new Promise(r => { resolve = r; });
+
+    this.ctx.waitUntil((async () => {
+      await scheduler.wait(100);
+      resolve();
+    })());
+  }
 }
 
 export class MyActor extends DurableObject {
@@ -789,11 +812,8 @@ export let disposal = {
 
     // If we abort the server's I/O context, though, then the counter is disposed.
     await assert.rejects(obj.abort(), {
-      // TODO(someday): This ought to propagate the abort exception, but that requires a bunch
-      // more work...
-      name: "Error",
-      message: "The destination execution context for this RPC was canceled while the " +
-          "call was still running."
+      name: "RangeError",
+      message: "foo bar abort reason"
     });
 
     await counter.onDisposed();
@@ -852,6 +872,16 @@ export let crossContextSharingDoesntWork = {
   },
 }
 
+export let waitUntilWorks = {
+  async test(controller, env, ctx) {
+    globalWaitUntilPromise = null;
+    await env.MyService.testWaitUntil();
+
+    assert.strictEqual(globalWaitUntilPromise instanceof Promise, true);
+    await globalWaitUntilPromise;
+  }
+}
+
 function stripDispose(obj) {
   assert.deepEqual(!!obj[Symbol.dispose], true);
   delete obj[Symbol.dispose];
@@ -1062,7 +1092,7 @@ export let streams = {
 
     // Send an encoded ReadableStream
     {
-      let gzippedResp = await env.self.fetch("http://foo");
+      let gzippedResp = await env.self.fetch("http://foo/gzip");
 
       let text = await env.MyService.readFromStream(gzippedResp.body);
 
@@ -1087,6 +1117,13 @@
 
       assert.strictEqual(await readPromise, "foo, bar, baz!");
     }
+
+    // Perform an HTTP request whose response uses a ReadableStream obtained over RPC.
+    {
+      let resp = await env.self.fetch("http://foo/stream-from-rpc");
+
+      assert.strictEqual(await resp.text(), "foo, bar, baz!");
+    }
   }
 }
 
diff --git a/src/workerd/api/trace.c++ b/src/workerd/api/trace.c++
index 46628c0f31d..23eda5a2b78 100644
--- a/src/workerd/api/trace.c++
+++ b/src/workerd/api/trace.c++
@@ -610,7 +610,8 @@ kj::Promise<void> sendTracesToExportedHandler(
 } // namespace
 
 auto TraceCustomEventImpl::run(
-    kj::Own<IoContext::IncomingRequest> incomingRequest, kj::Maybe<kj::StringPtr> entrypointNamePtr)
+    kj::Own<IoContext::IncomingRequest> incomingRequest, kj::Maybe<kj::StringPtr> entrypointNamePtr,
+    kj::TaskSet& waitUntilTasks)
     -> kj::Promise<Result> {
   // Don't bother to wait around for the handler to run, just hand it off to the waitUntil tasks.
   waitUntilTasks.add(
diff --git a/src/workerd/api/trace.h b/src/workerd/api/trace.h
index 7c99e752f97..64f88c0d880 100644
--- a/src/workerd/api/trace.h
+++ b/src/workerd/api/trace.h
@@ -590,7 +590,8 @@ class TraceCustomEventImpl final: public WorkerInterface::CustomEvent {
 
   kj::Promise<Result> run(
       kj::Own<IoContext::IncomingRequest> incomingRequest,
-      kj::Maybe<kj::StringPtr> entrypointName) override;
+      kj::Maybe<kj::StringPtr> entrypointName,
+      kj::TaskSet& waitUntilTasks) override;
 
   kj::Promise<Result> sendRpc(
       capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
diff --git a/src/workerd/api/worker-rpc.c++ b/src/workerd/api/worker-rpc.c++
index e9b183d3496..4b591cbd0df 100644
--- a/src/workerd/api/worker-rpc.c++
+++ b/src/workerd/api/worker-rpc.c++
@@ -1673,7 +1673,8 @@ private:
 
 kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::run(
     kj::Own<IoContext::IncomingRequest> incomingRequest,
-    kj::Maybe<kj::StringPtr> entrypointName) {
+    kj::Maybe<kj::StringPtr> entrypointName,
+    kj::TaskSet& waitUntilTasks) {
   IoContext& ioctx = incomingRequest->getContext();
   incomingRequest->delivered();
 
@@ -1685,11 +1686,16 @@ kj::Promise<WorkerInterface::CustomEvent::Result> JsRpcSessionCustomEventImpl::r
               mapAddRef(incomingRequest->getWorkerTracer())),
           kj::refcounted(kj::mv(doneFulfiller))));
 
+  KJ_DEFER({
+    // waitUntil() should allow extending execution on the server side even when the client
+    // disconnects.
+    waitUntilTasks.add(incomingRequest->drain().attach(kj::mv(incomingRequest)));
+  });
+
   // `donePromise` resolves once there are no longer any capabilities pointing between the client
   // and server as part of this session.
-  co_await donePromise
-      .then([&ir = *incomingRequest]() { return ir.drain(); })
-      .exclusiveJoin(ioctx.onAbort());
+  co_await donePromise.exclusiveJoin(ioctx.onAbort());
+
   co_return WorkerInterface::CustomEvent::Result {
     .outcome = EventOutcome::OK
   };
diff --git a/src/workerd/api/worker-rpc.h b/src/workerd/api/worker-rpc.h
index f14a9773801..548620647e5 100644
--- a/src/workerd/api/worker-rpc.h
+++ b/src/workerd/api/worker-rpc.h
@@ -395,7 +395,8 @@ class JsRpcSessionCustomEventImpl final: public WorkerInterface::CustomEvent {
 
   kj::Promise<Result> run(
       kj::Own<IoContext::IncomingRequest> incomingRequest,
-      kj::Maybe<kj::StringPtr> entrypointName) override;
+      kj::Maybe<kj::StringPtr> entrypointName,
+      kj::TaskSet& waitUntilTasks) override;
 
   kj::Promise<Result> sendRpc(
       capnp::HttpOverCapnpFactory& httpOverCapnpFactory,
diff --git a/src/workerd/io/worker-entrypoint.c++ b/src/workerd/io/worker-entrypoint.c++
index 2e589ef6803..028fc0f242c 100644
--- a/src/workerd/io/worker-entrypoint.c++
+++ b/src/workerd/io/worker-entrypoint.c++
@@ -650,7 +650,8 @@ kj::Promise<WorkerInterface::CustomEvent::Result>
   this->incomingRequest = kj::none;
 
   auto& context = incomingRequest->getContext();
-  auto promise = event->run(kj::mv(incomingRequest), entrypointName).attach(kj::mv(event));
+  auto promise = event->run(kj::mv(incomingRequest), entrypointName, waitUntilTasks)
+      .attach(kj::mv(event));
 
   // TODO(cleanup): In theory `context` may have been destroyed by now if `event->run()` dropped
   // the `incomingRequest` synchronously. No current implementation does that, and
diff --git a/src/workerd/io/worker-interface.h b/src/workerd/io/worker-interface.h
index 6aa041c7814..03e10c5055c 100644
--- a/src/workerd/io/worker-interface.h
+++ b/src/workerd/io/worker-interface.h
@@ -106,7 +106,8 @@ class WorkerInterface: public kj::HttpService {
     // for this event.
     virtual kj::Promise<Result> run(
         kj::Own<IoContext::IncomingRequest> incomingRequest,
-        kj::Maybe<kj::StringPtr> entrypointName) = 0;
+        kj::Maybe<kj::StringPtr> entrypointName,
+        kj::TaskSet& waitUntilTasks) = 0;
 
     // Forward the event over RPC.
     virtual kj::Promise<Result> sendRpc(
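Note (illustrative, not part of the patch): a minimal sketch of how a CustomEvent implementation might use the new waitUntilTasks parameter, mirroring the JsRpcSessionCustomEventImpl change above. MyCustomEventImpl and runEventHandler() are hypothetical names; only the run() signature and the KJ_DEFER/drain pattern are taken from this diff.

kj::Promise<WorkerInterface::CustomEvent::Result> MyCustomEventImpl::run(
    kj::Own<IoContext::IncomingRequest> incomingRequest,
    kj::Maybe<kj::StringPtr> entrypointName,
    kj::TaskSet& waitUntilTasks) {
  incomingRequest->delivered();

  KJ_DEFER({
    // Hand the request off to the caller-provided TaskSet so that ctx.waitUntil() work can
    // keep running even after the client disconnects and this event is torn down.
    waitUntilTasks.add(incomingRequest->drain().attach(kj::mv(incomingRequest)));
  });

  co_await runEventHandler(*incomingRequest, entrypointName);  // hypothetical dispatch helper

  co_return WorkerInterface::CustomEvent::Result {
    .outcome = EventOutcome::OK
  };
}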