Skip to content

Commit

Permalink
Fix websocket auto-response message races
Browse files Browse the repository at this point in the history
This PR fixes racing issues between regular websocket messages and auto-responses.
  • Loading branch information
jqmmes committed Oct 16, 2023
1 parent 262307f commit accae32
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 27 deletions.
7 changes: 4 additions & 3 deletions src/workerd/api/actor-state.c++
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ void DurableObjectState::setWebSocketAutoResponse(
// If there's no request/response pair, we unset any current set auto response configuration.
KJ_IF_SOME(manager, a.getHibernationManager()) {
// If there's no hibernation manager created yet, there's nothing to do here.
manager.unsetWebSocketAutoResponse();
manager.setWebSocketAutoResponse(kj::none, kj::none);
}
return;
}
Expand All @@ -902,15 +902,16 @@ void DurableObjectState::setWebSocketAutoResponse(
// If there's no hibernation manager created yet, we should create one and
// set its auto response.
}
KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse(kj::mv(reqResp));
KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse(
reqResp->getRequest(), reqResp->getResponse());
}

kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> DurableObjectState::getWebSocketAutoResponse() {
auto& a = KJ_REQUIRE_NONNULL(IoContext::current().getActor());
KJ_IF_SOME(manager, a.getHibernationManager()) {
// If there's no hibernation manager created yet, there's nothing to do here.
auto r = manager.getWebSocketAutoResponse();
return r;
return kj::mv(r);
}
return kj::none;
}
Expand Down
67 changes: 62 additions & 5 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ IoOwn<WebSocket::Native> WebSocket::initNative(
// We might have called `close()` when this WebSocket was previously active.
// If so, we want to prevent any future calls to `send()`.
nativeObj->closedOutgoing = closedOutgoingConn;
autoResponseStatus.isClosed = nativeObj->closedOutgoing;
return ioContext.addObject(kj::mv(nativeObj));
}

Expand Down Expand Up @@ -531,7 +532,10 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> messa

KJ_UNREACHABLE;
}();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)});

auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() - autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg), pendingAutoResponses});

ensurePumping(js);
}
Expand Down Expand Up @@ -578,13 +582,22 @@ void WebSocket::close(
"If you specify a WebSocket close reason, you must also specify a code.");
}

// pendingAutoResponses stores the number of queuedAutoResponses that will be pumped before sending
// the current GatedMessage, guaranteeing order.
// queuedAutoResponses stores the total number of auto-response messages that are already in accounted
// for in previous GatedMessages. This is useful to easily calculate the number of pendingAutoResponses
// for each new GateMessage.
auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() -
autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();

outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::WebSocket::Close {
// Code 1005 actually translates to sending a close message with no body on the wire.
static_cast<uint16_t>(code.orDefault(1005)),
kj::mv(reason).orDefault(nullptr),
},
}, pendingAutoResponses
});

native.closedOutgoing = true;
Expand Down Expand Up @@ -670,8 +683,10 @@ void WebSocket::serializeAttachment(jsg::Lock& js, jsg::JsValue attachment) {
serializedAttachment = kj::mv(released.data);
}

void WebSocket::setAutoResponseTimestamp(kj::Maybe<kj::Date> time) {
void WebSocket::setAutoResponseStatus(kj::Maybe<kj::Date> time,
kj::Promise<void> autoResponsePromise) {
autoResponseTimestamp = time;
autoResponseStatus.ongoingAutoResponse = kj::mv(autoResponsePromise);
}


Expand All @@ -689,7 +704,8 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
auto& context = IoContext::current();
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native));
return accepted.canceler.wrap(pump(context, *outgoingMessages,
*accepted.ws, native, autoResponseStatus));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
Expand Down Expand Up @@ -729,6 +745,17 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
}
}

kj::Promise<void> WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) {
if (autoResponseStatus.isPumping) {
autoResponseStatus.pendingAutoResponseDeque.push_back(kj::mv(message));
} else if (!autoResponseStatus.isClosed){
auto p = ws.send(message).fork();
autoResponseStatus.ongoingAutoResponse = p.addBranch();
co_await p;
autoResponseStatus.ongoingAutoResponse = kj::READY_NOW;
}
}

namespace {

size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
Expand Down Expand Up @@ -757,7 +784,8 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
} // namespace

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) {
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
KJ_DEFER({
Expand All @@ -769,8 +797,19 @@ kj::Promise<void> WebSocket::pump(
// Either we were already through all our outgoing messages or we experienced failure/
// cancellation and cannot send these anyway.
outgoingMessages.clear();

autoResponse.isPumping = false;

if (autoResponse.pendingAutoResponseDeque.size() > 0) {
autoResponse.pendingAutoResponseDeque.clear();
}
});

// If we have a ongoingAutoResponse, we must co_await it here because there's a ws.send()
// in progress. Otherwise there can occur ws.send() race problems.
co_await autoResponse.ongoingAutoResponse;
autoResponse.ongoingAutoResponse = kj::READY_NOW;

while (outgoingMessages.size() > 0) {
GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin());
KJ_IF_SOME(promise, gatedMessage.outputLock) {
Expand All @@ -779,6 +818,15 @@ kj::Promise<void> WebSocket::pump(

auto size = countBytesFromMessage(gatedMessage.message);

while (gatedMessage.pendingAutoResponses > 0) {
KJ_ASSERT(autoResponse.pendingAutoResponseDeque.size() >= gatedMessage.pendingAutoResponses);
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
gatedMessage.pendingAutoResponses--;
autoResponse.queuedAutoResponses--;
co_await ws.send(message);
}

KJ_SWITCH_ONEOF(gatedMessage.message) {
KJ_CASE_ONEOF(text, kj::String) {
co_await ws.send(text);
Expand All @@ -790,6 +838,7 @@ kj::Promise<void> WebSocket::pump(
}
KJ_CASE_ONEOF(close, kj::WebSocket::Close) {
co_await ws.close(close.code, close.reason);
autoResponse.isClosed = true;
break;
}
}
Expand All @@ -798,6 +847,14 @@ kj::Promise<void> WebSocket::pump(
a.getMetrics().sentWebSocketMessage(size);
}
}

// If there are any auto-responses left to process, we should do it now.
// We should also check if the last sent message was a close. Shouldn't happen.
while (autoResponse.pendingAutoResponseDeque.size() > 0 && !autoResponse.isClosed) {
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
co_await ws.send(message);
}
}

void WebSocket::tryReleaseNative(jsg::Lock& js) {
Expand Down
22 changes: 20 additions & 2 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "basics.h"
#include <workerd/io/io-context.h>
#include <workerd/jsg/ser.h>
#include <stdlib.h>

namespace workerd {
class ActorObserver;
Expand Down Expand Up @@ -395,12 +396,15 @@ class WebSocket: public EventTarget {

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
void setAutoResponseTimestamp(kj::Maybe<kj::Date> time);
// Also used to track hibernatable websockets auto-response sends.
void setAutoResponseStatus(kj::Maybe<kj::Date> time, kj::Promise<void> autoResponsePromise);

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
kj::Maybe<kj::Date> getAutoResponseTimestamp();

kj::Promise<void> sendAutoResponse(kj::String message, kj::WebSocket& ws);

int getReadyState();

bool isAccepted();
Expand Down Expand Up @@ -642,12 +646,25 @@ class WebSocket: public EventTarget {
struct GatedMessage {
kj::Maybe<kj::Promise<void>> outputLock; // must wait for this before actually sending
kj::WebSocket::Message message;
size_t pendingAutoResponses = 0;
};
using OutgoingMessagesMap = kj::Table<GatedMessage, kj::InsertionOrderIndex>;
// Queue of messages to be sent. This is wraped in a IoOwn so that the pump loop can safely
// access the map without locking the isolate.
IoOwn<OutgoingMessagesMap> outgoingMessages;

// Keep track of current hibernatable websockets auto-response status to avoid racing
// between regular websocket messages, and auto-responses.
struct AutoResponse {
kj::Promise<void> ongoingAutoResponse = kj::READY_NOW;
std::deque<kj::String> pendingAutoResponseDeque;
size_t queuedAutoResponses = 0;
bool isPumping = false;
bool isClosed = false;
};

AutoResponse autoResponseStatus;

Locality locality;

// Contains a websocket and possibly some data from the WebSocketResponse headers.
Expand Down Expand Up @@ -677,7 +694,8 @@ class WebSocket: public EventTarget {
// objects so are safe to access from the thread without the isolate lock. The whole task is
// owned by the `IoContext` so it'll be canceled if the `IoContext` is destroyed.
static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native);
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse);

kj::Promise<kj::Maybe<kj::Exception>> readLoop();

Expand Down
41 changes: 32 additions & 9 deletions src/workerd/io/hibernation-manager.c++
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,20 @@ kj::Vector<jsg::Ref<api::WebSocket>> HibernationManagerImpl::getWebSockets(
}

void HibernationManagerImpl::setWebSocketAutoResponse(
jsg::Ref<api::WebSocketRequestResponsePair> reqResp) {
autoResponsePair = kj::mv(reqResp);
kj::Maybe<kj::StringPtr> request, kj::Maybe<kj::StringPtr> response) {
KJ_IF_SOME(req, request) {
KJ_IF_SOME(resp, response){
auto autoRR = kj::heap<AutoRequestResponsePair>();
autoRR->response = kj::mv(req);
autoRR->request = kj::mv(resp);
autoResponsePair = kj::mv(autoRR);
return;
}
}
}

void HibernationManagerImpl::unsetWebSocketAutoResponse() {
autoResponsePair = kj::none;
return api::WebSocketRequestResponsePair::constructor(kj::str(ar.request), kj::str(ar.response));
}

kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> HibernationManagerImpl::getWebSocketAutoResponse() {
Expand Down Expand Up @@ -191,7 +199,7 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {
KJ_IF_SOME (reqResp, autoResponsePair) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
if (text == (reqResp)->getRequest()) {
if (text == reqResp.request) {
// If the received message matches the one set for auto-response, we must
// short-circuit readLoop, store the current timestamp and and automatically respond
// with the expected response.
Expand All @@ -202,15 +210,30 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {
// We'll store the current timestamp in the HibernatableWebSocket to assure it gets
// stored even if the WebSocket is currently hibernating. In that scenario, the timestamp
// value will be loaded into the WebSocket during unhibernation.
KJ_IF_SOME(active, hib.activeOrPackage.tryGet<jsg::Ref<api::WebSocket>>()) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
(active)->setAutoResponseTimestamp(hib.autoResponseTimestamp);
KJ_SWITCH_ONEOF(hib.activeOrPackage){
KJ_CASE_ONEOF(apiWs, jsg::Ref<api::WebSocket>) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW);
co_await apiWs->sendAutoResponse(kj::str(reqResp.response.asArray()), ws);
}
KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) {
if (!package.closedOutgoingConnection) {
// We need to store the autoResponsePromise because we may instantiate an api::websocket
// If we do that, we have to provide it with the promise to avoid races. This can
// happen if we have a websocket hibernating, that unhibernates and sends a
// message while ws.send() for auto-response is also sending.
auto p = ws.send(reqResp.response.asArray()).fork();
hib.autoResponsePromise = p.addBranch();
co_await p;
hib.autoResponsePromise = kj::READY_NOW;
}
}
}
co_await ws.send((reqResp)->getResponse().asArray());
skip = true;
// If we've sent an auto response message, we should not unhibernate or deliver the
// received message to the actor
skip = true;
}
}
KJ_CASE_ONEOF_DEFAULT {}
Expand Down
23 changes: 17 additions & 6 deletions src/workerd/io/hibernation-manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
// This converts our activeOrPackage from an api::WebSocket to a HibernationPackage.
void hibernateWebSockets(Worker::Lock& lock) override;

void setWebSocketAutoResponse(jsg::Ref<api::WebSocketRequestResponsePair> reqResp) override;
void unsetWebSocketAutoResponse() override;
void setWebSocketAutoResponse(kj::Maybe<kj::StringPtr> request,
kj::Maybe<kj::StringPtr> response) override;
kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> getWebSocketAutoResponse() override;
void setTimerChannel(TimerChannel& timerChannel) override;

Expand Down Expand Up @@ -108,11 +108,13 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
// to the api::WebSocket.
jsg::Ref<api::WebSocket> getActiveOrUnhibernate(jsg::Lock& js) {
KJ_IF_SOME(package, activeOrPackage.tryGet<api::WebSocket::HibernationPackage>()) {
// Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp
// that was stored in the corresponding HibernatableWebSocket. We also move autoResponsePromise
// from the hibernation manager to api::websocket to prevent possible ws.send races.
activeOrPackage.init<jsg::Ref<api::WebSocket>>(
api::WebSocket::hibernatableFromNative(js, *KJ_REQUIRE_NONNULL(ws), kj::mv(package))
)->setAutoResponseTimestamp(autoResponseTimestamp);
// Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp
// that was stored in the corresponding HibernatableWebSocket.
)->setAutoResponseStatus(autoResponseTimestamp, kj::mv(autoResponsePromise));
autoResponsePromise = kj::READY_NOW;
}
return activeOrPackage.get<jsg::Ref<api::WebSocket>>().addRef();
}
Expand Down Expand Up @@ -151,6 +153,10 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
// Stores the last received autoResponseRequest timestamp.
kj::Maybe<kj::Date> autoResponseTimestamp;

// Keeps track of the currently ongoing websocket auto-response send promise. This promise may
// be moved to api::websocket if an hibernating websocket unhibernates.
kj::Promise<void> autoResponsePromise = kj::READY_NOW;

friend HibernationManagerImpl;
};

Expand Down Expand Up @@ -181,6 +187,11 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
TagCollection(TagCollection&& other) = default;
};

struct AutoRequestResponsePair {
kj::StringPtr request;
kj::StringPtr response;
};

// A hashmap of tags to HibernatableWebSockets associated with the tag.
// We use a kj::List so we can quickly remove websockets that have disconnected.
// Also note that we box the keys and values such that in the event of a hashmap resizing we don't
Expand Down Expand Up @@ -216,7 +227,7 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
};
DisconnectHandler onDisconnect;
kj::TaskSet readLoopTasks;
kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> autoResponsePair;
kj::Maybe<AutoRequestResponsePair&> autoResponsePair;
kj::Maybe<TimerChannel&> timer;
};
}; // namespace workerd
4 changes: 2 additions & 2 deletions src/workerd/io/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,8 @@ class Worker::Actor final: public kj::Refcounted {
jsg::Lock& js,
kj::Maybe<kj::StringPtr> tag) = 0;
virtual void hibernateWebSockets(Worker::Lock& lock) = 0;
virtual void setWebSocketAutoResponse(jsg::Ref<api::WebSocketRequestResponsePair> reqResp) = 0;
virtual void unsetWebSocketAutoResponse() = 0;
virtual void setWebSocketAutoResponse(kj::Maybe<kj::StringPtr> request,
kj::Maybe<kj::StringPtr> response) = 0;
virtual kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> getWebSocketAutoResponse() = 0;
virtual void setTimerChannel(TimerChannel& timerChannel) = 0;
};
Expand Down

0 comments on commit accae32

Please sign in to comment.