Skip to content

Commit

Permalink
Support websocket subprotocols
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 18, 2022
1 parent c3238ac commit 4b398ce
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 17 deletions.
105 changes: 105 additions & 0 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,64 @@
from .handlers import JupyterHandler


def serialize_binary_message(msg):
"""serialize a message as a binary blob
Header:
4 bytes: number of msg parts (nbufs) as 32b int
4 * nbufs bytes: offset for each buffer as integer as 32b int
Offsets are from the start of the buffer, including the header.
Returns
-------
The message serialized to bytes.
"""
# don't modify msg or buffer list in-place
msg = msg.copy()
buffers = list(msg.pop("buffers"))
if sys.version_info < (3, 4):
buffers = [x.tobytes() for x in buffers]
bmsg = json.dumps(msg, default=json_default).encode("utf8")
buffers.insert(0, bmsg)
nbufs = len(buffers)
offsets = [4 * (nbufs + 1)]
for buf in buffers[:-1]:
offsets.append(offsets[-1] + len(buf))
offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
buffers.insert(0, offsets_buf)
return b"".join(buffers)


def deserialize_binary_message(bmsg):
"""deserialize a message from a binary blog
Header:
4 bytes: number of msg parts (nbufs) as 32b int
4 * nbufs bytes: offset for each buffer as integer as 32b int
Offsets are from the start of the buffer, including the header.
Returns
-------
message dictionary
"""
nbufs = struct.unpack("!i", bmsg[:4])[0]
offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
offsets.append(None)
bufs = []
for start, stop in zip(offsets[:-1], offsets[1:]):
bufs.append(bmsg[start:stop])
msg = json.loads(bufs[0].decode("utf8"))
msg["header"] = extract_dates(msg["header"])
msg["parent_header"] = extract_dates(msg["parent_header"])
msg["buffers"] = bufs[1:]
return msg


# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -155,6 +213,37 @@ def send_error(self, *args, **kwargs):
# we can close the connection more gracefully.
self.stream.close()

def _reserialize_reply(self, msg_or_list, channel=None):
"""Reserialize a reply message using JSON.
msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
If it is the zmq list, it will be deserialized with self.session.
This takes the msg list from the ZMQ socket and serializes the result for the websocket.
This method should be used by self._on_zmq_reply to build messages that can
be sent back to the browser.
"""
if isinstance(msg_or_list, dict):
# already unpacked
msg = msg_or_list
else:
idents, msg_list = self.session.feed_identities(msg_or_list)
msg = self.session.deserialize(msg_list)
if channel:
msg["channel"] = channel
if msg["buffers"]:
buf = serialize_binary_message(msg)
return buf
else:
smsg = json.dumps(msg, default=json_default)
return cast_unicode(smsg)

def select_subprotocol(self, subprotocols):
selected_subprotocol = "0.0.1" if "0.0.1" in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol

def _on_zmq_reply(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
Expand All @@ -163,6 +252,22 @@ def _on_zmq_reply(self, stream, msg_list):
self.close()
return
channel = getattr(stream, "channel", None)
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))

def _on_zmq_reply_0_0_1(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
if self.ws_connection is None or stream.closed():
self.log.warning("zmq message arrived on closed channel")
self.close()
return

channel = getattr(stream, "channel", None)
offsets = []
curr_sum = 0
for msg in msg_list:
Expand Down
Loading

0 comments on commit 4b398ce

Please sign in to comment.