-
Notifications
You must be signed in to change notification settings - Fork 398
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: disconnect RTC websockets when disconnected from main WS (#3803)
Hopefully fix for #3554
- Loading branch information
Showing
3 changed files
with
226 additions
and
30 deletions.
There are no files selected for viewing
131 changes: 131 additions & 0 deletions
131
frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* Copyright 2024 Marimo. All rights reserved. */ | ||
import { | ||
describe, | ||
it, | ||
expect, | ||
vi, | ||
beforeEach, | ||
afterEach, | ||
type Mock, | ||
} from "vitest"; | ||
import { CellProviderManager } from "../cell-manager"; | ||
import { WebsocketProvider } from "y-websocket"; | ||
import * as Y from "yjs"; | ||
import { store } from "@/core/state/jotai"; | ||
import { WebSocketState } from "@/core/websocket/types"; | ||
import { getSessionId } from "@/core/kernel/session"; | ||
import type { CellId } from "@/core/cells/ids"; | ||
import { connectionAtom } from "@/core/network/connection"; | ||
|
||
// Mock dependencies | ||
vi.mock("y-websocket", () => ({ | ||
WebsocketProvider: vi.fn(), | ||
})); | ||
|
||
vi.mock("@/core/kernel/session", () => ({ | ||
getSessionId: vi.fn(), | ||
})); | ||
|
||
const CELL_ID = "cell1" as CellId; | ||
|
||
describe("CellProviderManager", () => { | ||
let manager: CellProviderManager; | ||
const mockProvider = { | ||
doc: { | ||
getText: vi.fn(), | ||
}, | ||
destroy: vi.fn(), | ||
}; | ||
const mockYText = {}; | ||
|
||
beforeEach(() => { | ||
vi.clearAllMocks(); | ||
manager = CellProviderManager.getInstance(); | ||
(WebsocketProvider as Mock).mockImplementation(() => mockProvider); | ||
mockProvider.doc.getText.mockReturnValue(mockYText); | ||
store.set(connectionAtom, { state: WebSocketState.OPEN }); | ||
vi.spyOn(store, "get"); | ||
(getSessionId as Mock).mockReturnValue("test-session"); | ||
}); | ||
|
||
afterEach(() => { | ||
manager.disconnectAll(); | ||
}); | ||
|
||
it("should be a singleton", () => { | ||
const instance1 = CellProviderManager.getInstance(); | ||
const instance2 = CellProviderManager.getInstance(); | ||
expect(instance1).toBe(instance2); | ||
}); | ||
|
||
it("should create new provider if none exists", async () => { | ||
const { provider, ytext } = manager.getOrCreateProvider( | ||
CELL_ID, | ||
"initial code", | ||
); | ||
|
||
expect(WebsocketProvider).toHaveBeenCalledWith( | ||
"ws", | ||
CELL_ID, | ||
expect.any(Y.Doc), | ||
{ | ||
params: { | ||
session_id: "test-session", | ||
}, | ||
}, | ||
); | ||
expect(provider).toBe(mockProvider); | ||
expect(ytext.toJSON()).toBe("initial code"); | ||
}); | ||
|
||
it("should return existing provider if one exists", async () => { | ||
manager.getOrCreateProvider(CELL_ID, "initial code"); | ||
const { provider: provider2 } = manager.getOrCreateProvider( | ||
CELL_ID, | ||
"different code", | ||
); | ||
|
||
expect(WebsocketProvider).toHaveBeenCalledTimes(1); | ||
expect(provider2).toBe(mockProvider); | ||
}); | ||
|
||
it("should include file path in params if present", async () => { | ||
const originalLocation = window.location; | ||
// @ts-expect-error ehhh typescript | ||
// biome-ignore lint/performance/noDelete: ehh | ||
delete window.location; | ||
window.location = { | ||
...originalLocation, | ||
search: "?file=/path/to/file.py", | ||
}; | ||
|
||
manager.getOrCreateProvider(CELL_ID, "initial code"); | ||
|
||
expect(WebsocketProvider).toHaveBeenCalledWith( | ||
"ws", | ||
"cell1", | ||
expect.any(Y.Doc), | ||
{ | ||
params: { | ||
session_id: "test-session", | ||
file: "/path/to/file.py", | ||
}, | ||
}, | ||
); | ||
}); | ||
|
||
it("should disconnect a specific provider", () => { | ||
manager.getOrCreateProvider(CELL_ID, "code"); | ||
manager.disconnect(CELL_ID); | ||
|
||
expect(mockProvider.destroy).toHaveBeenCalled(); | ||
}); | ||
|
||
it("should disconnect all providers", async () => { | ||
manager.getOrCreateProvider(CELL_ID, "code"); | ||
manager.getOrCreateProvider("cell2" as CellId, "code"); | ||
manager.disconnectAll(); | ||
|
||
expect(mockProvider.destroy).toHaveBeenCalledTimes(2); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* Copyright 2024 Marimo. All rights reserved. */ | ||
import { WebsocketProvider } from "y-websocket"; | ||
import * as Y from "yjs"; | ||
import type { CellId } from "@/core/cells/ids"; | ||
import { KnownQueryParams } from "@/core/constants"; | ||
import { getSessionId } from "@/core/kernel/session"; | ||
import { connectionAtom } from "@/core/network/connection"; | ||
import { WebSocketState } from "@/core/websocket/types"; | ||
import { store } from "@/core/state/jotai"; | ||
|
||
const DOC_KEY = "code"; | ||
|
||
export class CellProviderManager { | ||
private providers = new Map<CellId, WebsocketProvider>(); | ||
private static instance: CellProviderManager; | ||
|
||
private constructor() { | ||
this.listenForConnectionChanges(); | ||
} | ||
|
||
static getInstance(): CellProviderManager { | ||
if (!CellProviderManager.instance) { | ||
CellProviderManager.instance = new CellProviderManager(); | ||
} | ||
return CellProviderManager.instance; | ||
} | ||
|
||
getOrCreateProvider( | ||
cellId: CellId, | ||
initialCode: string, | ||
): { provider: WebsocketProvider; ytext: Y.Text } { | ||
const existingProvider = this.providers.get(cellId); | ||
if (existingProvider) { | ||
return { | ||
provider: existingProvider, | ||
ytext: existingProvider.doc.getText(DOC_KEY), | ||
}; | ||
} | ||
|
||
// Wait for connection to be established | ||
// while (store.get(connectionAtom).state !== WebSocketState.OPEN) { | ||
// await new Promise((resolve) => setTimeout(resolve, 100)); | ||
// } | ||
|
||
const ydoc = new Y.Doc(); | ||
const ytext = ydoc.getText(DOC_KEY); | ||
if (initialCode && ytext.length === 0) { | ||
ytext.insert(0, initialCode); | ||
} | ||
|
||
const params: Record<string, string> = { | ||
session_id: getSessionId(), | ||
}; | ||
const searchParams = new URLSearchParams(window.location.search); | ||
const filePath = searchParams.get(KnownQueryParams.filePath); | ||
if (filePath) { | ||
params.file = filePath; | ||
} | ||
|
||
const provider = new WebsocketProvider("ws", cellId, ydoc, { params }); | ||
this.providers.set(cellId, provider); | ||
|
||
return { provider, ytext }; | ||
} | ||
|
||
listenForConnectionChanges(): void { | ||
const handleDisconnect = () => { | ||
const value = store.get(connectionAtom); | ||
const shouldDisconnect = | ||
value.state === WebSocketState.CLOSED || | ||
value.state === WebSocketState.CLOSING; | ||
if (shouldDisconnect) { | ||
this.disconnectAll(); | ||
} | ||
}; | ||
store.sub(connectionAtom, handleDisconnect); | ||
} | ||
|
||
disconnect(cellId: CellId): void { | ||
const provider = this.providers.get(cellId); | ||
if (provider) { | ||
provider.destroy(); | ||
this.providers.delete(cellId); | ||
} | ||
} | ||
|
||
disconnectAll(): void { | ||
for (const [cellId] of this.providers) { | ||
this.disconnect(cellId); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters