From c0f3f3ca3f976ed4c5ae3570aaaa2ed81591d1fa Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 2 Dec 2022 16:36:38 +0100 Subject: [PATCH] Fix kernel WebSocket protocol --- .../services/kernels/connection/channels.py | 16 ++++------------ jupyter_server/services/kernels/websocket.py | 12 +++++++++++- tests/services/kernels/test_connection.py | 7 ++++--- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py index 13beab3cc3..b5a2d4ce60 100644 --- a/jupyter_server/services/kernels/connection/channels.py +++ b/jupyter_server/services/kernels/connection/channels.py @@ -4,6 +4,7 @@ import weakref from concurrent.futures import Future from textwrap import dedent +from typing import Dict as Dict_t from typing import MutableSet from jupyter_client import protocol_version as client_protocol_version @@ -21,6 +22,7 @@ from jupyter_server.transutils import _i18n +from ..websocket import KernelWebsocketHandler from .abc import KernelWebsocketConnectionABC from .base import ( BaseKernelWebsocketConnection, @@ -103,7 +105,7 @@ def write_message(self): # class-level registry of open sessions # allows checking for conflict on session-id, # which is used as a zmq identity and must be unique. - _open_sessions: dict = {} + _open_sessions: Dict_t[str, KernelWebsocketHandler] = {} _open_sockets: MutableSet["ZMQChannelsWebsocketConnection"] = weakref.WeakSet() _kernel_info_future: Future @@ -391,7 +393,7 @@ def close(self): def disconnect(self): self.log.debug("Websocket closed %s", self.session_key) # unregister myself as an open session (only if it's really me) - if self._open_sessions.get(self.session_key) is self: + if self._open_sessions.get(self.session_key) is self.websocket_handler: self._open_sessions.pop(self.session_key) if self.kernel_id in self.multi_kernel_manager: @@ -536,16 +538,6 @@ def _reserialize_reply(self, msg_or_list, channel=None): else: return json.dumps(msg, default=json_default) - def select_subprotocol(self, subprotocols): - preferred_protocol = self.kernel_ws_protocol - if preferred_protocol is None: - preferred_protocol = "v1.kernel.websocket.jupyter.org" - elif preferred_protocol == "": - preferred_protocol = None - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None - # None is the default, "legacy" protocol - return selected_subprotocol - def _on_zmq_reply(self, stream, msg_list): # Sometimes this gets triggered when the on_close method is scheduled in the # eventloop but hasn't been called. diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 2806053a98..be3e021548 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -14,7 +14,7 @@ class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): - """The kernels websocket should connecte""" + """The kernels websocket should connect""" auth_resource = AUTH_RESOURCE @@ -75,6 +75,16 @@ def on_close(self): self.connection.disconnect() self.connection = None + def select_subprotocol(self, subprotocols): + preferred_protocol = self.connection.kernel_ws_protocol + if preferred_protocol is None: + preferred_protocol = "v1.kernel.websocket.jupyter.org" + elif preferred_protocol == "": + preferred_protocol = None + selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + # None is the default, "legacy" protocol + return selected_subprotocol + default_handlers = [ (r"/api/kernels/%s/channels" % _kernel_id_regex, KernelWebsocketHandler), diff --git a/tests/services/kernels/test_connection.py b/tests/services/kernels/test_connection.py index 1022eb3cd4..7fb7a1eee4 100644 --- a/tests/services/kernels/test_connection.py +++ b/tests/services/kernels/test_connection.py @@ -5,12 +5,12 @@ from jupyter_client.jsonutil import json_clean, json_default from jupyter_client.session import Session from tornado.httpserver import HTTPRequest -from tornado.websocket import WebSocketHandler from jupyter_server.serverapp import ServerApp from jupyter_server.services.kernels.connection.channels import ( ZMQChannelsWebsocketConnection, ) +from jupyter_server.services.kernels.websocket import KernelWebsocketHandler async def test_websocket_connection(jp_serverapp): @@ -19,10 +19,11 @@ async def test_websocket_connection(jp_serverapp): kernel = app.kernel_manager.get_kernel(kernel_id) request = HTTPRequest("foo", "GET") request.connection = MagicMock() - handler = WebSocketHandler(app.web_app, request) + handler = KernelWebsocketHandler(app.web_app, request) handler.ws_connection = MagicMock() handler.ws_connection.is_closing = lambda: False conn = ZMQChannelsWebsocketConnection(parent=kernel, websocket_handler=handler) + handler.connection = conn await conn.prepare() conn.connect() await asyncio.wrap_future(conn.nudge()) @@ -37,7 +38,7 @@ async def test_websocket_connection(jp_serverapp): conn.handle_incoming_message(data) conn.handle_outgoing_message("iopub", session.serialize(msg)) assert ( - conn.select_subprotocol(["v1.kernel.websocket.jupyter.org"]) + conn.websocket_handler.select_subprotocol(["v1.kernel.websocket.jupyter.org"]) == "v1.kernel.websocket.jupyter.org" ) conn.write_stderr("test", {})