Skip to content

Commit

Permalink
Fix concurrent room initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Mar 21, 2024
1 parent 44f6a7a commit d89ab87
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 107 deletions.
117 changes: 63 additions & 54 deletions jupyter_collaboration/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler):

_message_queue: asyncio.Queue[Any]
_background_tasks: set[asyncio.Task]
_room_locks: dict[str, asyncio.Lock] = {}

def _room_lock(self, room_id: str) -> asyncio.Lock:
if room_id not in self._room_locks:
self._room_locks[room_id] = asyncio.Lock()
return self._room_locks[room_id]

def create_task(self, aw):
task = asyncio.create_task(aw)
Expand All @@ -70,38 +76,38 @@ async def prepare(self):
# Get room
self._room_id: str = self.request.path.split("/")[-1]

if self._websocket_server.room_exists(self._room_id):
self.room: YRoom = await self._websocket_server.get_room(self._room_id)

else:
if self._room_id.count(":") >= 2:
# DocumentRoom
file_format, file_type, file_id = decode_file_path(self._room_id)
if file_id in self._file_loaders:
self._emit(
LogLevel.WARNING,
None,
"There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.",
async with self._room_lock(self._room_id):
if self._websocket_server.room_exists(self._room_id):
self.room: YRoom = await self._websocket_server.get_room(self._room_id)
else:
if self._room_id.count(":") >= 2:
# DocumentRoom
file_format, file_type, file_id = decode_file_path(self._room_id)
if file_id in self._file_loaders:
self._emit(
LogLevel.WARNING,
None,
"There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.",
)

file = self._file_loaders[file_id]
updates_file_path = f".{file_type}:{file_id}.y"
ystore = self._ystore_class(path=updates_file_path, log=self.log)
self.room = DocumentRoom(
self._room_id,
file_format,
file_type,
file,
self.event_logger,
ystore,
self.log,
self._document_save_delay,
)

file = self._file_loaders[file_id]
updates_file_path = f".{file_type}:{file_id}.y"
ystore = self._ystore_class(path=updates_file_path, log=self.log)
self.room = DocumentRoom(
self._room_id,
file_format,
file_type,
file,
self.event_logger,
ystore,
self.log,
self._document_save_delay,
)

else:
# TransientRoom
# it is a transient document (e.g. awareness)
self.room = TransientRoom(self._room_id, self.log)
else:
# TransientRoom
# it is a transient document (e.g. awareness)
self.room = TransientRoom(self._room_id, self.log)

await self._websocket_server.start_room(self.room)
self._websocket_server.add_room(self._room_id, self.room)
Expand Down Expand Up @@ -184,7 +190,8 @@ async def open(self, room_id):

try:
# Initialize the room
await self.room.initialize()
async with self._room_lock(self._room_id):
await self.room.initialize()
self._emit_awareness_event(self.current_user.username, "join")
except Exception as e:
_, _, file_id = decode_file_path(self._room_id)
Expand Down Expand Up @@ -323,29 +330,31 @@ async def _clean_room(self) -> None:
contains a copy of the document. In addition, we remove the file if there is no rooms
subscribed to it.
"""
assert isinstance(self.room, DocumentRoom)

if self._cleanup_delay is None:
return

await asyncio.sleep(self._cleanup_delay)

# Remove the room from the websocket server
self.log.info("Deleting Y document from memory: %s", self.room.room_id)
self._websocket_server.delete_room(room=self.room)

# Clean room
del self.room
self.log.info("Room %s deleted", self._room_id)
self._emit(LogLevel.INFO, "clean", "Room deleted.")

# Clean the file loader if there are not rooms using it
_, _, file_id = decode_file_path(self._room_id)
file = self._file_loaders[file_id]
if file.number_of_subscriptions == 0:
self.log.info("Deleting file %s", file.path)
await self._file_loaders.remove(file_id)
self._emit(LogLevel.INFO, "clean", "Loader deleted.")
async with self._room_lock(self._room_id):
assert isinstance(self.room, DocumentRoom)

if self._cleanup_delay is None:
return

await asyncio.sleep(self._cleanup_delay)

# Remove the room from the websocket server
self.log.info("Deleting Y document from memory: %s", self.room.room_id)
self._websocket_server.delete_room(room=self.room)

# Clean room
del self.room
self.log.info("Room %s deleted", self._room_id)
self._emit(LogLevel.INFO, "clean", "Room deleted.")

# Clean the file loader if there are not rooms using it
_, _, file_id = decode_file_path(self._room_id)
file = self._file_loaders[file_id]
if file.number_of_subscriptions == 0:
self.log.info("Deleting file %s", file.path)
await self._file_loaders.remove(file_id)
self._emit(LogLevel.INFO, "clean", "Loader deleted.")
del self._room_locks[self._room_id]

def check_origin(self, origin):
"""
Expand Down
102 changes: 50 additions & 52 deletions jupyter_collaboration/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
self._save_delay = save_delay

self._update_lock = asyncio.Lock()
self._initialization_lock = asyncio.Lock()
self._cleaner: asyncio.Task | None = None
self._saving_document: asyncio.Task | None = None
self._messages: dict[str, asyncio.Lock] = {}
Expand Down Expand Up @@ -89,64 +88,63 @@ async def initialize(self) -> None:
It is important to set the ready property in the parent class (`self.ready = True`),
this setter will subscribe for updates on the shared document.
"""
async with self._initialization_lock:
if self.ready: # type: ignore[has-type]
return
if self.ready: # type: ignore[has-type]
return

self.log.info("Initializing room %s", self._room_id)
self.log.info("Initializing room %s", self._room_id)

model = await self._file.load_content(self._file_format, self._file_type)
model = await self._file.load_content(self._file_format, self._file_type)

async with self._update_lock:
# try to apply Y updates from the YStore for this document
read_from_source = True
if self.ystore is not None:
try:
await self.ystore.apply_updates(self.ydoc)
self._emit(
LogLevel.INFO,
"load",
"Content loaded from the store {}".format(
self.ystore.__class__.__qualname__
),
)
self.log.info(
"Content in room %s loaded from the ystore %s",
self._room_id,
self.ystore.__class__.__name__,
)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
pass

if not read_from_source:
# if YStore updates and source file are out-of-sync, resync updates with source
if self._document.source != model["content"]:
# TODO: Delete document from the store.
self._emit(
LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore."
)
self.log.info(
"Content in file %s is out-of-sync with the ystore %s",
self._file.path,
self.ystore.__class__.__name__,
)
read_from_source = True

if read_from_source:
self._emit(LogLevel.INFO, "load", "Content loaded from disk.")
async with self._update_lock:
# try to apply Y updates from the YStore for this document
read_from_source = True
if self.ystore is not None:
try:
await self.ystore.apply_updates(self.ydoc)
self._emit(
LogLevel.INFO,
"load",
"Content loaded from the store {}".format(
self.ystore.__class__.__qualname__
),
)
self.log.info(
"Content in room %s loaded from the ystore %s",
self._room_id,
self.ystore.__class__.__name__,
)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
pass

if not read_from_source:
# if YStore updates and source file are out-of-sync, resync updates with source
if self._document.source != model["content"]:
# TODO: Delete document from the store.
self._emit(
LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore."
)
self.log.info(
"Content in room %s loaded from file %s", self._room_id, self._file.path
"Content in file %s is out-of-sync with the ystore %s",
self._file.path,
self.ystore.__class__.__name__,
)
self._document.source = model["content"]
read_from_source = True

if read_from_source:
self._emit(LogLevel.INFO, "load", "Content loaded from disk.")
self.log.info(
"Content in room %s loaded from file %s", self._room_id, self._file.path
)
self._document.source = model["content"]

if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)
if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)

self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")
self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")

def _emit(self, level: LogLevel, action: str | None = None, msg: str | None = None) -> None:
data = {"level": level.value, "room": self._room_id, "path": self._file.path}
Expand Down
24 changes: 23 additions & 1 deletion tests/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# Distributed under the terms of the Modified BSD License.

import sys
from time import time

if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points

import pytest
from anyio import sleep
from anyio import create_task_group, sleep
from pycrdt_websocket import WebsocketProvider

jupyter_ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")}
Expand Down Expand Up @@ -37,3 +38,24 @@ async def test_dirty(
jupyter_ydoc.dirty = True
await sleep(rtc_document_save_delay * 1.5)
assert not jupyter_ydoc.dirty


async def test_room_concurrent_initialization(
rtc_create_file,
rtc_connect_doc_client,
):
file_format = "text"
file_type = "file"
file_path = "dummy.txt"
await rtc_create_file(file_path)

async def connect(file_format, file_type, file_path):
async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws:
pass

t0 = time()
async with create_task_group() as tg:
tg.start_soon(connect, file_format, file_type, file_path)
tg.start_soon(connect, file_format, file_type, file_path)
t1 = time()
assert t1 - t0 < 0.5

0 comments on commit d89ab87

Please sign in to comment.