diff --git a/src/adapters/action/dispatcher-ws.spec.ts b/src/adapters/action/dispatcher-ws.spec.ts index 2b0a545b..b0462971 100644 --- a/src/adapters/action/dispatcher-ws.spec.ts +++ b/src/adapters/action/dispatcher-ws.spec.ts @@ -29,7 +29,7 @@ describe("DispatcherWs", () => { }).toThrow(MastoUnexpectedError); }); - it("can be disposed", async () => { + it("can be disposed", () => { const connector = new WebSocketConnectorImpl({ constructorParameters: ["wss://example.com"], }); @@ -41,8 +41,6 @@ describe("DispatcherWs", () => { ); dispatcher[Symbol.dispose](); - await expect(() => connector.acquire()).rejects.toThrow( - MastoWebSocketError, - ); + expect(() => connector.acquire()).toThrow(MastoWebSocketError); }); }); diff --git a/src/adapters/action/dispatcher-ws.ts b/src/adapters/action/dispatcher-ws.ts index bd1cb47d..d394c688 100644 --- a/src/adapters/action/dispatcher-ws.ts +++ b/src/adapters/action/dispatcher-ws.ts @@ -24,7 +24,7 @@ export class WebSocketActionDispatcher dispatch(action: WebSocketAction): T { if (action.type === "close") { - this.connector.close(); + this.connector.kill(); return {} as T; } @@ -50,6 +50,6 @@ export class WebSocketActionDispatcher } [Symbol.dispose](): void { - this.connector.close(); + this.connector.kill(); } } diff --git a/src/adapters/ws/web-socket-connector.spec.ts b/src/adapters/ws/web-socket-connector.spec.ts index 7d261f40..49230c42 100644 --- a/src/adapters/ws/web-socket-connector.spec.ts +++ b/src/adapters/ws/web-socket-connector.spec.ts @@ -18,7 +18,7 @@ describe("WebSocketConnector", () => { expect(ws1).toBe(ws2); server.close(); - connector.close(); + connector.kill(); }); it("rejects if WebSocket closes", async () => { @@ -26,7 +26,7 @@ describe("WebSocketConnector", () => { constructorParameters: [`ws://localhost:0`], }); const promise = connector.acquire(); - connector.close(); + connector.kill(); await expect(promise).rejects.toBeInstanceOf(MastoWebSocketError); }); diff --git a/src/adapters/ws/web-socket-connector.ts b/src/adapters/ws/web-socket-connector.ts index 819ee851..74195040 100644 --- a/src/adapters/ws/web-socket-connector.ts +++ b/src/adapters/ws/web-socket-connector.ts @@ -18,12 +18,10 @@ interface WebSocketConnectorImplProps { export class WebSocketConnectorImpl implements WebSocketConnector { private ws?: WebSocket; + private killed = false; private queue: PromiseWithResolvers[] = []; private backoff: ExponentialBackoff; - private closed = false; - private initialized = false; - constructor( private readonly props: WebSocketConnectorImplProps, private readonly logger?: Logger, @@ -31,32 +29,31 @@ export class WebSocketConnectorImpl implements WebSocketConnector { this.backoff = new ExponentialBackoff({ maxAttempts: this.props.maxAttempts, }); + this.spawn(); } - async acquire(): Promise { - if (this.closed) { - throw new MastoWebSocketError("WebSocket closed"); + async *[Symbol.asyncIterator](): AsyncIterableIterator { + while (!this.killed) { + yield await this.acquire(); } + } - this.init(); + acquire(): Promise { + if (this.killed) { + throw new MastoWebSocketError("WebSocket closed"); + } if (this.ws != undefined) { - return this.ws; + return Promise.resolve(this.ws); } const promiseWithResolvers = createPromiseWithResolvers(); this.queue.push(promiseWithResolvers); - return await promiseWithResolvers.promise; - } - - async *[Symbol.asyncIterator](): AsyncIterableIterator { - while (!this.closed) { - yield await this.acquire(); - } + return promiseWithResolvers.promise; } - close(): void { - this.closed = true; + kill(): void { + this.killed = true; this.ws?.close(); this.backoff.clear(); @@ -67,14 +64,8 @@ export class WebSocketConnectorImpl implements WebSocketConnector { this.queue = []; } - private async init() { - if (this.initialized) { - return; - } - - this.initialized = true; - - while (!this.closed) { + private async spawn() { + while (!this.killed) { this.ws?.close(); try { @@ -114,6 +105,7 @@ export class WebSocketConnectorImpl implements WebSocketConnector { ), ); } + this.queue = []; } } diff --git a/src/adapters/ws/web-socket-subscription.spec.ts b/src/adapters/ws/web-socket-subscription.spec.ts index c68798ba..5978a6e6 100644 --- a/src/adapters/ws/web-socket-subscription.spec.ts +++ b/src/adapters/ws/web-socket-subscription.spec.ts @@ -8,24 +8,6 @@ import { WebSocketSubscription } from "./web-socket-subscription"; import { WebSocketSubscriptionCounterImpl } from "./web-socket-subscription-counter"; describe("WebSocketSubscription", () => { - it("doesn't do anything if no connection was established", async () => { - const logger = createLogger(); - - const subscription = new WebSocketSubscription( - new WebSocketConnectorImpl( - { constructorParameters: ["ws://localhost:0"] }, - logger, - ), - new WebSocketSubscriptionCounterImpl(), - new SerializerNativeImpl(), - "public", - logger, - ); - - const res = subscription.unsubscribe(); - expect(res).toBeUndefined(); - }); - it("implements async iterator", async () => { const logger = createLogger(); const port = await getPort(); @@ -43,12 +25,12 @@ describe("WebSocketSubscription", () => { }); }); - const connection = new WebSocketConnectorImpl( + const connector = new WebSocketConnectorImpl( { constructorParameters: [`ws://localhost:${port}`] }, logger, ); const subscription = new WebSocketSubscription( - connection, + connector, new WebSocketSubscriptionCounterImpl(), new SerializerNativeImpl(), "public", @@ -63,7 +45,7 @@ describe("WebSocketSubscription", () => { expect(value).toBe("123"); - connection.close(); + connector.kill(); server.close(); }); }); diff --git a/src/adapters/ws/web-socket-subscription.ts b/src/adapters/ws/web-socket-subscription.ts index 512dc167..779fe040 100644 --- a/src/adapters/ws/web-socket-subscription.ts +++ b/src/adapters/ws/web-socket-subscription.ts @@ -56,10 +56,6 @@ export class WebSocketSubscription implements mastodon.streaming.Subscription { } unsubscribe(): void { - if (this.connection == undefined) { - return; - } - this.counter.decrement(this.stream, this.params); if (this.counter.count(this.stream, this.params) <= 0) { @@ -69,10 +65,8 @@ export class WebSocketSubscription implements mastodon.streaming.Subscription { ...this.params, }); - this.connection.send(data); + this.connection?.send(data); } - - this.connection = undefined; } [Symbol.asyncIterator](): AsyncIterableIterator { diff --git a/src/interfaces/ws.ts b/src/interfaces/ws.ts index 712ae4de..d410f6b7 100644 --- a/src/interfaces/ws.ts +++ b/src/interfaces/ws.ts @@ -2,7 +2,7 @@ import { type WebSocket } from "isomorphic-ws"; export interface WebSocketConnector extends AsyncIterable { acquire(): Promise; - close(): void; + kill(): void; } export interface WebSocketSubscriptionCounter { diff --git a/src/mastodon/streaming/client.ts b/src/mastodon/streaming/client.ts index 9f44985e..da732256 100644 --- a/src/mastodon/streaming/client.ts +++ b/src/mastodon/streaming/client.ts @@ -10,7 +10,6 @@ export interface SubscribeHashtagParams { export interface Subscription extends AsyncIterable, Disposable { values(): AsyncIterableIterator; - unsubscribe(): void; } diff --git a/tests/streaming/connections.spec.ts b/tests/streaming/connections.spec.ts index 95ebd75f..5bde0419 100644 --- a/tests/streaming/connections.spec.ts +++ b/tests/streaming/connections.spec.ts @@ -1,17 +1,19 @@ import assert from "node:assert"; +import crypto from "node:crypto"; it("maintains connections for the event even if other handlers closed it", async () => { + const tag = `tag_${crypto.randomBytes(4).toString("hex")}`; await using alice = await sessions.acquire({ waitForWs: true }); - using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" }); - using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" }); + using subscription1 = alice.ws.hashtag.subscribe({ tag }); + using subscription2 = alice.ws.hashtag.subscribe({ tag }); const promise1 = subscription1.values().take(1).toArray(); const promise2 = subscription2.values().take(2).toArray(); // Dispatch event for subscription1 to establish connection const status1 = await alice.rest.v1.statuses.create({ - status: "#test", + status: `#${tag}`, visibility: "public", }); await promise1; @@ -19,7 +21,7 @@ it("maintains connections for the event even if other handlers closed it", async // subscription1 is now closed, so status2 will only be dispatched to subscription2 const status2 = await alice.rest.v1.statuses.create({ - status: "#test", + status: `#${tag}`, visibility: "public", }); @@ -37,16 +39,17 @@ it("maintains connections for the event even if other handlers closed it", async }); it("maintains connections for the event if unsubscribe called twice", async () => { + const tag = `tag_${crypto.randomBytes(4).toString("hex")}`; await using alice = await sessions.acquire({ waitForWs: true }); - using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" }); - using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" }); + using subscription1 = alice.ws.hashtag.subscribe({ tag }); + using subscription2 = alice.ws.hashtag.subscribe({ tag }); const promise1 = subscription1.values().take(1).toArray(); const promise2 = subscription2.values().take(2).toArray(); const status1 = await alice.rest.v1.statuses.create({ - status: "#test", + status: `#${tag}`, visibility: "public", }); await promise1; @@ -56,7 +59,7 @@ it("maintains connections for the event if unsubscribe called twice", async () = subscription1.unsubscribe(); const status2 = await alice.rest.v1.statuses.create({ - status: "#test", + status: `#${tag}`, visibility: "public", }); @@ -74,17 +77,18 @@ it("maintains connections for the event if unsubscribe called twice", async () = }); it("maintains connections for the event if another handler called unsubscribe before connection established", async () => { + const tag = `tag_${crypto.randomBytes(4).toString("hex")}`; await using alice = await sessions.acquire({ waitForWs: true }); - using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" }); - using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" }); + using subscription1 = alice.ws.hashtag.subscribe({ tag }); + using subscription2 = alice.ws.hashtag.subscribe({ tag }); subscription1.unsubscribe(); const promise2 = subscription2.values().take(1).toArray(); const status1 = await alice.rest.v1.statuses.create({ - status: "#test", + status: `#${tag}`, visibility: "public", });