Skip to content

Commit

Permalink
Use weak refs between WebSockets in a WebSocketPair (cloudflare#2161)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasnell authored May 24, 2024
1 parent cab848c commit 1d89f3b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
30 changes: 18 additions & 12 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ WebSocket::WebSocket(jsg::Lock& js,
IoContext& ioContext,
kj::WebSocket& ws,
HibernationPackage package)
: url(kj::mv(package.url)),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::mv(package.url)),
protocol(kj::mv(package.protocol)),
extensions(kj::mv(package.extensions)),
serializedAttachment(kj::mv(package.serializedAttachment)),
Expand All @@ -71,7 +72,8 @@ jsg::Ref<WebSocket> WebSocket::hibernatableFromNative(
}

WebSocket::WebSocket(kj::Own<kj::WebSocket> native, Locality locality)
: url(kj::none),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::none),
farNative(nullptr),
outgoingMessages(IoContext::current().addObject(kj::heap<OutgoingMessagesMap>())),
locality(locality) {
Expand All @@ -81,7 +83,8 @@ WebSocket::WebSocket(kj::Own<kj::WebSocket> native, Locality locality)
}

WebSocket::WebSocket(kj::String url, Locality locality)
: url(kj::mv(url)),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::mv(url)),
farNative(nullptr),
outgoingMessages(IoContext::current().addObject(kj::heap<OutgoingMessagesMap>())),
locality(locality) {
Expand Down Expand Up @@ -968,8 +971,8 @@ jsg::Ref<WebSocketPair> WebSocketPair::constructor() {
auto first = pair->getFirst();
auto second = pair->getSecond();

first->setMaybePair(second.addRef());
second->setMaybePair(first.addRef());
first->setMaybePair(second->addWeakRef());
second->setMaybePair(first->addWeakRef());
return kj::mv(pair);
}

Expand Down Expand Up @@ -1015,8 +1018,8 @@ void WebSocket::assertNoError(jsg::Lock& js) {
}
}

void WebSocket::setMaybePair(jsg::Ref<WebSocket> other) {
maybePair = other.addRef();
void WebSocket::setMaybePair(kj::Own<WeakRef<WebSocket>> other) {
maybePair = kj::mv(other);
}

kj::Own<kj::WebSocket> WebSocket::acceptAsHibernatable(kj::Array<kj::StringPtr> tags) {
Expand Down Expand Up @@ -1068,15 +1071,19 @@ bool WebSocket::awaitingHibernatableRelease() {
}

void WebSocket::setRemoteOnPair() {
JSG_REQUIRE_NONNULL(maybePair, Error,
"this WebSocket is not one end of a WebSocketPair")->locality = REMOTE;
auto& ref = JSG_REQUIRE_NONNULL(maybePair, Error,
"this WebSocket is not one end of a WebSocketPair");
ref->runIfAlive([](WebSocket& ref) { ref.locality = REMOTE; });
}

bool WebSocket::pairIsAwaitingCoupling() {
bool answer = false;
KJ_IF_SOME(pair, maybePair) {
return pair->farNative->state.is<AwaitingAcceptanceOrCoupling>();
pair->runIfAlive([&answer](WebSocket& pair) {
answer = pair.farNative->state.is<AwaitingAcceptanceOrCoupling>();
});
}
return false;
return answer;
}

WebSocket::HibernationPackage WebSocket::buildPackageForHibernation() {
Expand Down Expand Up @@ -1199,7 +1206,6 @@ void WebSocket::visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
tracker.trackField("error", error);
tracker.trackFieldWithSize("IoOwn<OutgoingMessagesMap>", sizeof(IoOwn<OutgoingMessagesMap>));
tracker.trackField("autoResponseStatus", autoResponseStatus);
tracker.trackField("maybePair", maybePair);
}

} // namespace workerd::api
22 changes: 20 additions & 2 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <kj/compat/http.h>
#include "basics.h"
#include <workerd/io/io-context.h>
#include <workerd/util/weak-refs.h>
#include <stdlib.h>

namespace workerd {
Expand Down Expand Up @@ -230,6 +231,10 @@ class WebSocket: public EventTarget {
bool closedOutgoingConnection = false;
};

~WebSocket() noexcept(false) {
weakRef->invalidate();
}

// This WebSocket constructor is only used when WebSockets wake up from hibernation.
// It will immediately set the `state` to `Accepted`, but it limits the behavior by specifying it
// as `Hibernatable` -- thereby making most api::WebSocket methods inaccessible.
Expand Down Expand Up @@ -409,7 +414,12 @@ class WebSocket: public EventTarget {

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const;

kj::Own<WeakRef<WebSocket>> addWeakRef() {
return weakRef->addRef();
}

private:
kj::Own<WeakRef<WebSocket>> weakRef;
kj::Maybe<kj::String> url;
kj::Maybe<kj::String> protocol = kj::String();
kj::Maybe<kj::String> extensions = kj::String();
Expand Down Expand Up @@ -599,9 +609,17 @@ class WebSocket: public EventTarget {
};

// So that each end of a WebSocketPair can keep track of its pair.
kj::Maybe<jsg::Ref<WebSocket>> maybePair;
// We use a weak ref to track the pair to avoid having a strong ref cycle
// between the two WebSocket instances that would cause them to leak. This
// can mean, however, that it's possible for one side of the pair to be garbage
// collected while the other still exists. This should be fairly unusual tho.
kj::Maybe<kj::Own<WeakRef<WebSocket>>> maybePair;

void visitForGc(jsg::GcVisitor& visitor) {
visitor.visit(error);
}

void setMaybePair(jsg::Ref<WebSocket> other);
void setMaybePair(kj::Own<WeakRef<WebSocket>> other);

friend jsg::Ref<WebSocketPair> WebSocketPair::constructor();

Expand Down

0 comments on commit 1d89f3b

Please sign in to comment.