diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py index 856cea494eb..6f7c9952daf 100644 --- a/notebook/gateway/managers.py +++ b/notebook/gateway/managers.py @@ -454,7 +454,7 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False): self.remove_kernel(kernel_id) @gen.coroutine - def restart_kernel(self, kernel_id, now=False, **kwargs): + def restart_kernel(self, kernel_id, channels=None, now=False, **kwargs): """Restart a kernel by its kernel uuid. Parameters diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index 73da737b150..75ebeb3cbb7 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -78,9 +78,8 @@ def post(self, kernel_id, action): yield maybe_future(km.interrupt_kernel(kernel_id)) self.set_status(204) if action == 'restart': - try: - yield maybe_future(km.restart_kernel(kernel_id)) + yield maybe_future(km.restart_kernel(kernel_id, km.channels)) except Exception as e: self.log.error("Exception restarting kernel", exc_info=True) self.set_status(500) @@ -121,12 +120,71 @@ def __repr__(self): return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized')) def create_stream(self): + self.log.debug("Create stream") km = self.kernel_manager identity = self.session.bsession for channel in ('shell', 'control', 'iopub', 'stdin'): meth = getattr(km, 'connect_' + channel) self.channels[channel] = stream = meth(self.kernel_id, identity=identity) stream.channel = channel + + shell_channel = self.channels['shell'] + iopub_channel = self.channels['iopub'] + + future = Future() + info_future = Future() + iopub_future = Future() + + def finish(): + """Common cleanup""" + loop.remove_timeout(timeout) + loop.remove_timeout(nudge_handle) + iopub_channel.stop_on_recv() + shell_channel.stop_on_recv() + + def on_shell_reply(msg): + if not info_future.done(): + self.log.debug("Nudge: shell info reply received: %s", self.kernel_id) + shell_channel.stop_on_recv() + self.log.debug("Nudge: resolving shell future") + info_future.set_result(msg) + if iopub_future.done(): + finish() + self.log.debug("Nudge: resolving main future in shell handler") + future.set_result(info_future.result()) + + def on_iopub(msg): + if not iopub_future.done(): + self.log.debug("Nudge: first IOPub received: %s", self.kernel_id) + iopub_channel.stop_on_recv() + self.log.debug("Nudge: resolving iopub future") + iopub_future.set_result(None) + if info_future.done(): + finish() + self.log.debug("Nudge: resolving main future in iopub handler") + future.set_result(info_future.result()) + + def on_timeout(): + self.log.warning("Nudge: Timeout waiting for kernel_info_reply: %s", self.kernel_id) + finish() + if not future.done(): + future.set_exception(TimeoutError("Timeout waiting for nudge")) + + iopub_channel.on_recv(on_iopub) + shell_channel.on_recv(on_shell_reply) + loop = IOLoop.current() + + # Nudge the kernel with kernel info requests until we get an IOPub message + def nudge(): + self.log.debug("Nudge") + if not future.done(): + self.log.debug("nudging") + self.session.send(shell_channel, "kernel_info_request") + nudge_handle = loop.call_later(0.5, nudge) + nudge_handle = loop.call_later(0, nudge) + + timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout) + return future def request_kernel_info(self): """send a request for kernel_info""" @@ -193,6 +251,7 @@ def initialize(self): super().initialize() self.zmq_stream = None self.channels = {} + self.kernel_manager.channels = self.channels self.kernel_id = None self.kernel_info_channel = None self._kernel_info_future = Future() @@ -253,6 +312,7 @@ def _register_session(self): yield stale_handler.close() self._open_sessions[self.session_key] = self + @gen.coroutine def open(self, kernel_id): super().open() km = self.kernel_manager @@ -269,9 +329,11 @@ def open(self, kernel_id): for channel, msg_list in replay_buffer: stream = self.channels[channel] self._on_zmq_reply(stream, msg_list) + connected = Future() + connected.set_result(None) else: try: - self.create_stream() + connected = self.create_stream() except web.HTTPError as e: self.log.error("Error opening stream: %s", e) # WebSockets don't response to traditional error codes so we @@ -285,8 +347,13 @@ def open(self, kernel_id): km.add_restart_callback(self.kernel_id, self.on_kernel_restarted) km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead') - for channel, stream in self.channels.items(): - stream.on_recv_stream(self._on_zmq_reply) + def subscribe(value): + for channel, stream in self.channels.items(): + stream.on_recv_stream(self._on_zmq_reply) + + connected.add_done_callback(subscribe) + + return connected def on_message(self, msg): if not self.channels: diff --git a/notebook/services/kernels/kernelmanager.py b/notebook/services/kernels/kernelmanager.py index 61cbbe58f52..2cef17a76db 100644 --- a/notebook/services/kernels/kernelmanager.py +++ b/notebook/services/kernels/kernelmanager.py @@ -304,33 +304,59 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False): return self.pinned_superclass.shutdown_kernel(self, kernel_id, now=now, restart=restart) - async def restart_kernel(self, kernel_id, now=False): + async def restart_kernel(self, kernel_id, channels, now=False): """Restart a kernel by kernel_id""" self._check_kernel_id(kernel_id) await maybe_future(self.pinned_superclass.restart_kernel(self, kernel_id, now=now)) kernel = self.get_kernel(kernel_id) # return a Future that will resolve when the kernel has successfully restarted - channel = kernel.connect_shell() + shell_channel = self.channels['shell'] + iopub_channel = self.channels['iopub'] + + session = Session( + config=kernel.session.config, + key=kernel.session.key, + ) + future = Future() + info_future = Future() + iopub_future = Future() def finish(): - """Common cleanup when restart finishes/fails for any reason.""" - if not channel.closed(): - channel.close() + """Common cleanup""" loop.remove_timeout(timeout) + loop.remove_timeout(nudge_handle) + iopub_channel.stop_on_recv() + shell_channel.stop_on_recv() kernel.remove_restart_callback(on_restart_failed, 'dead') - def on_reply(msg): - self.log.debug("Kernel info reply received: %s", kernel_id) - finish() - if not future.done(): - future.set_result(msg) + def on_shell_reply(msg): + if not info_future.done(): + self.log.debug("Nudge: shell info reply received: %s", kernel_id) + shell_channel.stop_on_recv() + self.log.debug("Nudge: resolving shell future") + info_future.set_result(msg) + if iopub_future.done(): + finish() + self.log.debug("Nudge: resolving main future in shell handler") + future.set_result(info_future.result()) + + def on_iopub(msg): + if not iopub_future.done(): + self.log.debug("Nudge: first IOPub received: %s", kernel_id) + iopub_channel.stop_on_recv() + self.log.debug("Nudge: resolving iopub future") + iopub_future.set_result(None) + if info_future.done(): + finish() + self.log.debug("Nudge: resolving main future in iopub handler") + future.set_result(info_future.result()) def on_timeout(): - self.log.warning("Timeout waiting for kernel_info_reply: %s", kernel_id) + self.log.warning("Nudge: Timeout waiting for kernel_info_reply: %s", kernel_id) finish() if not future.done(): - future.set_exception(TimeoutError("Timeout waiting for restart")) + future.set_exception(TimeoutError("Timeout waiting for nudge")) def on_restart_failed(): self.log.warning("Restarting kernel failed: %s", kernel_id) @@ -339,10 +365,20 @@ def on_restart_failed(): future.set_exception(RuntimeError("Restart failed")) kernel.add_restart_callback(on_restart_failed, 'dead') - kernel.session.send(channel, "kernel_info_request") - channel.on_recv(on_reply) + + iopub_channel.on_recv(on_iopub) + shell_channel.on_recv(on_shell_reply) loop = IOLoop.current() - timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout) + + # Nudge the kernel with kernel info requests until we get an IOPub message + def nudge(): + self.log.debug("Nudge") + if not future.done(): + self.log.debug("nudging") + session.send(shell_channel, "kernel_info_request") + nudge_handle = loop.call_later(0.5, nudge) + nudge_handle = loop.call_later(0, nudge) + return future def notify_connect(self, kernel_id):