Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement server-side ypywidgets rendering #364

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,6 @@ $RECYCLE.BIN/
.jupyter_ystore.db
.jupyter_ystore.db-journal
fps_cli_args.toml

# pixi environments
.pixi
1 change: 1 addition & 0 deletions jupyverse_api/jupyverse_api/jupyterlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ async def get_workspace(

class JupyterLabConfig(Config):
dev_mode: bool = False
server_side_execution: bool = False
2 changes: 1 addition & 1 deletion jupyverse_api/jupyverse_api/kernels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ class Session(BaseModel):

class Execution(BaseModel):
document_id: str
cell_idx: int
cell_id: str
6 changes: 5 additions & 1 deletion plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def get_lab(
self.get_index(
"default",
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)
Expand All @@ -71,6 +72,7 @@ async def load_workspace(
self.get_index(
"default",
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)
Expand Down Expand Up @@ -99,11 +101,12 @@ async def get_workspace(
return self.get_index(
name,
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)

def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
def get_index(self, workspace, collaborative, server_side_execution, dev_mode, base_url="/"):
for path in (self.static_lab_dir).glob("main.*.js"):
main_id = path.name.split(".")[1]
break
Expand All @@ -121,6 +124,7 @@ def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
"baseUrl": base_url,
"cacheFiles": False,
"collaborative": collaborative,
"serverSideExecution": server_side_execution,
"devMode": dev_mode,
"disabledExtensions": self.disabled_extension,
"exposeAppInBrowser": False,
Expand Down
174 changes: 131 additions & 43 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import uuid
from typing import Any, Dict, List, Optional, cast

from pycrdt import Array, Map

from jupyverse_api.yjs import Yjs

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
Expand All @@ -23,10 +27,12 @@ def __init__(
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
yjs: Optional[Yjs] = None,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
self.kernel_cwd = kernel_cwd
self.yjs = yjs
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
if write_connection_file:
Expand All @@ -37,11 +43,12 @@ def __init__(
self.key = cast(str, self.connection_cfg["key"])
self.session_id = uuid.uuid4().hex
self.msg_cnt = 0
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
self.channel_tasks: List[asyncio.Task] = []
self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {}
self.comm_messages: asyncio.Queue = asyncio.Queue()
self.tasks: List[asyncio.Task] = []

async def restart(self, startup_timeout: float = float("inf")) -> None:
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()
msg = create_message("shutdown_request", content={"restart": True})
await send_message(msg, self.control_channel, self.key, change_date_to_str=True)
Expand All @@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
break
await self._wait_for_ready(startup_timeout)
self.channel_tasks = []
self.tasks = []
self.listen_channels()

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
Expand All @@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
self.connect_channels()
await self._wait_for_ready(startup_timeout)
self.listen_channels()
self.tasks.append(asyncio.create_task(self._handle_comms()))

def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
connection_cfg = connection_cfg or self.connection_cfg
Expand All @@ -77,40 +85,43 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
self.iopub_channel = connect_channel("iopub", connection_cfg)

def listen_channels(self):
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))
self.tasks.append(asyncio.create_task(self.listen_iopub()))
self.tasks.append(asyncio.create_task(self.listen_shell()))

async def stop(self) -> None:
self.kernel_process.kill()
await self.kernel_process.wait()
os.remove(self.connection_file_path)
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()

async def listen_iopub(self):
while True:
msg = await receive_message(self.iopub_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
parent_id = msg["parent_header"].get("msg_id")
if msg["msg_type"] in ("comm_open", "comm_msg"):
self.comm_messages.put_nowait(msg)
elif parent_id in self.execute_requests.keys():
self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg)

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
self.execute_requests[msg_id]["shell_msg"].put_nowait(msg)

async def execute(
self,
cell: Dict[str, Any],
ycell: Map,
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
if cell["cell_type"] != "code":
if ycell["cell_type"] != "code":
return
content = {"code": cell["source"], "silent": False}
ycell["execution_state"] = "busy"
content = {"code": str(ycell["source"]), "silent": False}
msg = create_message(
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
)
Expand All @@ -120,40 +131,68 @@ async def execute(
msg_id = msg["header"]["msg_id"]
self.msg_cnt += 1
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Queue(),
"shell_msg": asyncio.Queue(),
}
if wait_for_executed:
deadline = time.time() + timeout
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Future(),
"shell_msg": asyncio.Future(),
}
while True:
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["iopub_msg"].result()
self._handle_outputs(cell["outputs"], msg)
await self._handle_outputs(ycell["outputs"], msg)
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
break
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell_msg"].result()
cell["execution_count"] = msg["content"]["execution_count"]
with ycell.doc.transaction():
ycell["execution_count"] = msg["content"]["execution_count"]
ycell["execution_state"] = "idle"
del self.execute_requests[msg_id]
else:
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell)))

async def _handle_iopub(self, msg_id: str, ycell: Map) -> None:
while True:
msg = await self.execute_requests[msg_id]["iopub_msg"].get()
await self._handle_outputs(ycell["outputs"], msg)
if (
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
msg = await self.execute_requests[msg_id]["shell_msg"].get()
with ycell.doc.transaction():
ycell["execution_count"] = msg["content"]["execution_count"]
ycell["execution_state"] = "idle"

async def _handle_comms(self) -> None:
if self.yjs is None:
return

while True:
msg = await self.comm_messages.get()
msg_type = msg["header"]["msg_type"]
if msg_type == "comm_open":
comm_id = msg["content"]["comm_id"]
comm = Comm(comm_id, self.shell_channel, self.session_id, self.key)
self.yjs.widgets.comm_open(msg, comm) # type: ignore
elif msg_type == "comm_msg":
self.yjs.widgets.comm_msg(msg) # type: ignore

async def _wait_for_ready(self, timeout):
deadline = time.time() + timeout
Expand All @@ -178,22 +217,51 @@ async def _wait_for_ready(self, timeout):
break
new_timeout = deadline_to_timeout(deadline)

def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]):
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
outputs[-1]["text"].append(content["text"])
with outputs.doc.transaction():
# TODO: uncomment when changes are made in jupyter-ydoc
if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore
outputs.append(
#Map(
# {
# "name": content["name"],
# "output_type": msg_type,
# "text": Array([content["text"]]),
# }
#)
{
"name": content["name"],
"output_type": msg_type,
"text": [content["text"]],
}
)
else:
#outputs[-1]["text"].append(content["text"]) # type: ignore
last_output = outputs[-1]
last_output["text"].append(content["text"]) # type: ignore
outputs[-1] = last_output
elif msg_type in ("display_data", "execute_result"):
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
if "application/vnd.jupyter.ywidget-view+json" in content["data"]:
# this is a collaborative widget
model_id = content["data"]["application/vnd.jupyter.ywidget-view+json"]["model_id"]
if self.yjs is not None:
if model_id in self.yjs.widgets.widgets: # type: ignore
doc = self.yjs.widgets.widgets[model_id]["model"].ydoc # type: ignore
path = f"ywidget:{doc.guid}"
await self.yjs.room_manager.websocket_server.get_room(path, ydoc=doc) # type: ignore
outputs.append(doc)
else:
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
elif msg_type == "error":
outputs.append(
{
Expand All @@ -203,5 +271,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
"traceback": content["traceback"],
}
)
else:
return


class Comm:
def __init__(self, comm_id: str, shell_channel, session_id: str, key: str):
self.comm_id = comm_id
self.shell_channel = shell_channel
self.session_id = session_id
self.key = key
self.msg_cnt = 0

def send(self, buffers):
msg = create_message(
"comm_msg",
content={"comm_id": self.comm_id},
session_id=self.session_id,
msg_id=self.msg_cnt,
buffers=buffers,
)
self.msg_cnt += 1
asyncio.create_task(
send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
)
3 changes: 2 additions & 1 deletion plugins/kernels/fps_kernels/kernel_driver/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def create_message(
content: Dict = {},
session_id: str = "",
msg_id: str = "",
buffers: List = [],
) -> Dict[str, Any]:
header = create_message_header(msg_type, session_id, msg_id)
msg = {
Expand All @@ -65,7 +66,7 @@ def create_message(
"parent_header": {},
"content": content,
"metadata": {},
"buffers": [],
"buffers": buffers,
}
return msg

Expand Down
12 changes: 8 additions & 4 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,25 @@ async def execute_cell(
execution = Execution(**r)
if kernel_id in kernels:
ynotebook = self.yjs.get_document(execution.document_id)
cell = ynotebook.get_cell(execution.cell_idx)
cell["outputs"] = []
ycells = [ycell for ycell in ynotebook.ycells if ycell["id"] == execution.cell_id]
if not ycells:
return # FIXME

ycell = ycells[0]
del ycell["outputs"][:]

kernel = kernels[kernel_id]
if not kernel["driver"]:
kernel["driver"] = driver = KernelDriver(
kernelspec_path=Path(find_kernelspec(kernel["name"])).as_posix(),
write_connection_file=False,
connection_file=kernel["server"].connection_file_path,
yjs=self.yjs,
)
await driver.connect()
driver = kernel["driver"]

await driver.execute(cell)
ynotebook.set_cell(execution.cell_idx, cell)
await driver.execute(ycell, wait_for_executed=False)

async def get_kernel(
self,
Expand Down
Loading
Loading