Skip to content

Commit

Permalink
stop hook for extensions (#526)
Browse files Browse the repository at this point in the history
* stop hook for extensions

* closes #241
* call a stop_extension method on server shutdown if present

* fix typo

* make extension stop hooks async

* extension stop hook tests

* extension stop hooks feedback

* run_sync

* extension stop hooks extension_apps property
  • Loading branch information
oliver-sanders authored Jul 8, 2021
1 parent f7290dc commit 195ed51
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 29 deletions.
7 changes: 6 additions & 1 deletion docs/source/developers/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,19 @@ The basic structure of an ExtensionApp is shown below:
...
# Change the jinja templating environment
async def stop_extension(self):
...
# Perform any required shut down steps
The ``ExtensionApp`` uses the following methods and properties to connect your extension to the Jupyter server. You do not need to define a ``_load_jupyter_server_extension`` function for these apps. Instead, overwrite the pieces below to add your custom settings, handlers and templates:

Methods

* ``initialize_setting()``: adds custom settings to the Tornado Web Application.
* ``initialize_settings()``: adds custom settings to the Tornado Web Application.
* ``initialize_handlers()``: appends handlers to the Tornado Web Application.
* ``initialize_templates()``: initialize the templating engine (e.g. jinja2) for your frontend.
* ``stop_extension()``: called on server shut down.

Properties

Expand Down
3 changes: 3 additions & 0 deletions jupyter_server/extension/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def start(self):
# Start the server.
self.serverapp.start()

async def stop_extension(self):
"""Cleanup any resources managed by this extension."""

def stop(self):
"""Stop the underlying Jupyter server.
"""
Expand Down
46 changes: 40 additions & 6 deletions jupyter_server/extension/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
import traceback

from tornado.gen import multi

from traitlets.config import LoggingConfigurable

from traitlets import (
Expand Down Expand Up @@ -230,15 +232,17 @@ def link_point(self, point_name, serverapp):

def load_point(self, point_name, serverapp):
point = self.extension_points[point_name]
point.load(serverapp)
return point.load(serverapp)

def link_all_points(self, serverapp):
for point_name in self.extension_points:
self.link_point(point_name, serverapp)

def load_all_points(self, serverapp):
for point_name in self.extension_points:
return [
self.load_point(point_name, serverapp)
for point_name in self.extension_points
]


class ExtensionManager(LoggingConfigurable):
Expand Down Expand Up @@ -290,12 +294,26 @@ def sorted_extensions(self):
"""
)

@property
def extension_apps(self):
"""Return mapping of extension names and sets of ExtensionApp objects.
"""
return {
name: {
point.app
for point in extension.extension_points.values()
if point.app
}
for name, extension in self.extensions.items()
}

@property
def extension_points(self):
extensions = self.extensions
"""Return mapping of extension point names and ExtensionPoint objects.
"""
return {
name: point
for value in extensions.values()
for value in self.extensions.values()
for name, point in value.extension_points.items()
}

Expand Down Expand Up @@ -341,13 +359,22 @@ def link_extension(self, name, serverapp):

def load_extension(self, name, serverapp):
extension = self.extensions.get(name)

if extension.enabled:
try:
extension.load_all_points(serverapp)
self.log.info("{name} | extension was successfully loaded.".format(name=name))
points = extension.load_all_points(serverapp)
except Exception as e:
self.log.debug("".join(traceback.format_exception(*sys.exc_info())))
self.log.warning("{name} | extension failed loading with message: {error}".format(name=name,error=str(e)))
else:
self.log.info("{name} | extension was successfully loaded.".format(name=name))

async def stop_extension(self, name, apps):
"""Call the shutdown hooks in the specified apps."""
for app in apps:
self.log.debug('{} | extension app "{}" stopping'.format(name, app.name))
await app.stop_extension()
self.log.debug('{} | extension app "{}" stopped'.format(name, app.name))

def link_all_extensions(self, serverapp):
"""Link all enabled extensions
Expand All @@ -366,3 +393,10 @@ def load_all_extensions(self, serverapp):
# order.
for name in self.sorted_extensions.keys():
self.load_extension(name, serverapp)

async def stop_all_extensions(self, serverapp):
"""Call the shutdown hooks in all extensions."""
await multi([
self.stop_extension(name, apps)
for name, apps in sorted(dict(self.extension_apps).items())
])
4 changes: 2 additions & 2 deletions jupyter_server/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from jupyter_server.extension import serverextension
from jupyter_server.serverapp import ServerApp
from jupyter_server.utils import url_path_join
from jupyter_server.utils import url_path_join, run_sync
from jupyter_server.services.contents.filemanager import FileContentsManager
from jupyter_server.services.contents.largefilemanager import LargeFileManager

Expand Down Expand Up @@ -284,7 +284,7 @@ def jp_serverapp(
"""Starts a Jupyter Server instance based on the established configuration values."""
app = jp_configurable_serverapp(config=jp_server_config, argv=jp_argv)
yield app
app._cleanup()
run_sync(app._cleanup())


@pytest.fixture
Expand Down
62 changes: 42 additions & 20 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

from jupyter_core.paths import secure_write
from jupyter_server.transutils import trans, _i18n
from jupyter_server.utils import run_sync
from jupyter_server.utils import run_sync_in_loop

# the minimum viable tornado version: needs to be kept in sync with setup.py
MIN_TORNADO = (6, 1, 0)
Expand Down Expand Up @@ -1777,7 +1777,7 @@ def _confirm_exit(self):
self.log.critical(_i18n("Shutting down..."))
# schedule stop on the main thread,
# since this might be called from a signal handler
self.io_loop.add_callback_from_signal(self.io_loop.stop)
self.stop(from_signal=True)
return
print(self.running_server_info())
yes = _i18n('y')
Expand All @@ -1791,7 +1791,7 @@ def _confirm_exit(self):
self.log.critical(_i18n("Shutdown confirmed"))
# schedule stop on the main thread,
# since this might be called from a signal handler
self.io_loop.add_callback_from_signal(self.io_loop.stop)
self.stop(from_signal=True)
return
else:
print(_i18n("No answer for 5s:"), end=' ')
Expand All @@ -1804,7 +1804,7 @@ def _confirm_exit(self):

def _signal_stop(self, sig, frame):
self.log.critical(_i18n("received signal %s, stopping"), sig)
self.io_loop.add_callback_from_signal(self.io_loop.stop)
self.stop(from_signal=True)

def _signal_info(self, sig, frame):
print(self.running_server_info())
Expand Down Expand Up @@ -2086,7 +2086,7 @@ def initialize(self, argv=None, find_extensions=True, new_httpserver=True, start
if new_httpserver:
self.init_httpserver()

def cleanup_kernels(self):
async def cleanup_kernels(self):
"""Shutdown all kernels.
The kernels will shutdown themselves when this process no longer exists,
Expand All @@ -2095,9 +2095,9 @@ def cleanup_kernels(self):
n_kernels = len(self.kernel_manager.list_kernel_ids())
kernel_msg = trans.ngettext('Shutting down %d kernel', 'Shutting down %d kernels', n_kernels)
self.log.info(kernel_msg % n_kernels)
run_sync(self.kernel_manager.shutdown_all())
await run_sync_in_loop(self.kernel_manager.shutdown_all())

def cleanup_terminals(self):
async def cleanup_terminals(self):
"""Shutdown all terminals.
The terminals will shutdown themselves when this process no longer exists,
Expand All @@ -2110,7 +2110,20 @@ def cleanup_terminals(self):
n_terminals = len(terminal_manager.list())
terminal_msg = trans.ngettext('Shutting down %d terminal', 'Shutting down %d terminals', n_terminals)
self.log.info(terminal_msg % n_terminals)
run_sync(terminal_manager.terminate_all())
await run_sync_in_loop(terminal_manager.terminate_all())

async def cleanup_extensions(self):
"""Call shutdown hooks in all extensions."""
n_extensions = len(self.extension_manager.extension_apps)
extension_msg = trans.ngettext(
'Shutting down %d extension',
'Shutting down %d extensions',
n_extensions
)
self.log.info(extension_msg % n_extensions)
await run_sync_in_loop(
self.extension_manager.stop_all_extensions(self)
)

def running_server_info(self, kernel_count=True):
"Return the current working directory and the server url information"
Expand Down Expand Up @@ -2348,14 +2361,15 @@ def start_app(self):
' %s' % self.display_url,
]))

def _cleanup(self):
"""General cleanup of files and kernels created
async def _cleanup(self):
"""General cleanup of files, extensions and kernels created
by this instance ServerApp.
"""
self.remove_server_info_file()
self.remove_browser_open_files()
self.cleanup_kernels()
self.cleanup_terminals()
await self.cleanup_extensions()
await self.cleanup_kernels()
await self.cleanup_terminals()

def start_ioloop(self):
"""Start the IO Loop."""
Expand All @@ -2368,8 +2382,6 @@ def start_ioloop(self):
self.io_loop.start()
except KeyboardInterrupt:
self.log.info(_i18n("Interrupted..."))
finally:
self._cleanup()

def init_ioloop(self):
"""init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks"""
Expand All @@ -2383,13 +2395,23 @@ def start(self):
self.start_app()
self.start_ioloop()

def stop(self):
def _stop():
async def _stop(self):
"""Cleanup resources and stop the IO Loop."""
await self._cleanup()
self.io_loop.stop()

def stop(self, from_signal=False):
"""Cleanup resources and stop the server."""
if hasattr(self, '_http_server'):
# Stop a server if its set.
if hasattr(self, '_http_server'):
self.http_server.stop()
self.io_loop.stop()
self.io_loop.add_callback(_stop)
self.http_server.stop()
if getattr(self, 'io_loop', None):
# use IOLoop.add_callback because signal.signal must be called
# from main thread
if from_signal:
self.io_loop.add_callback_from_signal(self._stop)
else:
self.io_loop.add_callback(self._stop)


def list_running_servers(runtime_dir=None):
Expand Down
40 changes: 40 additions & 0 deletions jupyter_server/tests/extension/test_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from traitlets.config import Config
from jupyter_server.serverapp import ServerApp
from jupyter_server.utils import run_sync
from .mockextensions.app import MockExtensionApp


Expand Down Expand Up @@ -101,3 +102,42 @@ def test_load_parallel_extensions(monkeypatch, jp_environ):
exts = serverapp.jpserver_extensions
assert exts['jupyter_server.tests.extension.mockextensions.mock1']
assert exts['jupyter_server.tests.extension.mockextensions']


def test_stop_extension(jp_serverapp, caplog):
"""Test the stop_extension method.
This should be fired by ServerApp.cleanup_extensions.
"""
calls = 0

# load extensions (make sure we only have the one extension loaded
jp_serverapp.extension_manager.load_all_extensions(jp_serverapp)
extension_name = 'jupyter_server.tests.extension.mockextensions'
assert list(jp_serverapp.extension_manager.extension_apps) == [
extension_name
]

# add a stop_extension method for the extension app
async def _stop(*args):
nonlocal calls
calls += 1
for apps in jp_serverapp.extension_manager.extension_apps.values():
for app in apps:
if app:
app.stop_extension = _stop

# call cleanup_extensions, check the logging is correct
caplog.clear()
run_sync(jp_serverapp.cleanup_extensions())
assert [
msg
for *_, msg in caplog.record_tuples
] == [
'Shutting down 1 extension',
'{} | extension app "mockextension" stopping'.format(extension_name),
'{} | extension app "mockextension" stopped'.format(extension_name),
]

# check the shutdown method was called once
assert calls == 1
23 changes: 23 additions & 0 deletions jupyter_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,29 @@ def wrapped():
return wrapped()


async def run_sync_in_loop(maybe_async):
"""Runs a function synchronously whether it is an async function or not.
If async, runs maybe_async and blocks until it has executed.
If not async, just returns maybe_async as it is the result of something
that has already executed.
Parameters
----------
maybe_async : async or non-async object
The object to be executed, if it is async.
Returns
-------
result
Whatever the async object returns, or the object itself.
"""
if not inspect.isawaitable(maybe_async):
return maybe_async
return await maybe_async


def urlencode_unix_socket_path(socket_path):
"""Encodes a UNIX socket path string from a socket path for the `http+unix` URI form."""
return socket_path.replace('/', '%2F')
Expand Down

0 comments on commit 195ed51

Please sign in to comment.