Skip to content

Commit

Permalink
Protocol alignment (#657)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Feb 4, 2022
1 parent 3a5f4b1 commit cbc54fa
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 56 deletions.
56 changes: 51 additions & 5 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@ def deserialize_binary_message(bmsg):
return msg


def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
if pack:
msg_list = [
pack(msg_or_list["header"]),
pack(msg_or_list["parent_header"]),
pack(msg_or_list["metadata"]),
pack(msg_or_list["content"]),
]
else:
msg_list = msg_or_list
channel = channel.encode("utf-8")
offsets = []
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
offsets.append(len(channel) + offsets[-1])
for msg in msg_list:
offsets.append(len(msg) + offsets[-1])
offset_number = len(offsets).to_bytes(8, byteorder="little")
offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets]
bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list)
return bin_msg


def deserialize_msg_from_ws_v1(ws_msg):
offset_number = int.from_bytes(ws_msg[:8], "little")
offsets = [
int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number)
]
channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)]
return channel, msg_list


# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -239,6 +271,16 @@ def _reserialize_reply(self, msg_or_list, channel=None):
smsg = json.dumps(msg, default=json_default)
return cast_unicode(smsg)

def select_subprotocol(self, subprotocols):
preferred_protocol = self.settings.get("kernel_ws_protocol")
if preferred_protocol is None:
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
preferred_protocol = None
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol

def _on_zmq_reply(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
Expand All @@ -247,12 +289,16 @@ def _on_zmq_reply(self, stream, msg_list):
self.close()
return
channel = getattr(stream, "channel", None)
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(msg_list, channel)
self.write_message(bin_msg, binary=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))


class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):
Expand Down
26 changes: 26 additions & 0 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def init_settings(
"no_cache_paths": [url_path_join(base_url, "static", "custom")],
},
version_hash=version_hash,
# kernel message protocol over websoclet
kernel_ws_protocol=jupyter_app.kernel_ws_protocol,
# rate limits
limit_rate=jupyter_app.limit_rate,
iopub_msg_rate_limit=jupyter_app.iopub_msg_rate_limit,
iopub_data_rate_limit=jupyter_app.iopub_data_rate_limit,
rate_limit_window=jupyter_app.rate_limit_window,
Expand Down Expand Up @@ -1612,6 +1615,29 @@ def _update_server_extensions(self, change):
help=_i18n("Reraise exceptions encountered loading server extensions?"),
)

kernel_ws_protocol = Unicode(
None,
allow_none=True,
config=True,
help=_i18n(
"Preferred kernel message protocol over websocket to use (default: None). "
"If an empty string is passed, select the legacy protocol. If None, "
"the selected protocol will depend on what the front-end supports "
"(usually the most recent protocol supported by the back-end and the "
"front-end)."
),
)

limit_rate = Bool(
True,
config=True,
help=_i18n(
"Whether to limit the rate of IOPub messages (default: True). "
"If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window "
"to tune the rate."
),
)

iopub_msg_rate_limit = Float(
1000,
config=True,
Expand Down
178 changes: 127 additions & 51 deletions jupyter_server/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

from ...base.handlers import APIHandler
from ...base.zmqhandlers import AuthenticatedZMQStreamHandler
from ...base.zmqhandlers import deserialize_binary_message
from ...base.zmqhandlers import (
deserialize_binary_message,
serialize_msg_to_ws_v1,
deserialize_msg_from_ws_v1,
)
from jupyter_server.utils import ensure_async
from jupyter_server.utils import url_escape
from jupyter_server.utils import url_path_join
Expand Down Expand Up @@ -105,6 +109,10 @@ def kernel_info_timeout(self):
km_default = self.kernel_manager.kernel_info_timeout
return self.settings.get("kernel_info_timeout", km_default)

@property
def limit_rate(self):
return self.settings.get("limit_rate", True)

@property
def iopub_msg_rate_limit(self):
return self.settings.get("iopub_msg_rate_limit", 0)
Expand Down Expand Up @@ -452,64 +460,112 @@ def subscribe(value):

return connected

def on_message(self, msg):
def on_message(self, ws_msg):
if not self.channels:
# already closed, ignore the message
self.log.debug("Received message on closed websocket %r", msg)
self.log.debug("Received message on closed websocket %r", ws_msg)
return
if isinstance(msg, bytes):
msg = deserialize_binary_message(msg)

if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
channel, msg_list = deserialize_msg_from_ws_v1(ws_msg)
msg = {
"header": None,
}
else:
msg = json.loads(msg)
channel = msg.pop("channel", None)
if isinstance(ws_msg, bytes):
msg = deserialize_binary_message(ws_msg)
else:
msg = json.loads(ws_msg)
msg_list = []
channel = msg.pop("channel", None)

if channel is None:
self.log.warning("No channel specified, assuming shell: %s", msg)
channel = "shell"
if channel not in self.channels:
self.log.warning("No such channel: %r", channel)
return
am = self.kernel_manager.allowed_message_types
mt = msg["header"]["msg_type"]
if am and mt not in am:
self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt)
else:
ignore_msg = False
if am:
msg["header"] = self.get_part("header", msg["header"], msg_list)
if msg["header"]["msg_type"] not in am:
self.log.warning(
'Received message of type "%s", which is not allowed. Ignoring.'
% msg["header"]["msg_type"]
)
ignore_msg = True
if not ignore_msg:
stream = self.channels[channel]
self.session.send(stream, msg)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
self.session.send_raw(stream, msg_list)
else:
self.session.send(stream, msg)

def get_part(self, field, value, msg_list):
if value is None:
field2idx = {
"header": 0,
"parent_header": 1,
"content": 3,
}
value = self.session.unpack(msg_list[field2idx[field]])
return value

def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)

parent = msg["parent_header"]

def write_stderr(error_message):
self.log.warning(error_message)
msg = self.session.msg(
"stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent
)
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
msg = {"header": None, "parent_header": None, "content": None}
else:
msg = self.session.deserialize(fed_msg_list)

channel = getattr(stream, "channel", None)
msg_type = msg["header"]["msg_type"]
parts = fed_msg_list[1:]

if channel == "iopub" and msg_type == "error":
self._on_error(msg)
self._on_error(channel, msg, parts)

if (
channel == "iopub"
and msg_type == "status"
and msg["content"].get("execution_state") == "idle"
):
# reset rate limit counter on status=idle,
# to avoid 'Run All' hitting limits prematurely.
self._iopub_window_byte_queue = []
self._iopub_window_msg_count = 0
self._iopub_window_byte_count = 0
self._iopub_msgs_exceeded = False
self._iopub_data_exceeded = False

if channel == "iopub" and msg_type not in {"status", "comm_open", "execute_input"}:
if self._limit_rate(channel, msg, parts):
return

if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, parts)
else:
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)

def write_stderr(self, error_message, parent_header):
self.log.warning(error_message)
err_msg = self.session.msg(
"stream",
content={"text": error_message + "\n", "name": "stderr"},
parent=parent_header,
)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack)
self.write_message(bin_msg, binary=True)
else:
err_msg["channel"] = "iopub"
self.write_message(json.dumps(err_msg, default=json_default))

def _limit_rate(self, channel, msg, msg_list):
if not (self.limit_rate and channel == "iopub"):
return False

msg["header"] = self.get_part("header", msg["header"], msg_list)

msg_type = msg["header"]["msg_type"]
if msg_type == "status":
msg["content"] = self.get_part("content", msg["content"], msg_list)
if msg["content"].get("execution_state") == "idle":
# reset rate limit counter on status=idle,
# to avoid 'Run All' hitting limits prematurely.
self._iopub_window_byte_queue = []
self._iopub_window_msg_count = 0
self._iopub_window_byte_count = 0
self._iopub_msgs_exceeded = False
self._iopub_data_exceeded = False

if msg_type not in {"status", "comm_open", "execute_input"}:

# Remove the counts queued for removal.
now = IOLoop.current().time()
Expand Down Expand Up @@ -545,7 +601,10 @@ def write_stderr(error_message):
if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
if not self._iopub_msgs_exceeded:
self._iopub_msgs_exceeded = True
write_stderr(
msg["parent_header"] = self.get_part(
"parent_header", msg["parent_header"], msg_list
)
self.write_stderr(
dedent(
"""\
IOPub message rate exceeded.
Expand All @@ -560,7 +619,8 @@ def write_stderr(error_message):
""".format(
self.iopub_msg_rate_limit, self.rate_limit_window
)
)
),
msg["parent_header"],
)
else:
# resume once we've got some headroom below the limit
Expand All @@ -573,7 +633,10 @@ def write_stderr(error_message):
if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
if not self._iopub_data_exceeded:
self._iopub_data_exceeded = True
write_stderr(
msg["parent_header"] = self.get_part(
"parent_header", msg["parent_header"], msg_list
)
self.write_stderr(
dedent(
"""\
IOPub data rate exceeded.
Expand All @@ -588,7 +651,8 @@ def write_stderr(error_message):
""".format(
self.iopub_data_rate_limit, self.rate_limit_window
)
)
),
msg["parent_header"],
)
else:
# resume once we've got some headroom below the limit
Expand All @@ -603,8 +667,9 @@ def write_stderr(error_message):
self._iopub_window_msg_count -= 1
self._iopub_window_byte_count -= byte_count
self._iopub_window_byte_queue.pop(-1)
return
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
return True

return False

def close(self):
super(ZMQChannelsHandler, self).close()
Expand Down Expand Up @@ -654,8 +719,12 @@ def _send_status_message(self, status):
# that all messages from the stopped kernel have been delivered
iopub.flush()
msg = self.session.msg("status", {"execution_state": status})
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack)
self.write_message(bin_msg, binary=True)
else:
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))

def on_kernel_restarted(self):
self.log.warning("kernel %s restarted", self.kernel_id)
Expand All @@ -665,12 +734,19 @@ def on_restart_failed(self):
self.log.error("kernel %s restarted failed!", self.kernel_id)
self._send_status_message("dead")

def _on_error(self, msg):
def _on_error(self, channel, msg, msg_list):
if self.kernel_manager.allow_tracebacks:
return
msg["content"]["ename"] = "ExecutionError"
msg["content"]["evalue"] = "Execution error"
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]

if channel == "iopub":
msg["header"] = self.get_part("header", msg["header"], msg_list)
if msg["header"]["msg_type"] == "error":
msg["content"] = self.get_part("content", msg["content"], msg_list)
msg["content"]["ename"] = "ExecutionError"
msg["content"]["evalue"] = "Execution error"
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
msg_list[3] = self.session.pack(msg["content"])


# -----------------------------------------------------------------------------
Expand Down

0 comments on commit cbc54fa

Please sign in to comment.