Skip to content

Commit

Permalink
fix: disconnect RTC websockets when disconnected from main WS (#3803)
Browse files Browse the repository at this point in the history
Hopefully fix for #3554
  • Loading branch information
mscolnick authored Feb 14, 2025
1 parent 83c21f6 commit d039aa7
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 30 deletions.
131 changes: 131 additions & 0 deletions frontend/src/core/codemirror/rtc/__tests__/cell-manager.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/* Copyright 2024 Marimo. All rights reserved. */
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from "vitest";
import { CellProviderManager } from "../cell-manager";
import { WebsocketProvider } from "y-websocket";
import * as Y from "yjs";
import { store } from "@/core/state/jotai";
import { WebSocketState } from "@/core/websocket/types";
import { getSessionId } from "@/core/kernel/session";
import type { CellId } from "@/core/cells/ids";
import { connectionAtom } from "@/core/network/connection";

// Mock dependencies
vi.mock("y-websocket", () => ({
WebsocketProvider: vi.fn(),
}));

vi.mock("@/core/kernel/session", () => ({
getSessionId: vi.fn(),
}));

const CELL_ID = "cell1" as CellId;

describe("CellProviderManager", () => {
let manager: CellProviderManager;
const mockProvider = {
doc: {
getText: vi.fn(),
},
destroy: vi.fn(),
};
const mockYText = {};

beforeEach(() => {
vi.clearAllMocks();
manager = CellProviderManager.getInstance();
(WebsocketProvider as Mock).mockImplementation(() => mockProvider);
mockProvider.doc.getText.mockReturnValue(mockYText);
store.set(connectionAtom, { state: WebSocketState.OPEN });
vi.spyOn(store, "get");
(getSessionId as Mock).mockReturnValue("test-session");
});

afterEach(() => {
manager.disconnectAll();
});

it("should be a singleton", () => {
const instance1 = CellProviderManager.getInstance();
const instance2 = CellProviderManager.getInstance();
expect(instance1).toBe(instance2);
});

it("should create new provider if none exists", async () => {
const { provider, ytext } = manager.getOrCreateProvider(
CELL_ID,
"initial code",
);

expect(WebsocketProvider).toHaveBeenCalledWith(
"ws",
CELL_ID,
expect.any(Y.Doc),
{
params: {
session_id: "test-session",
},
},
);
expect(provider).toBe(mockProvider);
expect(ytext.toJSON()).toBe("initial code");
});

it("should return existing provider if one exists", async () => {
manager.getOrCreateProvider(CELL_ID, "initial code");
const { provider: provider2 } = manager.getOrCreateProvider(
CELL_ID,
"different code",
);

expect(WebsocketProvider).toHaveBeenCalledTimes(1);
expect(provider2).toBe(mockProvider);
});

it("should include file path in params if present", async () => {
const originalLocation = window.location;
// @ts-expect-error ehhh typescript
// biome-ignore lint/performance/noDelete: ehh
delete window.location;
window.location = {
...originalLocation,
search: "?file=/path/to/file.py",
};

manager.getOrCreateProvider(CELL_ID, "initial code");

expect(WebsocketProvider).toHaveBeenCalledWith(
"ws",
"cell1",
expect.any(Y.Doc),
{
params: {
session_id: "test-session",
file: "/path/to/file.py",
},
},
);
});

it("should disconnect a specific provider", () => {
manager.getOrCreateProvider(CELL_ID, "code");
manager.disconnect(CELL_ID);

expect(mockProvider.destroy).toHaveBeenCalled();
});

it("should disconnect all providers", async () => {
manager.getOrCreateProvider(CELL_ID, "code");
manager.getOrCreateProvider("cell2" as CellId, "code");
manager.disconnectAll();

expect(mockProvider.destroy).toHaveBeenCalledTimes(2);
});
});
92 changes: 92 additions & 0 deletions frontend/src/core/codemirror/rtc/cell-manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* Copyright 2024 Marimo. All rights reserved. */
import { WebsocketProvider } from "y-websocket";
import * as Y from "yjs";
import type { CellId } from "@/core/cells/ids";
import { KnownQueryParams } from "@/core/constants";
import { getSessionId } from "@/core/kernel/session";
import { connectionAtom } from "@/core/network/connection";
import { WebSocketState } from "@/core/websocket/types";
import { store } from "@/core/state/jotai";

const DOC_KEY = "code";

export class CellProviderManager {
private providers = new Map<CellId, WebsocketProvider>();
private static instance: CellProviderManager;

private constructor() {
this.listenForConnectionChanges();
}

static getInstance(): CellProviderManager {
if (!CellProviderManager.instance) {
CellProviderManager.instance = new CellProviderManager();
}
return CellProviderManager.instance;
}

getOrCreateProvider(
cellId: CellId,
initialCode: string,
): { provider: WebsocketProvider; ytext: Y.Text } {
const existingProvider = this.providers.get(cellId);
if (existingProvider) {
return {
provider: existingProvider,
ytext: existingProvider.doc.getText(DOC_KEY),
};
}

// Wait for connection to be established
// while (store.get(connectionAtom).state !== WebSocketState.OPEN) {
// await new Promise((resolve) => setTimeout(resolve, 100));
// }

const ydoc = new Y.Doc();
const ytext = ydoc.getText(DOC_KEY);
if (initialCode && ytext.length === 0) {
ytext.insert(0, initialCode);
}

const params: Record<string, string> = {
session_id: getSessionId(),
};
const searchParams = new URLSearchParams(window.location.search);
const filePath = searchParams.get(KnownQueryParams.filePath);
if (filePath) {
params.file = filePath;
}

const provider = new WebsocketProvider("ws", cellId, ydoc, { params });
this.providers.set(cellId, provider);

return { provider, ytext };
}

listenForConnectionChanges(): void {
const handleDisconnect = () => {
const value = store.get(connectionAtom);
const shouldDisconnect =
value.state === WebSocketState.CLOSED ||
value.state === WebSocketState.CLOSING;
if (shouldDisconnect) {
this.disconnectAll();
}
};
store.sub(connectionAtom, handleDisconnect);
}

disconnect(cellId: CellId): void {
const provider = this.providers.get(cellId);
if (provider) {
provider.destroy();
this.providers.delete(cellId);
}
}

disconnectAll(): void {
for (const [cellId] of this.providers) {
this.disconnect(cellId);
}
}
}
33 changes: 3 additions & 30 deletions frontend/src/core/codemirror/rtc/extension.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
/* Copyright 2024 Marimo. All rights reserved. */
import { KnownQueryParams } from "@/core/constants";
import { getSessionId } from "@/core/kernel/session";
import { yCollab } from "y-codemirror.next";
import * as Y from "yjs";
import { WebsocketProvider } from "y-websocket";
import type { CellId } from "@/core/cells/ids";
import { isWasm } from "@/core/wasm/utils";
import type { Extension } from "@codemirror/state";

const cellProviders = new Map<CellId, WebsocketProvider>();
import { CellProviderManager } from "./cell-manager";

export function realTimeCollaboration(
cellId: CellId,
Expand All @@ -21,30 +16,8 @@ export function realTimeCollaboration(
};
}

let wsProvider = cellProviders.get(cellId);
let ytext: Y.Text;

if (wsProvider) {
ytext = wsProvider.doc.getText("code");
} else {
const ydoc = new Y.Doc();
ytext = ydoc.getText("code");
if (initialCode) {
ytext.insert(0, initialCode);
}
// Add file and session_id to the params
const params: Record<string, string> = {};
params.session_id = getSessionId();
const searchParams = new URLSearchParams(window.location.search);
const filePath = searchParams.get(KnownQueryParams.filePath);
if (filePath) {
params.file = filePath;
}
wsProvider = new WebsocketProvider("ws", cellId, ydoc, {
params,
});
cellProviders.set(cellId, wsProvider);
}
const manager = CellProviderManager.getInstance();
const { ytext } = manager.getOrCreateProvider(cellId, initialCode);

const extension = yCollab(ytext, null);

Expand Down

0 comments on commit d039aa7

Please sign in to comment.