From 494a30e087b26fd7fa5ed0905838cf6c42e38f8c Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Thu, 25 Oct 2018 17:17:18 +0100 Subject: [PATCH] Initial work towards using jupyter_kernel_mgmt in jupyter server Cherry-pick Notebook PR https://github.com/jupyter/notebook/pull/4837 --- jupyter_server/base/handlers.py | 10 +- jupyter_server/base/zmqhandlers.py | 174 +------ jupyter_server/kernelspecs/handlers.py | 20 +- jupyter_server/serverapp.py | 55 ++- jupyter_server/services/kernels/handlers.py | 279 +++++------ .../services/kernels/kernelmanager.py | 435 +++++++++++------- .../services/kernels/ws_serialize.py | 119 +++++ .../services/kernelspecs/handlers.py | 41 +- jupyter_server/services/sessions/handlers.py | 2 +- .../services/sessions/sessionmanager.py | 9 +- setup.py | 2 + tests/services/kernels/test_api.py | 10 +- tests/services/kernelspecs/test_api.py | 28 +- tests/services/sessions/test_api.py | 8 +- tests/test_serialize.py | 44 +- 15 files changed, 614 insertions(+), 622 deletions(-) create mode 100644 jupyter_server/services/kernels/ws_serialize.py diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 1d9bbb3cd5..a49271ed17 100755 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -245,7 +245,11 @@ def contents_js_source(self): #--------------------------------------------------------------- # Manager objects #--------------------------------------------------------------- - + + @property + def kernel_finder(self): + return self.settings['kernel_finder'] + @property def kernel_manager(self): return self.settings['kernel_manager'] @@ -261,10 +265,6 @@ def session_manager(self): @property def terminal_manager(self): return self.settings['terminal_manager'] - - @property - def kernel_spec_manager(self): - return self.settings['kernel_spec_manager'] @property def config_manager(self): diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index a7dde2cbe5..44d37ae5ac 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -4,81 +4,11 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import json -import struct -import sys -import tornado - from urllib.parse import urlparse -from tornado import gen, ioloop, web -from tornado.websocket import WebSocketHandler - -from jupyter_client.session import Session -from jupyter_client.jsonutil import date_default, extract_dates -from ipython_genutils.py3compat import cast_unicode - -from .handlers import JupyterHandler -from jupyter_server.utils import maybe_future - - -def serialize_binary_message(msg): - """serialize a message as a binary blob - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - - The message serialized to bytes. - - """ - # don't modify msg or buffer list in-place - msg = msg.copy() - buffers = list(msg.pop('buffers')) - if sys.version_info < (3, 4): - buffers = [x.tobytes() for x in buffers] - bmsg = json.dumps(msg, default=date_default).encode('utf8') - buffers.insert(0, bmsg) - nbufs = len(buffers) - offsets = [4 * (nbufs + 1)] - for buf in buffers[:-1]: - offsets.append(offsets[-1] + len(buf)) - offsets_buf = struct.pack('!' 
+ 'I' * (nbufs + 1), nbufs, *offsets) - buffers.insert(0, offsets_buf) - return b''.join(buffers) - -def deserialize_binary_message(bmsg): - """deserialize a message from a binary blog - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - - message dictionary - """ - nbufs = struct.unpack('!i', bmsg[:4])[0] - offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)])) - offsets.append(None) - bufs = [] - for start, stop in zip(offsets[:-1], offsets[1:]): - bufs.append(bmsg[start:stop]) - msg = json.loads(bufs[0].decode('utf8')) - msg['header'] = extract_dates(msg['header']) - msg['parent_header'] = extract_dates(msg['parent_header']) - msg['buffers'] = bufs[1:] - return msg +from tornado import ioloop +from tornado.iostream import StreamClosedError +from tornado.websocket import WebSocketHandler, WebSocketClosedError # ping interval for keeping websockets alive (30 seconds) WS_PING_INTERVAL = 30000 @@ -188,101 +118,3 @@ def send_ping(self): def on_pong(self, data): self.last_pong = ioloop.IOLoop.current().time() - -class ZMQStreamHandler(WebSocketMixin, WebSocketHandler): - - if tornado.version_info < (4,1): - """Backport send_error from tornado 4.1 to 4.0""" - def send_error(self, *args, **kwargs): - if self.stream is None: - super(WebSocketHandler, self).send_error(*args, **kwargs) - else: - # If we get an uncaught exception during the handshake, - # we have no choice but to abruptly close the connection. - # TODO: for uncaught exceptions after the handshake, - # we can close the connection more gracefully. - self.stream.close() - - - def _reserialize_reply(self, msg_or_list, channel=None): - """Reserialize a reply message using JSON. - - msg_or_list can be an already-deserialized msg dict or the zmq buffer list. - If it is the zmq list, it will be deserialized with self.session. - - This takes the msg list from the ZMQ socket and serializes the result for the websocket. - This method should be used by self._on_zmq_reply to build messages that can - be sent back to the browser. - - """ - if isinstance(msg_or_list, dict): - # already unpacked - msg = msg_or_list - else: - idents, msg_list = self.session.feed_identities(msg_or_list) - msg = self.session.deserialize(msg_list) - if channel: - msg['channel'] = channel - if msg['buffers']: - buf = serialize_binary_message(msg) - return buf - else: - smsg = json.dumps(msg, default=date_default) - return cast_unicode(smsg) - - def _on_zmq_reply(self, stream, msg_list): - # Sometimes this gets triggered when the on_close method is scheduled in the - # eventloop but hasn't been called. - if self.ws_connection is None or stream.closed(): - self.log.warning("zmq message arrived on closed channel") - self.close() - return - channel = getattr(stream, 'channel', None) - try: - msg = self._reserialize_reply(msg_list, channel=channel) - except Exception: - self.log.critical("Malformed message: %r" % msg_list, exc_info=True) - else: - self.write_message(msg, binary=isinstance(msg, bytes)) - - -class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler): - - def set_default_headers(self): - """Undo the set_default_headers in JupyterHandler - - which doesn't make sense for websockets - """ - pass - - def pre_get(self): - """Run before finishing the GET request - - Extend this method to add logic that should fire before - the websocket finishes completing. 
- """ - # authenticate the request before opening the websocket - if self.get_current_user() is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) - - if self.get_argument('session_id', False): - self.session.session = cast_unicode(self.get_argument('session_id')) - else: - self.log.warning("No session ID specified") - - @gen.coroutine - def get(self, *args, **kwargs): - # pre_get can be a coroutine in subclasses - # assign and yield in two step to avoid tornado 3 issues - res = self.pre_get() - yield maybe_future(res) - res = super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs) - yield maybe_future(res) - - def initialize(self): - self.log.debug("Initializing websocket connection %s", self.request.path) - self.session = Session(config=self.config) - - def get_compression_options(self): - return self.settings.get('websocket_compression_options', None) diff --git a/jupyter_server/kernelspecs/handlers.py b/jupyter_server/kernelspecs/handlers.py index 228694b8a5..d1297597ad 100644 --- a/jupyter_server/kernelspecs/handlers.py +++ b/jupyter_server/kernelspecs/handlers.py @@ -11,19 +11,23 @@ def initialize(self): @web.authenticated def get(self, kernel_name, path, include_body=True): - ksm = self.kernel_spec_manager - try: - self.root = ksm.get_kernel_spec(kernel_name).resource_dir - except KeyError: - raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) - self.log.debug("Serving kernel resource from: %s", self.root) - return web.StaticFileHandler.get(self, path, include_body=include_body) + kf = self.kernel_finder + # TODO: Do we actually want all kernel type names to be case-insensitive? + kernel_name = kernel_name.lower() + for name, info in kf.find_kernels(): + if name == kernel_name: + self.root = info['resource_dir'] + self.log.debug("Serving kernel resource from: %s", self.root) + return web.StaticFileHandler.get(self, path, + include_body=include_body) + + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) @web.authenticated def head(self, kernel_name, path): return self.get(kernel_name, path, include_body=False) + default_handlers = [ (r"/kernelspecs/%s/(?P.*)" % kernel_name_regex, KernelSpecResourceHandler), ] - diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 65d3562071..b55b0bcea9 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -88,8 +88,8 @@ ) from jupyter_core.paths import jupyter_config_path from jupyter_client import KernelManager -from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME from jupyter_client.session import Session +from jupyter_kernel_mgmt.discovery import KernelFinder from nbformat.sign import NotebookNotary from traitlets import ( Any, Dict, Unicode, Integer, List, Bool, Bytes, Instance, @@ -159,13 +159,13 @@ def load_handlers(name): class ServerWebApplication(web.Application): def __init__(self, jupyter_app, default_services, kernel_manager, contents_manager, - session_manager, kernel_spec_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options): settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, - session_manager, kernel_spec_manager, config_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options) handlers = self.init_handlers(default_services, settings) @@ -173,7 +173,7 @@ def 
__init__(self, jupyter_app, default_services, kernel_manager, contents_manag super(ServerWebApplication, self).__init__(handlers, **settings) def init_settings(self, jupyter_app, kernel_manager, contents_manager, - session_manager, kernel_spec_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options=None): @@ -248,10 +248,10 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, local_hostnames=jupyter_app.local_hostnames, # managers + kernel_finder=kernel_finder, kernel_manager=kernel_manager, contents_manager=contents_manager, session_manager=session_manager, - kernel_spec_manager=kernel_spec_manager, config_manager=config_manager, # handlers @@ -555,7 +555,7 @@ class ServerApp(JupyterApp): flags = flags classes = [ - KernelManager, Session, MappingKernelManager, KernelSpecManager, + KernelManager, Session, MappingKernelManager, ContentsManager, FileContentsManager, NotebookNotary, GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient, ] @@ -1033,6 +1033,12 @@ def template_file_path(self): (shutdown the Jupyter server).""" ) + kernel_providers = List(config=True, + help=_('A list of kernel provider instances. ' + 'If not specified, all installed kernel providers are found ' + 'using entry points.') + ) + contents_manager_class = Type( default_value=LargeFileManager, klass=ContentsManager, @@ -1058,20 +1064,6 @@ def template_file_path(self): help=_('The config manager class to use') ) - kernel_spec_manager = Instance(KernelSpecManager, allow_none=True) - - kernel_spec_manager_class = Type( - default_value=KernelSpecManager, - config=True, - help=""" - The kernel spec manager class to use. Should be a subclass - of `jupyter_client.kernelspec.KernelSpecManager`. - - The Api of KernelSpecManager is provisional and might change - without warning between this version of Jupyter and the next stable one. - """ - ) - login_handler_class = Type( default_value=LoginHandler, klass=web.RequestHandler, @@ -1104,7 +1096,7 @@ def _default_info_file(self): def _default_browser_open_file(self): basename = "jpserver-%s-open.html" % os.getpid() return os.path.join(self.runtime_dir, basename) - + pylab = Unicode('disabled', config=True, help=_(""" DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib. @@ -1237,16 +1229,23 @@ def init_configurables(self): if self.gateway_config.gateway_enabled: self.kernel_manager_class = 'jupyter_server.gateway.managers.GatewayKernelManager' self.session_manager_class = 'jupyter_server.gateway.managers.GatewaySessionManager' - self.kernel_spec_manager_class = 'jupyter_server.gateway.managers.GatewayKernelSpecManager' +# FIXME - no more kernel-spec-manager! 
+# self.kernel_spec_manager_class = 'jupyter_server.gateway.managers.GatewayKernelSpecManager' +# +# self.kernel_spec_manager = self.kernel_spec_manager_class( +# parent=self, +# ) + + if self.kernel_providers: + self.kernel_finder = KernelFinder(self.kernel_providers) + else: + self.kernel_finder = KernelFinder.from_entrypoints() - self.kernel_spec_manager = self.kernel_spec_manager_class( - parent=self, - ) self.kernel_manager = self.kernel_manager_class( parent=self, log=self.log, connection_dir=self.runtime_dir, - kernel_spec_manager=self.kernel_spec_manager, + kernel_finder=self.kernel_finder, ) self.contents_manager = self.contents_manager_class( parent=self, @@ -1301,7 +1300,7 @@ def init_webapp(self): self.web_app = ServerWebApplication( self, self.default_services, self.kernel_manager, self.contents_manager, - self.session_manager, self.kernel_spec_manager, + self.session_manager, self.kernel_finder, self.config_manager, self.extra_services, self.log, self.base_url, self.default_url, self.tornado_settings, self.jinja_environment_options, @@ -1490,7 +1489,7 @@ def init_server_extensions(self): Import the module, then call the load_jupyter_server_extension function, if one exists. - + The extension API is experimental, and may change in future releases. """ # Initialize extensions diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index 358798408c..c065ff8f8e 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -10,18 +10,20 @@ import logging from textwrap import dedent +import tornado from tornado import gen, web from tornado.concurrent import Future from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler -from jupyter_client import protocol_version as client_protocol_version from jupyter_client.jsonutil import date_default +from jupyter_protocol.messages import Message from ipython_genutils.py3compat import cast_unicode from jupyter_server.utils import url_path_join, url_escape, maybe_future -from ...base.handlers import APIHandler -from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message - +from ...base.handlers import APIHandler, JupyterHandler +from ...base.zmqhandlers import WebSocketMixin +from .ws_serialize import serialize_message, deserialize_message class MainKernelHandler(APIHandler): @@ -58,6 +60,7 @@ class KernelHandler(APIHandler): @web.authenticated def get(self, kernel_id): km = self.kernel_manager + km._check_kernel_id(kernel_id) model = km.kernel_model(kernel_id) self.finish(json.dumps(model, default=date_default)) @@ -77,12 +80,17 @@ class KernelActionHandler(APIHandler): def post(self, kernel_id, action): km = self.kernel_manager if action == 'interrupt': - km.interrupt_kernel(kernel_id) + kernel = km.get_kernel(kernel_id) + # Don't interrupt a kernel while it's still starting + yield kernel.client_ready() + kernel.interrupt() self.set_status(204) if action == 'restart': try: yield maybe_future(km.restart_kernel(kernel_id)) + except web.HTTPError: + raise except Exception as e: self.log.error("Exception restarting kernel", exc_info=True) self.set_status(500) @@ -92,7 +100,7 @@ def post(self, kernel_id, action): self.finish() -class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): +class ZMQChannelsHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): '''There is one ZMQChannelsHandler per running kernel and it oversees all the sessions. 
''' @@ -119,85 +127,29 @@ def iopub_data_rate_limit(self): def rate_limit_window(self): return self.settings.get('rate_limit_window', 1.0) + @property + def kernel_client(self): + return self.kernel_manager.get_kernel(self.kernel_id).client + def __repr__(self): return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized')) - def create_stream(self): - km = self.kernel_manager - identity = self.session.bsession - for channel in ('shell', 'control', 'iopub', 'stdin'): - meth = getattr(km, 'connect_' + channel) - self.channels[channel] = stream = meth(self.kernel_id, identity=identity) - stream.channel = channel - - def request_kernel_info(self): - """send a request for kernel_info""" - km = self.kernel_manager - kernel = km.get_kernel(self.kernel_id) - try: - # check for previous request - future = kernel._kernel_info_future - except AttributeError: - self.log.debug("Requesting kernel info from %s", self.kernel_id) - # Create a kernel_info channel to query the kernel protocol version. - # This channel will be closed after the kernel_info reply is received. - if self.kernel_info_channel is None: - self.kernel_info_channel = km.connect_shell(self.kernel_id) - self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) - self.session.send(self.kernel_info_channel, "kernel_info_request") - # store the future on the kernel, so only one request is sent - kernel._kernel_info_future = self._kernel_info_future - else: - if not future.done(): - self.log.debug("Waiting for pending kernel_info request") - future.add_done_callback(lambda f: self._finish_kernel_info(f.result())) - return self._kernel_info_future - - def _handle_kernel_info_reply(self, msg): - """process the kernel_info_reply + def set_default_headers(self): + """Undo the set_default_headers in IPythonHandler - enabling msg spec adaptation, if necessary + which doesn't make sense for websockets """ - idents,msg = self.session.feed_identities(msg) - try: - msg = self.session.deserialize(msg) - except: - self.log.error("Bad kernel_info reply", exc_info=True) - self._kernel_info_future.set_result({}) - return - else: - info = msg['content'] - self.log.debug("Received kernel info: %s", info) - if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info: - self.log.error("Kernel info request failed, assuming current %s", info) - info = {} - self._finish_kernel_info(info) - - # close the kernel_info channel, we don't need it anymore - if self.kernel_info_channel: - self.kernel_info_channel.close() - self.kernel_info_channel = None - - def _finish_kernel_info(self, info): - """Finish handling kernel_info reply - - Set up protocol adaptation, if needed, - and signal that connection can continue. 
- """ - protocol_version = info.get('protocol_version', client_protocol_version) - if protocol_version != client_protocol_version: - self.session.adapt_version = int(protocol_version.split('.')[0]) - self.log.info("Adapting from protocol version {protocol_version} (kernel {kernel_id}) to {client_protocol_version} (client).".format(protocol_version=protocol_version, kernel_id=self.kernel_id, client_protocol_version=client_protocol_version)) - if not self._kernel_info_future.done(): - self._kernel_info_future.set_result(info) + pass + + def get_compression_options(self): + return self.settings.get('websocket_compression_options', None) + + channels = {'shell', 'control', 'iopub', 'stdin'} def initialize(self): super(ZMQChannelsHandler, self).initialize() self.zmq_stream = None - self.channels = {} self.kernel_id = None - self.kernel_info_channel = None - self._kernel_info_future = Future() self._close_future = Future() self.session_key = '' @@ -211,33 +163,27 @@ def initialize(self): # by a delta amount at some point in the future. self._iopub_window_byte_queue = [] + session_id = None + @gen.coroutine def pre_get(self): - # authenticate first - super(ZMQChannelsHandler, self).pre_get() + # authenticate the request before opening the websocket + if self.get_current_user() is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + if self.get_argument('session_id', False): + self.session_id = cast_unicode(self.get_argument('session_id')) + else: + self.log.warning("No session ID specified") + # check session collision: yield self._register_session() - # then request kernel info, waiting up to a certain time before giving up. - # We don't want to wait forever, because browsers don't take it well when - # servers never respond to websocket connection requests. - kernel = self.kernel_manager.get_kernel(self.kernel_id) - self.session.key = kernel.session.key - future = self.request_kernel_info() - - def give_up(): - """Don't wait forever for the kernel to reply""" - if future.done(): - return - self.log.warning("Timeout waiting for kernel_info reply from %s", self.kernel_id) - future.set_result({}) - loop = IOLoop.current() - loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up) - # actually wait for it - yield future @gen.coroutine def get(self, kernel_id): self.kernel_id = cast_unicode(kernel_id, 'ascii') + yield self.pre_get() yield super(ZMQChannelsHandler, self).get(kernel_id=kernel_id) @gen.coroutine @@ -248,57 +194,45 @@ def _register_session(self): This is likely due to a client reconnecting from a lost network connection, where the socket on our side has not been cleaned up yet. 
""" - self.session_key = '%s:%s' % (self.kernel_id, self.session.session) + self.session_key = '%s:%s' % (self.kernel_id, self.session_id) stale_handler = self._open_sessions.get(self.session_key) if stale_handler: self.log.warning("Replacing stale connection: %s", self.session_key) yield stale_handler.close() self._open_sessions[self.session_key] = self + @gen.coroutine def open(self, kernel_id): super(ZMQChannelsHandler, self).open() km = self.kernel_manager + km._check_kernel_id(kernel_id) km.notify_connect(kernel_id) + kernel = km.get_kernel(kernel_id) + yield from kernel.client_ready() # on new connections, flush the message buffer - buffer_info = km.get_buffer(kernel_id, self.session_key) - if buffer_info and buffer_info['session_key'] == self.session_key: + buffer_key, replay_buffer = kernel.get_buffer() + if buffer_key == self.session_key: self.log.info("Restoring connection for %s", self.session_key) - self.channels = buffer_info['channels'] - replay_buffer = buffer_info['buffer'] if replay_buffer: self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) - else: - try: - self.create_stream() - except web.HTTPError as e: - self.log.error("Error opening stream: %s", e) - # WebSockets don't response to traditional error codes so we - # close the connection. - for channel, stream in self.channels.items(): - if not stream.closed(): - stream.close() - self.close() - return + for msg, channel in replay_buffer: + self._on_zmq_msg(msg, channel) - km.add_restart_callback(self.kernel_id, self.on_kernel_restarted) - km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead') + kernel.restarter.add_callback(self.on_kernel_died, 'died') + kernel.restarter.add_callback(self.on_kernel_restarted, 'restarted') + kernel.restarter.add_callback(self.on_restart_failed, 'failed') - for channel, stream in self.channels.items(): - stream.on_recv_stream(self._on_zmq_reply) + kernel.msg_handlers.append(self._on_zmq_msg) def on_message(self, msg): - if not self.channels: + """Received websocket message; forward to kernel""" + if self._close_future.done(): # already closed, ignore the message self.log.debug("Received message on closed websocket %r", msg) return - if isinstance(msg, bytes): - msg = deserialize_binary_message(msg) - else: - msg = json.loads(msg) + + msg = deserialize_message(msg) channel = msg.pop('channel', None) if channel is None: self.log.warning("No channel specified, assuming shell: %s", msg) @@ -311,25 +245,25 @@ def on_message(self, msg): if am and mt not in am: self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' 
% mt) else: - stream = self.channels[channel] - self.session.send(stream, msg) + self.kernel_client.messaging.send(channel, Message(**msg)) + + def _on_zmq_msg(self, msg: Message, channel): + """Received message from kernel; forward over websocket""" + if self.ws_connection is None: + return - def _on_zmq_reply(self, stream, msg_list): - idents, fed_msg_list = self.session.feed_identities(msg_list) - msg = self.session.deserialize(fed_msg_list) - parent = msg['parent_header'] def write_stderr(error_message): self.log.warning(error_message) - msg = self.session.msg("stream", + stream_msg = Message.from_type("stream", content={"text": error_message + '\n', "name": "stderr"}, - parent=parent - ) - msg['channel'] = 'iopub' - self.write_message(json.dumps(msg, default=date_default)) - channel = getattr(stream, 'channel', None) - msg_type = msg['header']['msg_type'] - - if channel == 'iopub' and msg_type == 'status' and msg['content'].get('execution_state') == 'idle': + ).make_dict() + stream_msg['parent_header'] = msg.parent_header + stream_msg['channel'] = 'iopub' + self.write_message(json.dumps(stream_msg, default=date_default)) + + msg_type = msg.header['msg_type'] + + if channel == 'iopub' and msg_type == 'status' and msg.content.get('execution_state') == 'idle': # reset rate limit counter on status=idle, # to avoid 'Run All' hitting limits prematurely. self._iopub_window_byte_queue = [] @@ -356,7 +290,7 @@ def write_stderr(error_message): # Increment the bytes and message count self._iopub_window_msg_count += 1 if msg_type == 'stream': - byte_count = sum([len(x) for x in msg_list]) + byte_count = len(msg.content['text'].encode('utf-8')) else: byte_count = 0 self._iopub_window_byte_count += byte_count @@ -421,7 +355,14 @@ def write_stderr(error_message): self._iopub_window_byte_count -= byte_count self._iopub_window_byte_queue.pop(-1) return - super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) + + try: + ws_msg = serialize_message(msg, channel=channel) + except Exception: + self.log.critical("Malformed message: %r" % msg, + exc_info=True) + else: + self.write_message(ws_msg, binary=isinstance(ws_msg, bytes)) def close(self): super(ZMQChannelsHandler, self).close() @@ -436,50 +377,44 @@ def on_close(self): km = self.kernel_manager if self.kernel_id in km: km.notify_disconnect(self.kernel_id) - km.remove_restart_callback( - self.kernel_id, self.on_kernel_restarted, - ) - km.remove_restart_callback( - self.kernel_id, self.on_restart_failed, 'dead', - ) + kernel = km.get_kernel(self.kernel_id) + try: + kernel.msg_handlers.remove(self._on_zmq_msg) + except ValueError: + self.log.debug("Message handler not connected") + + kernel.restarter.remove_callback(self.on_kernel_died, 'died') + kernel.restarter.remove_callback(self.on_restart_failed, 'failed') + kernel.restarter.remove_callback(self.on_kernel_restarted, 'restarted') # start buffering instead of closing if this was the last connection - if km._kernel_connections[self.kernel_id] == 0: + if kernel.n_connections == 0: km.start_buffering(self.kernel_id, self.session_key, self.channels) - self._close_future.set_result(None) - return - - # This method can be called twice, once by self.kernel_died and once - # from the WebSocket close event. If the WebSocket connection is - # closed before the ZMQ streams are setup, they could be None. 
- for channel, stream in self.channels.items(): - if stream is not None and not stream.closed(): - stream.on_recv(None) - stream.close() - self.channels = {} self._close_future.set_result(None) def _send_status_message(self, status): - iopub = self.channels.get('iopub', None) - if iopub and not iopub.closed(): - # flush IOPub before sending a restarting/dead status message - # ensures proper ordering on the IOPub channel - # that all messages from the stopped kernel have been delivered - iopub.flush() - msg = self.session.msg("status", + msg = Message.from_type("status", {'execution_state': status} ) - msg['channel'] = 'iopub' - self.write_message(json.dumps(msg, default=date_default)) + ws_msg = serialize_message(msg, channel='iopub') + return self.write_message(ws_msg, binary=isinstance(ws_msg, bytes)) + + def on_kernel_died(self, _data): + logging.warning("kernel %s died, noticed by auto restarter", self.kernel_id) + return self._send_status_message('restarting') - def on_kernel_restarted(self): - logging.warn("kernel %s restarted", self.kernel_id) - self._send_status_message('restarting') + @gen.coroutine + def on_kernel_restarted(self, _data): + kernel = self.kernel_manager.get_kernel(self.kernel_id) + # Send the status message once the client is connected + yield kernel.client_ready() + logging.warning("kernel %s restarted", self.kernel_id) + return self._send_status_message('starting') - def on_restart_failed(self): + def on_restart_failed(self, _data): logging.error("kernel %s restarted failed!", self.kernel_id) - self._send_status_message('dead') + return self._send_status_message('dead') #----------------------------------------------------------------------------- diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index 250472f2cc..4238ca4557 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -9,18 +9,19 @@ from collections import defaultdict from datetime import datetime, timedelta -from functools import partial import os +import uuid from tornado import gen, web -from tornado.concurrent import Future from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.locks import Event -from jupyter_client.session import Session -from jupyter_client.multikernelmanager import MultiKernelManager +from jupyter_kernel_mgmt.client import IOLoopKernelClient +from jupyter_kernel_mgmt.restarter import TornadoKernelRestarter from traitlets import (Any, Bool, Dict, List, Unicode, TraitError, Integer, Float, Instance, default, validate ) +from traitlets.config.configurable import LoggingConfigurable from jupyter_server.utils import maybe_future, to_os_path, exists from jupyter_server._tz import utcnow, isoformat @@ -29,18 +30,148 @@ from jupyter_server.prometheus.metrics import KERNEL_CURRENTLY_RUNNING_TOTAL -class MappingKernelManager(MultiKernelManager): - """A KernelManager that handles - - File mapping - - HTTP error handling - - Kernel message filtering +class KernelInterface(LoggingConfigurable): + """A wrapper around one kernel, including manager, client and restarter. + + A KernelInterface instance persists across kernel restarts, whereas + manager and client objects are recreated. 
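+
+    Rough lifecycle, as a hedged sketch (``'spec/sample'`` is a hypothetical
+    kernel type name; everything else is defined on this class)::
+
+        kernel = KernelInterface('spec/sample', kernel_finder)  # launches now
+        yield kernel.client_ready()   # resolves once a ZMQ client is ready
+        kernel.interrupt()            # proxied to the manager
+        yield kernel.restart()        # same KernelInterface, new manager/client
+        yield kernel.shutdown()       # graceful if the client is up, else kill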
""" + def __init__(self, kernel_type, kernel_finder): + super(KernelInterface, self).__init__() + self.kernel_type = kernel_type + self.kernel_finder = kernel_finder + + self.connection_info, self.manager = kernel_finder.launch(kernel_type) + self.n_connections = 0 + self.execution_state = 'starting' + self.last_activity = utcnow() + + self.restarter = TornadoKernelRestarter(self.manager, kernel_type, + kernel_finder=self.kernel_finder) + self.restarter.add_callback(self._handle_kernel_died, 'died') + self.restarter.add_callback(self._handle_kernel_restarted, 'restarted') + self.restarter.start() + + self.buffer_for_key = None + # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple + self.buffer = [] + + # Message handlers stored here don't have to be re-added if the kernel + # is restarted. + self.msg_handlers = [] + # A future that resolves when the client is connected + self.client_connected = self._connect_client() + self._client_connected_evt = Event() + + client = None + + @gen.coroutine + def _connect_client(self): + """Connect a client and wait for it to be ready.""" + self.client = IOLoopKernelClient(self.connection_info, self.manager) + yield self.client.wait_for_ready() + self.client.add_handler(self._msg_received, {'shell', 'iopub', 'stdin'}) + self._client_connected_evt.set() + + def _close_client(self): + if self.client is not None: + self._client_connected_evt.clear() + self.client_connected.cancel() + self.client.close() + self.client = None + + def client_ready(self): + """Return a future which resolves when the client is ready""" + if self.client is None: + return self._client_connected_evt.wait() + else: + return self.client_connected + + def _msg_received(self, msg, channel): + loop = IOLoop.current() + for handler in self.msg_handlers: + loop.add_callback(handler, msg, channel) + + @gen.coroutine + def shutdown(self, now=False): + self.restarter.stop() + + if now or (self.client is None): + self.manager.kill() + else: + yield self.client_connected + yield self.client.shutdown_or_terminate() + + self._close_client() + self.manager.cleanup() + + def interrupt(self): + self.manager.interrupt() + + @gen.coroutine + def _handle_kernel_died(self, data): + """Called when the auto-restarter notices the kernel has died""" + self._close_client() + + @gen.coroutine + def _handle_kernel_restarted(self, data): + """Called when the kernel has been restarted""" + self.manager = data['manager'] + self.connection_info = data['connection_info'] + self.client_connected = self._connect_client() + yield self.client_connected + + @gen.coroutine + def restart(self): + yield self.shutdown() + # The restart will trigger _handle_kernel_restarted() to connect a + # new client. 
+        self.restarter.do_restart()
+        # Resume monitoring the kernel for auto-restart
+        self.restarter.start()
+        yield self._client_connected_evt.wait()
+
+    def start_buffering(self, session_key):
+        # record the session key because only one session can buffer
+        self.buffer_for_key = session_key
+
+        # forward any future messages to the internal buffer
+        self.client.add_handler(self._buffer_msg, {'shell', 'iopub', 'stdin'})
+
+    def _buffer_msg(self, msg, channel):
+        self.log.debug("Buffering msg on %s", channel)
+        self.buffer.append((msg, channel))
+
+    def get_buffer(self):
+        """Get the buffer for a given kernel, and stop buffering new messages
+        """
+        buffer, key = self.buffer, self.buffer_for_key
+        self.buffer = []
+        self.stop_buffering()
+        return buffer, key
+
+    def stop_buffering(self):
+        """Stop buffering kernel messages
+        """
+        self.client.remove_handler(self._buffer_msg)
+
+        if self.buffer:
+            self.log.info("Discarding %s buffered messages for %s",
+                          len(self.buffer), self.buffer_for_key)
+        self.buffer = []
+        self.buffer_for_key = None
+
+
+class MappingKernelManager(LoggingConfigurable):
+    """A KernelManager that handles notebook mapping and HTTP error handling"""
 
     @default('kernel_manager_class')
     def _default_kernel_manager_class(self):
         return "jupyter_client.ioloop.IOLoopKernelManager"
 
-    kernel_argv = List(Unicode())
+    default_kernel_name = Unicode('pyimport/kernel', config=True,
+        help="The name of the default kernel to start"
+    )
 
     root_dir = Unicode(config=True)
 
@@ -120,16 +251,24 @@ def _default_kernel_buffers(self):
     last_kernel_activity = Instance(datetime,
         help="The last activity on any kernel, including shutting down a kernel")
 
-    def __init__(self, **kwargs):
-        super(MappingKernelManager, self).__init__(**kwargs)
-        self.last_kernel_activity = utcnow()
-
     allowed_message_types = List(trait=Unicode(), config=True,
         help="""White list of allowed kernel message types.
         When the list is empty, all message types are allowed.
         """
     )
 
+    def __init__(self, kernel_finder, **kwargs):
+        super(MappingKernelManager, self).__init__(**kwargs)
+        self.last_kernel_activity = utcnow()
+        self._kernels = {}
+        self._kernels_starting = {}
+        self._restarters = {}
+        self.kernel_finder = kernel_finder
+        self.initialize_culler()
+
+    def get_kernel(self, kernel_id):
+        return self._kernels[kernel_id]
+
     #-------------------------------------------------------------------------
     # Methods for managing kernels and sessions
     #-------------------------------------------------------------------------
 
@@ -137,7 +276,13 @@ def __init__(self, **kwargs):
     def _handle_kernel_died(self, kernel_id):
         """notice that a kernel died"""
         self.log.warning("Kernel %s died, removing from map.", kernel_id)
-        self.remove_kernel(kernel_id)
+        kernel = self._kernels.pop(kernel_id)
+        kernel.client.close()
+        kernel.manager.cleanup()
+
+        KERNEL_CURRENTLY_RUNNING_TOTAL.labels(
+            type=kernel.kernel_type
+        ).dec()
 
     def cwd_for_path(self, path):
         """Turn API path into absolute OS path."""
@@ -149,7 +294,7 @@ def cwd_for_path(self, path):
         return os_path
 
     @gen.coroutine
-    def start_kernel(self, kernel_id=None, path=None, **kwargs):
+    def start_kernel(self, kernel_id=None, path=None, kernel_name=None, **kwargs):
         """Start a kernel for a session and return its kernel_id.
 
         Parameters
         ----------
@@ -166,38 +311,52 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):
         an existing kernel is returned, but it may be checked in the future. 
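+
+        With this change, launching is synchronous (start_launching_kernel)
+        and this coroutine then waits for the kernel's client to be ready.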
""" if kernel_id is None: - if path is not None: - kwargs['cwd'] = self.cwd_for_path(path) - kernel_id = yield maybe_future( - super(MappingKernelManager, self).start_kernel(**kwargs) - ) - self._kernel_connections[kernel_id] = 0 - self.start_watching_activity(kernel_id) - self.log.info("Kernel started: %s" % kernel_id) - self.log.debug("Kernel args: %r" % kwargs) - # register callback for failed auto-restart - self.add_restart_callback(kernel_id, - lambda : self._handle_kernel_died(kernel_id), - 'dead', - ) - - # Increase the metric of number of kernels running - # for the relevant kernel type by 1 - KERNEL_CURRENTLY_RUNNING_TOTAL.labels( - type=self._kernels[kernel_id].kernel_name - ).inc() - + kernel_id = self.start_launching_kernel(path, kernel_name, **kwargs) + yield self.get_kernel(kernel_id).client_ready() else: self._check_kernel_id(kernel_id) self.log.info("Using existing kernel: %s" % kernel_id) - # Initialize culling if not already - if not self._initialized_culler: - self.initialize_culler() - # py2-compat raise gen.Return(kernel_id) + def start_launching_kernel(self, path=None, kernel_name=None, **kwargs): + """Launch a new kernel, return its kernel ID + + This is a synchronous method which starts the process of launching a + kernel. Retrieve the KernelInterface object and call ``.client_ready()`` + to get a future for the rest of the startup & connection. + """ + if path is not None: + kwargs['cwd'] = self.cwd_for_path(path) + + if kernel_name is None: + kernel_name = 'pyimport/kernel' + elif '/' not in kernel_name: + kernel_name = 'spec/' + kernel_name + + kernel = KernelInterface(kernel_name, self.kernel_finder) + kernel_id = kernel.manager.kernel_id + if kernel_id is None: # if provider didn't set a kernel_id, let's associate one to this kernel + kernel_id = str(uuid.uuid4()) + self._kernels[kernel_id] = kernel + + self.start_watching_activity(kernel_id) + self.log.info("Kernel started: %s" % kernel_id) + + kernel.restarter.add_callback( + lambda data: self._handle_kernel_died(kernel_id), + 'failed' + ) + + # Increase the metric of number of kernels running + # for the relevant kernel type by 1 + KERNEL_CURRENTLY_RUNNING_TOTAL.labels( + type=self._kernels[kernel_id].kernel_type + ).inc() + + return kernel_id + def start_buffering(self, kernel_id, session_key, channels): """Start buffering messages for a kernel @@ -220,142 +379,70 @@ def start_buffering(self, kernel_id, session_key, channels): self.log.info("Starting buffering for %s", session_key) self._check_kernel_id(kernel_id) + kernel = self._kernels[kernel_id] # clear previous buffering state - self.stop_buffering(kernel_id) - buffer_info = self._kernel_buffers[kernel_id] - # record the session key because only one session can buffer - buffer_info['session_key'] = session_key - # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple - buffer_info['buffer'] = [] - buffer_info['channels'] = channels - - # forward any future messages to the internal buffer - def buffer_msg(channel, msg_parts): - self.log.debug("Buffering msg on %s:%s", kernel_id, channel) - buffer_info['buffer'].append((channel, msg_parts)) - - for channel, stream in channels.items(): - stream.on_recv(partial(buffer_msg, channel)) - - def get_buffer(self, kernel_id, session_key): - """Get the buffer for a given kernel - - Parameters - ---------- - kernel_id : str - The id of the kernel to stop buffering. - session_key: str, optional - The session_key, if any, that should get the buffer. 
- If the session_key matches the current buffered session_key, - the buffer will be returned. - """ - self.log.debug("Getting buffer for %s", kernel_id) - if kernel_id not in self._kernel_buffers: - return - - buffer_info = self._kernel_buffers[kernel_id] - if buffer_info['session_key'] == session_key: - # remove buffer - self._kernel_buffers.pop(kernel_id) - # only return buffer_info if it's a match - return buffer_info - else: - self.stop_buffering(kernel_id) - - def stop_buffering(self, kernel_id): - """Stop buffering kernel messages - - Parameters - ---------- - kernel_id : str - The id of the kernel to stop buffering. - """ - self.log.debug("Clearing buffer for %s", kernel_id) - self._check_kernel_id(kernel_id) + kernel.stop_buffering() + kernel.start_buffering(session_key) - if kernel_id not in self._kernel_buffers: - return - buffer_info = self._kernel_buffers.pop(kernel_id) - # close buffering streams - for stream in buffer_info['channels'].values(): - if not stream.closed(): - stream.on_recv(None) - stream.close() + @gen.coroutine + def _shutdown_all(self): + futures = [self.shutdown_kernel(kid) for kid in self.list_kernel_ids()] + yield gen.multi(futures) - msg_buffer = buffer_info['buffer'] - if msg_buffer: - self.log.info("Discarding %s buffered messages for %s", - len(msg_buffer), buffer_info['session_key']) + def shutdown_all(self): + # Blocking function to call when the notebook server is shutting down + loop = IOLoop.current() + loop.run_sync(self._shutdown_all) + @gen.coroutine def shutdown_kernel(self, kernel_id, now=False): """Shutdown a kernel by kernel_id""" self._check_kernel_id(kernel_id) - kernel = self._kernels[kernel_id] - if kernel._activity_stream: - kernel._activity_stream.close() - kernel._activity_stream = None - self.stop_buffering(kernel_id) - self._kernel_connections.pop(kernel_id, None) + kernel = self._kernels.pop(kernel_id) + self.log.info("Shutting down kernel %s", kernel_id) + yield kernel.shutdown(now=now) + self.last_kernel_activity = utcnow() # Decrease the metric of number of kernels # running for the relevant kernel type by 1 KERNEL_CURRENTLY_RUNNING_TOTAL.labels( - type=self._kernels[kernel_id].kernel_name + type=kernel.kernel_type ).dec() - return super(MappingKernelManager, self).shutdown_kernel(kernel_id, now=now) - @gen.coroutine def restart_kernel(self, kernel_id): - """Restart a kernel by kernel_id""" + """Restart a kernel by kernel_id + + The restarted kernel keeps the same ID and KernelInterface object. 
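+
+        On timeout, the kernel is removed from the mapping and a
+        tornado.gen.TimeoutError is raised (see the body below).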
+ """ self._check_kernel_id(kernel_id) - yield maybe_future(super(MappingKernelManager, self).restart_kernel(kernel_id)) kernel = self.get_kernel(kernel_id) - # return a Future that will resolve when the kernel has successfully restarted - channel = kernel.connect_shell() - future = Future() - - def finish(): - """Common cleanup when restart finishes/fails for any reason.""" - if not channel.closed(): - channel.close() - loop.remove_timeout(timeout) - kernel.remove_restart_callback(on_restart_failed, 'dead') - - def on_reply(msg): - self.log.debug("Kernel info reply received: %s", kernel_id) - finish() - if not future.done(): - future.set_result(msg) - - def on_timeout(): - self.log.warning("Timeout waiting for kernel_info_reply: %s", kernel_id) - finish() - if not future.done(): - future.set_exception(gen.TimeoutError("Timeout waiting for restart")) - - def on_restart_failed(): - self.log.warning("Restarting kernel failed: %s", kernel_id) - finish() - if not future.done(): - future.set_exception(RuntimeError("Restart failed")) - - kernel.add_restart_callback(on_restart_failed, 'dead') - kernel.session.send(channel, "kernel_info_request") - channel.on_recv(on_reply) - loop = IOLoop.current() - timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout) - raise gen.Return(future) + + try: + yield gen.with_timeout( + timedelta(seconds=self.kernel_info_timeout), + kernel.restart(), + ) + except gen.TimeoutError: + self.log.warning("Timeout waiting for kernel_info_reply: %s", + kernel_id) + self._kernels.pop(kernel_id) + # Decrease the metric of number of kernels + # running for the relevant kernel type by 1 + KERNEL_CURRENTLY_RUNNING_TOTAL.labels( + type=kernel.kernel_type + ).dec() + raise gen.TimeoutError("Timeout waiting for restart") def notify_connect(self, kernel_id): """Notice a new connection to a kernel""" - if kernel_id in self._kernel_connections: - self._kernel_connections[kernel_id] += 1 + if kernel_id in self._kernels: + self._kernels[kernel_id].n_connections += 1 def notify_disconnect(self, kernel_id): """Notice a disconnection from a kernel""" - if kernel_id in self._kernel_connections: - self._kernel_connections[kernel_id] -= 1 + if kernel_id in self._kernels: + self._kernels[kernel_id].n_connections -= 1 def kernel_model(self, kernel_id): """Return a JSON-safe dict representing a kernel @@ -367,22 +454,27 @@ def kernel_model(self, kernel_id): model = { "id":kernel_id, - "name": kernel.kernel_name, + "name": kernel.kernel_type, "last_activity": isoformat(kernel.last_activity), "execution_state": kernel.execution_state, - "connections": self._kernel_connections[kernel_id], + "connections": kernel.n_connections, } return model def list_kernels(self): - """Returns a list of kernel_id's of kernels running.""" + """Returns a list of models for kernels running.""" kernels = [] - kernel_ids = super(MappingKernelManager, self).list_kernel_ids() - for kernel_id in kernel_ids: + for kernel_id in self._kernels.keys(): model = self.kernel_model(kernel_id) kernels.append(model) return kernels + def list_kernel_ids(self): + return list(self._kernels.keys()) + + def __contains__(self, kernel_id): + return kernel_id in self._kernels + # override _check_kernel_id to raise 404 instead of KeyError def _check_kernel_id(self, kernel_id): """Check a that a kernel_id exists and raise 404 if not.""" @@ -398,30 +490,19 @@ def start_watching_activity(self, kernel_id): - record execution_state from status messages """ kernel = self._kernels[kernel_id] - # add busy/activity markers: - 
kernel.execution_state = 'starting' - kernel.last_activity = utcnow() - kernel._activity_stream = kernel.connect_iopub() - session = Session( - config=kernel.session.config, - key=kernel.session.key, - ) - def record_activity(msg_list): + def record_activity(msg, _channel): """Record an IOPub message arriving from a kernel""" self.last_kernel_activity = kernel.last_activity = utcnow() - idents, fed_msg_list = session.feed_identities(msg_list) - msg = session.deserialize(fed_msg_list) - - msg_type = msg['header']['msg_type'] + msg_type = msg.header['msg_type'] if msg_type == 'status': - kernel.execution_state = msg['content']['execution_state'] + kernel.execution_state = msg.content['execution_state'] self.log.debug("activity on %s: %s (%s)", kernel_id, msg_type, kernel.execution_state) else: self.log.debug("activity on %s: %s", kernel_id, msg_type) - kernel._activity_stream.on_recv(record_activity) + kernel.msg_handlers.append(record_activity) def initialize_culler(self): """Start idle culler if 'cull_idle_timeout' is greater than zero. @@ -460,7 +541,7 @@ def cull_kernels(self): def cull_kernel_if_idle(self, kernel_id): kernel = self._kernels[kernel_id] - self.log.debug("kernel_id=%s, kernel_name=%s, last_activity=%s", kernel_id, kernel.kernel_name, kernel.last_activity) + self.log.debug("kernel_id=%s, kernel_name=%s, last_activity=%s", kernel_id, kernel.kernel_type, kernel.last_activity) if kernel.last_activity is not None: dt_now = utcnow() dt_idle = dt_now - kernel.last_activity @@ -473,5 +554,5 @@ def cull_kernel_if_idle(self, kernel_id): if (is_idle_time and is_idle_execute and is_idle_connected): idle_duration = int(dt_idle.total_seconds()) self.log.warning("Culling '%s' kernel '%s' (%s) with %d connections due to %s seconds of inactivity.", - kernel.execution_state, kernel.kernel_name, kernel_id, connections, idle_duration) - self.shutdown_kernel(kernel_id) + kernel.execution_state, kernel.kernel_type, kernel_id, connections, idle_duration) + self.shutdown_kernel(kernel_id, now=True) diff --git a/jupyter_server/services/kernels/ws_serialize.py b/jupyter_server/services/kernels/ws_serialize.py new file mode 100644 index 0000000000..cec99f3426 --- /dev/null +++ b/jupyter_server/services/kernels/ws_serialize.py @@ -0,0 +1,119 @@ +"""Serialize & deserialize Jupyter messages to send over websockets. +""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from datetime import datetime, timezone +import json +import struct +import sys + +from dateutil.tz import tzutc + +from ipython_genutils.py3compat import cast_unicode +from jupyter_client.jsonutil import date_default, extract_dates + +def serialize_message(msg, channel): + """Serialize a message from the kernel, using JSON. + + msg is a jupyter_protocol Message object + + Returns a str of JSON if there are no binary buffers, or bytes which may + be sent as a binary websocket message. + """ + d = msg.make_dict() + if channel: + d['channel'] = channel + if msg.buffers: + buf = serialize_binary_message(d, msg.buffers) + return buf + else: + d['buffers'] = [] + smsg = json.dumps(d, default=date_default) + return cast_unicode(smsg) + +def serialize_binary_message(msg_dict, buffers): + """serialize a message as a binary blob + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + + The message serialized to bytes. 
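+    Example (sketch): with one JSON part of length L plus one extra binary
+    buffer, nbufs is 2, the header is struct.pack('!III', 2, 12, 12 + L),
+    and the JSON part starts at byte offset 12.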
+ + """ + # don't modify msg or buffer list in-place + msg_dict = msg_dict.copy() + buffers = list(buffers) + + if sys.version_info < (3, 4): + buffers = [x.tobytes() for x in buffers] + bmsg = json.dumps(msg_dict, default=date_default).encode('utf8') + buffers.insert(0, bmsg) + nbufs = len(buffers) + offsets = [4 * (nbufs + 1)] + for buf in buffers[:-1]: + offsets.append(offsets[-1] + len(buf)) + offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets) + buffers.insert(0, offsets_buf) + return b''.join(buffers) + +def deserialize_message(msg): + """Deserialize a websocket message, return a dict. + + msg may be either bytes, for a binary websocket message including buffers, + or str, for a pure JSON message. + """ + if isinstance(msg, bytes): + msg = deserialize_binary_message(msg) + else: + msg = json.loads(msg) + + msg['header'] = convert_tz_utc(extract_dates(msg['header'])) + msg['parent_header'] = convert_tz_utc(extract_dates(msg['parent_header'])) + return msg + +def deserialize_binary_message(bmsg): + """deserialize a message from a binary blob + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + + message dictionary + """ + nbufs = struct.unpack('!i', bmsg[:4])[0] + offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)])) + offsets.append(None) + bufs = [] + for start, stop in zip(offsets[:-1], offsets[1:]): + bufs.append(bmsg[start:stop]) + msg = json.loads(bufs[0].decode('utf8')) + + msg['buffers'] = bufs[1:] + return msg + +def convert_tz_utc(obj): + """Replace dateutil tzutc objects with stdlib datetime utc objects""" + if isinstance(obj, dict): + new_obj = {} # don't clobber + for k,v in obj.items(): + new_obj[k] = convert_tz_utc(v) + obj = new_obj + elif isinstance(obj, (list, tuple)): + obj = [ convert_tz_utc(o) for o in obj ] + elif isinstance(obj, datetime) and isinstance(obj.tzinfo, tzutc): + obj = obj.replace(tzinfo=timezone.utc) + return obj diff --git a/jupyter_server/services/kernelspecs/handlers.py b/jupyter_server/services/kernelspecs/handlers.py index 6302e96dfa..4639f4d454 100644 --- a/jupyter_server/services/kernelspecs/handlers.py +++ b/jupyter_server/services/kernelspecs/handlers.py @@ -14,17 +14,17 @@ from tornado import web, gen from ...base.handlers import APIHandler -from ...utils import maybe_future, url_path_join, url_unescape - +from ...utils import maybe_future, url_path_join, url_unescape, quote def kernelspec_model(handler, name, spec_dict, resource_dir): """Load a KernelSpec by name and return the REST API model""" d = { 'name': name, - 'spec': spec_dict, + 'spec': spec_dict.copy(), 'resources': {} } + d['spec']['language'] = d['spec']['language_info']['name'] # Add resource files if they exist resource_dir = resource_dir @@ -33,7 +33,7 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): d['resources'][resource] = url_path_join( handler.base_url, 'kernelspecs', - name, + quote(name, safe=''), resource ) for logo_file in glob.glob(pjoin(resource_dir, 'logo-*')): @@ -42,7 +42,7 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): d['resources'][no_ext] = url_path_join( handler.base_url, 'kernelspecs', - name, + quote(name, safe=''), fname ) return d @@ -58,18 +58,17 @@ class MainKernelSpecHandler(APIHandler): @web.authenticated @gen.coroutine def get(self): - ksm = self.kernel_spec_manager + kf = self.kernel_finder km = self.kernel_manager 
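+        # The JSON model assembled below looks roughly like this (a sketch;
+        # 'spec/sample' stands in for a kernel type from find_kernels()):
+        #   {"default": "pyimport/kernel",
+        #    "kernelspecs": {"spec/sample": {"name": "spec/sample",
+        #                                    "spec": {...},
+        #                                    "resources": {...}}}}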
model = {} model['default'] = km.default_kernel_name model['kernelspecs'] = specs = {} - kspecs = yield maybe_future(ksm.get_all_specs()) - for kernel_name, kernel_info in kspecs.items(): + for kernel_name, kernel_info in kf.find_kernels(): try: if is_kernelspec_model(kernel_info): d = kernel_info else: - d = kernelspec_model(self, kernel_name, kernel_info['spec'], kernel_info['resource_dir']) + d = kernelspec_model(self, kernel_name, kernel_info, kernel_info['resource_dir']) except Exception: self.log.error("Failed to load kernel spec: '%s'", kernel_name, exc_info=True) continue @@ -83,18 +82,18 @@ class KernelSpecHandler(APIHandler): @web.authenticated @gen.coroutine def get(self, kernel_name): - ksm = self.kernel_spec_manager - kernel_name = url_unescape(kernel_name) - try: - spec = yield maybe_future(ksm.get_kernel_spec(kernel_name)) - except KeyError: - raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) - if is_kernelspec_model(spec): - model = spec - else: - model = kernelspec_model(self, kernel_name, spec.to_dict(), spec.resource_dir) - self.set_header("Content-Type", 'application/json') - self.finish(json.dumps(model)) + kf = self.kernel_finder + # TODO: Do we actually want all kernel type names to be case-insensitive? + kernel_name = kernel_name.lower() + for name, info in kf.find_kernels(): + if name == kernel_name: + model = kernelspec_model(self, kernel_name, info, + info['resource_dir']) + self.set_header("Content-Type", 'application/json') + return self.finish(json.dumps(model)) + + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) + # URL to handler mappings diff --git a/jupyter_server/services/sessions/handlers.py b/jupyter_server/services/sessions/handlers.py index 78072c5771..1ac9970acf 100644 --- a/jupyter_server/services/sessions/handlers.py +++ b/jupyter_server/services/sessions/handlers.py @@ -132,7 +132,7 @@ def patch(self, session_id): changes['kernel_id'] = kernel_id elif model['kernel'].get('name') is not None: kernel_name = model['kernel']['name'] - kernel_id = yield sm.start_kernel_for_session( + kernel_id = sm.start_kernel_for_session( session_id, kernel_name=kernel_name, name=before['name'], path=before['path'], type=before['type']) changes['kernel_id'] = kernel_id diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 50012bc2a1..77e6386ff4 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -89,6 +89,10 @@ def create_session(self, path=None, name=None, type=None, kernel_name=None, kern result = yield maybe_future( self.save_session(session_id, path=path, name=name, type=type, kernel_id=kernel_id) ) + + # Now wait for the kernel to finish starting. 
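+        # (start_kernel_for_session() above only launches the kernel;
+        # client_ready() resolves once a client has connected to it, so the
+        # session should be fully usable when this coroutine returns.)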
+ yield self.kernel_manager.get_kernel(kernel_id).client_ready() + # py2-compat raise gen.Return(result) @@ -97,9 +101,8 @@ def start_kernel_for_session(self, session_id, path, name, type, kernel_name): """Start a new kernel for a given session.""" # allow contents manager to specify kernels cwd kernel_path = self.contents_manager.get_kernel_path(path=path) - kernel_id = yield maybe_future( - self.kernel_manager.start_kernel(path=kernel_path, kernel_name=kernel_name) - ) + kernel_id = yield maybe_future(self.kernel_manager.start_launching_kernel( + path=kernel_path, kernel_name=kernel_name,)) # py2-compat raise gen.Return(kernel_id) diff --git a/setup.py b/setup.py index 190ce5c1a6..e38fabe455 100755 --- a/setup.py +++ b/setup.py @@ -82,6 +82,8 @@ 'traitlets>=4.2.1', 'jupyter_core>=4.4.0', 'jupyter_client>=5.3.1', + 'jupyter_kernel_mgmt>=0.4', + 'jupyter_protocol', 'nbformat', 'nbconvert<6', 'ipykernel', # bless IPython kernel for now diff --git a/tests/services/kernels/test_api.py b/tests/services/kernels/test_api.py index cfdc6b80c1..296e5d4b1f 100644 --- a/tests/services/kernels/test_api.py +++ b/tests/services/kernels/test_api.py @@ -7,8 +7,6 @@ import urllib.parse from tornado.escape import url_escape -from jupyter_client.kernelspec import NATIVE_KERNEL_NAME - from jupyter_server.utils import url_path_join from ...conftest import expected_http_error @@ -72,7 +70,7 @@ async def test_main_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel1 = json.loads(r.body.decode()) @@ -104,7 +102,7 @@ async def test_main_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel2 = json.loads(r.body.decode()) @@ -145,7 +143,7 @@ async def test_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel_id = json.loads(r.body.decode())['id'] @@ -200,7 +198,7 @@ async def test_connection(fetch, ws_fetch, http_port, auth_header): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kid = json.loads(r.body.decode())['id'] diff --git a/tests/services/kernelspecs/test_api.py b/tests/services/kernelspecs/test_api.py index 0d3a2ba387..8026b070fb 100644 --- a/tests/services/kernelspecs/test_api.py +++ b/tests/services/kernelspecs/test_api.py @@ -3,8 +3,6 @@ import tornado -from jupyter_client.kernelspec import NATIVE_KERNEL_NAME - from ...conftest import expected_http_error @@ -41,7 +39,7 @@ async def test_list_kernelspecs_bad(fetch, kernelspecs, data_dir): ) model = json.loads(r.body.decode()) assert isinstance(model, dict) - assert model['default'] == NATIVE_KERNEL_NAME + assert model['default'] == 'pyimport/kernel' specs = model['kernelspecs'] assert isinstance(specs, dict) assert len(specs) > 2 @@ -54,16 +52,16 @@ async def test_list_kernelspecs(fetch, kernelspecs): ) model = json.loads(r.body.decode()) assert isinstance(model, dict) - assert model['default'] == NATIVE_KERNEL_NAME + assert model['default'] == 'pyimport/kernel' specs = model['kernelspecs'] assert isinstance(specs, dict) assert len(specs) > 2 def is_sample_kernelspec(s): - return s['name'] == 'sample' and s['spec']['display_name'] == 'Test kernel' + return s['name'] == 'spec/sample' and s['spec']['display_name'] == 'Test kernel' def is_default_kernelspec(s): - return s['name'] == NATIVE_KERNEL_NAME and 
s['spec']['display_name'].startswith("Python") + return s['name'] == 'pyimport/kernel' and s['spec']['display_name'].startswith("Python") assert any(is_sample_kernelspec(s) for s in specs.values()), specs assert any(is_default_kernelspec(s) for s in specs.values()), specs @@ -71,11 +69,11 @@ def is_default_kernelspec(s): async def test_get_kernelspecs(fetch, kernelspecs): r = await fetch( - 'api', 'kernelspecs', 'Sample', + 'api', 'kernelspecs', 'spec/Sample', method='GET' ) model = json.loads(r.body.decode()) - assert model['name'].lower() == 'sample' + assert model['name'].lower() == 'spec/sample' assert isinstance(model['spec'], dict) assert model['spec']['display_name'] == 'Test kernel' assert isinstance(model['resources'], dict) @@ -83,17 +81,17 @@ async def test_get_kernelspecs(fetch, kernelspecs): async def test_get_kernelspec_spaces(fetch, kernelspecs): r = await fetch( - 'api', 'kernelspecs', 'sample%202', + 'api', 'kernelspecs', 'spec/sample%202', method='GET' ) model = json.loads(r.body.decode()) - assert model['name'].lower() == 'sample 2' + assert model['name'].lower() == 'spec/sample 2' async def test_get_nonexistant_kernelspec(fetch, kernelspecs): with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'api', 'kernelspecs', 'nonexistant', + 'api', 'kernelspecs', 'spec/nonexistant', method='GET' ) assert expected_http_error(e, 404) @@ -101,7 +99,7 @@ async def test_get_nonexistant_kernelspec(fetch, kernelspecs): async def test_get_kernel_resource_file(fetch, kernelspecs): r = await fetch( - 'kernelspecs', 'sAmple', 'resource.txt', + 'kernelspecs', 'spec/sAmple', 'resource.txt', method='GET' ) res = r.body.decode('utf-8') @@ -111,14 +109,14 @@ async def test_get_kernel_resource_file(fetch, kernelspecs): async def test_get_nonexistant_resource(fetch, kernelspecs): with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'kernelspecs', 'nonexistant', 'resource.txt', + 'kernelspecs', 'spec/nonexistant', 'resource.txt', method='GET' ) assert expected_http_error(e, 404) with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'kernelspecs', 'sample', 'nonexistant.txt', + 'kernelspecs', 'spec/sample', 'nonexistant.txt', method='GET' ) - assert expected_http_error(e, 404) \ No newline at end of file + assert expected_http_error(e, 404) diff --git a/tests/services/sessions/test_api.py b/tests/services/sessions/test_api.py index 20d90ded26..965e3e79dd 100644 --- a/tests/services/sessions/test_api.py +++ b/tests/services/sessions/test_api.py @@ -5,7 +5,7 @@ def get_session_model( path, type='notebook', - kernel_name='python', + kernel_name='pyimport/kernel', kernel_id=None ): return { @@ -56,6 +56,12 @@ async def test_create(fetch): method='GET' ) got = json.loads(r.body.decode()) + + # Kernel state may have changed from 'starting' to 'idle' + # so don't assert state. 
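# --- editor's sketch (illustrative, not part of the patch) --------------------
# The two del statements just below strip the volatile field from both models
# before comparing. A non-mutating equivalent, purely for illustration
# (drop_execution_state is a made-up helper name):
def drop_execution_state(model):
    kernel = {k: v for k, v in model['kernel'].items() if k != 'execution_state'}
    return {**model, 'kernel': kernel}

# usage: assert drop_execution_state(got) == drop_execution_state(new_session)
# -------------------------------------------------------------------------------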
+    del got['kernel']['execution_state']
+    del new_session['kernel']['execution_state']
+
     assert got == new_session
diff --git a/tests/test_serialize.py b/tests/test_serialize.py
index 07947dc549..0dac0c8fac 100644
--- a/tests/test_serialize.py
+++ b/tests/test_serialize.py
@@ -2,24 +2,40 @@
 
 import os
 
-from jupyter_client.session import Session
-from jupyter_server.base.zmqhandlers import (
-    serialize_binary_message,
-    deserialize_binary_message,
+from jupyter_protocol.messages import Message
+from jupyter_server.services.kernels.ws_serialize import (
+    serialize_message, deserialize_message
 )
 
+def test_serialize_json():
+    msg = Message.from_type('data_pub', content={'a': 'b'})
+    smsg = serialize_message(msg, 'iopub')
+    assert isinstance(smsg, str)
+
 def test_serialize_binary():
-    s = Session()
-    msg = s.msg('data_pub', content={'a': 'b'})
-    msg['buffers'] = [ memoryview(os.urandom(3)) for i in range(3) ]
-    bmsg = serialize_binary_message(msg)
+    msg = Message.from_type('data_pub', content={'a': 'b'})
+    msg.buffers = [memoryview(os.urandom(3)) for i in range(3)]
+    bmsg = serialize_message(msg, 'iopub')
     assert isinstance(bmsg, bytes)
 
+def test_deserialize_json():
+    msg = Message.from_type('data_pub', content={'a': 'b'})
+    smsg = serialize_message(msg, 'iopub')
+    print("Serialised: ", smsg)
+    msg_dict = msg.make_dict()
+    msg_dict['channel'] = 'iopub'
+    msg_dict['buffers'] = []
+
+    msg2 = deserialize_message(smsg)
+    assert msg2 == msg_dict
 
 def test_deserialize_binary():
-    s = Session()
-    msg = s.msg('data_pub', content={'a': 'b'})
-    msg['buffers'] = [ memoryview(os.urandom(2)) for i in range(3) ]
-    bmsg = serialize_binary_message(msg)
-    msg2 = deserialize_binary_message(bmsg)
-    assert msg2 == msg
\ No newline at end of file
+    msg = Message.from_type('data_pub', content={'a': 'b'})
+    msg.buffers = [memoryview(os.urandom(3)) for i in range(3)]
+    bmsg = serialize_message(msg, 'iopub')
+    msg_dict = msg.make_dict()
+    msg_dict['channel'] = 'iopub'
+    msg_dict['buffers'] = msg.buffers
+
+    msg2 = deserialize_message(bmsg)
+    assert msg2 == msg_dict
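# --- editor's sketch (illustrative, not part of the patch) --------------------
# The binary framing the tests above round-trip, reproduced with nothing but
# json and struct: a 32-bit big-endian part count, one 32-bit absolute offset
# per part (offsets include the header itself), the JSON-encoded message dict,
# then the raw buffers. The message dict and payload here are made up.
import json
import struct

msg_dict = {'header': {}, 'parent_header': {}, 'content': {'a': 'b'}}
payload = [b'abc', b'defg']

parts = [json.dumps(msg_dict).encode('utf8')] + payload
nbufs = len(parts)
offsets = [4 * (nbufs + 1)]            # first part starts right after the header
for part in parts[:-1]:
    offsets.append(offsets[-1] + len(part))
blob = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets) + b''.join(parts)

# decoding mirrors deserialize_binary_message: slice each part out by its offset
n = struct.unpack('!I', blob[:4])[0]
offs = list(struct.unpack('!' + 'I' * n, blob[4:4 * (n + 1)])) + [None]
parts2 = [blob[a:b] for a, b in zip(offs[:-1], offs[1:])]
assert json.loads(parts2[0].decode('utf8')) == msg_dict
assert parts2[1:] == payload
# -------------------------------------------------------------------------------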