From 96d70dfd8628237f1e4cb08290adeee7fb1c2cf8 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 6 Jul 2022 22:52:15 +0200 Subject: [PATCH] Add POST execute cell --- .github/workflows/main.yml | 1 + .../fps_kernels/kernel_driver/__init__.py | 0 .../fps_kernels/kernel_driver/connect.py | 102 ++++++++ .../fps_kernels/kernel_driver/driver.py | 220 ++++++++++++++++++ .../fps_kernels/kernel_driver/kernelspec.py | 68 ++++++ .../fps_kernels/kernel_driver/message.py | 103 ++++++++ .../fps_kernels/kernel_driver/paths.py | 114 +++++++++ plugins/kernels/fps_kernels/models.py | 5 + plugins/kernels/fps_kernels/routes.py | 36 ++- plugins/kernels/setup.cfg | 2 + setup.cfg | 1 + tests/conftest.py | 10 +- tests/data/notebook0.ipynb | 49 ++++ tests/test_contents.py | 2 + tests/test_server.py | 67 ++++++ 15 files changed, 777 insertions(+), 3 deletions(-) create mode 100644 plugins/kernels/fps_kernels/kernel_driver/__init__.py create mode 100644 plugins/kernels/fps_kernels/kernel_driver/connect.py create mode 100644 plugins/kernels/fps_kernels/kernel_driver/driver.py create mode 100644 plugins/kernels/fps_kernels/kernel_driver/kernelspec.py create mode 100644 plugins/kernels/fps_kernels/kernel_driver/message.py create mode 100644 plugins/kernels/fps_kernels/kernel_driver/paths.py create mode 100644 tests/data/notebook0.ipynb diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a6e3eea3..3690da24 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -46,6 +46,7 @@ jobs: pip install ./plugins/lab pip install ./plugins/jupyterlab pip install "jupyter_ydoc >=0.1.16,<0.2.0" # FIXME: remove with next JupyterLab release + pip install "y-py >=0.5.4" pip install mypy pytest pytest-asyncio requests ipykernel diff --git a/plugins/kernels/fps_kernels/kernel_driver/__init__.py b/plugins/kernels/fps_kernels/kernel_driver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/kernels/fps_kernels/kernel_driver/connect.py 
b/plugins/kernels/fps_kernels/kernel_driver/connect.py new file mode 100644 index 00000000..b137bbff --- /dev/null +++ b/plugins/kernels/fps_kernels/kernel_driver/connect.py @@ -0,0 +1,102 @@ +import asyncio +import json +import os +import socket +import tempfile +import uuid +from typing import Dict, Tuple, Union + +import zmq +import zmq.asyncio +from zmq.asyncio import Socket + +channel_socket_types = { + "hb": zmq.REQ, + "shell": zmq.DEALER, + "iopub": zmq.SUB, + "stdin": zmq.DEALER, + "control": zmq.DEALER, +} + +context = zmq.asyncio.Context() + +cfg_t = Dict[str, Union[str, int]] + + +def get_port(ip: str) -> int: + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8) + sock.bind((ip, 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def write_connection_file( + fname: str = "", + ip: str = "", + transport: str = "tcp", + signature_scheme: str = "hmac-sha256", + kernel_name: str = "", +) -> Tuple[str, cfg_t]: + ip = ip or "127.0.0.1" + + if not fname: + fd, fname = tempfile.mkstemp(suffix=".json") + os.close(fd) + f = open(fname, "wt") + + channels = ["shell", "iopub", "stdin", "control", "hb"] + + cfg: cfg_t = {f"{c}_port": get_port(ip) for c in channels} + + cfg["ip"] = ip + cfg["key"] = uuid.uuid4().hex + cfg["transport"] = transport + cfg["signature_scheme"] = signature_scheme + cfg["kernel_name"] = kernel_name + + f.write(json.dumps(cfg, indent=2)) + f.close() + + return fname, cfg + + +def read_connection_file(fname: str = "") -> cfg_t: + with open(fname, "rt") as f: + cfg: cfg_t = json.load(f) + + return cfg + + +async def launch_kernel( + kernelspec_path: str, connection_file_path: str, capture_output: bool +) -> asyncio.subprocess.Process: + with open(kernelspec_path) as f: + kernelspec = json.load(f) + cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]] + if capture_output: + p = await asyncio.create_subprocess_exec( + *cmd, stdout=asyncio.subprocess.DEVNULL, 
stderr=asyncio.subprocess.STDOUT + ) + else: + p = await asyncio.create_subprocess_exec(*cmd) + return p + + +def create_socket(channel: str, cfg: cfg_t) -> Socket: + ip = cfg["ip"] + port = cfg[f"{channel}_port"] + url = f"tcp://{ip}:{port}" + socket_type = channel_socket_types[channel] + sock = context.socket(socket_type) + sock.linger = 1000 # set linger to 1s to prevent hangs at exit + sock.connect(url) + return sock + + +def connect_channel(channel_name: str, cfg: cfg_t) -> Socket: + sock = create_socket(channel_name, cfg) + if channel_name == "iopub": + sock.setsockopt(zmq.SUBSCRIBE, b"") + return sock diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py new file mode 100644 index 00000000..47422c08 --- /dev/null +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -0,0 +1,220 @@ +import asyncio +import os +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple, cast + +from zmq.asyncio import Socket + +from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file +from .connect import write_connection_file as _write_connection_file +from .kernelspec import find_kernelspec +from .message import create_message, deserialize, serialize + +DELIM = b"<IDS|MSG>" + + +def deadline_to_timeout(deadline: float) -> float: + return max(0, deadline - time.time()) + + +def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]: + idx = msg_list.index(DELIM) + return msg_list[:idx], msg_list[idx + 1 :] # noqa + + +def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None: + to_send = serialize(msg, key) + sock.send_multipart(to_send, copy=True) + + +async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]: + timeout *= 1000 # in ms + ready = await sock.poll(timeout) + if ready: + msg_list = await sock.recv_multipart() + idents, msg_list = feed_identities(msg_list) + return deserialize(msg_list) + 
return None + + +class KernelDriver: + def __init__( + self, + kernel_name: str = "", + kernelspec_path: str = "", + connection_file: str = "", + write_connection_file: bool = True, + capture_kernel_output: bool = True, + ) -> None: + self.capture_kernel_output = capture_kernel_output + self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name) + if not self.kernelspec_path: + raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") + if write_connection_file: + self.connection_file_path, self.connection_cfg = _write_connection_file(connection_file) + else: + self.connection_file_path = connection_file + self.connection_cfg = read_connection_file(connection_file) + self.key = cast(str, self.connection_cfg["key"]) + self.session_id = uuid.uuid4().hex + self.msg_cnt = 0 + self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {} + self.channel_tasks: List[asyncio.Task] = [] + + async def restart(self, startup_timeout: float = float("inf")) -> None: + for task in self.channel_tasks: + task.cancel() + msg = create_message("shutdown_request", content={"restart": True}) + send_message(msg, self.control_channel, self.key) + while True: + msg = cast(Dict[str, Any], await receive_message(self.control_channel)) + if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: + break + await self._wait_for_ready(startup_timeout) + self.channel_tasks = [] + self.listen_channels() + + async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None: + self.kernel_process = await launch_kernel( + self.kernelspec_path, self.connection_file_path, self.capture_kernel_output + ) + if connect: + await self.connect(startup_timeout) + + async def connect(self, startup_timeout: float = float("inf")) -> None: + self.connect_channels() + await self._wait_for_ready(startup_timeout) + self.listen_channels() + + def connect_channels(self, connection_cfg: cfg_t = None): + connection_cfg = connection_cfg or 
self.connection_cfg + self.shell_channel = connect_channel("shell", connection_cfg) + self.control_channel = connect_channel("control", connection_cfg) + self.iopub_channel = connect_channel("iopub", connection_cfg) + + def listen_channels(self): + self.channel_tasks.append(asyncio.create_task(self.listen_iopub())) + self.channel_tasks.append(asyncio.create_task(self.listen_shell())) + + async def stop(self) -> None: + self.kernel_process.kill() + await self.kernel_process.wait() + os.remove(self.connection_file_path) + for task in self.channel_tasks: + task.cancel() + + async def listen_iopub(self): + while True: + msg = await receive_message(self.iopub_channel) # type: ignore + msg_id = msg["parent_header"].get("msg_id") + if msg_id in self.execute_requests.keys(): + self.execute_requests[msg_id]["iopub_msg"].set_result(msg) + + async def listen_shell(self): + while True: + msg = await receive_message(self.shell_channel) # type: ignore + msg_id = msg["parent_header"].get("msg_id") + if msg_id in self.execute_requests.keys(): + self.execute_requests[msg_id]["shell_msg"].set_result(msg) + + async def execute( + self, + cell: Dict[str, Any], + timeout: float = float("inf"), + msg_id: str = "", + wait_for_executed: bool = True, + ) -> None: + if cell["cell_type"] != "code": + return + content = {"code": cell["source"], "silent": False} + msg = create_message( + "execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt + ) + if msg_id: + msg["header"]["msg_id"] = msg_id + else: + msg_id = msg["header"]["msg_id"] + self.msg_cnt += 1 + send_message(msg, self.shell_channel, self.key) + if wait_for_executed: + deadline = time.time() + timeout + self.execute_requests[msg_id] = { + "iopub_msg": asyncio.Future(), + "shell_msg": asyncio.Future(), + } + while True: + try: + await asyncio.wait_for( + self.execute_requests[msg_id]["iopub_msg"], + deadline_to_timeout(deadline), + ) + except asyncio.TimeoutError: + error_message = f"Kernel didn't respond in 
{timeout} seconds" + raise RuntimeError(error_message) + msg = self.execute_requests[msg_id]["iopub_msg"].result() + self._handle_outputs(cell["outputs"], msg) + if ( + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" + ): + break + self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future() + try: + await asyncio.wait_for( + self.execute_requests[msg_id]["shell_msg"], + deadline_to_timeout(deadline), + ) + except asyncio.TimeoutError: + error_message = f"Kernel didn't respond in {timeout} seconds" + raise RuntimeError(error_message) + msg = self.execute_requests[msg_id]["shell_msg"].result() + cell["execution_count"] = msg["content"]["execution_count"] + del self.execute_requests[msg_id] + + async def _wait_for_ready(self, timeout): + deadline = time.time() + timeout + new_timeout = timeout + while True: + msg = create_message( + "kernel_info_request", session_id=self.session_id, msg_cnt=self.msg_cnt + ) + self.msg_cnt += 1 + send_message(msg, self.shell_channel, self.key) + msg = await receive_message(self.shell_channel, new_timeout) + if msg is None: + error_message = f"Kernel didn't respond in {timeout} seconds" + raise RuntimeError(error_message) + if msg["msg_type"] == "kernel_info_reply": + msg = await receive_message(self.iopub_channel, 0.2) + if msg is not None: + break + new_timeout = deadline_to_timeout(deadline) + + def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]): + msg_type = msg["header"]["msg_type"] + content = msg["content"] + if msg_type == "stream": + if (not outputs) or (outputs[-1]["name"] != content["name"]): + outputs.append({"name": content["name"], "output_type": msg_type, "text": []}) + outputs[-1]["text"].append(content["text"]) + elif msg_type in ("display_data", "execute_result"): + outputs.append( + { + "data": {"text/plain": [content["data"].get("text/plain", "")]}, + "execution_count": content["execution_count"], + "metadata": {}, + "output_type": msg_type, + } + ) 
+ elif msg_type == "error": + outputs.append( + { + "ename": content["ename"], + "evalue": content["evalue"], + "output_type": "error", + "traceback": content["traceback"], + } + ) + else: + return diff --git a/plugins/kernels/fps_kernels/kernel_driver/kernelspec.py b/plugins/kernels/fps_kernels/kernel_driver/kernelspec.py new file mode 100644 index 00000000..a97ef7f7 --- /dev/null +++ b/plugins/kernels/fps_kernels/kernel_driver/kernelspec.py @@ -0,0 +1,68 @@ +import os +import sys + +from .paths import jupyter_data_dir + +if os.name == "nt": + programdata = os.environ.get("PROGRAMDATA", None) + if programdata: + SYSTEM_JUPYTER_PATH = [os.path.join(programdata, "jupyter")] + else: # PROGRAMDATA is not defined by default on XP + SYSTEM_JUPYTER_PATH = [os.path.join(sys.prefix, "share", "jupyter")] +else: + SYSTEM_JUPYTER_PATH = [ + "/usr/local/share/jupyter", + "/usr/share/jupyter", + ] + +ENV_JUPYTER_PATH = [os.path.join(sys.prefix, "share", "jupyter")] + + +def jupyter_path(*subdirs): + paths = [] + # highest priority is env + if os.environ.get("JUPYTER_PATH"): + paths.extend(p.rstrip(os.sep) for p in os.environ["JUPYTER_PATH"].split(os.pathsep)) + # then user dir + paths.append(jupyter_data_dir()) + # then sys.prefix + for p in ENV_JUPYTER_PATH: + if p not in SYSTEM_JUPYTER_PATH: + paths.append(p) + # finally, system + paths.extend(SYSTEM_JUPYTER_PATH) + + # add subdir, if requested + if subdirs: + paths = [os.path.join(p, *subdirs) for p in paths] + return paths + + +def kernelspec_dirs(): + return jupyter_path("kernels") + + +def _is_kernel_dir(path): + return os.path.isdir(path) and os.path.isfile(os.path.join(path, "kernel.json")) + + +def _list_kernels_in(kernel_dir): + if kernel_dir is None or not os.path.isdir(kernel_dir): + return {} + kernels = {} + for f in os.listdir(kernel_dir): + path = os.path.join(kernel_dir, f) + if _is_kernel_dir(path): + key = f.lower() + kernels[key] = path + return kernels + + +def find_kernelspec(kernel_name): + d = {} + for 
kernel_dir in kernelspec_dirs(): + kernels = _list_kernels_in(kernel_dir) + for kname, spec in kernels.items(): + if kname not in d: + d[kname] = os.path.join(spec, "kernel.json") + return d.get(kernel_name, "") diff --git a/plugins/kernels/fps_kernels/kernel_driver/message.py b/plugins/kernels/fps_kernels/kernel_driver/message.py new file mode 100644 index 00000000..e6231dfe --- /dev/null +++ b/plugins/kernels/fps_kernels/kernel_driver/message.py @@ -0,0 +1,103 @@ +import hashlib +import hmac +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, cast + +from dateutil.parser import parse as dateutil_parse # type: ignore +from zmq.utils import jsonapi + +protocol_version_info = (5, 3) +protocol_version = "%i.%i" % protocol_version_info + +DELIM = b"<IDS|MSG>" + + +def str_to_date(obj: Dict[str, Any]) -> Dict[str, Any]: + if "date" in obj: + obj["date"] = dateutil_parse(obj["date"]) + return obj + + +def date_to_str(obj: Dict[str, Any]): + if "date" in obj and type(obj["date"]) is not str: + obj["date"] = obj["date"].isoformat().replace("+00:00", "Z") + return obj + + +def utcnow() -> datetime: + return datetime.utcnow().replace(tzinfo=timezone.utc) + + +def create_message_header(msg_type: str, session_id: str, msg_cnt: int) -> Dict[str, Any]: + if not session_id: + session_id = msg_id = uuid.uuid4().hex + else: + msg_id = f"{session_id}_{msg_cnt}" + header = { + "date": utcnow(), + "msg_id": msg_id, + "msg_type": msg_type, + "session": session_id, + "username": "david", + "version": protocol_version, + } + return header + + +def create_message( + msg_type: str, + content: Dict = {}, + session_id: str = "", + msg_cnt: int = 0, +) -> Dict[str, Any]: + header = create_message_header(msg_type, session_id, msg_cnt) + msg = { + "header": header, + "msg_id": header["msg_id"], + "msg_type": header["msg_type"], + "parent_header": {}, + "content": content, + "metadata": {}, + } + return msg + + +def pack(obj: Dict[str, Any]) -> bytes: + return 
jsonapi.dumps(obj) + + +def unpack(s: bytes) -> Dict[str, Any]: + return cast(Dict[str, Any], jsonapi.loads(s)) + + +def sign(msg_list: List[bytes], key: str) -> bytes: + auth = hmac.new(key.encode("ascii"), digestmod=hashlib.sha256) + h = auth.copy() + for m in msg_list: + h.update(m) + return h.hexdigest().encode() + + +def serialize(msg: Dict[str, Any], key: str) -> List[bytes]: + message = [ + pack(date_to_str(msg["header"])), + pack(date_to_str(msg["parent_header"])), + pack(date_to_str(msg["metadata"])), + pack(date_to_str(msg.get("content", {}))), + ] + to_send = [DELIM, sign(message, key)] + message + return to_send + + +def deserialize(msg_list: List[bytes]) -> Dict[str, Any]: + message: Dict[str, Any] = {} + header = unpack(msg_list[1]) + message["header"] = str_to_date(header) + message["msg_id"] = header["msg_id"] + message["msg_type"] = header["msg_type"] + message["parent_header"] = str_to_date(unpack(msg_list[2])) + message["metadata"] = unpack(msg_list[3]) + message["content"] = unpack(msg_list[4]) + message["buffers"] = [memoryview(b) for b in msg_list[5:]] + return message diff --git a/plugins/kernels/fps_kernels/kernel_driver/paths.py b/plugins/kernels/fps_kernels/kernel_driver/paths.py new file mode 100644 index 00000000..deae1832 --- /dev/null +++ b/plugins/kernels/fps_kernels/kernel_driver/paths.py @@ -0,0 +1,114 @@ +import glob +import os +import sys +import tempfile +import uuid +from typing import Dict, List + + +def _expand_path(s): + if os.name == "nt": + i = str(uuid.uuid4()) + s = s.replace("$\\", i) + s = os.path.expandvars(os.path.expanduser(s)) + if os.name == "nt": + s = s.replace(i, "$\\") + return s + + +def _filefind(filename, path_dirs=()): + filename = filename.strip('"').strip("'") + if os.path.isabs(filename) and os.path.isfile(filename): + return filename + + path_dirs = path_dirs or ("",) + + for path in path_dirs: + if path == ".": + path = os.getcwd() + testname = _expand_path(os.path.join(path, filename)) + if 
os.path.isfile(testname): + return os.path.abspath(testname) + + raise IOError(f"File {filename} does not exist in any of the search paths: {path_dirs}") + + +def get_home_dir(): + home = os.path.expanduser("~") + home = os.path.realpath(home) + return home + + +_dtemps: Dict = {} + + +def _mkdtemp_once(name): + if name in _dtemps: + return _dtemps[name] + d = _dtemps[name] = tempfile.mkdtemp(prefix=name + "-") + return d + + +def jupyter_config_dir(): + if os.environ.get("JUPYTER_NO_CONFIG"): + return _mkdtemp_once("jupyter-clean-cfg") + if "JUPYTER_CONFIG_DIR" in os.environ: + return os.environ["JUPYTER_CONFIG_DIR"] + home = get_home_dir() + return os.path.join(home, ".jupyter") + + +def jupyter_data_dir(): + if "JUPYTER_DATA_DIR" in os.environ: + return os.environ["JUPYTER_DATA_DIR"] + + home = get_home_dir() + + if sys.platform == "darwin": + return os.path.join(home, "Library", "Jupyter") + elif os.name == "nt": + appdata = os.environ.get("APPDATA", None) + if appdata: + return os.path.join(appdata, "jupyter") + else: + return os.path.join(jupyter_config_dir(), "data") + else: + xdg = os.environ.get("XDG_DATA_HOME", None) + if not xdg: + xdg = os.path.join(home, ".local", "share") + return os.path.join(xdg, "jupyter") + + +def jupyter_runtime_dir(): + if "JUPYTER_RUNTIME_DIR" in os.environ: + return os.environ["JUPYTER_RUNTIME_DIR"] + return os.path.join(jupyter_data_dir(), "runtime") + + +def find_connection_file( + filename: str = "kernel-*.json", + paths: List[str] = [], +) -> str: + if not paths: + paths = [".", jupyter_runtime_dir()] + + path = _filefind(filename, paths) + if path: + return path + + if "*" in filename: + pat = filename + else: + pat = f"*{filename}*" + + matches = [] + for p in paths: + matches.extend(glob.glob(os.path.join(p, pat))) + + matches = [os.path.abspath(m) for m in matches] + if not matches: + raise IOError(f"Could not find {filename} in {paths}") + elif len(matches) == 1: + return matches[0] + else: + return 
sorted(matches, key=lambda f: os.stat(f).st_atime)[-1] diff --git a/plugins/kernels/fps_kernels/models.py b/plugins/kernels/fps_kernels/models.py index ecbb21c0..55cd3f4c 100644 --- a/plugins/kernels/fps_kernels/models.py +++ b/plugins/kernels/fps_kernels/models.py @@ -32,3 +32,8 @@ class Session(BaseModel): type: str kernel: Kernel notebook: Notebook + + +class Execution(BaseModel): + document_id: str + cell_idx: int diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index 1c43ca77..58c1f952 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -16,14 +16,16 @@ from fps_auth.config import get_auth_config # type: ignore from fps_auth.models import UserRead # type: ignore from fps_lab.config import get_lab_config # type: ignore +from fps_yjs.routes import YDocWebSocketHandler # type: ignore from starlette.requests import Request # type: ignore +from .kernel_driver.driver import KernelDriver # type: ignore from .kernel_server.server import ( # type: ignore AcceptedWebSocket, KernelServer, kernels, ) -from .models import Session +from .models import Execution, Session router = APIRouter() @@ -132,7 +134,7 @@ async def create_session( ), ) kernel_id = str(uuid.uuid4()) - kernels[kernel_id] = {"name": kernel_name, "server": kernel_server} + kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None} await kernel_server.start() session_id = str(uuid.uuid4()) session = { @@ -171,6 +173,36 @@ async def restart_kernel( return result +@router.post("/api/kernels/{kernel_id}/execute") +async def execute_cell( + request: Request, + kernel_id, + user: UserRead = Depends(current_user), +): + r = await request.json() + execution = Execution(**r) + if kernel_id in kernels: + ynotebook = YDocWebSocketHandler.websocket_server.get_room(execution.document_id).document + cell = ynotebook.get_cell(execution.cell_idx) + cell["outputs"] = [] + + kernel = kernels[kernel_id] + 
kernelspec_path = str( + prefix_dir / "share" / "jupyter" / "kernels" / kernel["name"] / "kernel.json" + ) + if not kernel["driver"]: + kernel["driver"] = driver = KernelDriver( + kernelspec_path=kernelspec_path, + write_connection_file=False, + connection_file=kernel["server"].connection_file_path, + ) + await driver.connect() + driver = kernel["driver"] + + await driver.execute(cell) + ynotebook.set_cell(execution.cell_idx, cell) + + @router.get("/api/kernels/{kernel_id}") async def get_kernel( kernel_id, diff --git a/plugins/kernels/setup.cfg b/plugins/kernels/setup.cfg index 14694899..1e6183a2 100644 --- a/plugins/kernels/setup.cfg +++ b/plugins/kernels/setup.cfg @@ -23,8 +23,10 @@ install_requires = fps >=0.0.8 fps-auth fps-lab + fps-yjs pyzmq websockets + python-dateutil [options.entry_points] fps_router = diff --git a/setup.cfg b/setup.cfg index cd736168..76e99bd3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ test = pytest pytest-asyncio requests + websockets ipykernel [options.entry_points] diff --git a/tests/conftest.py b/tests/conftest.py index 35078a83..cf61f0e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ +import os import socket import subprocess import time +from pathlib import Path from uuid import uuid4 import pytest @@ -11,6 +13,11 @@ ) +@pytest.fixture() +def cwd(): + return Path(__file__).parent.parent + + @pytest.fixture() def authenticated_client(client): # create a new user @@ -66,7 +73,8 @@ def get_open_port(): @pytest.fixture() -def start_jupyverse(auth_mode, clear_users, capfd): +def start_jupyverse(auth_mode, clear_users, cwd, capfd): + os.chdir(cwd) port = get_open_port() command_list = [ "jupyverse", diff --git a/tests/data/notebook0.ipynb b/tests/data/notebook0.ipynb new file mode 100644 index 00000000..6285ff33 --- /dev/null +++ b/tests/data/notebook0.ipynb @@ -0,0 +1,49 @@ +{ + "cells": [ + { + "source": "1 + 2", + "outputs": [], + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"id": "a7243792-6f06-4462-a6b5-7e9ec604348e" + }, + { + "source": "print(\"Hello World!\")", + "outputs": [], + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "id": "dc428ca1-d2f1-4dc1-8811-6be77103d683" + }, + { + "source": "3 + 4", + "outputs": [], + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "id": "a7243792-6f06-4462-a6b5-7e9ec604348f" + } + ], + "metadata": { + "kernelspec": { + "language": "python", + "display_name": "Python 3 (ipykernel)", + "name": "python3" + }, + "language_info": { + "pygments_lexer": "ipython3", + "codemirror_mode": { + "version": 3, + "name": "ipython" + }, + "nbconvert_exporter": "python", + "mimetype": "text/x-python", + "file_extension": ".py", + "name": "python", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_contents.py b/tests/test_contents.py index 3f540dae..26f0ae81 100644 --- a/tests/test_contents.py +++ b/tests/test_contents.py @@ -7,6 +7,7 @@ @pytest.mark.parametrize("auth_mode", ("noauth",)) def test_tree(client, tmp_path): + prev_dir = os.getcwd() os.chdir(tmp_path) dname = Path(".") expected = [] @@ -58,3 +59,4 @@ def test_tree(client, tmp_path): sort_content_by_name(actual) sort_content_by_name(expected) assert actual == expected + os.chdir(prev_dir) diff --git a/tests/test_server.py b/tests/test_server.py index 881cc116..03e277bb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,11 @@ +import asyncio import json import pytest import requests +import y_py as Y +from websockets import connect +from ypy_websocket import WebsocketProvider prev_theme = {} test_theme = {"raw": '{// jupyverse test\n"theme": "JupyterLab Dark"}'} @@ -38,3 +42,66 @@ def test_settings_persistence_get(start_jupyverse): data=json.dumps(prev_theme), ) assert response.status_code == 204 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_mode", ("noauth",)) +@pytest.mark.parametrize("clear_users", (False,)) +async def 
test_rest_api(start_jupyverse): + url = start_jupyverse + ws_url = url.replace("http", "ws", 1) + # create a session to launch a kernel + response = requests.post( + f"{url}/api/sessions", + data=json.dumps( + { + "kernel": {"name": "python3"}, + "name": "notebook0.ipynb", + "path": "69e8a762-86c6-4102-a3da-a43d735fec2b", + "type": "notebook", + } + ), + ) + r = response.json() + kernel_id = r["kernel"]["id"] + document_id = "json:notebook:tests/data/notebook0.ipynb" + async with connect(f"{ws_url}/api/yjs/{document_id}") as websocket: + # connect to the shared notebook document + ydoc = Y.YDoc() + WebsocketProvider(ydoc, websocket) + # wait for file to be loaded and Y model to be created in server and client + await asyncio.sleep(0.1) + # execute notebook + for cell_idx in range(3): + response = requests.post( + f"{url}/api/kernels/{kernel_id}/execute", + data=json.dumps( + { + "document_id": document_id, + "cell_idx": cell_idx, + } + ), + ) + # wait for Y model to be updated + await asyncio.sleep(0.1) + # retrieve cells + cells = ydoc.get_array("cells").to_json() + assert cells[0]["outputs"] == [ + { + "data": {"text/plain": ["3"]}, + "execution_count": 1.0, + "metadata": {}, + "output_type": "execute_result", + } + ] + assert cells[1]["outputs"] == [ + {"name": "stdout", "output_type": "stream", "text": ["Hello World!\n"]} + ] + assert cells[2]["outputs"] == [ + { + "data": {"text/plain": ["7"]}, + "execution_count": 3.0, + "metadata": {}, + "output_type": "execute_result", + } + ]