diff --git a/src/workerd/api/actor-state.c++ b/src/workerd/api/actor-state.c++ index cb8a66ad99d..0d9454bf283 100644 --- a/src/workerd/api/actor-state.c++ +++ b/src/workerd/api/actor-state.c++ @@ -880,7 +880,7 @@ void DurableObjectState::setWebSocketAutoResponse( // If there's no request/response pair, we unset any current set auto response configuration. KJ_IF_SOME(manager, a.getHibernationManager()) { // If there's no hibernation manager created yet, there's nothing to do here. - manager.unsetWebSocketAutoResponse(); + manager.setWebSocketAutoResponse(kj::none, kj::none); } return; } @@ -902,7 +902,8 @@ void DurableObjectState::setWebSocketAutoResponse( // If there's no hibernation manager created yet, we should create one and // set its auto response. } - KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse(kj::mv(reqResp)); + KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse( + reqResp->getRequest(), reqResp->getResponse()); } kj::Maybe> DurableObjectState::getWebSocketAutoResponse() { @@ -910,7 +911,7 @@ kj::Maybe> DurableObjectState::getWe KJ_IF_SOME(manager, a.getHibernationManager()) { // If there's no hibernation manager created yet, there's nothing to do here. auto r = manager.getWebSocketAutoResponse(); - return r; + return kj::mv(r); } return kj::none; } diff --git a/src/workerd/api/web-socket.c++ b/src/workerd/api/web-socket.c++ index e45af2a8c6a..c9fda211714 100644 --- a/src/workerd/api/web-socket.c++ +++ b/src/workerd/api/web-socket.c++ @@ -34,6 +34,7 @@ IoOwn WebSocket::initNative( // We might have called `close()` when this WebSocket was previously active. // If so, we want to prevent any future calls to `send()`. nativeObj->closedOutgoing = closedOutgoingConn; + autoResponseStatus.isClosed = nativeObj->closedOutgoing; return ioContext.addObject(kj::mv(nativeObj)); } @@ -531,7 +532,10 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf, kj::String> messa KJ_UNREACHABLE; }(); - outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)}); + + auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() - autoResponseStatus.queuedAutoResponses; + autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size(); + outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg), pendingAutoResponses}); ensurePumping(js); } @@ -578,13 +582,22 @@ void WebSocket::close( "If you specify a WebSocket close reason, you must also specify a code."); } + // pendingAutoResponses stores the number of queuedAutoResponses that will be pumped before sending + // the current GatedMessage, guaranteeing order. + // queuedAutoResponses stores the total number of auto-response messages that are already in accounted + // for in previous GatedMessages. This is useful to easily calculate the number of pendingAutoResponses + // for each new GateMessage. + auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() - + autoResponseStatus.queuedAutoResponses; + autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size(); + outgoingMessages->insert(GatedMessage{ IoContext::current().waitForOutputLocksIfNecessary(), kj::WebSocket::Close { // Code 1005 actually translates to sending a close message with no body on the wire. static_cast(code.orDefault(1005)), kj::mv(reason).orDefault(nullptr), - }, + }, pendingAutoResponses }); native.closedOutgoing = true; @@ -670,8 +683,10 @@ void WebSocket::serializeAttachment(jsg::Lock& js, jsg::JsValue attachment) { serializedAttachment = kj::mv(released.data); } -void WebSocket::setAutoResponseTimestamp(kj::Maybe time) { +void WebSocket::setAutoResponseStatus(kj::Maybe time, + kj::Promise autoResponsePromise) { autoResponseTimestamp = time; + autoResponseStatus.ongoingAutoResponse = kj::mv(autoResponsePromise); } @@ -689,7 +704,8 @@ void WebSocket::ensurePumping(jsg::Lock& js) { auto& context = IoContext::current(); auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet()); auto promise = kj::evalNow([&]() { - return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native)); + return accepted.canceler.wrap(pump(context, *outgoingMessages, + *accepted.ws, native, autoResponseStatus)); }); // TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending @@ -729,6 +745,17 @@ void WebSocket::ensurePumping(jsg::Lock& js) { } } +kj::Promise WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) { + if (autoResponseStatus.isPumping) { + autoResponseStatus.pendingAutoResponseDeque.push_back(kj::mv(message)); + } else if (!autoResponseStatus.isClosed){ + auto p = ws.send(message).fork(); + autoResponseStatus.ongoingAutoResponse = p.addBranch(); + co_await p; + autoResponseStatus.ongoingAutoResponse = kj::READY_NOW; + } +} + namespace { size_t countBytesFromMessage(const kj::WebSocket::Message& message) { @@ -757,7 +784,8 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) { } // namespace kj::Promise WebSocket::pump( - IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) { + IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native, + AutoResponse& autoResponse) { KJ_ASSERT(!native.isPumping); native.isPumping = true; KJ_DEFER({ @@ -769,8 +797,19 @@ kj::Promise WebSocket::pump( // Either we were already through all our outgoing messages or we experienced failure/ // cancellation and cannot send these anyway. outgoingMessages.clear(); + + autoResponse.isPumping = false; + + if (autoResponse.pendingAutoResponseDeque.size() > 0) { + autoResponse.pendingAutoResponseDeque.clear(); + } }); + // If we have a ongoingAutoResponse, we must co_await it here because there's a ws.send() + // in progress. Otherwise there can occur ws.send() race problems. + co_await autoResponse.ongoingAutoResponse; + autoResponse.ongoingAutoResponse = kj::READY_NOW; + while (outgoingMessages.size() > 0) { GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin()); KJ_IF_SOME(promise, gatedMessage.outputLock) { @@ -779,6 +818,15 @@ kj::Promise WebSocket::pump( auto size = countBytesFromMessage(gatedMessage.message); + while (gatedMessage.pendingAutoResponses > 0) { + KJ_ASSERT(autoResponse.pendingAutoResponseDeque.size() >= gatedMessage.pendingAutoResponses); + auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front()); + autoResponse.pendingAutoResponseDeque.pop_front(); + gatedMessage.pendingAutoResponses--; + autoResponse.queuedAutoResponses--; + co_await ws.send(message); + } + KJ_SWITCH_ONEOF(gatedMessage.message) { KJ_CASE_ONEOF(text, kj::String) { co_await ws.send(text); @@ -790,6 +838,7 @@ kj::Promise WebSocket::pump( } KJ_CASE_ONEOF(close, kj::WebSocket::Close) { co_await ws.close(close.code, close.reason); + autoResponse.isClosed = true; break; } } @@ -798,6 +847,14 @@ kj::Promise WebSocket::pump( a.getMetrics().sentWebSocketMessage(size); } } + + // If there are any auto-responses left to process, we should do it now. + // We should also check if the last sent message was a close. Shouldn't happen. + while (autoResponse.pendingAutoResponseDeque.size() > 0 && !autoResponse.isClosed) { + auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front()); + autoResponse.pendingAutoResponseDeque.pop_front(); + co_await ws.send(message); + } } void WebSocket::tryReleaseNative(jsg::Lock& js) { diff --git a/src/workerd/api/web-socket.h b/src/workerd/api/web-socket.h index 8e228d55985..2db82642e6e 100644 --- a/src/workerd/api/web-socket.h +++ b/src/workerd/api/web-socket.h @@ -9,6 +9,7 @@ #include "basics.h" #include #include +#include namespace workerd { class ActorObserver; @@ -395,12 +396,15 @@ class WebSocket: public EventTarget { // Used to get/store the last auto request/response timestamp for this WebSocket. // These methods are c++ only and are not exposed to our js interface. - void setAutoResponseTimestamp(kj::Maybe time); + // Also used to track hibernatable websockets auto-response sends. + void setAutoResponseStatus(kj::Maybe time, kj::Promise autoResponsePromise); // Used to get/store the last auto request/response timestamp for this WebSocket. // These methods are c++ only and are not exposed to our js interface. kj::Maybe getAutoResponseTimestamp(); + kj::Promise sendAutoResponse(kj::String message, kj::WebSocket& ws); + int getReadyState(); bool isAccepted(); @@ -642,12 +646,25 @@ class WebSocket: public EventTarget { struct GatedMessage { kj::Maybe> outputLock; // must wait for this before actually sending kj::WebSocket::Message message; + size_t pendingAutoResponses = 0; }; using OutgoingMessagesMap = kj::Table; // Queue of messages to be sent. This is wraped in a IoOwn so that the pump loop can safely // access the map without locking the isolate. IoOwn outgoingMessages; + // Keep track of current hibernatable websockets auto-response status to avoid racing + // between regular websocket messages, and auto-responses. + struct AutoResponse { + kj::Promise ongoingAutoResponse = kj::READY_NOW; + std::deque pendingAutoResponseDeque; + size_t queuedAutoResponses = 0; + bool isPumping = false; + bool isClosed = false; + }; + + AutoResponse autoResponseStatus; + Locality locality; // Contains a websocket and possibly some data from the WebSocketResponse headers. @@ -677,7 +694,8 @@ class WebSocket: public EventTarget { // objects so are safe to access from the thread without the isolate lock. The whole task is // owned by the `IoContext` so it'll be canceled if the `IoContext` is destroyed. static kj::Promise pump( - IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native); + IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native, + AutoResponse& autoResponse); kj::Promise> readLoop(); diff --git a/src/workerd/io/hibernation-manager.c++ b/src/workerd/io/hibernation-manager.c++ index e2662a0d1bc..9adc2f4d31a 100644 --- a/src/workerd/io/hibernation-manager.c++ +++ b/src/workerd/io/hibernation-manager.c++ @@ -99,12 +99,20 @@ kj::Vector> HibernationManagerImpl::getWebSockets( } void HibernationManagerImpl::setWebSocketAutoResponse( - jsg::Ref reqResp) { - autoResponsePair = kj::mv(reqResp); + kj::Maybe request, kj::Maybe response) { + KJ_IF_SOME(req, request) { + KJ_IF_SOME(resp, response){ + auto autoRR = kj::heap(); + autoRR->response = kj::mv(req); + autoRR->request = kj::mv(resp); + autoResponsePair = kj::mv(autoRR); + return; + } + } } void HibernationManagerImpl::unsetWebSocketAutoResponse() { - autoResponsePair = kj::none; + return api::WebSocketRequestResponsePair::constructor(kj::str(ar.request), kj::str(ar.response)); } kj::Maybe> HibernationManagerImpl::getWebSocketAutoResponse() { @@ -191,7 +199,7 @@ kj::Promise HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) { KJ_IF_SOME (reqResp, autoResponsePair) { KJ_SWITCH_ONEOF(message) { KJ_CASE_ONEOF(text, kj::String) { - if (text == (reqResp)->getRequest()) { + if (text == reqResp.request) { // If the received message matches the one set for auto-response, we must // short-circuit readLoop, store the current timestamp and and automatically respond // with the expected response. @@ -202,15 +210,30 @@ kj::Promise HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) { // We'll store the current timestamp in the HibernatableWebSocket to assure it gets // stored even if the WebSocket is currently hibernating. In that scenario, the timestamp // value will be loaded into the WebSocket during unhibernation. - KJ_IF_SOME(active, hib.activeOrPackage.tryGet>()) { - // If the actor is not hibernated/If the WebSocket is active, we need to update - // autoResponseTimestamp on the active websocket. - (active)->setAutoResponseTimestamp(hib.autoResponseTimestamp); + KJ_SWITCH_ONEOF(hib.activeOrPackage){ + KJ_CASE_ONEOF(apiWs, jsg::Ref) { + // If the actor is not hibernated/If the WebSocket is active, we need to update + // autoResponseTimestamp on the active websocket. + apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW); + co_await apiWs->sendAutoResponse(kj::str(reqResp.response.asArray()), ws); + } + KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) { + if (!package.closedOutgoingConnection) { + // We need to store the autoResponsePromise because we may instantiate an api::websocket + // If we do that, we have to provide it with the promise to avoid races. This can + // happen if we have a websocket hibernating, that unhibernates and sends a + // message while ws.send() for auto-response is also sending. + auto p = ws.send(reqResp.response.asArray()).fork(); + hib.autoResponsePromise = p.addBranch(); + co_await p; + hib.autoResponsePromise = kj::READY_NOW; + } + } } co_await ws.send((reqResp)->getResponse().asArray()); - skip = true; // If we've sent an auto response message, we should not unhibernate or deliver the // received message to the actor + skip = true; } } KJ_CASE_ONEOF_DEFAULT {} diff --git a/src/workerd/io/hibernation-manager.h b/src/workerd/io/hibernation-manager.h index 6d7ae60d2e0..2964bcf5fed 100644 --- a/src/workerd/io/hibernation-manager.h +++ b/src/workerd/io/hibernation-manager.h @@ -40,8 +40,8 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager { // This converts our activeOrPackage from an api::WebSocket to a HibernationPackage. void hibernateWebSockets(Worker::Lock& lock) override; - void setWebSocketAutoResponse(jsg::Ref reqResp) override; - void unsetWebSocketAutoResponse() override; + void setWebSocketAutoResponse(kj::Maybe request, + kj::Maybe response) override; kj::Maybe> getWebSocketAutoResponse() override; void setTimerChannel(TimerChannel& timerChannel) override; @@ -108,11 +108,13 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager { // to the api::WebSocket. jsg::Ref getActiveOrUnhibernate(jsg::Lock& js) { KJ_IF_SOME(package, activeOrPackage.tryGet()) { + // Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp + // that was stored in the corresponding HibernatableWebSocket. We also move autoResponsePromise + // from the hibernation manager to api::websocket to prevent possible ws.send races. activeOrPackage.init>( api::WebSocket::hibernatableFromNative(js, *KJ_REQUIRE_NONNULL(ws), kj::mv(package)) - )->setAutoResponseTimestamp(autoResponseTimestamp); - // Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp - // that was stored in the corresponding HibernatableWebSocket. + )->setAutoResponseStatus(autoResponseTimestamp, kj::mv(autoResponsePromise)); + autoResponsePromise = kj::READY_NOW; } return activeOrPackage.get>().addRef(); } @@ -151,6 +153,10 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager { // Stores the last received autoResponseRequest timestamp. kj::Maybe autoResponseTimestamp; + // Keeps track of the currently ongoing websocket auto-response send promise. This promise may + // be moved to api::websocket if an hibernating websocket unhibernates. + kj::Promise autoResponsePromise = kj::READY_NOW; + friend HibernationManagerImpl; }; @@ -181,6 +187,11 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager { TagCollection(TagCollection&& other) = default; }; + struct AutoRequestResponsePair { + kj::StringPtr request; + kj::StringPtr response; + }; + // A hashmap of tags to HibernatableWebSockets associated with the tag. // We use a kj::List so we can quickly remove websockets that have disconnected. // Also note that we box the keys and values such that in the event of a hashmap resizing we don't @@ -216,7 +227,7 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager { }; DisconnectHandler onDisconnect; kj::TaskSet readLoopTasks; - kj::Maybe> autoResponsePair; + kj::Maybe autoResponsePair; kj::Maybe timer; }; }; // namespace workerd diff --git a/src/workerd/io/worker.h b/src/workerd/io/worker.h index a355cab7ba6..ace5c03f230 100644 --- a/src/workerd/io/worker.h +++ b/src/workerd/io/worker.h @@ -697,8 +697,8 @@ class Worker::Actor final: public kj::Refcounted { jsg::Lock& js, kj::Maybe tag) = 0; virtual void hibernateWebSockets(Worker::Lock& lock) = 0; - virtual void setWebSocketAutoResponse(jsg::Ref reqResp) = 0; - virtual void unsetWebSocketAutoResponse() = 0; + virtual void setWebSocketAutoResponse(kj::Maybe request, + kj::Maybe response) = 0; virtual kj::Maybe> getWebSocketAutoResponse() = 0; virtual void setTimerChannel(TimerChannel& timerChannel) = 0; };