diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py index 13beab3cc3..9bc06236e3 100644 --- a/jupyter_server/services/kernels/connection/channels.py +++ b/jupyter_server/services/kernels/connection/channels.py @@ -4,7 +4,7 @@ import weakref from concurrent.futures import Future from textwrap import dedent -from typing import MutableSet +from typing import Dict as Dict_t, MutableSet from jupyter_client import protocol_version as client_protocol_version from tornado import gen, web @@ -21,6 +21,7 @@ from jupyter_server.transutils import _i18n +from ..websocket import KernelWebsocketHandler from .abc import KernelWebsocketConnectionABC from .base import ( BaseKernelWebsocketConnection, @@ -103,7 +104,7 @@ def write_message(self): # class-level registry of open sessions # allows checking for conflict on session-id, # which is used as a zmq identity and must be unique. - _open_sessions: dict = {} + _open_sessions: Dict_t[str, KernelWebsocketHandler] = {} _open_sockets: MutableSet["ZMQChannelsWebsocketConnection"] = weakref.WeakSet() _kernel_info_future: Future @@ -391,7 +392,7 @@ def close(self): def disconnect(self): self.log.debug("Websocket closed %s", self.session_key) # unregister myself as an open session (only if it's really me) - if self._open_sessions.get(self.session_key) is self: + if self._open_sessions.get(self.session_key) is self.websocket_handler: self._open_sessions.pop(self.session_key) if self.kernel_id in self.multi_kernel_manager: @@ -536,16 +537,6 @@ def _reserialize_reply(self, msg_or_list, channel=None): else: return json.dumps(msg, default=json_default) - def select_subprotocol(self, subprotocols): - preferred_protocol = self.kernel_ws_protocol - if preferred_protocol is None: - preferred_protocol = "v1.kernel.websocket.jupyter.org" - elif preferred_protocol == "": - preferred_protocol = None - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None - # None is the default, "legacy" protocol - return selected_subprotocol - def _on_zmq_reply(self, stream, msg_list): # Sometimes this gets triggered when the on_close method is scheduled in the # eventloop but hasn't been called. diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 2806053a98..be3e021548 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -14,7 +14,7 @@ class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): - """The kernels websocket should connecte""" + """The kernels websocket should connect""" auth_resource = AUTH_RESOURCE @@ -75,6 +75,16 @@ def on_close(self): self.connection.disconnect() self.connection = None + def select_subprotocol(self, subprotocols): + preferred_protocol = self.connection.kernel_ws_protocol + if preferred_protocol is None: + preferred_protocol = "v1.kernel.websocket.jupyter.org" + elif preferred_protocol == "": + preferred_protocol = None + selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + # None is the default, "legacy" protocol + return selected_subprotocol + default_handlers = [ (r"/api/kernels/%s/channels" % _kernel_id_regex, KernelWebsocketHandler),