From 7778f83ff1bc07ef70e89fd6bec4a388886646bf Mon Sep 17 00:00:00 2001
From: Kevin Bates
Date: Thu, 25 Oct 2018 17:17:18 +0100
Subject: [PATCH] Adopt jupyter_kernel_mgmt as kernel management framework

- Initial work towards using jupyter_kernel_mgmt in jupyter server
- Cherry-pick Notebook PR https://github.com/jupyter/notebook/pull/4837
- Remove jupyter_client, use KernelFinder in Gateway
- Minor refactor to preserve MappingKernelManager behavior
- Session is no longer a Configurable, removed from classes list.
  Also removed some of the Gateway classes that shouldn't be there either.
- Get gateway functionality working
- Fix SessionHandler call to start kernel
- Initial support for async kernel management
- Plumb launch parameters
- Adjust kernel management with recent async updates
- Don't get child_watcher if on Windows
- Fix gateway kernelspec tests to match the updated JKM call.
  Also fixed Windows testing by increasing the delay during cleanup of
  session and kernel tests - otherwise the temp directory could not be
  cleaned up, resulting in downstream side-effects.
- Require JKM >= 0.5, bump core min release
- Remove install of special patch branches for jkm
- Merge pytest PR, encode/decode kernel type
- Merge/convert missing sessions tests for JKM
- Add session and kernel equality methods

Co-authored-by: Thomas Kluyver
---
 appveyor.yml                                  |   5 +-
 jupyter_server/base/handlers.py               |  10 +-
 jupyter_server/base/zmqhandlers.py            | 174 +------
 jupyter_server/gateway/handlers.py            |   8 +-
 jupyter_server/gateway/managers.py            |  41 +-
 jupyter_server/kernelspecs/handlers.py        |  21 +-
 jupyter_server/serverapp.py                   |  61 +--
 jupyter_server/services/contents/handlers.py  |   2 +-
 jupyter_server/services/kernels/handlers.py   | 293 +++++------
 .../services/kernels/kernelmanager.py         | 456 +++++++++++------
 .../services/kernels/ws_serialize.py          | 124 +++++
 .../services/kernelspecs/handlers.py          |  45 +-
 jupyter_server/services/sessions/handlers.py  |   8 +-
 .../services/sessions/sessionmanager.py       |   6 +-
 setup.py                                      |   5 +-
 tests/services/kernels/test_api.py            |  10 +-
 tests/services/kernelspecs/test_api.py        |  28 +-
 tests/services/sessions/test_api.py           |  79 ++-
 tests/services/sessions/test_manager.py       |  22 +-
 tests/test_gateway.py                         |   2 +-
 tests/test_serialize.py                       |  44 +-
 21 files changed, 747 insertions(+), 697 deletions(-)
 create mode 100644 jupyter_server/services/kernels/ws_serialize.py

diff --git a/appveyor.yml b/appveyor.yml
index 79c6e6dddc..51e5d5f599 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -1,4 +1,5 @@
 # miniconda bootstrap from conda-forge recipe
+
 matrix:
   fast_finish: true

@@ -24,12 +25,12 @@ platform:
 build: off

 install:
-  - cmd: call %CONDA_INSTALL_LOCN%\Scripts\activate.bat
+  - cmd: call %CONDA_INSTALL_LOCN%\\Scripts\\activate.bat
   - cmd: set CONDA_PY=%CONDA_PY%
   - cmd: set CONDA_PY_SPEC=%CONDA_PY_SPEC%
   - cmd: conda config --set show_channel_urls true
   - cmd: conda config --add channels conda-forge
-  - cmd: conda update --yes --quiet conda
+  - cmd: conda update -y -q conda
   - cmd: conda info -a
   - cmd: conda create -y -q -n test-env-%CONDA_PY% python=%CONDA_PY_SPEC% pyzmq tornado jupyter_client nbformat nbconvert ipykernel pip nose
   - cmd: conda activate test-env-%CONDA_PY%
diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py
index 1d9bbb3cd5..a49271ed17 100755
--- a/jupyter_server/base/handlers.py
+++ b/jupyter_server/base/handlers.py
@@ -245,7 +245,11 @@ def contents_js_source(self):

     #---------------------------------------------------------------
     # Manager objects
#--------------------------------------------------------------- - + + @property + def kernel_finder(self): + return self.settings['kernel_finder'] + @property def kernel_manager(self): return self.settings['kernel_manager'] @@ -261,10 +265,6 @@ def session_manager(self): @property def terminal_manager(self): return self.settings['terminal_manager'] - - @property - def kernel_spec_manager(self): - return self.settings['kernel_spec_manager'] @property def config_manager(self): diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index a7dde2cbe5..44d37ae5ac 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -4,81 +4,11 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import json -import struct -import sys -import tornado - from urllib.parse import urlparse -from tornado import gen, ioloop, web -from tornado.websocket import WebSocketHandler - -from jupyter_client.session import Session -from jupyter_client.jsonutil import date_default, extract_dates -from ipython_genutils.py3compat import cast_unicode - -from .handlers import JupyterHandler -from jupyter_server.utils import maybe_future - - -def serialize_binary_message(msg): - """serialize a message as a binary blob - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - - The message serialized to bytes. - - """ - # don't modify msg or buffer list in-place - msg = msg.copy() - buffers = list(msg.pop('buffers')) - if sys.version_info < (3, 4): - buffers = [x.tobytes() for x in buffers] - bmsg = json.dumps(msg, default=date_default).encode('utf8') - buffers.insert(0, bmsg) - nbufs = len(buffers) - offsets = [4 * (nbufs + 1)] - for buf in buffers[:-1]: - offsets.append(offsets[-1] + len(buf)) - offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets) - buffers.insert(0, offsets_buf) - return b''.join(buffers) - -def deserialize_binary_message(bmsg): - """deserialize a message from a binary blog - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - - message dictionary - """ - nbufs = struct.unpack('!i', bmsg[:4])[0] - offsets = list(struct.unpack('!' 
+ 'I' * nbufs, bmsg[4:4*(nbufs+1)])) - offsets.append(None) - bufs = [] - for start, stop in zip(offsets[:-1], offsets[1:]): - bufs.append(bmsg[start:stop]) - msg = json.loads(bufs[0].decode('utf8')) - msg['header'] = extract_dates(msg['header']) - msg['parent_header'] = extract_dates(msg['parent_header']) - msg['buffers'] = bufs[1:] - return msg +from tornado import ioloop +from tornado.iostream import StreamClosedError +from tornado.websocket import WebSocketHandler, WebSocketClosedError # ping interval for keeping websockets alive (30 seconds) WS_PING_INTERVAL = 30000 @@ -188,101 +118,3 @@ def send_ping(self): def on_pong(self, data): self.last_pong = ioloop.IOLoop.current().time() - -class ZMQStreamHandler(WebSocketMixin, WebSocketHandler): - - if tornado.version_info < (4,1): - """Backport send_error from tornado 4.1 to 4.0""" - def send_error(self, *args, **kwargs): - if self.stream is None: - super(WebSocketHandler, self).send_error(*args, **kwargs) - else: - # If we get an uncaught exception during the handshake, - # we have no choice but to abruptly close the connection. - # TODO: for uncaught exceptions after the handshake, - # we can close the connection more gracefully. - self.stream.close() - - - def _reserialize_reply(self, msg_or_list, channel=None): - """Reserialize a reply message using JSON. - - msg_or_list can be an already-deserialized msg dict or the zmq buffer list. - If it is the zmq list, it will be deserialized with self.session. - - This takes the msg list from the ZMQ socket and serializes the result for the websocket. - This method should be used by self._on_zmq_reply to build messages that can - be sent back to the browser. - - """ - if isinstance(msg_or_list, dict): - # already unpacked - msg = msg_or_list - else: - idents, msg_list = self.session.feed_identities(msg_or_list) - msg = self.session.deserialize(msg_list) - if channel: - msg['channel'] = channel - if msg['buffers']: - buf = serialize_binary_message(msg) - return buf - else: - smsg = json.dumps(msg, default=date_default) - return cast_unicode(smsg) - - def _on_zmq_reply(self, stream, msg_list): - # Sometimes this gets triggered when the on_close method is scheduled in the - # eventloop but hasn't been called. - if self.ws_connection is None or stream.closed(): - self.log.warning("zmq message arrived on closed channel") - self.close() - return - channel = getattr(stream, 'channel', None) - try: - msg = self._reserialize_reply(msg_list, channel=channel) - except Exception: - self.log.critical("Malformed message: %r" % msg_list, exc_info=True) - else: - self.write_message(msg, binary=isinstance(msg, bytes)) - - -class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler): - - def set_default_headers(self): - """Undo the set_default_headers in JupyterHandler - - which doesn't make sense for websockets - """ - pass - - def pre_get(self): - """Run before finishing the GET request - - Extend this method to add logic that should fire before - the websocket finishes completing. 
- """ - # authenticate the request before opening the websocket - if self.get_current_user() is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) - - if self.get_argument('session_id', False): - self.session.session = cast_unicode(self.get_argument('session_id')) - else: - self.log.warning("No session ID specified") - - @gen.coroutine - def get(self, *args, **kwargs): - # pre_get can be a coroutine in subclasses - # assign and yield in two step to avoid tornado 3 issues - res = self.pre_get() - yield maybe_future(res) - res = super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs) - yield maybe_future(res) - - def initialize(self): - self.log.debug("Initializing websocket connection %s", self.request.path) - self.session = Session(config=self.config) - - def get_compression_options(self): - return self.settings.get('websocket_compression_options', None) diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index a1b76b5536..644b94f4e6 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -16,7 +16,7 @@ from tornado.escape import url_escape, json_decode, utf8 from ipython_genutils.py3compat import cast_unicode -from jupyter_client.session import Session +from jupyter_protocol.session import Session, new_id_bytes from traitlets.config.configurable import LoggingConfigurable from .managers import GatewayClient @@ -58,7 +58,7 @@ def authenticate(self): def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) - self.session = Session(config=self.config) + self.session = Session(key=new_id_bytes()) self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url) @gen.coroutine @@ -231,8 +231,8 @@ class GatewayResourceHandler(APIHandler): @web.authenticated @gen.coroutine def get(self, kernel_name, path, include_body=True): - ksm = self.kernel_spec_manager - kernel_spec_res = yield ksm.get_kernel_spec_resource(kernel_name, path) + kf = self.kernel_finder + kernel_spec_res = yield kf.get_kernel_spec_resource(kernel_name, path) if kernel_spec_res is None: self.log.warning("Kernelspec resource '{}' for '{}' not found. Gateway may not support" " resource serving.".format(path, kernel_name)) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index e2f74ce43c..d507910a57 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -1,20 +1,21 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
-
+import asyncio
 import os
 import json
+import logging
+from jupyter_kernel_mgmt.discovery import KernelFinder
 from socket import gaierror
 from tornado import gen, web
+from tornado.concurrent import Future
 from tornado.escape import json_encode, json_decode, url_escape
 from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError

 from ..services.kernels.kernelmanager import MappingKernelManager
 from ..services.sessions.sessionmanager import SessionManager

-from jupyter_client.kernelspec import KernelSpecManager
-from ..utils import url_path_join
-
+from ..utils import url_path_join, maybe_future
 from traitlets import Instance, Unicode, Float, Bool, default, validate, TraitError
 from traitlets.config import SingletonConfigurable

@@ -496,14 +497,16 @@ def shutdown_all(self, now=False):
             self.remove_kernel(kernel_id)


-class GatewayKernelSpecManager(KernelSpecManager):
-
-    def __init__(self, **kwargs):
-        super(GatewayKernelSpecManager, self).__init__(**kwargs)
+class GatewayKernelFinder(KernelFinder):
+    def __init__(self, parent, providers=[]):
+        super(GatewayKernelFinder, self).__init__(providers=providers)
         self.base_endpoint = url_path_join(GatewayClient.instance().url,
                                            GatewayClient.instance().kernelspecs_endpoint)
         self.base_resource_endpoint = url_path_join(GatewayClient.instance().url,
                                                     GatewayClient.instance().kernelspecs_resource_endpoint)
+        # Because KernelFinder is not a traitlets Configurable, we need to simulate a configurable
+        self.parent = parent
+        self.log = logging.getLogger(__name__)

     def _get_kernelspecs_endpoint_url(self, kernel_name=None):
         """Builds a url for the kernels endpoint
@@ -517,9 +520,16 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None):

         return self.base_endpoint

-    @gen.coroutine
-    def get_all_specs(self):
-        fetched_kspecs = yield self.list_kernel_specs()
+    @asyncio.coroutine
+    def find_kernels(self):
+        remote_kspecs = yield from self.get_all_specs()
+
+        # convert to list of 2-tuples
+        for kernel_type, attributes in remote_kspecs.items():
+            yield kernel_type, attributes
+
+    async def get_all_specs(self):
+        fetched_kspecs = await self.list_kernel_specs()

         # get the default kernel name and compare to that of this server.
         # If different log a warning and reset the default.
However, the @@ -535,16 +545,15 @@ def get_all_specs(self): km.default_kernel_name = remote_default_kernel_name remote_kspecs = fetched_kspecs.get('kernelspecs') - raise gen.Return(remote_kspecs) + return remote_kspecs - @gen.coroutine - def list_kernel_specs(self): + async def list_kernel_specs(self): """Get a list of kernel specs.""" kernel_spec_url = self._get_kernelspecs_endpoint_url() self.log.debug("Request list kernel specs at: %s", kernel_spec_url) - response = yield gateway_request(kernel_spec_url, method='GET') + response = await gateway_request(kernel_spec_url, method='GET') kernel_specs = json_decode(response.body) - raise gen.Return(kernel_specs) + return kernel_specs @gen.coroutine def get_kernel_spec(self, kernel_name, **kwargs): diff --git a/jupyter_server/kernelspecs/handlers.py b/jupyter_server/kernelspecs/handlers.py index 228694b8a5..3e7bc7d491 100644 --- a/jupyter_server/kernelspecs/handlers.py +++ b/jupyter_server/kernelspecs/handlers.py @@ -1,4 +1,5 @@ from tornado import web +from urllib.parse import unquote from ..base.handlers import JupyterHandler from ..services.kernelspecs.handlers import kernel_name_regex @@ -11,19 +12,23 @@ def initialize(self): @web.authenticated def get(self, kernel_name, path, include_body=True): - ksm = self.kernel_spec_manager - try: - self.root = ksm.get_kernel_spec(kernel_name).resource_dir - except KeyError: - raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) - self.log.debug("Serving kernel resource from: %s", self.root) - return web.StaticFileHandler.get(self, path, include_body=include_body) + kf = self.kernel_finder + # TODO: Do we actually want all kernel type names to be case-insensitive? + kernel_name = unquote(kernel_name.lower()) + for name, info in kf.find_kernels(): + if name == kernel_name: + self.root = info['resource_dir'] + self.log.debug("Serving kernel resource from: %s", self.root) + return web.StaticFileHandler.get(self, path, + include_body=include_body) + + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) @web.authenticated def head(self, kernel_name, path): return self.get(kernel_name, path, include_body=False) + default_handlers = [ (r"/kernelspecs/%s/(?P.*)" % kernel_name_regex, KernelSpecResourceHandler), ] - diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 9ff6d280ca..4b744ee38b 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -75,7 +75,7 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager -from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient +from .gateway.managers import GatewayKernelFinder, GatewayClient from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -87,9 +87,7 @@ JupyterApp, base_flags, base_aliases, ) from jupyter_core.paths import jupyter_config_path -from jupyter_client import KernelManager -from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME -from jupyter_client.session import Session +from jupyter_kernel_mgmt.discovery import KernelFinder from nbformat.sign import NotebookNotary from traitlets import ( Any, Dict, Unicode, Integer, List, Bool, Bytes, Instance, @@ -161,13 +159,13 @@ def load_handlers(name): class ServerWebApplication(web.Application): def __init__(self, jupyter_app, default_services, kernel_manager, contents_manager, - 
session_manager, kernel_spec_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options): settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, - session_manager, kernel_spec_manager, config_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options) handlers = self.init_handlers(default_services, settings) @@ -175,7 +173,7 @@ def __init__(self, jupyter_app, default_services, kernel_manager, contents_manag super(ServerWebApplication, self).__init__(handlers, **settings) def init_settings(self, jupyter_app, kernel_manager, contents_manager, - session_manager, kernel_spec_manager, + session_manager, kernel_finder, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options=None): @@ -250,10 +248,10 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, local_hostnames=jupyter_app.local_hostnames, # managers + kernel_finder=kernel_finder, kernel_manager=kernel_manager, contents_manager=contents_manager, session_manager=session_manager, - kernel_spec_manager=kernel_spec_manager, config_manager=config_manager, # handlers @@ -531,7 +529,6 @@ def start(self): 'ip': 'ServerApp.ip', 'port': 'ServerApp.port', 'port-retries': 'ServerApp.port_retries', - 'transport': 'KernelManager.transport', 'keyfile': 'ServerApp.keyfile', 'certfile': 'ServerApp.certfile', 'client-ca': 'ServerApp.client_ca', @@ -557,9 +554,7 @@ class ServerApp(JupyterApp): flags = flags classes = [ - KernelManager, Session, MappingKernelManager, KernelSpecManager, - ContentsManager, FileContentsManager, NotebookNotary, - GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient, + MappingKernelManager, ContentsManager, FileContentsManager, NotebookNotary, GatewayClient, ] flags = Dict(flags) aliases = Dict(aliases) @@ -1036,6 +1031,12 @@ def template_file_path(self): (shutdown the Jupyter server).""" ) + kernel_providers = List(config=True, + help=_('A list of kernel provider instances. ' + 'If not specified, all installed kernel providers are found ' + 'using entry points.') + ) + contents_manager_class = Type( default_value=LargeFileManager, klass=ContentsManager, @@ -1061,20 +1062,6 @@ def template_file_path(self): help=_('The config manager class to use') ) - kernel_spec_manager = Instance(KernelSpecManager, allow_none=True) - - kernel_spec_manager_class = Type( - default_value=KernelSpecManager, - config=True, - help=""" - The kernel spec manager class to use. Should be a subclass - of `jupyter_client.kernelspec.KernelSpecManager`. - - The Api of KernelSpecManager is provisional and might change - without warning between this version of Jupyter and the next stable one. - """ - ) - login_handler_class = Type( default_value=LoginHandler, klass=web.RequestHandler, @@ -1107,7 +1094,7 @@ def _default_info_file(self): def _default_browser_open_file(self): basename = "jpserver-%s-open.html" % os.getpid() return os.path.join(self.runtime_dir, basename) - + pylab = Unicode('disabled', config=True, help=_(""" DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib. @@ -1236,6 +1223,7 @@ def parse_command_line(self, argv=None): def init_configurables(self): + # If gateway server is configured, replace appropriate managers to perform redirection. To make # this determination, instantiate the GatewayClient config singleton. 
self.gateway_config = GatewayClient.instance(parent=self) @@ -1243,16 +1231,17 @@ def init_configurables(self): if self.gateway_config.gateway_enabled: self.kernel_manager_class = 'jupyter_server.gateway.managers.GatewayKernelManager' self.session_manager_class = 'jupyter_server.gateway.managers.GatewaySessionManager' - self.kernel_spec_manager_class = 'jupyter_server.gateway.managers.GatewayKernelSpecManager' + self.kernel_finder = GatewayKernelFinder(parent=self) # no providers here, always go remote + else: + if self.kernel_providers: + self.kernel_finder = KernelFinder(self.kernel_providers) + else: + self.kernel_finder = KernelFinder.from_entrypoints() - self.kernel_spec_manager = self.kernel_spec_manager_class( - parent=self, - ) self.kernel_manager = self.kernel_manager_class( parent=self, log=self.log, - connection_dir=self.runtime_dir, - kernel_spec_manager=self.kernel_spec_manager, + kernel_finder=self.kernel_finder, ) self.contents_manager = self.contents_manager_class( parent=self, @@ -1307,7 +1296,7 @@ def init_webapp(self): self.web_app = ServerWebApplication( self, self.default_services, self.kernel_manager, self.contents_manager, - self.session_manager, self.kernel_spec_manager, + self.session_manager, self.kernel_finder, self.config_manager, self.extra_services, self.log, self.base_url, self.default_url, self.tornado_settings, self.jinja_environment_options, @@ -1496,7 +1485,7 @@ def init_server_extensions(self): Import the module, then call the load_jupyter_server_extension function, if one exists. - + The extension API is experimental, and may change in future releases. """ # Initialize extensions @@ -1680,7 +1669,7 @@ def cleanup_kernels(self): """Shutdown all kernels. The kernels will shutdown themselves when this process no longer exists, - but explicit shutdown allows the KernelManagers to cleanup the connection files. + but explicit shutdown allows the Kernel Providers to cleanup the connection files. 
""" n_kernels = len(self.kernel_manager.list_kernel_ids()) kernel_msg = trans.ngettext('Shutting down %d kernel', 'Shutting down %d kernels', n_kernels) diff --git a/jupyter_server/services/contents/handlers.py b/jupyter_server/services/contents/handlers.py index 943ba8638d..425ec3d29b 100644 --- a/jupyter_server/services/contents/handlers.py +++ b/jupyter_server/services/contents/handlers.py @@ -11,7 +11,7 @@ from tornado import gen, web from jupyter_server.utils import url_path_join, url_escape, maybe_future -from jupyter_client.jsonutil import date_default +from jupyter_protocol.jsonutil import date_default from jupyter_server.base.handlers import ( JupyterHandler, APIHandler, path_regex, diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index 358798408c..8fbafadb24 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -10,18 +10,20 @@ import logging from textwrap import dedent +import tornado from tornado import gen, web from tornado.concurrent import Future from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler -from jupyter_client import protocol_version as client_protocol_version -from jupyter_client.jsonutil import date_default +from jupyter_protocol.jsonutil import date_default +from jupyter_protocol.messages import Message from ipython_genutils.py3compat import cast_unicode from jupyter_server.utils import url_path_join, url_escape, maybe_future -from ...base.handlers import APIHandler -from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message - +from ...base.handlers import APIHandler, JupyterHandler +from ...base.zmqhandlers import WebSocketMixin +from .ws_serialize import serialize_message, deserialize_message class MainKernelHandler(APIHandler): @@ -45,7 +47,9 @@ def post(self): else: model.setdefault('name', km.default_kernel_name) - kernel_id = yield maybe_future(km.start_kernel(kernel_name=model['name'])) + launch_params = model.get('launch_params', {}) + + kernel_id = yield maybe_future(km.start_kernel(kernel_name=model['name'], launch_params=launch_params)) model = yield maybe_future(km.kernel_model(kernel_id)) location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id)) self.set_header('Location', location) @@ -58,6 +62,7 @@ class KernelHandler(APIHandler): @web.authenticated def get(self, kernel_id): km = self.kernel_manager + km._check_kernel_id(kernel_id) model = km.kernel_model(kernel_id) self.finish(json.dumps(model, default=date_default)) @@ -77,12 +82,21 @@ class KernelActionHandler(APIHandler): def post(self, kernel_id, action): km = self.kernel_manager if action == 'interrupt': - km.interrupt_kernel(kernel_id) - self.set_status(204) - if action == 'restart': + try: + yield maybe_future(km.interrupt_kernel(kernel_id)) + except web.HTTPError: + raise + except Exception as e: + self.log.error("Exception interrupting kernel", exc_info=True) + self.set_status(500) + else: + self.set_status(204) + if action == 'restart': try: yield maybe_future(km.restart_kernel(kernel_id)) + except web.HTTPError: + raise except Exception as e: self.log.error("Exception restarting kernel", exc_info=True) self.set_status(500) @@ -92,7 +106,7 @@ def post(self, kernel_id, action): self.finish() -class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): +class ZMQChannelsHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): '''There is one ZMQChannelsHandler per running kernel and it oversees all 
the sessions.
     '''
@@ -119,85 +133,29 @@ def iopub_data_rate_limit(self):
     def rate_limit_window(self):
         return self.settings.get('rate_limit_window', 1.0)

+    @property
+    def kernel_client(self):
+        return self.kernel_manager.get_kernel(self.kernel_id).client
+
     def __repr__(self):
         return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))

-    def create_stream(self):
-        km = self.kernel_manager
-        identity = self.session.bsession
-        for channel in ('shell', 'control', 'iopub', 'stdin'):
-            meth = getattr(km, 'connect_' + channel)
-            self.channels[channel] = stream = meth(self.kernel_id, identity=identity)
-            stream.channel = channel
-
-    def request_kernel_info(self):
-        """send a request for kernel_info"""
-        km = self.kernel_manager
-        kernel = km.get_kernel(self.kernel_id)
-        try:
-            # check for previous request
-            future = kernel._kernel_info_future
-        except AttributeError:
-            self.log.debug("Requesting kernel info from %s", self.kernel_id)
-            # Create a kernel_info channel to query the kernel protocol version.
-            # This channel will be closed after the kernel_info reply is received.
-            if self.kernel_info_channel is None:
-                self.kernel_info_channel = km.connect_shell(self.kernel_id)
-            self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
-            self.session.send(self.kernel_info_channel, "kernel_info_request")
-            # store the future on the kernel, so only one request is sent
-            kernel._kernel_info_future = self._kernel_info_future
-        else:
-            if not future.done():
-                self.log.debug("Waiting for pending kernel_info request")
-            future.add_done_callback(lambda f: self._finish_kernel_info(f.result()))
-        return self._kernel_info_future
-
-    def _handle_kernel_info_reply(self, msg):
-        """process the kernel_info_reply
+    def set_default_headers(self):
+        """Undo the set_default_headers in JupyterHandler

-        enabling msg spec adaptation, if necessary
-        """
-        idents,msg = self.session.feed_identities(msg)
-        try:
-            msg = self.session.deserialize(msg)
-        except:
-            self.log.error("Bad kernel_info reply", exc_info=True)
-            self._kernel_info_future.set_result({})
-            return
-        else:
-            info = msg['content']
-            self.log.debug("Received kernel info: %s", info)
-            if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
-                self.log.error("Kernel info request failed, assuming current %s", info)
-                info = {}
-            self._finish_kernel_info(info)
-
-        # close the kernel_info channel, we don't need it anymore
-        if self.kernel_info_channel:
-            self.kernel_info_channel.close()
-        self.kernel_info_channel = None
-
-    def _finish_kernel_info(self, info):
-        """Finish handling kernel_info reply
-
-        Set up protocol adaptation, if needed,
-        and signal that connection can continue.
+ which doesn't make sense for websockets """ - protocol_version = info.get('protocol_version', client_protocol_version) - if protocol_version != client_protocol_version: - self.session.adapt_version = int(protocol_version.split('.')[0]) - self.log.info("Adapting from protocol version {protocol_version} (kernel {kernel_id}) to {client_protocol_version} (client).".format(protocol_version=protocol_version, kernel_id=self.kernel_id, client_protocol_version=client_protocol_version)) - if not self._kernel_info_future.done(): - self._kernel_info_future.set_result(info) + pass + + def get_compression_options(self): + return self.settings.get('websocket_compression_options', None) + + channels = {'shell', 'control', 'iopub', 'stdin'} def initialize(self): super(ZMQChannelsHandler, self).initialize() self.zmq_stream = None - self.channels = {} self.kernel_id = None - self.kernel_info_channel = None - self._kernel_info_future = Future() self._close_future = Future() self.session_key = '' @@ -211,33 +169,27 @@ def initialize(self): # by a delta amount at some point in the future. self._iopub_window_byte_queue = [] + session_id = None + @gen.coroutine def pre_get(self): - # authenticate first - super(ZMQChannelsHandler, self).pre_get() + # authenticate the request before opening the websocket + if self.get_current_user() is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + if self.get_argument('session_id', False): + self.session_id = cast_unicode(self.get_argument('session_id')) + else: + self.log.warning("No session ID specified") + # check session collision: yield self._register_session() - # then request kernel info, waiting up to a certain time before giving up. - # We don't want to wait forever, because browsers don't take it well when - # servers never respond to websocket connection requests. - kernel = self.kernel_manager.get_kernel(self.kernel_id) - self.session.key = kernel.session.key - future = self.request_kernel_info() - - def give_up(): - """Don't wait forever for the kernel to reply""" - if future.done(): - return - self.log.warning("Timeout waiting for kernel_info reply from %s", self.kernel_id) - future.set_result({}) - loop = IOLoop.current() - loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up) - # actually wait for it - yield future @gen.coroutine def get(self, kernel_id): self.kernel_id = cast_unicode(kernel_id, 'ascii') + yield self.pre_get() yield super(ZMQChannelsHandler, self).get(kernel_id=kernel_id) @gen.coroutine @@ -248,57 +200,45 @@ def _register_session(self): This is likely due to a client reconnecting from a lost network connection, where the socket on our side has not been cleaned up yet. 
""" - self.session_key = '%s:%s' % (self.kernel_id, self.session.session) + self.session_key = '%s:%s' % (self.kernel_id, self.session_id) stale_handler = self._open_sessions.get(self.session_key) if stale_handler: self.log.warning("Replacing stale connection: %s", self.session_key) yield stale_handler.close() self._open_sessions[self.session_key] = self + @gen.coroutine def open(self, kernel_id): super(ZMQChannelsHandler, self).open() km = self.kernel_manager + km._check_kernel_id(kernel_id) km.notify_connect(kernel_id) + kernel = km.get_kernel(kernel_id) + yield from kernel.client_ready() # on new connections, flush the message buffer - buffer_info = km.get_buffer(kernel_id, self.session_key) - if buffer_info and buffer_info['session_key'] == self.session_key: + buffer_key, replay_buffer = kernel.get_buffer() + if buffer_key == self.session_key: self.log.info("Restoring connection for %s", self.session_key) - self.channels = buffer_info['channels'] - replay_buffer = buffer_info['buffer'] if replay_buffer: self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) - else: - try: - self.create_stream() - except web.HTTPError as e: - self.log.error("Error opening stream: %s", e) - # WebSockets don't response to traditional error codes so we - # close the connection. - for channel, stream in self.channels.items(): - if not stream.closed(): - stream.close() - self.close() - return + for msg, channel in replay_buffer: + self._on_zmq_msg(msg, channel) - km.add_restart_callback(self.kernel_id, self.on_kernel_restarted) - km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead') + kernel.restarter.add_callback(self.on_kernel_died, 'died') + kernel.restarter.add_callback(self.on_kernel_restarted, 'restarted') + kernel.restarter.add_callback(self.on_restart_failed, 'failed') - for channel, stream in self.channels.items(): - stream.on_recv_stream(self._on_zmq_reply) + kernel.msg_handlers.append(self._on_zmq_msg) def on_message(self, msg): - if not self.channels: + """Received websocket message; forward to kernel""" + if self._close_future.done(): # already closed, ignore the message self.log.debug("Received message on closed websocket %r", msg) return - if isinstance(msg, bytes): - msg = deserialize_binary_message(msg) - else: - msg = json.loads(msg) + + msg = deserialize_message(msg) channel = msg.pop('channel', None) if channel is None: self.log.warning("No channel specified, assuming shell: %s", msg) @@ -311,25 +251,25 @@ def on_message(self, msg): if am and mt not in am: self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' 
% mt) else: - stream = self.channels[channel] - self.session.send(stream, msg) + self.kernel_client.messaging.send(channel, Message(**msg)) + + def _on_zmq_msg(self, msg: Message, channel): + """Received message from kernel; forward over websocket""" + if self.ws_connection is None: + return - def _on_zmq_reply(self, stream, msg_list): - idents, fed_msg_list = self.session.feed_identities(msg_list) - msg = self.session.deserialize(fed_msg_list) - parent = msg['parent_header'] def write_stderr(error_message): self.log.warning(error_message) - msg = self.session.msg("stream", + stream_msg = Message.from_type("stream", content={"text": error_message + '\n', "name": "stderr"}, - parent=parent - ) - msg['channel'] = 'iopub' - self.write_message(json.dumps(msg, default=date_default)) - channel = getattr(stream, 'channel', None) - msg_type = msg['header']['msg_type'] - - if channel == 'iopub' and msg_type == 'status' and msg['content'].get('execution_state') == 'idle': + ).make_dict() + stream_msg['parent_header'] = msg.parent_header + stream_msg['channel'] = 'iopub' + self.write_message(json.dumps(stream_msg, default=date_default)) + + msg_type = msg.header['msg_type'] + + if channel == 'iopub' and msg_type == 'status' and msg.content.get('execution_state') == 'idle': # reset rate limit counter on status=idle, # to avoid 'Run All' hitting limits prematurely. self._iopub_window_byte_queue = [] @@ -356,7 +296,7 @@ def write_stderr(error_message): # Increment the bytes and message count self._iopub_window_msg_count += 1 if msg_type == 'stream': - byte_count = sum([len(x) for x in msg_list]) + byte_count = len(msg.content['text'].encode('utf-8')) else: byte_count = 0 self._iopub_window_byte_count += byte_count @@ -421,7 +361,14 @@ def write_stderr(error_message): self._iopub_window_byte_count -= byte_count self._iopub_window_byte_queue.pop(-1) return - super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) + + try: + ws_msg = serialize_message(msg, channel=channel) + except Exception: + self.log.critical("Malformed message: %r" % msg, + exc_info=True) + else: + self.write_message(ws_msg, binary=isinstance(ws_msg, bytes)) def close(self): super(ZMQChannelsHandler, self).close() @@ -436,50 +383,44 @@ def on_close(self): km = self.kernel_manager if self.kernel_id in km: km.notify_disconnect(self.kernel_id) - km.remove_restart_callback( - self.kernel_id, self.on_kernel_restarted, - ) - km.remove_restart_callback( - self.kernel_id, self.on_restart_failed, 'dead', - ) + kernel = km.get_kernel(self.kernel_id) + try: + kernel.msg_handlers.remove(self._on_zmq_msg) + except ValueError: + self.log.debug("Message handler not connected") + + kernel.restarter.remove_callback(self.on_kernel_died, 'died') + kernel.restarter.remove_callback(self.on_restart_failed, 'failed') + kernel.restarter.remove_callback(self.on_kernel_restarted, 'restarted') # start buffering instead of closing if this was the last connection - if km._kernel_connections[self.kernel_id] == 0: + if kernel.n_connections == 0: km.start_buffering(self.kernel_id, self.session_key, self.channels) - self._close_future.set_result(None) - return - - # This method can be called twice, once by self.kernel_died and once - # from the WebSocket close event. If the WebSocket connection is - # closed before the ZMQ streams are setup, they could be None. 
- for channel, stream in self.channels.items(): - if stream is not None and not stream.closed(): - stream.on_recv(None) - stream.close() - self.channels = {} self._close_future.set_result(None) def _send_status_message(self, status): - iopub = self.channels.get('iopub', None) - if iopub and not iopub.closed(): - # flush IOPub before sending a restarting/dead status message - # ensures proper ordering on the IOPub channel - # that all messages from the stopped kernel have been delivered - iopub.flush() - msg = self.session.msg("status", + msg = Message.from_type("status", {'execution_state': status} ) - msg['channel'] = 'iopub' - self.write_message(json.dumps(msg, default=date_default)) + ws_msg = serialize_message(msg, channel='iopub') + return self.write_message(ws_msg, binary=isinstance(ws_msg, bytes)) - def on_kernel_restarted(self): - logging.warn("kernel %s restarted", self.kernel_id) - self._send_status_message('restarting') + def on_kernel_died(self, _data): + logging.warning("kernel %s died, noticed by auto restarter", self.kernel_id) + return self._send_status_message('restarting') + + @gen.coroutine + def on_kernel_restarted(self, _data): + kernel = self.kernel_manager.get_kernel(self.kernel_id) + # Send the status message once the client is connected + yield kernel.client_ready() + logging.warning("kernel %s restarted", self.kernel_id) + return self._send_status_message('starting') - def on_restart_failed(self): + def on_restart_failed(self, _data): logging.error("kernel %s restarted failed!", self.kernel_id) - self._send_status_message('dead') + return self._send_status_message('dead') #----------------------------------------------------------------------------- diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index 250472f2cc..81efececa9 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -9,18 +9,19 @@ from collections import defaultdict from datetime import datetime, timedelta -from functools import partial import os +import uuid from tornado import gen, web -from tornado.concurrent import Future from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.locks import Event -from jupyter_client.session import Session -from jupyter_client.multikernelmanager import MultiKernelManager +from jupyter_kernel_mgmt.client import IOLoopKernelClient +from jupyter_kernel_mgmt.restarter import TornadoKernelRestarter from traitlets import (Any, Bool, Dict, List, Unicode, TraitError, Integer, Float, Instance, default, validate ) +from traitlets.config.configurable import LoggingConfigurable from jupyter_server.utils import maybe_future, to_os_path, exists from jupyter_server._tz import utcnow, isoformat @@ -29,18 +30,153 @@ from jupyter_server.prometheus.metrics import KERNEL_CURRENTLY_RUNNING_TOTAL -class MappingKernelManager(MultiKernelManager): - """A KernelManager that handles - - File mapping - - HTTP error handling - - Kernel message filtering +class KernelInterface(LoggingConfigurable): + """A wrapper around one kernel, including manager, client and restarter. + + A KernelInterface instance persists across kernel restarts, whereas + manager and client objects are recreated. 
""" + def __init__(self, kernel_type, kernel_finder, connection_info, manager): + super(KernelInterface, self).__init__() + self.kernel_type = kernel_type + self.kernel_finder = kernel_finder + + self.connection_info = connection_info + self.manager = manager + self.n_connections = 0 + self.execution_state = 'starting' + self.last_activity = utcnow() + + self.restarter = TornadoKernelRestarter(self.manager, self.kernel_type, + kernel_finder=self.kernel_finder) + self.restarter.add_callback(self._handle_kernel_died, 'died') + self.restarter.add_callback(self._handle_kernel_restarted, 'restarted') + self.restarter.start() + + # A future that resolves when the client is connected + self.client_connected = self._connect_client() + self._client_connected_evt = Event() + + self.buffer_for_key = None + # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple + self.buffer = [] + + # Message handlers stored here don't have to be re-added if the kernel + # is restarted. + self.msg_handlers = [] + + @classmethod + @gen.coroutine + def launch(cls, kernel_type, kernel_finder, **kwargs): + connection_info, manager = yield kernel_finder.launch(kernel_type, **kwargs) + raise gen.Return(cls(kernel_type, kernel_finder, connection_info, manager)) + + client = None + + @gen.coroutine + def _connect_client(self): + """Connect a client and wait for it to be ready.""" + self.client = IOLoopKernelClient(self.connection_info, self.manager) + yield self.client.wait_for_ready() + self.client.add_handler(self._msg_received, {'shell', 'iopub', 'stdin'}) + self._client_connected_evt.set() + + def _close_client(self): + if self.client is not None: + self._client_connected_evt.clear() + self.client_connected.cancel() + self.client.close() + self.client = None + + def client_ready(self): + """Return a future which resolves when the client is ready""" + if self.client is None: + return self._client_connected_evt.wait() + else: + return self.client_connected + + def _msg_received(self, msg, channel): + loop = IOLoop.current() + for handler in self.msg_handlers: + loop.add_callback(handler, msg, channel) + + @gen.coroutine + def shutdown(self, now=False): + self.restarter.stop() + + if now or (self.client is None): + yield self.manager.kill() + else: + yield self.client_connected + yield self.client.shutdown_or_terminate() + + self._close_client() + yield self.manager.cleanup() + + @gen.coroutine + def interrupt(self): + yield self.manager.interrupt() + + @gen.coroutine + def _handle_kernel_died(self, data): + """Called when the auto-restarter notices the kernel has died""" + self._close_client() + + @gen.coroutine + def _handle_kernel_restarted(self, data): + """Called when the kernel has been restarted""" + self.manager = data['manager'] + self.connection_info = data['connection_info'] + self.client_connected = self._connect_client() + yield self.client_connected + + @gen.coroutine + def restart(self): + yield self.shutdown() + # The restart will trigger _handle_kernel_restarted() to connect a + # new client. 
+        yield self.restarter.do_restart()
+        # Resume monitoring the kernel for auto-restart
+        self.restarter.start()
+        yield self._client_connected_evt.wait()
+
+    def start_buffering(self, session_key):
+        # record the session key because only one session can buffer
+        self.buffer_for_key = session_key
+
+        # forward any future messages to the internal buffer
+        self.client.add_handler(self._buffer_msg, {'shell', 'iopub', 'stdin'})
+
+    def _buffer_msg(self, msg, channel):
+        self.log.debug("Buffering msg on %s", channel)
+        self.buffer.append((msg, channel))
+
+    def get_buffer(self):
+        """Get the buffer for a given kernel, and stop buffering new messages
+        """
+        buffer, key = self.buffer, self.buffer_for_key
+        self.buffer = []
+        self.stop_buffering()
+        return buffer, key
+
+    def stop_buffering(self):
+        """Stop buffering kernel messages
+        """
+        self.client.remove_handler(self._buffer_msg)
+
+        if self.buffer:
+            self.log.info("Discarding %s buffered messages for %s",
+                          len(self.buffer), self.buffer_for_key)
+        self.buffer = []
+        self.buffer_for_key = None

-    @default('kernel_manager_class')
-    def _default_kernel_manager_class(self):
-        return "jupyter_client.ioloop.IOLoopKernelManager"
-
-    kernel_argv = List(Unicode())
+class MappingKernelManager(LoggingConfigurable):
+    """A KernelManager that handles notebook mapping and HTTP error handling"""
+
+    default_kernel_name = Unicode('pyimport/kernel', config=True,
+        help="The name of the default kernel to start"
+    )

     root_dir = Unicode(config=True)

@@ -120,16 +256,24 @@ def _default_kernel_buffers(self):
     last_kernel_activity = Instance(datetime,
        help="The last activity on any kernel, including shutting down a kernel")

-    def __init__(self, **kwargs):
-        super(MappingKernelManager, self).__init__(**kwargs)
-        self.last_kernel_activity = utcnow()
-
     allowed_message_types = List(trait=Unicode(), config=True,
         help="""White list of allowed kernel message types.
         When the list is empty, all message types are allowed.
         """
     )

+    def __init__(self, kernel_finder, **kwargs):
+        super(MappingKernelManager, self).__init__(**kwargs)
+        self.last_kernel_activity = utcnow()
+        self._kernels = {}
+        self._kernels_starting = {}
+        self._restarters = {}
+        self.kernel_finder = kernel_finder
+        self.initialize_culler()
+
+    def get_kernel(self, kernel_id):
+        return self._kernels[kernel_id]
+
     #-------------------------------------------------------------------------
     # Methods for managing kernels and sessions
     #-------------------------------------------------------------------------

@@ -137,7 +281,13 @@ def __init__(self, **kwargs):
     def _handle_kernel_died(self, kernel_id):
         """notice that a kernel died"""
         self.log.warning("Kernel %s died, removing from map.", kernel_id)
-        self.remove_kernel(kernel_id)
+        kernel = self._kernels.pop(kernel_id)
+        kernel.client.close()
+        kernel.manager.cleanup()
+
+        KERNEL_CURRENTLY_RUNNING_TOTAL.labels(
+            type=kernel.kernel_type
+        ).dec()

     def cwd_for_path(self, path):
         """Turn API path into absolute OS path."""
@@ -149,7 +299,7 @@ def cwd_for_path(self, path):
         return os_path

     @gen.coroutine
-    def start_kernel(self, kernel_id=None, path=None, **kwargs):
+    def start_kernel(self, kernel_id=None, path=None, kernel_name=None, **kwargs):
         """Start a kernel for a session and return its kernel_id.

         Parameters
@@ -166,38 +316,53 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):
            an existing kernel is returned, but it may be checked in the future.
""" if kernel_id is None: - if path is not None: - kwargs['cwd'] = self.cwd_for_path(path) - kernel_id = yield maybe_future( - super(MappingKernelManager, self).start_kernel(**kwargs) - ) - self._kernel_connections[kernel_id] = 0 - self.start_watching_activity(kernel_id) - self.log.info("Kernel started: %s" % kernel_id) - self.log.debug("Kernel args: %r" % kwargs) - # register callback for failed auto-restart - self.add_restart_callback(kernel_id, - lambda : self._handle_kernel_died(kernel_id), - 'dead', - ) - - # Increase the metric of number of kernels running - # for the relevant kernel type by 1 - KERNEL_CURRENTLY_RUNNING_TOTAL.labels( - type=self._kernels[kernel_id].kernel_name - ).inc() - + kernel_id = yield maybe_future(self.start_launching_kernel(path=path, kernel_name=kernel_name, **kwargs)) + yield self.get_kernel(kernel_id).client_ready() else: self._check_kernel_id(kernel_id) self.log.info("Using existing kernel: %s" % kernel_id) - # Initialize culling if not already - if not self._initialized_culler: - self.initialize_culler() - # py2-compat raise gen.Return(kernel_id) + @gen.coroutine + def start_launching_kernel(self, path=None, kernel_name=None, **kwargs): + """Launch a new kernel, return its kernel ID + + This is a coroutine which starts the process of launching a kernel. + Retrieve the KernelInterface object via the launch class method and call + ``.client_ready()`` to get a future for the rest of the startup & connection. + """ + if path is not None: + kwargs['cwd'] = self.cwd_for_path(path) + + if kernel_name is None: + kernel_name = 'pyimport/kernel' + elif '/' not in kernel_name: + kernel_name = 'spec/' + kernel_name + + kernel = yield KernelInterface.launch(kernel_name, self.kernel_finder, **kwargs) + kernel_id = kernel.manager.kernel_id + if kernel_id is None: # if provider didn't set a kernel_id, let's associate one to this kernel + kernel_id = str(uuid.uuid4()) + self._kernels[kernel_id] = kernel + + self.start_watching_activity(kernel_id) + self.log.info("Kernel started: %s" % kernel_id) + + kernel.restarter.add_callback( + lambda data: self._handle_kernel_died(kernel_id), + 'failed' + ) + + # Increase the metric of number of kernels running + # for the relevant kernel type by 1 + KERNEL_CURRENTLY_RUNNING_TOTAL.labels( + type=self._kernels[kernel_id].kernel_type + ).inc() + + return kernel_id + def start_buffering(self, kernel_id, session_key, channels): """Start buffering messages for a kernel @@ -220,142 +385,81 @@ def start_buffering(self, kernel_id, session_key, channels): self.log.info("Starting buffering for %s", session_key) self._check_kernel_id(kernel_id) + kernel = self._kernels[kernel_id] # clear previous buffering state - self.stop_buffering(kernel_id) - buffer_info = self._kernel_buffers[kernel_id] - # record the session key because only one session can buffer - buffer_info['session_key'] = session_key - # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple - buffer_info['buffer'] = [] - buffer_info['channels'] = channels - - # forward any future messages to the internal buffer - def buffer_msg(channel, msg_parts): - self.log.debug("Buffering msg on %s:%s", kernel_id, channel) - buffer_info['buffer'].append((channel, msg_parts)) - - for channel, stream in channels.items(): - stream.on_recv(partial(buffer_msg, channel)) - - def get_buffer(self, kernel_id, session_key): - """Get the buffer for a given kernel - - Parameters - ---------- - kernel_id : str - The id of the kernel to stop buffering. 
- session_key: str, optional - The session_key, if any, that should get the buffer. - If the session_key matches the current buffered session_key, - the buffer will be returned. - """ - self.log.debug("Getting buffer for %s", kernel_id) - if kernel_id not in self._kernel_buffers: - return - - buffer_info = self._kernel_buffers[kernel_id] - if buffer_info['session_key'] == session_key: - # remove buffer - self._kernel_buffers.pop(kernel_id) - # only return buffer_info if it's a match - return buffer_info - else: - self.stop_buffering(kernel_id) - - def stop_buffering(self, kernel_id): - """Stop buffering kernel messages - - Parameters - ---------- - kernel_id : str - The id of the kernel to stop buffering. - """ - self.log.debug("Clearing buffer for %s", kernel_id) - self._check_kernel_id(kernel_id) + kernel.stop_buffering() + kernel.start_buffering(session_key) - if kernel_id not in self._kernel_buffers: - return - buffer_info = self._kernel_buffers.pop(kernel_id) - # close buffering streams - for stream in buffer_info['channels'].values(): - if not stream.closed(): - stream.on_recv(None) - stream.close() + @gen.coroutine + def _shutdown_all(self): + futures = [self.shutdown_kernel(kid) for kid in self.list_kernel_ids()] + yield gen.multi(futures) - msg_buffer = buffer_info['buffer'] - if msg_buffer: - self.log.info("Discarding %s buffered messages for %s", - len(msg_buffer), buffer_info['session_key']) + def shutdown_all(self): + # Blocking function to call when the notebook server is shutting down + loop = IOLoop.current() + loop.run_sync(self._shutdown_all) + @gen.coroutine def shutdown_kernel(self, kernel_id, now=False): """Shutdown a kernel by kernel_id""" self._check_kernel_id(kernel_id) - kernel = self._kernels[kernel_id] - if kernel._activity_stream: - kernel._activity_stream.close() - kernel._activity_stream = None - self.stop_buffering(kernel_id) - self._kernel_connections.pop(kernel_id, None) + kernel = self._kernels.pop(kernel_id) + self.log.info("Shutting down kernel %s", kernel_id) + yield kernel.shutdown(now=now) + self.last_kernel_activity = utcnow() # Decrease the metric of number of kernels # running for the relevant kernel type by 1 KERNEL_CURRENTLY_RUNNING_TOTAL.labels( - type=self._kernels[kernel_id].kernel_name + type=kernel.kernel_type ).dec() - return super(MappingKernelManager, self).shutdown_kernel(kernel_id, now=now) + @gen.coroutine + def interrupt_kernel(self, kernel_id): + """Interrupt a kernel by kernel_id. """ + + self._check_kernel_id(kernel_id) + kernel = self.get_kernel(kernel_id) + + # Don't interrupt a kernel while it's still starting + yield kernel.client_ready() + yield kernel.interrupt() @gen.coroutine def restart_kernel(self, kernel_id): - """Restart a kernel by kernel_id""" + """Restart a kernel by kernel_id + + The restarted kernel keeps the same ID and KernelInterface object. 
+ """ self._check_kernel_id(kernel_id) - yield maybe_future(super(MappingKernelManager, self).restart_kernel(kernel_id)) kernel = self.get_kernel(kernel_id) - # return a Future that will resolve when the kernel has successfully restarted - channel = kernel.connect_shell() - future = Future() - - def finish(): - """Common cleanup when restart finishes/fails for any reason.""" - if not channel.closed(): - channel.close() - loop.remove_timeout(timeout) - kernel.remove_restart_callback(on_restart_failed, 'dead') - - def on_reply(msg): - self.log.debug("Kernel info reply received: %s", kernel_id) - finish() - if not future.done(): - future.set_result(msg) - - def on_timeout(): - self.log.warning("Timeout waiting for kernel_info_reply: %s", kernel_id) - finish() - if not future.done(): - future.set_exception(gen.TimeoutError("Timeout waiting for restart")) - - def on_restart_failed(): - self.log.warning("Restarting kernel failed: %s", kernel_id) - finish() - if not future.done(): - future.set_exception(RuntimeError("Restart failed")) - - kernel.add_restart_callback(on_restart_failed, 'dead') - kernel.session.send(channel, "kernel_info_request") - channel.on_recv(on_reply) - loop = IOLoop.current() - timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout) - raise gen.Return(future) + + try: + yield gen.with_timeout( + timedelta(seconds=self.kernel_info_timeout), + kernel.restart(), + ) + except gen.TimeoutError: + self.log.warning("Timeout waiting for kernel_info_reply: %s", + kernel_id) + self._kernels.pop(kernel_id) + # Decrease the metric of number of kernels + # running for the relevant kernel type by 1 + KERNEL_CURRENTLY_RUNNING_TOTAL.labels( + type=kernel.kernel_type + ).dec() + raise gen.TimeoutError("Timeout waiting for restart") def notify_connect(self, kernel_id): """Notice a new connection to a kernel""" - if kernel_id in self._kernel_connections: - self._kernel_connections[kernel_id] += 1 + if kernel_id in self._kernels: + self._kernels[kernel_id].n_connections += 1 def notify_disconnect(self, kernel_id): """Notice a disconnection from a kernel""" - if kernel_id in self._kernel_connections: - self._kernel_connections[kernel_id] -= 1 + if kernel_id in self._kernels: + self._kernels[kernel_id].n_connections -= 1 def kernel_model(self, kernel_id): """Return a JSON-safe dict representing a kernel @@ -367,22 +471,27 @@ def kernel_model(self, kernel_id): model = { "id":kernel_id, - "name": kernel.kernel_name, + "name": kernel.kernel_type, "last_activity": isoformat(kernel.last_activity), "execution_state": kernel.execution_state, - "connections": self._kernel_connections[kernel_id], + "connections": kernel.n_connections, } return model def list_kernels(self): - """Returns a list of kernel_id's of kernels running.""" + """Returns a list of models for kernels running.""" kernels = [] - kernel_ids = super(MappingKernelManager, self).list_kernel_ids() - for kernel_id in kernel_ids: + for kernel_id in self._kernels.keys(): model = self.kernel_model(kernel_id) kernels.append(model) return kernels + def list_kernel_ids(self): + return list(self._kernels.keys()) + + def __contains__(self, kernel_id): + return kernel_id in self._kernels + # override _check_kernel_id to raise 404 instead of KeyError def _check_kernel_id(self, kernel_id): """Check a that a kernel_id exists and raise 404 if not.""" @@ -398,30 +507,19 @@ def start_watching_activity(self, kernel_id): - record execution_state from status messages """ kernel = self._kernels[kernel_id] - # add busy/activity markers: - 
kernel.execution_state = 'starting' - kernel.last_activity = utcnow() - kernel._activity_stream = kernel.connect_iopub() - session = Session( - config=kernel.session.config, - key=kernel.session.key, - ) - def record_activity(msg_list): + def record_activity(msg, _channel): """Record an IOPub message arriving from a kernel""" self.last_kernel_activity = kernel.last_activity = utcnow() - idents, fed_msg_list = session.feed_identities(msg_list) - msg = session.deserialize(fed_msg_list) - - msg_type = msg['header']['msg_type'] + msg_type = msg.header['msg_type'] if msg_type == 'status': - kernel.execution_state = msg['content']['execution_state'] + kernel.execution_state = msg.content['execution_state'] self.log.debug("activity on %s: %s (%s)", kernel_id, msg_type, kernel.execution_state) else: self.log.debug("activity on %s: %s", kernel_id, msg_type) - kernel._activity_stream.on_recv(record_activity) + kernel.msg_handlers.append(record_activity) def initialize_culler(self): """Start idle culler if 'cull_idle_timeout' is greater than zero. @@ -460,7 +558,7 @@ def cull_kernels(self): def cull_kernel_if_idle(self, kernel_id): kernel = self._kernels[kernel_id] - self.log.debug("kernel_id=%s, kernel_name=%s, last_activity=%s", kernel_id, kernel.kernel_name, kernel.last_activity) + self.log.debug("kernel_id=%s, kernel_name=%s, last_activity=%s", kernel_id, kernel.kernel_type, kernel.last_activity) if kernel.last_activity is not None: dt_now = utcnow() dt_idle = dt_now - kernel.last_activity @@ -473,5 +571,5 @@ def cull_kernel_if_idle(self, kernel_id): if (is_idle_time and is_idle_execute and is_idle_connected): idle_duration = int(dt_idle.total_seconds()) self.log.warning("Culling '%s' kernel '%s' (%s) with %d connections due to %s seconds of inactivity.", - kernel.execution_state, kernel.kernel_name, kernel_id, connections, idle_duration) - self.shutdown_kernel(kernel_id) + kernel.execution_state, kernel.kernel_type, kernel_id, connections, idle_duration) + self.shutdown_kernel(kernel_id, now=True) diff --git a/jupyter_server/services/kernels/ws_serialize.py b/jupyter_server/services/kernels/ws_serialize.py new file mode 100644 index 0000000000..9af7139dce --- /dev/null +++ b/jupyter_server/services/kernels/ws_serialize.py @@ -0,0 +1,124 @@ +"""Serialize & deserialize Jupyter messages to send over websockets. +""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from datetime import datetime, timezone +import json +import struct +import sys + +from dateutil.tz import tzutc + +from ipython_genutils.py3compat import cast_unicode +from jupyter_protocol.jsonutil import date_default, extract_dates + + +def serialize_message(msg, channel): + """Serialize a message from the kernel, using JSON. + + msg is a jupyter_protocol Message object + + Returns a str of JSON if there are no binary buffers, or bytes which may + be sent as a binary websocket message. + """ + d = msg.make_dict() + if channel: + d['channel'] = channel + if msg.buffers: + buf = serialize_binary_message(d, msg.buffers) + return buf + else: + d['buffers'] = [] + smsg = json.dumps(d, default=date_default) + return cast_unicode(smsg) + + +def serialize_binary_message(msg_dict, buffers): + """serialize a message as a binary blob + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + + The message serialized to bytes. 
+ + """ + # don't modify msg or buffer list in-place + msg_dict = msg_dict.copy() + buffers = list(buffers) + + if sys.version_info < (3, 4): + buffers = [x.tobytes() for x in buffers] + bmsg = json.dumps(msg_dict, default=date_default).encode('utf8') + buffers.insert(0, bmsg) + nbufs = len(buffers) + offsets = [4 * (nbufs + 1)] + for buf in buffers[:-1]: + offsets.append(offsets[-1] + len(buf)) + offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets) + buffers.insert(0, offsets_buf) + return b''.join(buffers) + + +def deserialize_message(msg): + """Deserialize a websocket message, return a dict. + + msg may be either bytes, for a binary websocket message including buffers, + or str, for a pure JSON message. + """ + if isinstance(msg, bytes): + msg = deserialize_binary_message(msg) + else: + msg = json.loads(msg) + + msg['header'] = convert_tz_utc(extract_dates(msg['header'])) + msg['parent_header'] = convert_tz_utc(extract_dates(msg['parent_header'])) + return msg + + +def deserialize_binary_message(bmsg): + """deserialize a message from a binary blob + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + + message dictionary + """ + nbufs = struct.unpack('!i', bmsg[:4])[0] + offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)])) + offsets.append(None) + bufs = [] + for start, stop in zip(offsets[:-1], offsets[1:]): + bufs.append(bmsg[start:stop]) + msg = json.loads(bufs[0].decode('utf8')) + + msg['buffers'] = bufs[1:] + return msg + + +def convert_tz_utc(obj): + """Replace dateutil tzutc objects with stdlib datetime utc objects""" + if isinstance(obj, dict): + new_obj = {} # don't clobber + for k,v in obj.items(): + new_obj[k] = convert_tz_utc(v) + obj = new_obj + elif isinstance(obj, (list, tuple)): + obj = [ convert_tz_utc(o) for o in obj ] + elif isinstance(obj, datetime) and isinstance(obj.tzinfo, tzutc): + obj = obj.replace(tzinfo=timezone.utc) + return obj diff --git a/jupyter_server/services/kernelspecs/handlers.py b/jupyter_server/services/kernelspecs/handlers.py index 6302e96dfa..3a3fc3e6ba 100644 --- a/jupyter_server/services/kernelspecs/handlers.py +++ b/jupyter_server/services/kernelspecs/handlers.py @@ -12,19 +12,20 @@ pjoin = os.path.join from tornado import web, gen +from urllib.parse import unquote from ...base.handlers import APIHandler -from ...utils import maybe_future, url_path_join, url_unescape - +from ...utils import maybe_future, url_path_join, url_unescape, quote def kernelspec_model(handler, name, spec_dict, resource_dir): """Load a KernelSpec by name and return the REST API model""" d = { 'name': name, - 'spec': spec_dict, + 'spec': spec_dict.copy(), 'resources': {} } + d['spec']['language'] = d['spec']['language_info']['name'] # Add resource files if they exist resource_dir = resource_dir @@ -33,7 +34,7 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): d['resources'][resource] = url_path_join( handler.base_url, 'kernelspecs', - name, + quote(name, safe=''), resource ) for logo_file in glob.glob(pjoin(resource_dir, 'logo-*')): @@ -42,7 +43,7 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): d['resources'][no_ext] = url_path_join( handler.base_url, 'kernelspecs', - name, + quote(name, safe=''), fname ) return d @@ -58,22 +59,22 @@ class MainKernelSpecHandler(APIHandler): @web.authenticated @gen.coroutine def get(self): - ksm = 
self.kernel_spec_manager + kf = self.kernel_finder km = self.kernel_manager model = {} model['default'] = km.default_kernel_name model['kernelspecs'] = specs = {} - kspecs = yield maybe_future(ksm.get_all_specs()) - for kernel_name, kernel_info in kspecs.items(): + for kernel_name, kernel_info in kf.find_kernels(): try: if is_kernelspec_model(kernel_info): d = kernel_info else: - d = kernelspec_model(self, kernel_name, kernel_info['spec'], kernel_info['resource_dir']) + d = kernelspec_model(self, kernel_name, kernel_info, kernel_info['resource_dir']) except Exception: self.log.error("Failed to load kernel spec: '%s'", kernel_name, exc_info=True) continue specs[kernel_name] = d + self.set_header("Content-Type", 'application/json') self.finish(json.dumps(model)) @@ -83,18 +84,20 @@ class KernelSpecHandler(APIHandler): @web.authenticated @gen.coroutine def get(self, kernel_name): - ksm = self.kernel_spec_manager - kernel_name = url_unescape(kernel_name) - try: - spec = yield maybe_future(ksm.get_kernel_spec(kernel_name)) - except KeyError: - raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) - if is_kernelspec_model(spec): - model = spec - else: - model = kernelspec_model(self, kernel_name, spec.to_dict(), spec.resource_dir) - self.set_header("Content-Type", 'application/json') - self.finish(json.dumps(model)) + kf = self.kernel_finder + kernel_name = unquote(kernel_name.lower()) + for name, kernel_info in kf.find_kernels(): + if name == kernel_name: + if is_kernelspec_model(kernel_info): + model = kernel_info + else: + model = kernelspec_model(self, kernel_name, kernel_info, + kernel_info['resource_dir']) + self.set_header("Content-Type", 'application/json') + return self.finish(json.dumps(model)) + + raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) + # URL to handler mappings diff --git a/jupyter_server/services/sessions/handlers.py b/jupyter_server/services/sessions/handlers.py index 78072c5771..fa6ea66203 100644 --- a/jupyter_server/services/sessions/handlers.py +++ b/jupyter_server/services/sessions/handlers.py @@ -11,9 +11,9 @@ from tornado import gen, web from ...base.handlers import APIHandler -from jupyter_client.jsonutil import date_default +from jupyter_protocol.jsonutil import date_default from jupyter_server.utils import maybe_future, url_path_join -from jupyter_client.kernelspec import NoSuchKernel +from jupyter_kernel_mgmt.kernelspec import NoSuchKernel class SessionRootHandler(APIHandler): @@ -132,9 +132,9 @@ def patch(self, session_id): changes['kernel_id'] = kernel_id elif model['kernel'].get('name') is not None: kernel_name = model['kernel']['name'] - kernel_id = yield sm.start_kernel_for_session( + kernel_id = yield maybe_future(sm.start_kernel_for_session( session_id, kernel_name=kernel_name, name=before['name'], - path=before['path'], type=before['type']) + path=before['path'], type=before['type'])) changes['kernel_id'] = kernel_id yield maybe_future(sm.update_session(session_id, **changes)) diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 50012bc2a1..f992c03220 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -89,6 +89,7 @@ def create_session(self, path=None, name=None, type=None, kernel_name=None, kern result = yield maybe_future( self.save_session(session_id, path=path, name=name, type=type, kernel_id=kernel_id) ) + # py2-compat raise gen.Return(result) @@ -97,9 +98,8 @@ def 
start_kernel_for_session(self, session_id, path, name, type, kernel_name): """Start a new kernel for a given session.""" # allow contents manager to specify kernels cwd kernel_path = self.contents_manager.get_kernel_path(path=path) - kernel_id = yield maybe_future( - self.kernel_manager.start_kernel(path=kernel_path, kernel_name=kernel_name) - ) + kernel_id = yield maybe_future(self.kernel_manager.start_kernel( + path=kernel_path, kernel_name=kernel_name,)) # py2-compat raise gen.Return(kernel_id) diff --git a/setup.py b/setup.py index 1e5d636f9f..0acdaa3497 100755 --- a/setup.py +++ b/setup.py @@ -80,8 +80,9 @@ 'pyzmq>=17', 'ipython_genutils', 'traitlets>=4.2.1', - 'jupyter_core>=4.4.0', - 'jupyter_client>=5.3.1', + 'jupyter_core>=4.6.1', + 'jupyter_kernel_mgmt>=0.5', + 'jupyter_protocol', 'nbformat', 'nbconvert<6', 'ipykernel', # bless IPython kernel for now diff --git a/tests/services/kernels/test_api.py b/tests/services/kernels/test_api.py index cfdc6b80c1..296e5d4b1f 100644 --- a/tests/services/kernels/test_api.py +++ b/tests/services/kernels/test_api.py @@ -7,8 +7,6 @@ import urllib.parse from tornado.escape import url_escape -from jupyter_client.kernelspec import NATIVE_KERNEL_NAME - from jupyter_server.utils import url_path_join from ...conftest import expected_http_error @@ -72,7 +70,7 @@ async def test_main_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel1 = json.loads(r.body.decode()) @@ -104,7 +102,7 @@ async def test_main_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel2 = json.loads(r.body.decode()) @@ -145,7 +143,7 @@ async def test_kernel_handler(fetch): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kernel_id = json.loads(r.body.decode())['id'] @@ -200,7 +198,7 @@ async def test_connection(fetch, ws_fetch, http_port, auth_header): 'api', 'kernels', method='POST', body=json.dumps({ - 'name': NATIVE_KERNEL_NAME + 'name': 'pyimport/kernel' }) ) kid = json.loads(r.body.decode())['id'] diff --git a/tests/services/kernelspecs/test_api.py b/tests/services/kernelspecs/test_api.py index 0d3a2ba387..49ec590827 100644 --- a/tests/services/kernelspecs/test_api.py +++ b/tests/services/kernelspecs/test_api.py @@ -3,8 +3,6 @@ import tornado -from jupyter_client.kernelspec import NATIVE_KERNEL_NAME - from ...conftest import expected_http_error @@ -41,7 +39,7 @@ async def test_list_kernelspecs_bad(fetch, kernelspecs, data_dir): ) model = json.loads(r.body.decode()) assert isinstance(model, dict) - assert model['default'] == NATIVE_KERNEL_NAME + assert model['default'] == 'pyimport/kernel' specs = model['kernelspecs'] assert isinstance(specs, dict) assert len(specs) > 2 @@ -54,16 +52,16 @@ async def test_list_kernelspecs(fetch, kernelspecs): ) model = json.loads(r.body.decode()) assert isinstance(model, dict) - assert model['default'] == NATIVE_KERNEL_NAME + assert model['default'] == 'pyimport/kernel' specs = model['kernelspecs'] assert isinstance(specs, dict) assert len(specs) > 2 def is_sample_kernelspec(s): - return s['name'] == 'sample' and s['spec']['display_name'] == 'Test kernel' + return s['name'] == 'spec/sample' and s['spec']['display_name'] == 'Test kernel' def is_default_kernelspec(s): - return s['name'] == NATIVE_KERNEL_NAME and s['spec']['display_name'].startswith("Python") + return s['name'] == 'pyimport/kernel' and 
s['spec']['display_name'].startswith("Python") assert any(is_sample_kernelspec(s) for s in specs.values()), specs assert any(is_default_kernelspec(s) for s in specs.values()), specs @@ -71,11 +69,11 @@ def is_default_kernelspec(s): async def test_get_kernelspecs(fetch, kernelspecs): r = await fetch( - 'api', 'kernelspecs', 'Sample', + 'api', 'kernelspecs', 'spec%2FSample', method='GET' ) model = json.loads(r.body.decode()) - assert model['name'].lower() == 'sample' + assert model['name'].lower() == 'spec/sample' assert isinstance(model['spec'], dict) assert model['spec']['display_name'] == 'Test kernel' assert isinstance(model['resources'], dict) @@ -83,17 +81,17 @@ async def test_get_kernelspecs(fetch, kernelspecs): async def test_get_kernelspec_spaces(fetch, kernelspecs): r = await fetch( - 'api', 'kernelspecs', 'sample%202', + 'api', 'kernelspecs', 'spec%2Fsample%202', method='GET' ) model = json.loads(r.body.decode()) - assert model['name'].lower() == 'sample 2' + assert model['name'].lower() == 'spec/sample 2' async def test_get_nonexistant_kernelspec(fetch, kernelspecs): with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'api', 'kernelspecs', 'nonexistant', + 'api', 'kernelspecs', 'spec/nonexistant', method='GET' ) assert expected_http_error(e, 404) @@ -101,7 +99,7 @@ async def test_get_nonexistant_kernelspec(fetch, kernelspecs): async def test_get_kernel_resource_file(fetch, kernelspecs): r = await fetch( - 'kernelspecs', 'sAmple', 'resource.txt', + 'kernelspecs', 'spec%2FsAmple', 'resource.txt', method='GET' ) res = r.body.decode('utf-8') @@ -111,14 +109,14 @@ async def test_get_kernel_resource_file(fetch, kernelspecs): async def test_get_nonexistant_resource(fetch, kernelspecs): with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'kernelspecs', 'nonexistant', 'resource.txt', + 'kernelspecs', 'spec%2Fnonexistant', 'resource.txt', method='GET' ) assert expected_http_error(e, 404) with pytest.raises(tornado.httpclient.HTTPClientError) as e: await fetch( - 'kernelspecs', 'sample', 'nonexistant.txt', + 'kernelspecs', 'spec%2Fsample', 'nonexistant.txt', method='GET' ) - assert expected_http_error(e, 404) \ No newline at end of file + assert expected_http_error(e, 404) diff --git a/tests/services/sessions/test_api.py b/tests/services/sessions/test_api.py index 31b9ef5d87..181036ac59 100644 --- a/tests/services/sessions/test_api.py +++ b/tests/services/sessions/test_api.py @@ -37,10 +37,10 @@ async def get(self, id): return await self._req(id, method='GET') async def create( - self, - path, - type='notebook', - kernel_name='python', + self, + path, + type='notebook', + kernel_name='pyimport/kernel', kernel_id=None): body = { 'path': path, @@ -58,7 +58,7 @@ def create_deprecated(self, path): 'path': path }, 'kernel': { - 'name': 'python', + 'name': 'pyimport/kernel', 'id': 'foo' } } @@ -96,7 +96,6 @@ async def cleanup(self): time.sleep(0.1) - @pytest.fixture def session_client(root_dir, fetch): subdir = root_dir.joinpath('foo') @@ -114,7 +113,31 @@ def session_client(root_dir, fetch): # Remove subdir shutil.rmtree(str(subdir), ignore_errors=True) - + + +def assert_kernel_equality(actual, expected): + """ Compares kernel models after taking into account that execution_states + may differ from 'starting' to 'idle'. The 'actual' argument is the + current state (which may have an 'idle' status) while the 'expected' + argument is the previous state (which may have a 'starting' status). 
+ """ + actual.pop('execution_state', None) + actual.pop('last_activity', None) + expected.pop('execution_state', None) + expected.pop('last_activity', None) + assert actual == expected + + +def assert_session_equality(actual, expected): + """ Compares session models. `actual` is the most current session, + while `expected` is the target of the comparison. This order + matters when comparing the kernel sub-models. + """ + assert actual['id'] == expected['id'] + assert actual['path'] == expected['path'] + assert actual['type'] == expected['type'] + assert_kernel_equality(actual['kernel'], expected['kernel']) + async def test_create(session_client): # Make sure no sessions exist. @@ -134,13 +157,14 @@ async def test_create(session_client): # Check that the new session appears in list. resp = await session_client.list() sessions = j(resp) - assert sessions == [new_session] + assert len(sessions) == 1 + assert_session_equality(sessions[0], new_session) # Retrieve that session. sid = new_session['id'] resp = await session_client.get(sid) got = j(resp) - assert got == new_session + assert_session_equality(got, new_session) # Need to find a better solution to this. await session_client.cleanup() @@ -170,7 +194,7 @@ async def test_create_deprecated(session_client): assert resp.code == 201 newsession = j(resp) assert newsession['path'] == 'foo/nb1.ipynb' - assert newsession['type'] == 'notebook' + assert newsession['type'] == 'notebook' assert newsession['notebook']['path'] == 'foo/nb1.ipynb' # Need to find a better solution to this. await session_client.cleanup() @@ -183,24 +207,27 @@ async def test_create_with_kernel_id(session_client, fetch): resp = await session_client.create('foo/nb1.ipynb', kernel_id=kernel['id']) assert resp.code == 201 - newsession = j(resp) - assert 'id' in newsession - assert newsession['path'] == 'foo/nb1.ipynb' - assert newsession['kernel']['id'] == kernel['id'] - assert resp.headers['Location'] == '/api/sessions/{0}'.format(newsession['id']) + new_session = j(resp) + assert 'id' in new_session + assert new_session['path'] == 'foo/nb1.ipynb' + assert_kernel_equality(new_session['kernel'], kernel) + assert resp.headers['Location'] == '/api/sessions/{0}'.format(new_session['id']) resp = await session_client.list() sessions = j(resp) - assert sessions == [newsession] + + assert_session_equality(sessions[0], new_session) # Retrieve it - sid = newsession['id'] + sid = new_session['id'] resp = await session_client.get(sid) got = j(resp) - assert got == newsession + assert_session_equality(got, new_session) + # Need to find a better solution to this. await session_client.cleanup() + async def test_delete(session_client): resp = await session_client.create('foo/nb1.ipynb') newsession = j(resp) @@ -219,6 +246,7 @@ async def test_delete(session_client): # Need to find a better solution to this. await session_client.cleanup() + async def test_modify_path(session_client): resp = await session_client.create('foo/nb1.ipynb') newsession = j(resp) @@ -231,6 +259,7 @@ async def test_modify_path(session_client): # Need to find a better solution to this. await session_client.cleanup() + async def test_modify_path_deprecated(session_client): resp = await session_client.create('foo/nb1.ipynb') newsession = j(resp) @@ -243,6 +272,7 @@ async def test_modify_path_deprecated(session_client): # Need to find a better solution to this. 
await session_client.cleanup() + async def test_modify_type(session_client): resp = await session_client.create('foo/nb1.ipynb') newsession = j(resp) @@ -255,6 +285,7 @@ async def test_modify_type(session_client): # Need to find a better solution to this. await session_client.cleanup() + async def test_modify_kernel_name(session_client, fetch): resp = await session_client.create('foo/nb1.ipynb') before = j(resp) @@ -270,9 +301,9 @@ async def test_modify_kernel_name(session_client, fetch): # check kernel list, to be sure previous kernel was cleaned up resp = await fetch('api/kernels', method='GET') kernel_list = j(resp) - after['kernel'].pop('last_activity') - [ k.pop('last_activity') for k in kernel_list ] - assert kernel_list == [after['kernel']] + + assert_kernel_equality(kernel_list[0], after['kernel']) + # Need to find a better solution to this. await session_client.cleanup() @@ -299,9 +330,7 @@ async def test_modify_kernel_id(session_client, fetch): resp = await fetch('api/kernels', method='GET') kernel_list = j(resp) - kernel.pop('last_activity') - [ k.pop('last_activity') for k in kernel_list ] - assert kernel_list == [kernel] + assert_kernel_equality(kernel_list[0], kernel) # Need to find a better solution to this. - await session_client.cleanup() \ No newline at end of file + await session_client.cleanup() diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index a37a8c7f0a..32a0967c39 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -1,7 +1,9 @@ import pytest -from tornado import web +from tornado import gen, web +from tornado.ioloop import IOLoop +from jupyter_kernel_mgmt.discovery import KernelFinder from jupyter_server.services.sessions.sessionmanager import SessionManager from jupyter_server.services.kernels.kernelmanager import MappingKernelManager from jupyter_server.services.contents.manager import ContentsManager @@ -9,8 +11,12 @@ class DummyKernel(object): - def __init__(self, kernel_name='python'): - self.kernel_name = kernel_name + def __init__(self, kernel_type='python'): + self.kernel_type = kernel_type + + @gen.coroutine + def client_ready(self): + return # Don't wait for anything dummy_date = utcnow() @@ -26,10 +32,10 @@ def __init__(self, *args, **kwargs): def _new_id(self): return next(self.id_letters) - def start_kernel(self, kernel_id=None, path=None, kernel_name='python', **kwargs): - kernel_id = kernel_id or self._new_id() - k = self._kernels[kernel_id] = DummyKernel(kernel_name=kernel_name) - self._kernel_connections[kernel_id] = 0 + def start_launching_kernel(self, path=None, kernel_name='python', **kwargs): + kernel_id = self._new_id() + k = self._kernels[kernel_id] = DummyKernel(kernel_type=kernel_name) + k.n_connections = 0 k.last_activity = dummy_date k.execution_state = 'idle' return kernel_id @@ -41,7 +47,7 @@ def shutdown_kernel(self, kernel_id, now=False): @pytest.fixture def session_manager(): return SessionManager( - kernel_manager=DummyMKM(), + kernel_manager=DummyMKM(kernel_finder=KernelFinder(providers=[])), contents_manager=ContentsManager()) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 4ad4d71a68..bca1398cd0 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -177,7 +177,7 @@ async def test_gateway_class_mappings(init_gateway, serverapp): # Ensure appropriate class mappings are in place. 
assert serverapp.kernel_manager_class.__name__ == 'GatewayKernelManager' assert serverapp.session_manager_class.__name__ == 'GatewaySessionManager' - assert serverapp.kernel_spec_manager_class.__name__ == 'GatewayKernelSpecManager' + assert serverapp.kernel_finder.__class__.__name__ == 'GatewayKernelFinder' async def test_gateway_get_kernelspecs(init_gateway, fetch): diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 07947dc549..0dac0c8fac 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -2,24 +2,40 @@ import os -from jupyter_client.session import Session -from jupyter_server.base.zmqhandlers import ( - serialize_binary_message, - deserialize_binary_message, +from jupyter_protocol.messages import Message +from jupyter_server.services.kernels.ws_serialize import ( + serialize_message, deserialize_message ) +def test_serialize_json(): + msg = Message.from_type('data_pub', content={'a': 'b'}) + smsg = serialize_message(msg, 'iopub') + assert isinstance(smsg, str) + def test_serialize_binary(): - s = Session() - msg = s.msg('data_pub', content={'a': 'b'}) - msg['buffers'] = [ memoryview(os.urandom(3)) for i in range(3) ] - bmsg = serialize_binary_message(msg) + msg = Message.from_type('data_pub', content={'a': 'b'}) + msg.buffers = [memoryview(os.urandom(3)) for i in range(3)] + bmsg = serialize_message(msg, 'iopub') assert isinstance(bmsg, bytes) +def test_deserialize_json(): + msg = Message.from_type('data_pub', content={'a': 'b'}) + smsg = serialize_message(msg, 'iopub') + print("Serialised: ", smsg) + msg_dict = msg.make_dict() + msg_dict['channel'] = 'iopub' + msg_dict['buffers'] = [] + + msg2 = deserialize_message(smsg) + assert msg2 == msg_dict def test_deserialize_binary(): - s = Session() - msg = s.msg('data_pub', content={'a': 'b'}) - msg['buffers'] = [ memoryview(os.urandom(2)) for i in range(3) ] - bmsg = serialize_binary_message(msg) - msg2 = deserialize_binary_message(bmsg) - assert msg2 == msg \ No newline at end of file + msg = Message.from_type('data_pub', content={'a': 'b'}) + msg.buffers = [memoryview(os.urandom(3)) for i in range(3)] + bmsg = serialize_message(msg, 'iopub') + msg_dict = msg.make_dict() + msg_dict['channel'] = 'iopub' + msg_dict['buffers'] = msg.buffers + + msg2 = deserialize_message(bmsg) + assert msg2 == msg_dict
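
For reference, the framing that the new serialize_binary_message/deserialize_binary_message pair implements can be reproduced standalone: a 32-bit big-endian part count, one 32-bit offset per part (measured from the start of the whole blob), then the JSON-encoded message dict followed by any raw buffers. The sketch below uses only the stdlib; pack_parts/unpack_parts are illustrative names, not part of the module's API.

import json
import struct


def pack_parts(msg_dict, buffers):
    """Frame a JSON-able dict plus raw byte buffers into one blob."""
    parts = [json.dumps(msg_dict).encode('utf8')] + list(buffers)
    nparts = len(parts)
    # header = part count plus one offset per part, each a 32-bit unsigned int
    offsets = [4 * (nparts + 1)]
    for part in parts[:-1]:
        offsets.append(offsets[-1] + len(part))
    header = struct.pack('!' + 'I' * (nparts + 1), nparts, *offsets)
    return header + b''.join(parts)


def unpack_parts(blob):
    """Invert pack_parts: return (msg_dict, buffers)."""
    nparts = struct.unpack('!I', blob[:4])[0]
    offsets = list(struct.unpack('!' + 'I' * nparts, blob[4:4 * (nparts + 1)]))
    offsets.append(None)  # final slice runs to the end of the blob
    parts = [blob[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    return json.loads(parts[0].decode('utf8')), parts[1:]


if __name__ == '__main__':
    msg, bufs = unpack_parts(pack_parts({'a': 'b'}, [b'\x00\x01', b'xyz']))
    assert msg == {'a': 'b'} and bufs == [b'\x00\x01', b'xyz']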
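
The kernelspec handler and test changes above assume that 'provider/name' kernel type names are percent-encoded when they appear as a single URL path segment (hence 'spec%2Fsample' in the tests, quote(name, safe='') in kernelspec_model, and unquote on the handler side). A quick stdlib round-trip of that convention:

from urllib.parse import quote, unquote

name = 'spec/sample'
segment = quote(name, safe='')      # -> 'spec%2Fsample'; the slash must not survive
assert segment == 'spec%2Fsample'
assert unquote(segment) == name     # the handler recovers the original type name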
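
Lastly, deserialize_message routes extracted header dates through convert_tz_utc so that timestamps carrying dateutil's tzutc compare cleanly against stdlib-aware datetimes. A small check of the equivalence that conversion relies on (assuming python-dateutil is available, as the new module imports it):

from datetime import datetime, timezone
from dateutil.tz import tzutc

# an aware datetime as dateutil's parser would produce it
dt = datetime(2018, 10, 25, 17, 17, 18, tzinfo=tzutc())
# the per-value replacement convert_tz_utc performs
converted = dt.replace(tzinfo=timezone.utc)

assert converted == dt                   # same instant in time
assert converted.tzinfo is timezone.utc  # now the stdlib UTC tzinfo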