From 50cf9c6703aa9dc01fb8d080250867113a7913ef Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Fri, 14 Feb 2025 13:21:24 -0500 Subject: [PATCH 1/2] fix: disconnect RTC websockets when disconnected from main WS --- .../rtc/__tests__/cell-manager.test.ts | 131 ++++++++++++++++++ .../src/core/codemirror/rtc/cell-manager.ts | 92 ++++++++++++ frontend/src/core/codemirror/rtc/extension.ts | 33 +---- 3 files changed, 226 insertions(+), 30 deletions(-) create mode 100644 frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts create mode 100644 frontend/src/core/codemirror/rtc/cell-manager.ts diff --git a/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts b/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts new file mode 100644 index 00000000000..922b1fa9872 --- /dev/null +++ b/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts @@ -0,0 +1,131 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from "vitest"; +import { CellProviderManager } from "../cell-manager"; +import { WebsocketProvider } from "y-websocket"; +import * as Y from "yjs"; +import { store } from "@/core/state/jotai"; +import { WebSocketState } from "@/core/websocket/types"; +import { getSessionId } from "@/core/kernel/session"; +import type { CellId } from "@/core/cells/ids"; +import { connectionAtom } from "@/core/network/connection"; + +// Mock dependencies +vi.mock("y-websocket", () => ({ + WebsocketProvider: vi.fn(), +})); + +vi.mock("@/core/kernel/session", () => ({ + getSessionId: vi.fn(), +})); + +const CELL_ID = "cell1" as CellId; + +describe("CellProviderManager", () => { + let manager: CellProviderManager; + const mockProvider = { + doc: { + getText: vi.fn(), + }, + destroy: vi.fn(), + }; + const mockYText = {}; + + beforeEach(() => { + vi.clearAllMocks(); + manager = CellProviderManager.getInstance(); + (WebsocketProvider as Mock).mockImplementation(() => mockProvider); + mockProvider.doc.getText.mockReturnValue(mockYText); + store.set(connectionAtom, { state: WebSocketState.OPEN }); + vi.spyOn(store, "get"); + (getSessionId as Mock).mockReturnValue("test-session"); + }); + + afterEach(() => { + manager.disconnectAll(); + }); + + it("should be a singleton", () => { + const instance1 = CellProviderManager.getInstance(); + const instance2 = CellProviderManager.getInstance(); + expect(instance1).toBe(instance2); + }); + + it("should create new provider if none exists", async () => { + const { provider, ytext } = await manager.getOrCreateProvider( + CELL_ID, + "initial code", + ); + + expect(WebsocketProvider).toHaveBeenCalledWith( + "ws", + CELL_ID, + expect.any(Y.Doc), + { + params: { + session_id: "test-session", + }, + }, + ); + expect(provider).toBe(mockProvider); + expect(ytext.toJSON()).toBe("initial code"); + }); + + it("should return existing provider if one exists", async () => { + await manager.getOrCreateProvider(CELL_ID, "initial code"); + const { provider: provider2 } = await manager.getOrCreateProvider( + CELL_ID, + "different code", + ); + + expect(WebsocketProvider).toHaveBeenCalledTimes(1); + expect(provider2).toBe(mockProvider); + }); + + it("should include file path in params if present", async () => { + const originalLocation = window.location; + // @ts-expect-error ehhh typescript + // biome-ignore lint/performance/noDelete: ehh + delete window.location; + window.location = { + ...originalLocation, + search: "?file=/path/to/file.py", + }; + + await manager.getOrCreateProvider(CELL_ID, "initial code"); + + expect(WebsocketProvider).toHaveBeenCalledWith( + "ws", + "cell1", + expect.any(Y.Doc), + { + params: { + session_id: "test-session", + file: "/path/to/file.py", + }, + }, + ); + }); + + it("should disconnect a specific provider", () => { + manager.getOrCreateProvider(CELL_ID, "code"); + manager.disconnect(CELL_ID); + + expect(mockProvider.destroy).toHaveBeenCalled(); + }); + + it("should disconnect all providers", async () => { + await manager.getOrCreateProvider(CELL_ID, "code"); + await manager.getOrCreateProvider("cell2" as CellId, "code"); + manager.disconnectAll(); + + expect(mockProvider.destroy).toHaveBeenCalledTimes(2); + }); +}); diff --git a/frontend/src/core/codemirror/rtc/cell-manager.ts b/frontend/src/core/codemirror/rtc/cell-manager.ts new file mode 100644 index 00000000000..1c87372da36 --- /dev/null +++ b/frontend/src/core/codemirror/rtc/cell-manager.ts @@ -0,0 +1,92 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { WebsocketProvider } from "y-websocket"; +import * as Y from "yjs"; +import type { CellId } from "@/core/cells/ids"; +import { KnownQueryParams } from "@/core/constants"; +import { getSessionId } from "@/core/kernel/session"; +import { connectionAtom } from "@/core/network/connection"; +import { WebSocketState } from "@/core/websocket/types"; +import { store } from "@/core/state/jotai"; + +const DOC_KEY = "code"; + +export class CellProviderManager { + private providers = new Map(); + private static instance: CellProviderManager; + + private constructor() { + this.listenForConnectionChanges(); + } + + static getInstance(): CellProviderManager { + if (!CellProviderManager.instance) { + CellProviderManager.instance = new CellProviderManager(); + } + return CellProviderManager.instance; + } + + getOrCreateProvider( + cellId: CellId, + initialCode: string, + ): { provider: WebsocketProvider; ytext: Y.Text } { + const existingProvider = this.providers.get(cellId); + if (existingProvider) { + return { + provider: existingProvider, + ytext: existingProvider.doc.getText(DOC_KEY), + }; + } + + // Wait for connection to be established + // while (store.get(connectionAtom).state !== WebSocketState.OPEN) { + // await new Promise((resolve) => setTimeout(resolve, 100)); + // } + + const ydoc = new Y.Doc(); + const ytext = ydoc.getText(DOC_KEY); + if (initialCode && ytext.length === 0) { + ytext.insert(0, initialCode); + } + + const params: Record = { + session_id: getSessionId(), + }; + const searchParams = new URLSearchParams(window.location.search); + const filePath = searchParams.get(KnownQueryParams.filePath); + if (filePath) { + params.file = filePath; + } + + const provider = new WebsocketProvider("ws", cellId, ydoc, { params }); + this.providers.set(cellId, provider); + + return { provider, ytext }; + } + + listenForConnectionChanges(): void { + const handleDisconnect = () => { + const value = store.get(connectionAtom); + const shouldDisconnect = + value.state === WebSocketState.CLOSED || + value.state === WebSocketState.CLOSING; + if (shouldDisconnect) { + this.disconnectAll(); + } + }; + store.sub(connectionAtom, handleDisconnect); + } + + disconnect(cellId: CellId): void { + const provider = this.providers.get(cellId); + if (provider) { + provider.destroy(); + this.providers.delete(cellId); + } + } + + disconnectAll(): void { + for (const [cellId] of this.providers) { + this.disconnect(cellId); + } + } +} diff --git a/frontend/src/core/codemirror/rtc/extension.ts b/frontend/src/core/codemirror/rtc/extension.ts index f9ced8d229b..b988cf32842 100644 --- a/frontend/src/core/codemirror/rtc/extension.ts +++ b/frontend/src/core/codemirror/rtc/extension.ts @@ -1,14 +1,9 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { KnownQueryParams } from "@/core/constants"; -import { getSessionId } from "@/core/kernel/session"; import { yCollab } from "y-codemirror.next"; -import * as Y from "yjs"; -import { WebsocketProvider } from "y-websocket"; import type { CellId } from "@/core/cells/ids"; import { isWasm } from "@/core/wasm/utils"; import type { Extension } from "@codemirror/state"; - -const cellProviders = new Map(); +import { CellProviderManager } from "./cell-manager"; export function realTimeCollaboration( cellId: CellId, @@ -21,30 +16,8 @@ export function realTimeCollaboration( }; } - let wsProvider = cellProviders.get(cellId); - let ytext: Y.Text; - - if (wsProvider) { - ytext = wsProvider.doc.getText("code"); - } else { - const ydoc = new Y.Doc(); - ytext = ydoc.getText("code"); - if (initialCode) { - ytext.insert(0, initialCode); - } - // Add file and session_id to the params - const params: Record = {}; - params.session_id = getSessionId(); - const searchParams = new URLSearchParams(window.location.search); - const filePath = searchParams.get(KnownQueryParams.filePath); - if (filePath) { - params.file = filePath; - } - wsProvider = new WebsocketProvider("ws", cellId, ydoc, { - params, - }); - cellProviders.set(cellId, wsProvider); - } + const manager = CellProviderManager.getInstance(); + const { ytext } = manager.getOrCreateProvider(cellId, initialCode); const extension = yCollab(ytext, null); From dd3e6e3bc2e256fb0f367eb588cd6ec794348119 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Fri, 14 Feb 2025 13:33:43 -0500 Subject: [PATCH 2/2] lint --- .../codemirror/rtc/__tests__/cell-manager.test.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts b/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts index 922b1fa9872..c5397355850 100644 --- a/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts +++ b/frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts @@ -59,7 +59,7 @@ describe("CellProviderManager", () => { }); it("should create new provider if none exists", async () => { - const { provider, ytext } = await manager.getOrCreateProvider( + const { provider, ytext } = manager.getOrCreateProvider( CELL_ID, "initial code", ); @@ -79,8 +79,8 @@ describe("CellProviderManager", () => { }); it("should return existing provider if one exists", async () => { - await manager.getOrCreateProvider(CELL_ID, "initial code"); - const { provider: provider2 } = await manager.getOrCreateProvider( + manager.getOrCreateProvider(CELL_ID, "initial code"); + const { provider: provider2 } = manager.getOrCreateProvider( CELL_ID, "different code", ); @@ -99,7 +99,7 @@ describe("CellProviderManager", () => { search: "?file=/path/to/file.py", }; - await manager.getOrCreateProvider(CELL_ID, "initial code"); + manager.getOrCreateProvider(CELL_ID, "initial code"); expect(WebsocketProvider).toHaveBeenCalledWith( "ws", @@ -122,8 +122,8 @@ describe("CellProviderManager", () => { }); it("should disconnect all providers", async () => { - await manager.getOrCreateProvider(CELL_ID, "code"); - await manager.getOrCreateProvider("cell2" as CellId, "code"); + manager.getOrCreateProvider(CELL_ID, "code"); + manager.getOrCreateProvider("cell2" as CellId, "code"); manager.disconnectAll(); expect(mockProvider.destroy).toHaveBeenCalledTimes(2);