Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix kernel WebSocket protocol #1110

Merged
merged 1 commit into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions jupyter_server/services/kernels/connection/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import weakref
from concurrent.futures import Future
from textwrap import dedent
from typing import Dict as Dict_t
from typing import MutableSet

from jupyter_client import protocol_version as client_protocol_version
Expand All @@ -21,6 +22,7 @@

from jupyter_server.transutils import _i18n

from ..websocket import KernelWebsocketHandler
from .abc import KernelWebsocketConnectionABC
from .base import (
BaseKernelWebsocketConnection,
Expand Down Expand Up @@ -103,7 +105,7 @@ def write_message(self):
# class-level registry of open sessions
# allows checking for conflict on session-id,
# which is used as a zmq identity and must be unique.
_open_sessions: dict = {}
_open_sessions: Dict_t[str, KernelWebsocketHandler] = {}
_open_sockets: MutableSet["ZMQChannelsWebsocketConnection"] = weakref.WeakSet()

_kernel_info_future: Future
Expand Down Expand Up @@ -391,7 +393,7 @@ def close(self):
def disconnect(self):
self.log.debug("Websocket closed %s", self.session_key)
# unregister myself as an open session (only if it's really me)
if self._open_sessions.get(self.session_key) is self:
if self._open_sessions.get(self.session_key) is self.websocket_handler:
self._open_sessions.pop(self.session_key)

if self.kernel_id in self.multi_kernel_manager:
Expand Down Expand Up @@ -536,16 +538,6 @@ def _reserialize_reply(self, msg_or_list, channel=None):
else:
return json.dumps(msg, default=json_default)

def select_subprotocol(self, subprotocols):
preferred_protocol = self.kernel_ws_protocol
if preferred_protocol is None:
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
preferred_protocol = None
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol

def _on_zmq_reply(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
Expand Down
12 changes: 11 additions & 1 deletion jupyter_server/services/kernels/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler):
"""The kernels websocket should connecte"""
"""The kernels websocket should connect"""

auth_resource = AUTH_RESOURCE

Expand Down Expand Up @@ -75,6 +75,16 @@ def on_close(self):
self.connection.disconnect()
self.connection = None

def select_subprotocol(self, subprotocols):
preferred_protocol = self.connection.kernel_ws_protocol
if preferred_protocol is None:
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
preferred_protocol = None
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol


default_handlers = [
(r"/api/kernels/%s/channels" % _kernel_id_regex, KernelWebsocketHandler),
Expand Down
7 changes: 4 additions & 3 deletions tests/services/kernels/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from jupyter_client.jsonutil import json_clean, json_default
from jupyter_client.session import Session
from tornado.httpserver import HTTPRequest
from tornado.websocket import WebSocketHandler

from jupyter_server.serverapp import ServerApp
from jupyter_server.services.kernels.connection.channels import (
ZMQChannelsWebsocketConnection,
)
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler


async def test_websocket_connection(jp_serverapp):
Expand All @@ -19,10 +19,11 @@ async def test_websocket_connection(jp_serverapp):
kernel = app.kernel_manager.get_kernel(kernel_id)
request = HTTPRequest("foo", "GET")
request.connection = MagicMock()
handler = WebSocketHandler(app.web_app, request)
handler = KernelWebsocketHandler(app.web_app, request)
handler.ws_connection = MagicMock()
handler.ws_connection.is_closing = lambda: False
conn = ZMQChannelsWebsocketConnection(parent=kernel, websocket_handler=handler)
handler.connection = conn
await conn.prepare()
conn.connect()
await asyncio.wrap_future(conn.nudge())
Expand All @@ -37,7 +38,7 @@ async def test_websocket_connection(jp_serverapp):
conn.handle_incoming_message(data)
conn.handle_outgoing_message("iopub", session.serialize(msg))
assert (
conn.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
conn.websocket_handler.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
== "v1.kernel.websocket.jupyter.org"
)
conn.write_stderr("test", {})
Expand Down