Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add POST execute cell #191

Merged
merged 1 commit into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
pip install ./plugins/lab
pip install ./plugins/jupyterlab
pip install "jupyter_ydoc >=0.1.16,<0.2.0" # FIXME: remove with next JupyterLab release
pip install "y-py >=0.5.4"

pip install mypy pytest pytest-asyncio requests ipykernel

Expand Down
Empty file.
102 changes: 102 additions & 0 deletions plugins/kernels/fps_kernels/kernel_driver/connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import asyncio
import json
import os
import socket
import tempfile
import uuid
from typing import Dict, Tuple, Union

import zmq
import zmq.asyncio
from zmq.asyncio import Socket

channel_socket_types = {
"hb": zmq.REQ,
"shell": zmq.DEALER,
"iopub": zmq.SUB,
"stdin": zmq.DEALER,
"control": zmq.DEALER,
}

context = zmq.asyncio.Context()

cfg_t = Dict[str, Union[str, int]]


def get_port(ip: str) -> int:
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
sock.bind((ip, 0))
port = sock.getsockname()[1]
sock.close()
return port


def write_connection_file(
fname: str = "",
ip: str = "",
transport: str = "tcp",
signature_scheme: str = "hmac-sha256",
kernel_name: str = "",
) -> Tuple[str, cfg_t]:
ip = ip or "127.0.0.1"

if not fname:
fd, fname = tempfile.mkstemp(suffix=".json")
os.close(fd)
f = open(fname, "wt")

channels = ["shell", "iopub", "stdin", "control", "hb"]

cfg: cfg_t = {f"{c}_port": get_port(ip) for c in channels}

cfg["ip"] = ip
cfg["key"] = uuid.uuid4().hex
cfg["transport"] = transport
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name

f.write(json.dumps(cfg, indent=2))
f.close()

return fname, cfg


def read_connection_file(fname: str = "") -> cfg_t:
with open(fname, "rt") as f:
cfg: cfg_t = json.load(f)

return cfg


async def launch_kernel(
kernelspec_path: str, connection_file_path: str, capture_output: bool
) -> asyncio.subprocess.Process:
with open(kernelspec_path) as f:
kernelspec = json.load(f)
cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]]
if capture_output:
p = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT
)
else:
p = await asyncio.create_subprocess_exec(*cmd)
return p


def create_socket(channel: str, cfg: cfg_t) -> Socket:
ip = cfg["ip"]
port = cfg[f"{channel}_port"]
url = f"tcp://{ip}:{port}"
socket_type = channel_socket_types[channel]
sock = context.socket(socket_type)
sock.linger = 1000 # set linger to 1s to prevent hangs at exit
sock.connect(url)
return sock


def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
sock = create_socket(channel_name, cfg)
if channel_name == "iopub":
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock
220 changes: 220 additions & 0 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast

from zmq.asyncio import Socket

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
from .message import create_message, deserialize, serialize

DELIM = b"<IDS|MSG>"


def deadline_to_timeout(deadline: float) -> float:
return max(0, deadline - time.time())


def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]:
idx = msg_list.index(DELIM)
return msg_list[:idx], msg_list[idx + 1 :] # noqa


def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None:
to_send = serialize(msg, key)
sock.send_multipart(to_send, copy=True)


async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]:
timeout *= 1000 # in ms
ready = await sock.poll(timeout)
if ready:
msg_list = await sock.recv_multipart()
idents, msg_list = feed_identities(msg_list)
return deserialize(msg_list)
return None


class KernelDriver:
def __init__(
self,
kernel_name: str = "",
kernelspec_path: str = "",
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
if write_connection_file:
self.connection_file_path, self.connection_cfg = _write_connection_file(connection_file)
else:
self.connection_file_path = connection_file
self.connection_cfg = read_connection_file(connection_file)
self.key = cast(str, self.connection_cfg["key"])
self.session_id = uuid.uuid4().hex
self.msg_cnt = 0
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
self.channel_tasks: List[asyncio.Task] = []

async def restart(self, startup_timeout: float = float("inf")) -> None:
for task in self.channel_tasks:
task.cancel()
msg = create_message("shutdown_request", content={"restart": True})
send_message(msg, self.control_channel, self.key)
while True:
msg = cast(Dict[str, Any], await receive_message(self.control_channel))
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
break
await self._wait_for_ready(startup_timeout)
self.channel_tasks = []
self.listen_channels()

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
self.kernel_process = await launch_kernel(
self.kernelspec_path, self.connection_file_path, self.capture_kernel_output
)
if connect:
await self.connect(startup_timeout)

async def connect(self, startup_timeout: float = float("inf")) -> None:
self.connect_channels()
await self._wait_for_ready(startup_timeout)
self.listen_channels()

def connect_channels(self, connection_cfg: cfg_t = None):
connection_cfg = connection_cfg or self.connection_cfg
self.shell_channel = connect_channel("shell", connection_cfg)
self.control_channel = connect_channel("control", connection_cfg)
self.iopub_channel = connect_channel("iopub", connection_cfg)

def listen_channels(self):
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))

async def stop(self) -> None:
self.kernel_process.kill()
await self.kernel_process.wait()
os.remove(self.connection_file_path)
for task in self.channel_tasks:
task.cancel()

async def listen_iopub(self):
while True:
msg = await receive_message(self.iopub_channel) # type: ignore
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel) # type: ignore
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["shell_msg"].set_result(msg)

async def execute(
self,
cell: Dict[str, Any],
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
if cell["cell_type"] != "code":
return
content = {"code": cell["source"], "silent": False}
msg = create_message(
"execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt
)
if msg_id:
msg["header"]["msg_id"] = msg_id
else:
msg_id = msg["header"]["msg_id"]
self.msg_cnt += 1
send_message(msg, self.shell_channel, self.key)
if wait_for_executed:
deadline = time.time() + timeout
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Future(),
"shell_msg": asyncio.Future(),
}
while True:
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"],
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["iopub_msg"].result()
self._handle_outputs(cell["outputs"], msg)
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
):
break
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"],
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell_msg"].result()
cell["execution_count"] = msg["content"]["execution_count"]
del self.execute_requests[msg_id]

async def _wait_for_ready(self, timeout):
deadline = time.time() + timeout
new_timeout = timeout
while True:
msg = create_message(
"kernel_info_request", session_id=self.session_id, msg_cnt=self.msg_cnt
)
self.msg_cnt += 1
send_message(msg, self.shell_channel, self.key)
msg = await receive_message(self.shell_channel, new_timeout)
if msg is None:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
if msg["msg_type"] == "kernel_info_reply":
msg = await receive_message(self.iopub_channel, 0.2)
if msg is not None:
break
new_timeout = deadline_to_timeout(deadline)

def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
outputs[-1]["text"].append(content["text"])
elif msg_type in ("display_data", "execute_result"):
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
elif msg_type == "error":
outputs.append(
{
"ename": content["ename"],
"evalue": content["evalue"],
"output_type": "error",
"traceback": content["traceback"],
}
)
else:
return
Loading