Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix websocket auto-response message races #1309

Merged
merged 3 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/workerd/api/actor-state.c++
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ void DurableObjectState::setWebSocketAutoResponse(
// If there's no request/response pair, we unset any current set auto response configuration.
KJ_IF_SOME(manager, a.getHibernationManager()) {
// If there's no hibernation manager created yet, there's nothing to do here.
manager.unsetWebSocketAutoResponse();
manager.setWebSocketAutoResponse(kj::none, kj::none);
}
return;
}
Expand All @@ -902,15 +902,15 @@ void DurableObjectState::setWebSocketAutoResponse(
// If there's no hibernation manager created yet, we should create one and
// set its auto response.
}
KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse(kj::mv(reqResp));
KJ_REQUIRE_NONNULL(a.getHibernationManager()).setWebSocketAutoResponse(
reqResp->getRequest(), reqResp->getResponse());
}

kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> DurableObjectState::getWebSocketAutoResponse() {
auto& a = KJ_REQUIRE_NONNULL(IoContext::current().getActor());
KJ_IF_SOME(manager, a.getHibernationManager()) {
// If there's no hibernation manager created yet, there's nothing to do here.
auto r = manager.getWebSocketAutoResponse();
return r;
return manager.getWebSocketAutoResponse();
}
return kj::none;
}
Expand Down
69 changes: 64 additions & 5 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ IoOwn<WebSocket::Native> WebSocket::initNative(
// We might have called `close()` when this WebSocket was previously active.
// If so, we want to prevent any future calls to `send()`.
nativeObj->closedOutgoing = closedOutgoingConn;
autoResponseStatus.isClosed = nativeObj->closedOutgoing;
return ioContext.addObject(kj::mv(nativeObj));
}

Expand Down Expand Up @@ -531,7 +532,11 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> messa

KJ_UNREACHABLE;
}();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)});

auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() -
autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg), pendingAutoResponses});

ensurePumping(js);
}
Expand Down Expand Up @@ -578,13 +583,22 @@ void WebSocket::close(
"If you specify a WebSocket close reason, you must also specify a code.");
}

// pendingAutoResponses stores the number of queuedAutoResponses that will be pumped before sending
// the current GatedMessage, guaranteeing order.
// queuedAutoResponses stores the total number of auto-response messages that are already accounted
// for in previous GatedMessages. This is useful to easily calculate the number of pendingAutoResponses
// for each new GatedMessage.
auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() -
autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();

outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::WebSocket::Close {
// Code 1005 actually translates to sending a close message with no body on the wire.
static_cast<uint16_t>(code.orDefault(1005)),
kj::mv(reason).orDefault(nullptr),
},
}, pendingAutoResponses
});

native.closedOutgoing = true;
Expand Down Expand Up @@ -670,8 +684,10 @@ void WebSocket::serializeAttachment(jsg::Lock& js, jsg::JsValue attachment) {
serializedAttachment = kj::mv(released.data);
}

void WebSocket::setAutoResponseTimestamp(kj::Maybe<kj::Date> time) {
void WebSocket::setAutoResponseStatus(kj::Maybe<kj::Date> time,
kj::Promise<void> autoResponsePromise) {
autoResponseTimestamp = time;
autoResponseStatus.ongoingAutoResponse = kj::mv(autoResponsePromise);
}


Expand All @@ -689,7 +705,8 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
auto& context = IoContext::current();
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native));
return accepted.canceler.wrap(pump(context, *outgoingMessages,
*accepted.ws, native, autoResponseStatus));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
Expand Down Expand Up @@ -729,6 +746,17 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
}
}

// Sends the auto-response `message` on `ws`, coordinating with the outgoing-message pump so the
// two never call ws.send() concurrently.
//
// - If autoResponseStatus.isPumping is set, the message is not sent here; it is appended to
//   autoResponseStatus.pendingAutoResponseDeque, which pump() drains.
// - Otherwise, if autoResponseStatus.isClosed is not set, the message is sent immediately.
// - Otherwise (not pumping and closed) nothing is sent and the message is dropped.
kj::Promise<void> WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) {
if (autoResponseStatus.isPumping) {
// pump() is running: hand the message over to it instead of sending from here.
autoResponseStatus.pendingAutoResponseDeque.push_back(kj::mv(message));
} else if (!autoResponseStatus.isClosed){
// Fork the send promise so that one branch can be published in ongoingAutoResponse (pump()
// awaits that field before sending anything) while this coroutine awaits the send itself.
auto p = ws.send(message).fork();
autoResponseStatus.ongoingAutoResponse = p.addBranch();
co_await p;
// The send has completed; there is no longer an auto-response in flight.
autoResponseStatus.ongoingAutoResponse = kj::READY_NOW;
}
}

namespace {

size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
Expand Down Expand Up @@ -757,9 +785,11 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
} // namespace

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) {
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
autoResponse.isPumping = true;
KJ_DEFER({
// We use a KJ_DEFER to set native.isPumping = false to ensure that it happens -- we had a bug
// in the past where this was handled by the caller of WebSocket::pump() and it allowed for
Expand All @@ -769,8 +799,19 @@ kj::Promise<void> WebSocket::pump(
// Either we were already through all our outgoing messages or we experienced failure/
// cancellation and cannot send these anyway.
outgoingMessages.clear();

autoResponse.isPumping = false;

if (autoResponse.pendingAutoResponseDeque.size() > 0) {
autoResponse.pendingAutoResponseDeque.clear();
}
});

// If we have an ongoingAutoResponse, we must co_await it here because there's a ws.send()
// in progress. Otherwise, concurrent ws.send() calls could race.
co_await autoResponse.ongoingAutoResponse;
autoResponse.ongoingAutoResponse = kj::READY_NOW;

while (outgoingMessages.size() > 0) {
GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin());
KJ_IF_SOME(promise, gatedMessage.outputLock) {
Expand All @@ -779,6 +820,15 @@ kj::Promise<void> WebSocket::pump(

auto size = countBytesFromMessage(gatedMessage.message);

while (gatedMessage.pendingAutoResponses > 0) {
KJ_ASSERT(autoResponse.pendingAutoResponseDeque.size() >= gatedMessage.pendingAutoResponses);
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
gatedMessage.pendingAutoResponses--;
autoResponse.queuedAutoResponses--;
co_await ws.send(message);
}

KJ_SWITCH_ONEOF(gatedMessage.message) {
KJ_CASE_ONEOF(text, kj::String) {
co_await ws.send(text);
Expand All @@ -790,6 +840,7 @@ kj::Promise<void> WebSocket::pump(
}
KJ_CASE_ONEOF(close, kj::WebSocket::Close) {
co_await ws.close(close.code, close.reason);
autoResponse.isClosed = true;
break;
}
}
Expand All @@ -798,6 +849,14 @@ kj::Promise<void> WebSocket::pump(
a.getMetrics().sentWebSocketMessage(size);
}
}

// If there are any auto-responses left to process, we should do it now.
// We should also check if the last sent message was a close. Shouldn't happen.
while (autoResponse.pendingAutoResponseDeque.size() > 0 && !autoResponse.isClosed) {
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
co_await ws.send(message);
}
}

void WebSocket::tryReleaseNative(jsg::Lock& js) {
Expand Down
22 changes: 20 additions & 2 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "basics.h"
#include <workerd/io/io-context.h>
#include <workerd/jsg/ser.h>
#include <stdlib.h>

namespace workerd {
class ActorObserver;
Expand Down Expand Up @@ -395,12 +396,15 @@ class WebSocket: public EventTarget {

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
void setAutoResponseTimestamp(kj::Maybe<kj::Date> time);
// Also used to track hibernatable websockets auto-response sends.
void setAutoResponseStatus(kj::Maybe<kj::Date> time, kj::Promise<void> autoResponsePromise);

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
kj::Maybe<kj::Date> getAutoResponseTimestamp();

kj::Promise<void> sendAutoResponse(kj::String message, kj::WebSocket& ws);

int getReadyState();

bool isAccepted();
Expand Down Expand Up @@ -642,12 +646,25 @@ class WebSocket: public EventTarget {
// A single entry in the outgoing message queue.
struct GatedMessage {
kj::Maybe<kj::Promise<void>> outputLock; // must wait for this before actually sending
kj::WebSocket::Message message;
// Number of queued auto-response messages that pump() sends from the pending auto-response
// deque before sending `message`, so that ordering is preserved.
size_t pendingAutoResponses = 0;
};
using OutgoingMessagesMap = kj::Table<GatedMessage, kj::InsertionOrderIndex>;
// Queue of messages to be sent. This is wrapped in an IoOwn so that the pump loop can safely
// access the map without locking the isolate.
IoOwn<OutgoingMessagesMap> outgoingMessages;

// Keep track of current hibernatable websockets auto-response status to avoid racing
// between regular websocket messages, and auto-responses.
// NOTE(review): this struct uses std::deque; the only include added alongside it in this diff
// is <stdlib.h> — confirm <deque> is included (directly or transitively).
struct AutoResponse {
// Promise for an auto-response ws.send() that may still be in flight. pump() awaits it before
// sending anything; it is reset to kj::READY_NOW once that send has completed.
kj::Promise<void> ongoingAutoResponse = kj::READY_NOW;
// Auto-response messages queued by sendAutoResponse() while pump() is running. pump() sends
// them and clears whatever is left when it exits.
std::deque<kj::String> pendingAutoResponseDeque;
// Number of entries in pendingAutoResponseDeque that are already accounted for by the
// pendingAutoResponses count of some GatedMessage.
size_t queuedAutoResponses = 0;
// Set to true when pump() starts and back to false in its deferred cleanup.
bool isPumping = false;
// Initialized from the native closedOutgoing flag and set to true after pump() sends a Close
// message. When set (and not pumping), sendAutoResponse() does not send.
bool isClosed = false;
};

AutoResponse autoResponseStatus;

Locality locality;

// Contains a websocket and possibly some data from the WebSocketResponse headers.
Expand Down Expand Up @@ -677,7 +694,8 @@ class WebSocket: public EventTarget {
// objects so are safe to access from the thread without the isolate lock. The whole task is
// owned by the `IoContext` so it'll be canceled if the `IoContext` is destroyed.
static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native);
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse);

kj::Promise<kj::Maybe<kj::Exception>> readLoop();

Expand Down
63 changes: 45 additions & 18 deletions src/workerd/io/hibernation-manager.c++
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,26 @@ kj::Vector<jsg::Ref<api::WebSocket>> HibernationManagerImpl::getWebSockets(
}

void HibernationManagerImpl::setWebSocketAutoResponse(
jsg::Ref<api::WebSocketRequestResponsePair> reqResp) {
autoResponsePair = kj::mv(reqResp);
}

void HibernationManagerImpl::unsetWebSocketAutoResponse() {
autoResponsePair = kj::none;
kj::Maybe<kj::StringPtr> request, kj::Maybe<kj::StringPtr> response) {
jqmmes marked this conversation as resolved.
Show resolved Hide resolved
KJ_IF_SOME(req, request) {
// If we have a request, we must also have a response. If response is kj::none, we'll throw.
autoResponsePair->request = kj::str(req);
autoResponsePair->response = kj::str(KJ_REQUIRE_NONNULL(response));
return;
}
// If we don't have a request, we must unset both request and response.
autoResponsePair->request = kj::none;
autoResponsePair->response = kj::none;
}

kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> HibernationManagerImpl::getWebSocketAutoResponse() {
KJ_IF_SOME(ar, autoResponsePair) {
return ar.addRef();
} else {
return kj::none;
KJ_IF_SOME(req, autoResponsePair->request) {
// When getting the currently set auto-response pair, if we have a request we must have a response
// set. If not, we'll throw.
return api::WebSocketRequestResponsePair::constructor(kj::str(req),
kj::str(KJ_REQUIRE_NONNULL(autoResponsePair->response)));
}
return kj::none;
}

void HibernationManagerImpl::setTimerChannel(TimerChannel& timerChannel) {
Expand Down Expand Up @@ -188,10 +194,12 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {

auto skip = false;

KJ_IF_SOME (reqResp, autoResponsePair) {
// If we have a request != kj::none, we can compare it to the received message. This also implies
// that we have a response set in autoResponsePair.
KJ_IF_SOME (req, autoResponsePair->request) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
if (text == (reqResp)->getRequest()) {
if (text == req) {
// If the received message matches the one set for auto-response, we must
// short-circuit readLoop, store the current timestamp and automatically respond
// with the expected response.
Expand All @@ -202,15 +210,34 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {
// We'll store the current timestamp in the HibernatableWebSocket to assure it gets
// stored even if the WebSocket is currently hibernating. In that scenario, the timestamp
// value will be loaded into the WebSocket during unhibernation.
KJ_IF_SOME(active, hib.activeOrPackage.tryGet<jsg::Ref<api::WebSocket>>()) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
(active)->setAutoResponseTimestamp(hib.autoResponseTimestamp);
KJ_SWITCH_ONEOF(hib.activeOrPackage){
KJ_CASE_ONEOF(apiWs, jsg::Ref<api::WebSocket>) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW);
// Since we had a request set, we must have a response that's sent back using the
// same websocket here. The sending of the response is managed in web-socket to avoid
// possible racing problems with regular websocket messages.
co_await apiWs->sendAutoResponse(
kj::str(KJ_REQUIRE_NONNULL(autoResponsePair->response).asArray()), ws);
}
KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) {
if (!package.closedOutgoingConnection) {
// We need to store the autoResponsePromise because we may instantiate an api::WebSocket.
// If we do that, we have to provide it with the promise to avoid races. This can
// happen if we have a websocket hibernating, that unhibernates and sends a
// message while ws.send() for auto-response is also sending.
auto p = ws.send(
KJ_REQUIRE_NONNULL(autoResponsePair->response).asArray()).fork();
hib.autoResponsePromise = p.addBranch();
co_await p;
hib.autoResponsePromise = kj::READY_NOW;
}
}
}
co_await ws.send((reqResp)->getResponse().asArray());
skip = true;
// If we've sent an auto response message, we should not unhibernate or deliver the
// received message to the actor
skip = true;
}
}
KJ_CASE_ONEOF_DEFAULT {}
Expand Down
Loading
Loading