Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protocol alignment #657

Merged
merged 10 commits into from
Feb 4, 2022
56 changes: 51 additions & 5 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@ def deserialize_binary_message(bmsg):
return msg


def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
if pack:
msg_list = [
pack(msg_or_list["header"]),
pack(msg_or_list["parent_header"]),
pack(msg_or_list["metadata"]),
pack(msg_or_list["content"]),
]
else:
msg_list = msg_or_list
channel = channel.encode("utf-8")
offsets = []
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
offsets.append(len(channel) + offsets[-1])
for msg in msg_list:
offsets.append(len(msg) + offsets[-1])
offset_number = len(offsets).to_bytes(8, byteorder="little")
offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets]
bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list)
return bin_msg


def deserialize_msg_from_ws_v1(ws_msg):
offset_number = int.from_bytes(ws_msg[:8], "little")
offsets = [
int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number)
]
channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)]
return channel, msg_list


# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -239,6 +271,16 @@ def _reserialize_reply(self, msg_or_list, channel=None):
smsg = json.dumps(msg, default=json_default)
return cast_unicode(smsg)

def select_subprotocol(self, subprotocols):
preferred_protocol = self.settings.get("kernel_ws_protocol")
if preferred_protocol is None:
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
preferred_protocol = None
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol

def _on_zmq_reply(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
Expand All @@ -247,12 +289,16 @@ def _on_zmq_reply(self, stream, msg_list):
self.close()
return
channel = getattr(stream, "channel", None)
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(msg_list, channel)
self.write_message(bin_msg, binary=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))


class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):
Expand Down
26 changes: 26 additions & 0 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def init_settings(
"no_cache_paths": [url_path_join(base_url, "static", "custom")],
},
version_hash=version_hash,
# kernel message protocol over websoclet
kernel_ws_protocol=jupyter_app.kernel_ws_protocol,
# rate limits
limit_rate=jupyter_app.limit_rate,
iopub_msg_rate_limit=jupyter_app.iopub_msg_rate_limit,
iopub_data_rate_limit=jupyter_app.iopub_data_rate_limit,
rate_limit_window=jupyter_app.rate_limit_window,
Expand Down Expand Up @@ -1612,6 +1615,29 @@ def _update_server_extensions(self, change):
help=_i18n("Reraise exceptions encountered loading server extensions?"),
)

kernel_ws_protocol = Unicode(
None,
allow_none=True,
config=True,
help=_i18n(
"Preferred kernel message protocol over websocket to use (default: None). "
"If an empty string is passed, select the legacy protocol. If None, "
"the selected protocol will depend on what the front-end supports "
"(usually the most recent protocol supported by the back-end and the "
"front-end)."
),
)

limit_rate = Bool(
True,
config=True,
help=_i18n(
"Whether to limit the rate of IOPub messages (default: True). "
"If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window "
"to tune the rate."
),
)

iopub_msg_rate_limit = Float(
1000,
config=True,
Expand Down
176 changes: 126 additions & 50 deletions jupyter_server/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

from ...base.handlers import APIHandler
from ...base.zmqhandlers import AuthenticatedZMQStreamHandler
from ...base.zmqhandlers import deserialize_binary_message
from ...base.zmqhandlers import (
deserialize_binary_message,
serialize_msg_to_ws_v1,
deserialize_msg_from_ws_v1,
)
from jupyter_server.utils import ensure_async
from jupyter_server.utils import url_escape
from jupyter_server.utils import url_path_join
Expand Down Expand Up @@ -105,6 +109,10 @@ def kernel_info_timeout(self):
km_default = self.kernel_manager.kernel_info_timeout
return self.settings.get("kernel_info_timeout", km_default)

@property
def limit_rate(self):
return self.settings.get("limit_rate", True)

@property
def iopub_msg_rate_limit(self):
return self.settings.get("iopub_msg_rate_limit", 0)
Expand Down Expand Up @@ -449,64 +457,112 @@ def subscribe(value):

return connected

def on_message(self, msg):
def on_message(self, ws_msg):
if not self.channels:
# already closed, ignore the message
self.log.debug("Received message on closed websocket %r", msg)
self.log.debug("Received message on closed websocket %r", ws_msg)
return
if isinstance(msg, bytes):
msg = deserialize_binary_message(msg)

if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
channel, msg_list = deserialize_msg_from_ws_v1(ws_msg)
msg = {
"header": None,
}
else:
msg = json.loads(msg)
channel = msg.pop("channel", None)
if isinstance(ws_msg, bytes):
msg = deserialize_binary_message(ws_msg)
else:
msg = json.loads(ws_msg)
msg_list = []
channel = msg.pop("channel", None)

if channel is None:
self.log.warning("No channel specified, assuming shell: %s", msg)
channel = "shell"
if channel not in self.channels:
self.log.warning("No such channel: %r", channel)
return
am = self.kernel_manager.allowed_message_types
mt = msg["header"]["msg_type"]
if am and mt not in am:
self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt)
else:
ignore_msg = False
if am:
msg["header"] = self.get_part("header", msg["header"], msg_list)
if msg["header"]["msg_type"] not in am:
self.log.warning(
'Received message of type "%s", which is not allowed. Ignoring.'
% msg["header"]["msg_type"]
)
ignore_msg = True
if not ignore_msg:
stream = self.channels[channel]
self.session.send(stream, msg)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
self.session.send_raw(stream, msg_list)
else:
self.session.send(stream, msg)

def get_part(self, field, value, msg_list):
if value is None:
field2idx = {
"header": 0,
"parent_header": 1,
"content": 3,
}
value = self.session.unpack(msg_list[field2idx[field]])
return value

def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)

parent = msg["parent_header"]

def write_stderr(error_message):
self.log.warning(error_message)
msg = self.session.msg(
"stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent
)
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
msg = {"header": None, "parent_header": None, "content": None}
else:
msg = self.session.deserialize(fed_msg_list)

channel = getattr(stream, "channel", None)
msg_type = msg["header"]["msg_type"]
parts = fed_msg_list[1:]

if channel == "iopub" and msg_type == "error":
self._on_error(msg)
self._on_error(channel, msg, parts)

if (
channel == "iopub"
and msg_type == "status"
and msg["content"].get("execution_state") == "idle"
):
# reset rate limit counter on status=idle,
# to avoid 'Run All' hitting limits prematurely.
self._iopub_window_byte_queue = []
self._iopub_window_msg_count = 0
self._iopub_window_byte_count = 0
self._iopub_msgs_exceeded = False
self._iopub_data_exceeded = False
if self._limit_rate(channel, msg, parts):
return

if channel == "iopub" and msg_type not in {"status", "comm_open", "execute_input"}:
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, parts)
else:
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)

def write_stderr(self, error_message, parent_header):
self.log.warning(error_message)
err_msg = self.session.msg(
"stream",
content={"text": error_message + "\n", "name": "stderr"},
parent=parent_header,
)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack)
self.write_message(bin_msg, binary=True)
else:
err_msg["channel"] = "iopub"
self.write_message(json.dumps(err_msg, default=json_default))

def _limit_rate(self, channel, msg, msg_list):
if not (self.limit_rate and channel == "iopub"):
return False

msg["header"] = self.get_part("header", msg["header"], msg_list)

msg_type = msg["header"]["msg_type"]
if msg_type == "status":
msg["content"] = self.get_part("content", msg["content"], msg_list)
if msg["content"].get("execution_state") == "idle":
# reset rate limit counter on status=idle,
# to avoid 'Run All' hitting limits prematurely.
self._iopub_window_byte_queue = []
self._iopub_window_msg_count = 0
self._iopub_window_byte_count = 0
self._iopub_msgs_exceeded = False
self._iopub_data_exceeded = False

if msg_type not in {"status", "comm_open", "execute_input"}:

# Remove the counts queued for removal.
now = IOLoop.current().time()
Expand Down Expand Up @@ -542,7 +598,10 @@ def write_stderr(error_message):
if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
if not self._iopub_msgs_exceeded:
self._iopub_msgs_exceeded = True
write_stderr(
msg["parent_header"] = self.get_part(
"parent_header", msg["parent_header"], msg_list
)
self.write_stderr(
dedent(
"""\
IOPub message rate exceeded.
Expand All @@ -557,7 +616,8 @@ def write_stderr(error_message):
""".format(
self.iopub_msg_rate_limit, self.rate_limit_window
)
)
),
msg["parent_header"],
)
else:
# resume once we've got some headroom below the limit
Expand All @@ -570,7 +630,10 @@ def write_stderr(error_message):
if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
if not self._iopub_data_exceeded:
self._iopub_data_exceeded = True
write_stderr(
msg["parent_header"] = self.get_part(
"parent_header", msg["parent_header"], msg_list
)
self.write_stderr(
dedent(
"""\
IOPub data rate exceeded.
Expand All @@ -585,7 +648,8 @@ def write_stderr(error_message):
""".format(
self.iopub_data_rate_limit, self.rate_limit_window
)
)
),
msg["parent_header"],
)
else:
# resume once we've got some headroom below the limit
Expand All @@ -600,8 +664,9 @@ def write_stderr(error_message):
self._iopub_window_msg_count -= 1
self._iopub_window_byte_count -= byte_count
self._iopub_window_byte_queue.pop(-1)
return
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
return True

return False

def close(self):
super(ZMQChannelsHandler, self).close()
Expand Down Expand Up @@ -651,8 +716,12 @@ def _send_status_message(self, status):
# that all messages from the stopped kernel have been delivered
iopub.flush()
msg = self.session.msg("status", {"execution_state": status})
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack)
self.write_message(bin_msg, binary=True)
else:
msg["channel"] = "iopub"
self.write_message(json.dumps(msg, default=json_default))

def on_kernel_restarted(self):
self.log.warning("kernel %s restarted", self.kernel_id)
Expand All @@ -662,12 +731,19 @@ def on_restart_failed(self):
self.log.error("kernel %s restarted failed!", self.kernel_id)
self._send_status_message("dead")

def _on_error(self, msg):
def _on_error(self, channel, msg, msg_list):
if self.kernel_manager.allow_tracebacks:
return
msg["content"]["ename"] = "ExecutionError"
msg["content"]["evalue"] = "Execution error"
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]

if channel == "iopub":
msg["header"] = self.get_part("header", msg["header"], msg_list)
if msg["header"]["msg_type"] == "error":
msg["content"] = self.get_part("content", msg["content"], msg_list)
msg["content"]["ename"] = "ExecutionError"
msg["content"]["evalue"] = "Execution error"
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
msg_list[3] = self.session.pack(msg["content"])


# -----------------------------------------------------------------------------
Expand Down