Skip to content

Commit

Permalink
Merge pull request #2385 from cloudflare/jphillips/websocket-observab…
Browse files Browse the repository at this point in the history
…ility

Add WebSocketObserver
  • Loading branch information
jp4a50 authored Jul 17, 2024
2 parents 66c5c5c + 69927cb commit 0e28d29
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/workerd/api/http.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ kj::Promise<DeferredProxy<void>> Response::send(
}

auto clientSocket = outer.acceptWebSocket(outHeaders);
auto wsPromise = ws->couple(kj::mv(clientSocket));
auto wsPromise = ws->couple(kj::mv(clientSocket), context.getMetrics());

KJ_IF_SOME(a, context.getActor()) {
KJ_IF_SOME(hib, a.getHibernationManager()) {
Expand Down
24 changes: 19 additions & 5 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ jsg::Ref<WebSocket> WebSocket::constructor(
return ws;
}

kj::Promise<DeferredProxy<void>> WebSocket::couple(kj::Own<kj::WebSocket> other) {
kj::Promise<DeferredProxy<void>> WebSocket::couple(kj::Own<kj::WebSocket> other, RequestObserver& request) {
auto& native = *farNative;
JSG_REQUIRE(!native.state.is<AwaitingConnection>(), TypeError,
"Can't return WebSocket in a Response if it was created with `new WebSocket()`");
Expand Down Expand Up @@ -341,10 +341,16 @@ kj::Promise<DeferredProxy<void>> WebSocket::couple(kj::Own<kj::WebSocket> other)
}
return false;
};
if (tryGetPeer() != kj::none) {
KJ_IF_SOME(p, tryGetPeer()) {
// We're terminating the WebSocket in this worker, so the upstream promise (which pumps
// messages from the client to this worker) counts as something the request is waiting for.
upstream = upstream.attach(context.registerPendingEvent());

// We can observe websocket traffic in both directions by attaching an observer to the peer
// websocket which terminates in the worker.
KJ_IF_SOME(observer, request.tryCreateWebSocketObserver()) {
p.observer = kj::mv(observer);
}
}

// We need to use `eagerlyEvaluate()` on both inputs to `joinPromises` to work around the awkward
Expand Down Expand Up @@ -745,7 +751,7 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages,
*accepted.ws, native, autoResponseStatus));
*accepted.ws, native, autoResponseStatus, observer));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
Expand Down Expand Up @@ -840,7 +846,7 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse) {
AutoResponse& autoResponse, kj::Maybe<kj::Own<WebSocketObserver>>& observer) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
autoResponse.isPumping = true;
Expand Down Expand Up @@ -899,6 +905,10 @@ kj::Promise<void> WebSocket::pump(
}
}

KJ_IF_SOME(o, observer) {
o->sentMessage(size);
}

KJ_IF_SOME(a, context.getActor()) {
a.getMetrics().sentWebSocketMessage(size);
}
Expand Down Expand Up @@ -941,9 +951,13 @@ kj::Promise<kj::Maybe<kj::Exception>> WebSocket::readLoop(
while (true) {
auto message = co_await ws.receive();

auto size = countBytesFromMessage(message);
KJ_IF_SOME(o, observer) {
o->receivedMessage(size);
}

context.getLimitEnforcer().topUpActor();
KJ_IF_SOME(a, context.getActor()) {
auto size = countBytesFromMessage(message);
a.getMetrics().receivedWebSocketMessage(size);
}

Expand Down
9 changes: 6 additions & 3 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,15 @@ class WebSocket: public EventTarget {
void initConnection(jsg::Lock& js, kj::Promise<PackedWebSocket>);

// Pumps messages from this WebSocket to `other`, and from `other` to this, making sure to
// register pending events as appropriate. Used to implement FetchEvent.respondWith().
// register pending events as appropriate. Used to connect a websocket to a client via an HTTP
// response.
//
// Only one of this or accept() is allowed to be invoked.
//
// As an exception to the usual KJ convention, it is not necessary for the JavaScript `WebSocket`
// object to be kept live while waiting for the promise returned by couple() to complete. Instead,
// the promise takes direct ownership of the underlying KJ-native WebSocket (as well as `other`).
kj::Promise<DeferredProxy<void>> couple(kj::Own<kj::WebSocket> other);
kj::Promise<DeferredProxy<void>> couple(kj::Own<kj::WebSocket> other, RequestObserver& request);

// Extract the kj::WebSocket from this api::WebSocket (if applicable). The kj::WebSocket will be
// owned elsewhere, but the api::WebSocket will retain a reference.
Expand Down Expand Up @@ -585,6 +586,8 @@ class WebSocket: public EventTarget {

AutoResponse autoResponseStatus;

kj::Maybe<kj::Own<WebSocketObserver>> observer;

// Contains a websocket and possibly some data from the WebSocketResponse headers.
struct PackedWebSocket {
kj::Own<kj::WebSocket> ws;
Expand Down Expand Up @@ -620,7 +623,7 @@ class WebSocket: public EventTarget {
// owned by the `IoContext` so it'll be canceled if the `IoContext` is destroyed.
static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse);
AutoResponse& autoResponse, kj::Maybe<kj::Own<WebSocketObserver>>& observer);

kj::Promise<kj::Maybe<kj::Exception>> readLoop(kj::Maybe<kj::Own<InputGate::CriticalSection>> cs);

Expand Down
16 changes: 16 additions & 0 deletions src/workerd/io/observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,27 @@ class WorkerInterface;
class LimitEnforcer;
class TimerChannel;

class WebSocketObserver: public kj::Refcounted {
public:
// Called when a worker sends a message on this WebSocket (includes close messages).
virtual void sentMessage(size_t bytes) { };
// Called when a worker receives a message on this WebSocket (includes close messages).
virtual void receivedMessage(size_t bytes) { };
};

// Observes a specific request to a specific worker. Also observes outgoing subrequests.
//
// Observing anything is optional. Default implementations of all methods observe nothing.
class RequestObserver: public kj::Refcounted {
public:
// This is called when the request is converted to a WebSocket connection terminating in a worker.
// An optional WebSocket observer may be returned to observe events on the worker's end of the
// WebSocket connection.
//
// This means that, when the returned observer observes a message being sent, the message is being
// sent from the worker to the client making the request.
virtual kj::Maybe<kj::Own<WebSocketObserver>> tryCreateWebSocketObserver() { return kj::none; };

// Invoked when the request is actually delivered.
//
// If, for some reason, this is not invoked before the object is destroyed, this indicate that
Expand Down

0 comments on commit 0e28d29

Please sign in to comment.