diff --git a/jupyter_collaboration/handlers.py b/jupyter_collaboration/handlers.py index b0513156..a4d05e72 100644 --- a/jupyter_collaboration/handlers.py +++ b/jupyter_collaboration/handlers.py @@ -56,6 +56,12 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler): _message_queue: asyncio.Queue[Any] _background_tasks: set[asyncio.Task] + _room_locks: dict[str, asyncio.Lock] = {} + + def _room_lock(self, room_id: str) -> asyncio.Lock: + if room_id not in self._room_locks: + self._room_locks[room_id] = asyncio.Lock() + return self._room_locks[room_id] def create_task(self, aw): task = asyncio.create_task(aw) @@ -70,38 +76,38 @@ async def prepare(self): # Get room self._room_id: str = self.request.path.split("/")[-1] - if self._websocket_server.room_exists(self._room_id): - self.room: YRoom = await self._websocket_server.get_room(self._room_id) - - else: - if self._room_id.count(":") >= 2: - # DocumentRoom - file_format, file_type, file_id = decode_file_path(self._room_id) - if file_id in self._file_loaders: - self._emit( - LogLevel.WARNING, - None, - "There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.", + async with self._room_lock(self._room_id): + if self._websocket_server.room_exists(self._room_id): + self.room: YRoom = await self._websocket_server.get_room(self._room_id) + else: + if self._room_id.count(":") >= 2: + # DocumentRoom + file_format, file_type, file_id = decode_file_path(self._room_id) + if file_id in self._file_loaders: + self._emit( + LogLevel.WARNING, + None, + "There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.", + ) + + file = self._file_loaders[file_id] + updates_file_path = f".{file_type}:{file_id}.y" + ystore = self._ystore_class(path=updates_file_path, log=self.log) + self.room = DocumentRoom( + self._room_id, + file_format, + file_type, + file, + self.event_logger, + ystore, + self.log, + self._document_save_delay, ) - file = self._file_loaders[file_id] - updates_file_path = f".{file_type}:{file_id}.y" - ystore = self._ystore_class(path=updates_file_path, log=self.log) - self.room = DocumentRoom( - self._room_id, - file_format, - file_type, - file, - self.event_logger, - ystore, - self.log, - self._document_save_delay, - ) - - else: - # TransientRoom - # it is a transient document (e.g. awareness) - self.room = TransientRoom(self._room_id, self.log) + else: + # TransientRoom + # it is a transient document (e.g. awareness) + self.room = TransientRoom(self._room_id, self.log) await self._websocket_server.start_room(self.room) self._websocket_server.add_room(self._room_id, self.room) @@ -184,7 +190,8 @@ async def open(self, room_id): try: # Initialize the room - await self.room.initialize() + async with self._room_lock(self._room_id): + await self.room.initialize() self._emit_awareness_event(self.current_user.username, "join") except Exception as e: _, _, file_id = decode_file_path(self._room_id) @@ -323,29 +330,31 @@ async def _clean_room(self) -> None: contains a copy of the document. In addition, we remove the file if there is no rooms subscribed to it. """ - assert isinstance(self.room, DocumentRoom) - - if self._cleanup_delay is None: - return - - await asyncio.sleep(self._cleanup_delay) - - # Remove the room from the websocket server - self.log.info("Deleting Y document from memory: %s", self.room.room_id) - self._websocket_server.delete_room(room=self.room) - - # Clean room - del self.room - self.log.info("Room %s deleted", self._room_id) - self._emit(LogLevel.INFO, "clean", "Room deleted.") - - # Clean the file loader if there are not rooms using it - _, _, file_id = decode_file_path(self._room_id) - file = self._file_loaders[file_id] - if file.number_of_subscriptions == 0: - self.log.info("Deleting file %s", file.path) - await self._file_loaders.remove(file_id) - self._emit(LogLevel.INFO, "clean", "Loader deleted.") + async with self._room_lock(self._room_id): + assert isinstance(self.room, DocumentRoom) + + if self._cleanup_delay is None: + return + + await asyncio.sleep(self._cleanup_delay) + + # Remove the room from the websocket server + self.log.info("Deleting Y document from memory: %s", self.room.room_id) + self._websocket_server.delete_room(room=self.room) + + # Clean room + del self.room + self.log.info("Room %s deleted", self._room_id) + self._emit(LogLevel.INFO, "clean", "Room deleted.") + + # Clean the file loader if there are not rooms using it + _, _, file_id = decode_file_path(self._room_id) + file = self._file_loaders[file_id] + if file.number_of_subscriptions == 0: + self.log.info("Deleting file %s", file.path) + await self._file_loaders.remove(file_id) + self._emit(LogLevel.INFO, "clean", "Loader deleted.") + del self._room_locks[self._room_id] def check_origin(self, origin): """ diff --git a/jupyter_collaboration/rooms.py b/jupyter_collaboration/rooms.py index e4e83264..e88e6d39 100644 --- a/jupyter_collaboration/rooms.py +++ b/jupyter_collaboration/rooms.py @@ -44,7 +44,6 @@ def __init__( self._save_delay = save_delay self._update_lock = asyncio.Lock() - self._initialization_lock = asyncio.Lock() self._cleaner: asyncio.Task | None = None self._saving_document: asyncio.Task | None = None self._messages: dict[str, asyncio.Lock] = {} @@ -89,64 +88,63 @@ async def initialize(self) -> None: It is important to set the ready property in the parent class (`self.ready = True`), this setter will subscribe for updates on the shared document. """ - async with self._initialization_lock: - if self.ready: # type: ignore[has-type] - return + if self.ready: # type: ignore[has-type] + return - self.log.info("Initializing room %s", self._room_id) + self.log.info("Initializing room %s", self._room_id) - model = await self._file.load_content(self._file_format, self._file_type) + model = await self._file.load_content(self._file_format, self._file_type) - async with self._update_lock: - # try to apply Y updates from the YStore for this document - read_from_source = True - if self.ystore is not None: - try: - await self.ystore.apply_updates(self.ydoc) - self._emit( - LogLevel.INFO, - "load", - "Content loaded from the store {}".format( - self.ystore.__class__.__qualname__ - ), - ) - self.log.info( - "Content in room %s loaded from the ystore %s", - self._room_id, - self.ystore.__class__.__name__, - ) - read_from_source = False - except YDocNotFound: - # YDoc not found in the YStore, create the document from the source file (no change history) - pass - - if not read_from_source: - # if YStore updates and source file are out-of-sync, resync updates with source - if self._document.source != model["content"]: - # TODO: Delete document from the store. - self._emit( - LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore." - ) - self.log.info( - "Content in file %s is out-of-sync with the ystore %s", - self._file.path, - self.ystore.__class__.__name__, - ) - read_from_source = True - - if read_from_source: - self._emit(LogLevel.INFO, "load", "Content loaded from disk.") + async with self._update_lock: + # try to apply Y updates from the YStore for this document + read_from_source = True + if self.ystore is not None: + try: + await self.ystore.apply_updates(self.ydoc) + self._emit( + LogLevel.INFO, + "load", + "Content loaded from the store {}".format( + self.ystore.__class__.__qualname__ + ), + ) + self.log.info( + "Content in room %s loaded from the ystore %s", + self._room_id, + self.ystore.__class__.__name__, + ) + read_from_source = False + except YDocNotFound: + # YDoc not found in the YStore, create the document from the source file (no change history) + pass + + if not read_from_source: + # if YStore updates and source file are out-of-sync, resync updates with source + if self._document.source != model["content"]: + # TODO: Delete document from the store. + self._emit( + LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore." + ) self.log.info( - "Content in room %s loaded from file %s", self._room_id, self._file.path + "Content in file %s is out-of-sync with the ystore %s", + self._file.path, + self.ystore.__class__.__name__, ) - self._document.source = model["content"] + read_from_source = True + + if read_from_source: + self._emit(LogLevel.INFO, "load", "Content loaded from disk.") + self.log.info( + "Content in room %s loaded from file %s", self._room_id, self._file.path + ) + self._document.source = model["content"] - if self.ystore: - await self.ystore.encode_state_as_update(self.ydoc) + if self.ystore: + await self.ystore.encode_state_as_update(self.ydoc) - self._document.dirty = False - self.ready = True - self._emit(LogLevel.INFO, "initialize", "Room initialized") + self._document.dirty = False + self.ready = True + self._emit(LogLevel.INFO, "initialize", "Room initialized") def _emit(self, level: LogLevel, action: str | None = None, msg: str | None = None) -> None: data = {"level": level.value, "room": self._room_id, "path": self._file.path} diff --git a/tests/test_documents.py b/tests/test_documents.py index 35b1f213..20efb8dc 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -2,6 +2,7 @@ # Distributed under the terms of the Modified BSD License. import sys +from time import time if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -9,7 +10,7 @@ from importlib.metadata import entry_points import pytest -from anyio import sleep +from anyio import create_task_group, sleep from pycrdt_websocket import WebsocketProvider jupyter_ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")} @@ -37,3 +38,24 @@ async def test_dirty( jupyter_ydoc.dirty = True await sleep(rtc_document_save_delay * 1.5) assert not jupyter_ydoc.dirty + + +async def test_room_concurrent_initialization( + rtc_create_file, + rtc_connect_doc_client, +): + file_format = "text" + file_type = "file" + file_path = "dummy.txt" + await rtc_create_file(file_path) + + async def connect(file_format, file_type, file_path): + async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws: + pass + + t0 = time() + async with create_task_group() as tg: + tg.start_soon(connect, file_format, file_type, file_path) + tg.start_soon(connect, file_format, file_type, file_path) + t1 = time() + assert t1 - t0 < 0.5