Skip to content

Commit

Permalink
Add UNIX socket support to notebook server.
Browse files Browse the repository at this point in the history
  • Loading branch information
kwlzn committed Sep 6, 2019
1 parent 20c2c66 commit 71b2132
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 49 deletions.
2 changes: 2 additions & 0 deletions notebook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
os.path.join(os.path.dirname(__file__), "templates"),
]

DEFAULT_NOTEBOOK_PORT = 8888

del os

from .nbextensions import install_nbextension
Expand Down
19 changes: 12 additions & 7 deletions notebook/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import notebook
from notebook._tz import utcnow
from notebook.i18n import combine_translations
from notebook.utils import is_hidden, url_path_join, url_is_absolute, url_escape
from notebook.utils import is_hidden, url_path_join, url_is_absolute, url_escape, urldecode_unix_socket_path
from notebook.services.security import csp_report_uri

#-----------------------------------------------------------------------------
Expand Down Expand Up @@ -483,13 +483,18 @@ def check_host(self):
# ip_address only accepts unicode on Python 2
host = host.decode('utf8', 'replace')

try:
addr = ipaddress.ip_address(host)
except ValueError:
# Not an IP address: check against hostnames
allow = host in self.settings.get('local_hostnames', ['localhost'])
# UNIX socket handling
check_host = urldecode_unix_socket_path(host)
if check_host.startswith('/') and os.path.exists(check_host):
allow = True
else:
allow = addr.is_loopback
try:
addr = ipaddress.ip_address(host)
except ValueError:
# Not an IP address: check against hostnames
allow = host in self.settings.get('local_hostnames', ['localhost'])
else:
allow = addr.is_loopback

if not allow:
self.log.warning(
Expand Down
191 changes: 156 additions & 35 deletions notebook/notebookapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@
from tornado import web
from tornado.httputil import url_concat
from tornado.log import LogFormatter, app_log, access_log, gen_log
if not sys.platform.startswith('win'):
from tornado.netutil import bind_unix_socket

from notebook import (
DEFAULT_NOTEBOOK_PORT,
DEFAULT_STATIC_FILES_PATH,
DEFAULT_TEMPLATE_PATH_LIST,
__version__,
Expand Down Expand Up @@ -109,7 +112,16 @@
from notebook._sysinfo import get_sys_info

from ._tz import utcnow, utcfromtimestamp
from .utils import url_path_join, check_pid, url_escape, urljoin, pathname2url
from .utils import (
check_pid,
pathname2url,
url_escape,
url_path_join,
urldecode_unix_socket_path,
urlencode_unix_socket,
urlencode_unix_socket_path,
urljoin,
)

#-----------------------------------------------------------------------------
# Module globals
Expand Down Expand Up @@ -213,7 +225,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager,
warnings.warn(_("The `ignore_minified_js` flag is deprecated and will be removed in Notebook 6.0"), DeprecationWarning)

now = utcnow()

root_dir = contents_manager.root_dir
home = py3compat.str_to_unicode(os.path.expanduser('~'), encoding=sys.getfilesystemencoding())
if root_dir.startswith(home + os.path.sep):
Expand Down Expand Up @@ -398,6 +410,7 @@ def start(self):
set_password(config_file=self.config_file)
self.log.info("Wrote hashed password to %s" % self.config_file)


def shutdown_server(server_info, timeout=5, log=None):
"""Shutdown a notebook server in a separate process.
Expand All @@ -410,14 +423,39 @@ def shutdown_server(server_info, timeout=5, log=None):
Returns True if the server was stopped by any means, False if stopping it
failed (on Windows).
"""
from tornado.httpclient import HTTPClient, HTTPRequest
from tornado import gen
from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest
from tornado.netutil import bind_unix_socket, Resolver
url = server_info['url']
pid = server_info['pid']
resolver = None

# UNIX Socket handling.
if url.startswith('http+unix://'):
# This library doesn't understand our URI form, but it's just HTTP.
url = url.replace('http+unix://', 'http://')

class UnixSocketResolver(Resolver):
def initialize(self, resolver):
self.resolver = resolver

def close(self):
self.resolver.close()

@gen.coroutine
def resolve(self, host, port, *args, **kwargs):
raise gen.Return([
(socket.AF_UNIX, urldecode_unix_socket_path(host))
])

resolver = UnixSocketResolver(resolver=Resolver())

req = HTTPRequest(url + 'api/shutdown', method='POST', body=b'', headers={
'Authorization': 'token ' + server_info['token']
})
if log: log.debug("POST request to %sapi/shutdown", url)
HTTPClient().fetch(req)
AsyncHTTPClient.configure(None, resolver=resolver)
HTTPClient(AsyncHTTPClient).fetch(req)

# Poll to see if it shut down.
for _ in range(timeout*10):
Expand Down Expand Up @@ -448,13 +486,20 @@ class NbserverStopApp(JupyterApp):
version = __version__
description="Stop currently running notebook server for a given port"

port = Integer(8888, config=True,
help="Port of the server to be killed. Default 8888")
port = Integer(DEFAULT_NOTEBOOK_PORT, config=True,
help="Port of the server to be killed. Default %s" % DEFAULT_NOTEBOOK_PORT)

sock = Unicode(u'', config=True,
help="UNIX socket of the server to be killed.")

def parse_command_line(self, argv=None):
super(NbserverStopApp, self).parse_command_line(argv)
if self.extra_args:
self.port=int(self.extra_args[0])
try:
self.port = int(self.extra_args[0])
except ValueError:
# self.extra_args[0] was not an int, so it must be a string (unix socket).
self.sock = self.extra_args[0]

def shutdown_server(self, server):
return shutdown_server(server, log=self.log)
Expand All @@ -464,16 +509,16 @@ def start(self):
if not servers:
self.exit("There are no running servers")
for server in servers:
if server['port'] == self.port:
print("Shutting down server on port", self.port, "...")
if server.get('sock') == self.sock or server['port'] == self.port:
print("Shutting down server on %s..." % self.sock or self.port)
if not self.shutdown_server(server):
sys.exit("Could not stop server")
return
else:
print("There is currently no server running on port {}".format(self.port), file=sys.stderr)
print("Ports currently in use:", file=sys.stderr)
print("Ports/sockets currently in use:", file=sys.stderr)
for server in servers:
print(" - {}".format(server['port']), file=sys.stderr)
print(" - {}".format(server.get('sock', server['port'])), file=sys.stderr)
self.exit(1)


Expand Down Expand Up @@ -553,6 +598,8 @@ def start(self):
'ip': 'NotebookApp.ip',
'port': 'NotebookApp.port',
'port-retries': 'NotebookApp.port_retries',
'sock': 'NotebookApp.sock',
'sock-umask': 'NotebookApp.sock_umask',
'transport': 'KernelManager.transport',
'keyfile': 'NotebookApp.keyfile',
'certfile': 'NotebookApp.certfile',
Expand Down Expand Up @@ -692,10 +739,18 @@ def _valdate_ip(self, proposal):
or containerized setups for example).""")
)

port = Integer(8888, config=True,
port = Integer(DEFAULT_NOTEBOOK_PORT, config=True,
help=_("The port the notebook server will listen on.")
)

sock = Unicode(u'', config=True,
help=_("The UNIX socket the notebook server will listen on.")
)

sock_umask = Unicode(u'0600', config=True,
help=_("The UNIX socket umask to set on creation (default: 0600).")
)

port_retries = Integer(50, config=True,
help=_("The number of additional ports to try if the specified port is not available.")
)
Expand Down Expand Up @@ -1400,6 +1455,27 @@ def init_webapp(self):
self.log.critical(_("\t$ python -m notebook.auth password"))
sys.exit(1)

# Socket options validation.
if self.sock:
if self.port != DEFAULT_NOTEBOOK_PORT:
self.log.critical(
_('Options --port and --sock are mutually exclusive. Aborting.'),
)
sys.exit(1)

if self.open_browser:
# If we're bound to a UNIX socket, we can't reliably connect from a browser.
self.log.critical(
_('Options --open-browser and --sock are mutually exclusive. Aborting.'),
)
sys.exit(1)

if sys.platform.startswith('win'):
self.log.critical(
_('Option --sock is not supported on Windows, but got value of %s. Aborting.' % self.sock),
)
sys.exit(1)

self.web_app = NotebookWebApplication(
self, self.kernel_manager, self.contents_manager,
self.session_manager, self.kernel_spec_manager,
Expand Down Expand Up @@ -1436,6 +1512,32 @@ def init_webapp(self):
max_body_size=self.max_body_size,
max_buffer_size=self.max_buffer_size)

success = self._bind_http_server()
if not success:
self.log.critical(_('ERROR: the notebook server could not be started because '
'no available port could be found.'))
self.exit(1)

def _bind_http_server(self):
return self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp()

def _bind_http_server_unix(self):
try:
sock = bind_unix_socket(self.sock, mode=int(self.sock_umask.encode(), 8))
self.http_server.add_socket(sock)
except socket.error as e:
if e.errno == errno.EADDRINUSE:
self.log.info(_('The socket %s is already in use.') % self.sock)
return False
elif e.errno in (errno.EACCES, getattr(errno, 'WSAEACCES', errno.EACCES)):
self.log.warning(_("Permission to listen on sock %s denied") % self.sock)
return False
else:
raise
else:
return True

def _bind_http_server_tcp(self):
success = None
for port in random_ports(self.port, self.port_retries+1):
try:
Expand All @@ -1453,39 +1555,45 @@ def init_webapp(self):
self.port = port
success = True
break
if not success:
self.log.critical(_('ERROR: the notebook server could not be started because '
'no available port could be found.'))
self.exit(1)
return success

def _concat_token(self, url):
token = self.token if self._token_generated else '...'
return url_concat(url, {'token': token})

@property
def display_url(self):
if self.custom_display_url:
url = self.custom_display_url
if not url.endswith('/'):
url += '/'
elif self.sock:
url = self._unix_sock_url()
else:
if self.ip in ('', '0.0.0.0'):
ip = "%s" % socket.gethostname()
else:
ip = self.ip
url = self._url(ip)
if self.token:
# Don't log full token if it came from config
token = self.token if self._token_generated else '...'
url = (url_concat(url, {'token': token})
+ '\n or '
+ url_concat(self._url('127.0.0.1'), {'token': token}))
url = self._tcp_url(ip)
if self.token and not self.sock:
url = self._concat_token(url)
url += '\n or %s' % self._concat_token(self._tcp_url('127.0.0.1'))
return url

@property
def connection_url(self):
ip = self.ip if self.ip else 'localhost'
return self._url(ip)
if self.sock:
return self._unix_sock_url()
else:
ip = self.ip if self.ip else 'localhost'
return self._tcp_url(ip)

def _url(self, ip):
def _unix_sock_url(self, token=None):
return '%s%s' % (urlencode_unix_socket(self.sock), self.base_url)

def _tcp_url(self, ip, port=None):
proto = 'https' if self.certfile else 'http'
return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url)
return "%s://%s:%i%s" % (proto, ip, port or self.port, self.base_url)

def init_terminals(self):
if not self.terminals_enabled:
Expand Down Expand Up @@ -1713,6 +1821,7 @@ def server_info(self):
return {'url': self.connection_url,
'hostname': self.ip if self.ip else 'localhost',
'port': self.port,
'sock': self.sock,
'secure': bool(self.certfile),
'base_url': self.base_url,
'token': self.token,
Expand Down Expand Up @@ -1833,19 +1942,31 @@ def start(self):
self.write_server_info_file()
self.write_browser_open_file()

if self.open_browser or self.file_to_run:
if (self.open_browser or self.file_to_run) and not self.sock:
self.launch_browser()

if self.token and self._token_generated:
# log full URL with generated token, so there's a copy/pasteable link
# with auth info.
self.log.critical('\n'.join([
'\n',
'To access the notebook, open this file in a browser:',
' %s' % urljoin('file:', pathname2url(self.browser_open_file)),
'Or copy and paste one of these URLs:',
' %s' % self.display_url,
]))
if self.sock:
self.log.critical('\n'.join([
'\n',
'Notebook is listening on %s' % self.display_url,
'',
(
'UNIX sockets are not browser-connectable, but you can tunnel to '
'the instance via e.g.`ssh -L 8888:%s -N user@this_host` and then '
'opening e.g. %s in a browser.'
) % (self.sock, self._concat_token(self._tcp_url('localhost', 8888)))
]))
else:
self.log.critical('\n'.join([
'\n',
'To access the notebook, open this file in a browser:',
' %s' % urljoin('file:', pathname2url(self.browser_open_file)),
'Or copy and paste one of these URLs:',
' %s' % self.display_url,
]))

self.io_loop = ioloop.IOLoop.current()
if sys.platform.startswith('win'):
Expand Down
Loading

0 comments on commit 71b2132

Please sign in to comment.