diff --git a/notebook/__init__.py b/notebook/__init__.py index f406f3c6eb..c72096f26b 100644 --- a/notebook/__init__.py +++ b/notebook/__init__.py @@ -20,6 +20,8 @@ os.path.join(os.path.dirname(__file__), "templates"), ] +DEFAULT_NOTEBOOK_PORT = 8888 + del os from .nbextensions import install_nbextension diff --git a/notebook/base/handlers.py b/notebook/base/handlers.py index c05d6ce60c..b8d7a3810c 100755 --- a/notebook/base/handlers.py +++ b/notebook/base/handlers.py @@ -32,7 +32,7 @@ import notebook from notebook._tz import utcnow from notebook.i18n import combine_translations -from notebook.utils import is_hidden, url_path_join, url_is_absolute, url_escape +from notebook.utils import is_hidden, url_path_join, url_is_absolute, url_escape, urldecode_unix_socket_path from notebook.services.security import csp_report_uri #----------------------------------------------------------------------------- @@ -471,13 +471,18 @@ def check_host(self): if host.startswith('[') and host.endswith(']'): host = host[1:-1] - try: - addr = ipaddress.ip_address(host) - except ValueError: - # Not an IP address: check against hostnames - allow = host in self.settings.get('local_hostnames', ['localhost']) + # UNIX socket handling + check_host = urldecode_unix_socket_path(host) + if check_host.startswith('/') and os.path.exists(check_host): + allow = True else: - allow = addr.is_loopback + try: + addr = ipaddress.ip_address(host) + except ValueError: + # Not an IP address: check against hostnames + allow = host in self.settings.get('local_hostnames', ['localhost']) + else: + allow = addr.is_loopback if not allow: self.log.warning( diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index 4fa3925422..5433f2b2bc 100755 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -26,6 +26,7 @@ import select import signal import socket +import stat import sys import tempfile import threading @@ -66,8 +67,11 @@ from tornado import web from tornado.httputil import url_concat from tornado.log import LogFormatter, app_log, access_log, gen_log +if not sys.platform.startswith('win'): + from tornado.netutil import bind_unix_socket from notebook import ( + DEFAULT_NOTEBOOK_PORT, DEFAULT_STATIC_FILES_PATH, DEFAULT_TEMPLATE_PATH_LIST, __version__, @@ -107,7 +111,18 @@ from notebook._sysinfo import get_sys_info from ._tz import utcnow, utcfromtimestamp -from .utils import url_path_join, check_pid, url_escape, urljoin, pathname2url, run_sync +from .utils import ( + check_pid, + pathname2url, + run_sync, + unix_socket_in_use, + url_escape, + url_path_join, + urldecode_unix_socket_path, + urlencode_unix_socket, + urlencode_unix_socket_path, + urljoin, +) # Check if we can use async kernel management try: @@ -218,7 +233,7 @@ def init_settings(self, jupyter_app, kernel_manager, contents_manager, warnings.warn(_("The `ignore_minified_js` flag is deprecated and will be removed in Notebook 6.0"), DeprecationWarning) now = utcnow() - + root_dir = contents_manager.root_dir home = py3compat.str_to_unicode(os.path.expanduser('~'), encoding=sys.getfilesystemencoding()) if root_dir.startswith(home + os.path.sep): @@ -403,6 +418,7 @@ def start(self): set_password(config_file=self.config_file) self.log.info("Wrote hashed password to %s" % self.config_file) + def shutdown_server(server_info, timeout=5, log=None): """Shutdown a notebook server in a separate process. @@ -415,14 +431,39 @@ def shutdown_server(server_info, timeout=5, log=None): Returns True if the server was stopped by any means, False if stopping it failed (on Windows). """ - from tornado.httpclient import HTTPClient, HTTPRequest + from tornado import gen + from tornado.httpclient import AsyncHTTPClient, HTTPClient, HTTPRequest + from tornado.netutil import bind_unix_socket, Resolver url = server_info['url'] pid = server_info['pid'] + resolver = None + + # UNIX Socket handling. + if url.startswith('http+unix://'): + # This library doesn't understand our URI form, but it's just HTTP. + url = url.replace('http+unix://', 'http://') + + class UnixSocketResolver(Resolver): + def initialize(self, resolver): + self.resolver = resolver + + def close(self): + self.resolver.close() + + @gen.coroutine + def resolve(self, host, port, *args, **kwargs): + raise gen.Return([ + (socket.AF_UNIX, urldecode_unix_socket_path(host)) + ]) + + resolver = UnixSocketResolver(resolver=Resolver()) + req = HTTPRequest(url + 'api/shutdown', method='POST', body=b'', headers={ 'Authorization': 'token ' + server_info['token'] }) if log: log.debug("POST request to %sapi/shutdown", url) - HTTPClient().fetch(req) + AsyncHTTPClient.configure(None, resolver=resolver) + HTTPClient(AsyncHTTPClient).fetch(req) # Poll to see if it shut down. for _ in range(timeout*10): @@ -451,34 +492,65 @@ def shutdown_server(server_info, timeout=5, log=None): class NbserverStopApp(JupyterApp): version = __version__ - description="Stop currently running notebook server for a given port" + description="Stop currently running notebook server." - port = Integer(8888, config=True, - help="Port of the server to be killed. Default 8888") + port = Integer(DEFAULT_NOTEBOOK_PORT, config=True, + help="Port of the server to be killed. Default %s" % DEFAULT_NOTEBOOK_PORT) + + sock = Unicode(u'', config=True, + help="UNIX socket of the server to be killed.") def parse_command_line(self, argv=None): super(NbserverStopApp, self).parse_command_line(argv) if self.extra_args: - self.port=int(self.extra_args[0]) + try: + self.port = int(self.extra_args[0]) + except ValueError: + # self.extra_args[0] was not an int, so it must be a string (unix socket). + self.sock = self.extra_args[0] def shutdown_server(self, server): return shutdown_server(server, log=self.log) + def _shutdown_or_exit(self, target_endpoint, server): + print("Shutting down server on %s..." % target_endpoint) + if not self.shutdown_server(server): + sys.exit("Could not stop server on %s" % target_endpoint) + + @staticmethod + def _maybe_remove_unix_socket(socket_path): + try: + os.unlink(socket_path) + except (OSError, IOError): + pass + def start(self): servers = list(list_running_servers(self.runtime_dir)) if not servers: - self.exit("There are no running servers") + self.exit("There are no running servers (per %s)" % self.runtime_dir) + for server in servers: - if server['port'] == self.port: - print("Shutting down server on port", self.port, "...") - if not self.shutdown_server(server): - sys.exit("Could not stop server") - return + if self.sock: + sock = server.get('sock', None) + if sock and sock == self.sock: + self._shutdown_or_exit(sock, server) + # Attempt to remove the UNIX socket after stopping. + self._maybe_remove_unix_socket(sock) + return + elif self.port: + port = server.get('port', None) + if port == self.port: + self._shutdown_or_exit(port, server) + return else: - print("There is currently no server running on port {}".format(self.port), file=sys.stderr) - print("Ports currently in use:", file=sys.stderr) + current_endpoint = self.sock or self.port + print( + "There is currently no server running on {}".format(current_endpoint), + file=sys.stderr + ) + print("Ports/sockets currently in use:", file=sys.stderr) for server in servers: - print(" - {}".format(server['port']), file=sys.stderr) + print(" - {}".format(server.get('sock') or server['port']), file=sys.stderr) self.exit(1) @@ -558,6 +630,8 @@ def start(self): 'ip': 'NotebookApp.ip', 'port': 'NotebookApp.port', 'port-retries': 'NotebookApp.port_retries', + 'sock': 'NotebookApp.sock', + 'sock-mode': 'NotebookApp.sock_mode', 'transport': 'KernelManager.transport', 'keyfile': 'NotebookApp.keyfile', 'certfile': 'NotebookApp.certfile', @@ -696,7 +770,7 @@ def _default_ip(self): return 'localhost' @validate('ip') - def _valdate_ip(self, proposal): + def _validate_ip(self, proposal): value = proposal['value'] if value == u'*': value = u'' @@ -715,10 +789,40 @@ def _valdate_ip(self, proposal): or containerized setups for example).""") ) - port = Integer(8888, config=True, + port = Integer(DEFAULT_NOTEBOOK_PORT, config=True, help=_("The port the notebook server will listen on.") ) + sock = Unicode(u'', config=True, + help=_("The UNIX socket the notebook server will listen on.") + ) + + sock_mode = Unicode('0600', config=True, + help=_("The permissions mode for UNIX socket creation (default: 0600).") + ) + + @validate('sock_mode') + def _validate_sock_mode(self, proposal): + value = proposal['value'] + try: + converted_value = int(value.encode(), 8) + assert all(( + # Ensure the mode is at least user readable/writable. + bool(converted_value & stat.S_IRUSR), + bool(converted_value & stat.S_IWUSR), + # And isn't out of bounds. + converted_value <= 2 ** 12 + )) + except ValueError: + raise TraitError( + 'invalid --sock-mode value: %s, please specify as e.g. "0600"' % value + ) + except AssertionError: + raise TraitError( + 'invalid --sock-mode value: %s, must have u+rw (0600) at a minimum' % value + ) + return value + port_retries = Integer(50, config=True, help=_("The number of additional ports to try if the specified port is not available.") ) @@ -1469,6 +1573,35 @@ def init_webapp(self): self.log.critical(_("\t$ python -m notebook.auth password")) sys.exit(1) + # Socket options validation. + if self.sock: + if self.port != DEFAULT_NOTEBOOK_PORT: + self.log.critical( + _('Options --port and --sock are mutually exclusive. Aborting.'), + ) + sys.exit(1) + else: + # Reset the default port if we're using a UNIX socket. + self.port = 0 + + if self.open_browser: + # If we're bound to a UNIX socket, we can't reliably connect from a browser. + self.log.info( + _('Ignoring --NotebookApp.open_browser due to --sock being used.'), + ) + + if self.file_to_run: + self.log.critical( + _('Options --NotebookApp.file_to_run and --sock are mutually exclusive.'), + ) + sys.exit(1) + + if sys.platform.startswith('win'): + self.log.critical( + _('Option --sock is not supported on Windows, but got value of %s. Aborting.' % self.sock), + ) + sys.exit(1) + self.web_app = NotebookWebApplication( self, self.kernel_manager, self.contents_manager, self.session_manager, self.kernel_spec_manager, @@ -1505,6 +1638,36 @@ def init_webapp(self): max_body_size=self.max_body_size, max_buffer_size=self.max_buffer_size) + success = self._bind_http_server() + if not success: + self.log.critical(_('ERROR: the notebook server could not be started because ' + 'no available port could be found.')) + self.exit(1) + + def _bind_http_server(self): + return self._bind_http_server_unix() if self.sock else self._bind_http_server_tcp() + + def _bind_http_server_unix(self): + if unix_socket_in_use(self.sock): + self.log.warning(_('The socket %s is already in use.') % self.sock) + return False + + try: + sock = bind_unix_socket(self.sock, mode=int(self.sock_mode.encode(), 8)) + self.http_server.add_socket(sock) + except socket.error as e: + if e.errno == errno.EADDRINUSE: + self.log.warning(_('The socket %s is already in use.') % self.sock) + return False + elif e.errno in (errno.EACCES, getattr(errno, 'WSAEACCES', errno.EACCES)): + self.log.warning(_("Permission to listen on sock %s denied") % self.sock) + return False + else: + raise + else: + return True + + def _bind_http_server_tcp(self): success = None for port in random_ports(self.port, self.port_retries+1): try: @@ -1533,35 +1696,45 @@ def init_webapp(self): self.log.critical(_('ERROR: the notebook server could not be started because ' 'port %i is not available.') % port) self.exit(1) - + return success + + def _concat_token(self, url): + token = self.token if self._token_generated else '...' + return url_concat(url, {'token': token}) + @property def display_url(self): if self.custom_display_url: url = self.custom_display_url if not url.endswith('/'): url += '/' + elif self.sock: + url = self._unix_sock_url() else: if self.ip in ('', '0.0.0.0'): ip = "%s" % socket.gethostname() else: ip = self.ip - url = self._url(ip) - if self.token: - # Don't log full token if it came from config - token = self.token if self._token_generated else '...' - url = (url_concat(url, {'token': token}) - + '\n or ' - + url_concat(self._url('127.0.0.1'), {'token': token})) + url = self._tcp_url(ip) + if self.token and not self.sock: + url = self._concat_token(url) + url += '\n or %s' % self._concat_token(self._tcp_url('127.0.0.1')) return url @property def connection_url(self): - ip = self.ip if self.ip else 'localhost' - return self._url(ip) + if self.sock: + return self._unix_sock_url() + else: + ip = self.ip if self.ip else 'localhost' + return self._tcp_url(ip) + + def _unix_sock_url(self, token=None): + return '%s%s' % (urlencode_unix_socket(self.sock), self.base_url) - def _url(self, ip): + def _tcp_url(self, ip, port=None): proto = 'https' if self.certfile else 'http' - return "%s://%s:%i%s" % (proto, ip, self.port, self.base_url) + return "%s://%s:%i%s" % (proto, ip, port or self.port, self.base_url) def init_terminals(self): if not self.terminals_enabled: @@ -1825,6 +1998,7 @@ def server_info(self): return {'url': self.connection_url, 'hostname': self.ip if self.ip else 'localhost', 'port': self.port, + 'sock': self.sock, 'secure': bool(self.certfile), 'base_url': self.base_url, 'token': self.token, @@ -1954,19 +2128,31 @@ def start(self): self.write_server_info_file() self.write_browser_open_file() - if self.open_browser or self.file_to_run: + if (self.open_browser or self.file_to_run) and not self.sock: self.launch_browser() if self.token and self._token_generated: # log full URL with generated token, so there's a copy/pasteable link # with auth info. - self.log.critical('\n'.join([ - '\n', - 'To access the notebook, open this file in a browser:', - ' %s' % urljoin('file:', pathname2url(self.browser_open_file)), - 'Or copy and paste one of these URLs:', - ' %s' % self.display_url, - ])) + if self.sock: + self.log.critical('\n'.join([ + '\n', + 'Notebook is listening on %s' % self.display_url, + '', + ( + 'UNIX sockets are not browser-connectable, but you can tunnel to ' + 'the instance via e.g.`ssh -L 8888:%s -N user@this_host` and then ' + 'open e.g. %s in a browser.' + ) % (self.sock, self._concat_token(self._tcp_url('localhost', 8888))) + ])) + else: + self.log.critical('\n'.join([ + '\n', + 'To access the notebook, open this file in a browser:', + ' %s' % urljoin('file:', pathname2url(self.browser_open_file)), + 'Or copy and paste one of these URLs:', + ' %s' % self.display_url, + ])) self.io_loop = ioloop.IOLoop.current() if sys.platform.startswith('win'): diff --git a/notebook/tests/launchnotebook.py b/notebook/tests/launchnotebook.py index b873457316..41d0131717 100644 --- a/notebook/tests/launchnotebook.py +++ b/notebook/tests/launchnotebook.py @@ -16,12 +16,13 @@ from unittest.mock import patch import requests +import requests_unixsocket from tornado.ioloop import IOLoop import zmq import jupyter_core.paths from traitlets.config import Config -from ..notebookapp import NotebookApp +from ..notebookapp import NotebookApp, urlencode_unix_socket from ..utils import url_path_join from ipython_genutils.tempdir import TemporaryDirectory @@ -52,7 +53,7 @@ def wait_until_alive(cls): url = cls.base_url() + 'api/contents' for _ in range(int(MAX_WAITTIME/POLL_INTERVAL)): try: - requests.get(url) + cls.fetch_url(url) except Exception as e: if not cls.notebook_thread.is_alive(): raise RuntimeError("The notebook server failed to start") @@ -76,6 +77,10 @@ def auth_headers(cls): headers['Authorization'] = 'token %s' % cls.token return headers + @staticmethod + def fetch_url(url): + return requests.get(url) + @classmethod def request(cls, verb, path, **kwargs): """Send a request to my server @@ -104,7 +109,11 @@ def get_patch_env(cls): @classmethod def get_argv(cls): return [] - + + @classmethod + def get_bind_args(cls): + return dict(port=cls.port) + @classmethod def setup_class(cls): cls.tmp_dir = TemporaryDirectory() @@ -116,7 +125,7 @@ def tmp(*parts): if e.errno != errno.EEXIST: raise return path - + cls.home_dir = tmp('home') data_dir = cls.data_dir = tmp('data') config_dir = cls.config_dir = tmp('config') @@ -144,8 +153,8 @@ def tmp(*parts): started = Event() def start_thread(): try: + bind_args = cls.get_bind_args() app = cls.notebook = NotebookApp( - port=cls.port, port_retries=0, open_browser=False, config_dir=cls.config_dir, @@ -156,6 +165,7 @@ def start_thread(): config=config, allow_root=True, token=cls.token, + **bind_args ) if 'asyncio' in sys.modules: app._init_asyncio_patch() @@ -206,6 +216,25 @@ def base_url(cls): return 'http://localhost:%i%s' % (cls.port, cls.url_prefix) +class UNIXSocketNotebookTestBase(NotebookTestBase): + # Rely on `/tmp` to avoid any Linux socket length max buffer + # issues. Key on PID for process-wise concurrency. + sock = '/tmp/.notebook.%i.sock' % os.getpid() + + @classmethod + def get_bind_args(cls): + return dict(sock=cls.sock) + + @classmethod + def base_url(cls): + return '%s%s' % (urlencode_unix_socket(cls.sock), cls.url_prefix) + + @staticmethod + def fetch_url(url): + with requests_unixsocket.monkeypatch(): + return requests.get(url) + + @contextmanager def assert_http_error(status, msg=None): try: diff --git a/notebook/tests/test_notebookapp.py b/notebook/tests/test_notebookapp.py index f3066d9fb7..d8543feed9 100644 --- a/notebook/tests/test_notebookapp.py +++ b/notebook/tests/test_notebookapp.py @@ -22,7 +22,7 @@ from notebook.auth.security import passwd_check NotebookApp = notebookapp.NotebookApp -from .launchnotebook import NotebookTestBase +from .launchnotebook import NotebookTestBase, UNIXSocketNotebookTestBase def test_help_output(): @@ -189,3 +189,15 @@ def test_list_running_servers(self): servers = list(notebookapp.list_running_servers()) assert len(servers) >= 1 assert self.port in {info['port'] for info in servers} + + +# UNIX sockets aren't available on Windows. +if not sys.platform.startswith('win'): + class NotebookUnixSocketTests(UNIXSocketNotebookTestBase): + def test_run(self): + self.fetch_url(self.base_url() + 'api/contents') + + def test_list_running_sock_servers(self): + servers = list(notebookapp.list_running_servers()) + assert len(servers) >= 1 + assert self.sock in {info['sock'] for info in servers} diff --git a/notebook/tests/test_notebookapp_integration.py b/notebook/tests/test_notebookapp_integration.py new file mode 100644 index 0000000000..9af505342b --- /dev/null +++ b/notebook/tests/test_notebookapp_integration.py @@ -0,0 +1,166 @@ +import os +import stat +import subprocess +import time + +from ipython_genutils.testing.decorators import skip_win32, onlyif +from notebook import DEFAULT_NOTEBOOK_PORT + +from .launchnotebook import UNIXSocketNotebookTestBase +from ..utils import urlencode_unix_socket, urlencode_unix_socket_path + + +@skip_win32 +def test_shutdown_sock_server_integration(): + sock = UNIXSocketNotebookTestBase.sock + url = urlencode_unix_socket(sock).encode() + encoded_sock_path = urlencode_unix_socket_path(sock) + + p = subprocess.Popen( + ['jupyter-notebook', '--sock=%s' % sock, '--sock-mode=0700'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + complete = False + for line in iter(p.stderr.readline, b''): + print(line.decode()) + if url in line: + complete = True + break + + assert complete, 'did not find socket URL in stdout when launching notebook' + + assert encoded_sock_path.encode() in subprocess.check_output(['jupyter-notebook', 'list']) + + # Ensure umask is properly applied. + assert stat.S_IMODE(os.lstat(sock).st_mode) == 0o700 + + try: + subprocess.check_output(['jupyter-notebook', 'stop'], stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + assert 'There is currently no server running on' in e.output.decode() + else: + raise AssertionError('expected stop command to fail due to target mis-match') + + assert encoded_sock_path.encode() in subprocess.check_output(['jupyter-notebook', 'list']) + + subprocess.check_output(['jupyter-notebook', 'stop', sock]) + + assert encoded_sock_path.encode() not in subprocess.check_output(['jupyter-notebook', 'list']) + + p.wait() + + +def test_sock_server_validate_sockmode_type(): + try: + subprocess.check_output( + ['jupyter-notebook', '--sock=/tmp/nonexistent', '--sock-mode=badbadbad'], + stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + assert 'badbadbad' in e.output.decode() + else: + raise AssertionError('expected execution to fail due to validation of --sock-mode param') + + +def test_sock_server_validate_sockmode_accessible(): + try: + subprocess.check_output( + ['jupyter-notebook', '--sock=/tmp/nonexistent', '--sock-mode=0444'], + stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + assert '0444' in e.output.decode() + else: + raise AssertionError('expected execution to fail due to validation of --sock-mode param') + + +def _ensure_stopped(check_msg='There are no running servers'): + try: + subprocess.check_output( + ['jupyter-notebook', 'stop'], + stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + assert check_msg in e.output.decode() + else: + raise AssertionError('expected all servers to be stopped') + + +@onlyif(bool(os.environ.get('RUN_NB_INTEGRATION_TESTS', False)), 'for local testing') +def test_stop_multi_integration(): + """Tests lifecycle behavior for mixed-mode server types w/ default ports. + + Mostly suitable for local dev testing due to reliance on default port binding. + """ + TEST_PORT = '9797' + MSG_TMPL = 'Shutting down server on {}...' + + _ensure_stopped() + + # Default port. + p1 = subprocess.Popen( + ['jupyter-notebook', '--no-browser'] + ) + + # Unix socket. + sock = UNIXSocketNotebookTestBase.sock + p2 = subprocess.Popen( + ['jupyter-notebook', '--sock=%s' % sock] + ) + + # Specified port + p3 = subprocess.Popen( + ['jupyter-notebook', '--no-browser', '--port=%s' % TEST_PORT] + ) + + time.sleep(3) + + assert MSG_TMPL.format(DEFAULT_NOTEBOOK_PORT) in subprocess.check_output( + ['jupyter-notebook', 'stop'] + ).decode() + + _ensure_stopped('There is currently no server running on 8888') + + assert MSG_TMPL.format(sock) in subprocess.check_output( + ['jupyter-notebook', 'stop', sock] + ).decode() + + assert MSG_TMPL.format(TEST_PORT) in subprocess.check_output( + ['jupyter-notebook', 'stop', TEST_PORT] + ).decode() + + _ensure_stopped() + + p1.wait() + p2.wait() + p3.wait() + + +@skip_win32 +def test_launch_socket_collision(): + """Tests UNIX socket in-use detection for lifecycle correctness.""" + sock = UNIXSocketNotebookTestBase.sock + check_msg = 'socket %s is already in use' % sock + + _ensure_stopped() + + # Start a server. + cmd = ['jupyter-notebook', '--sock=%s' % sock] + p1 = subprocess.Popen(cmd) + time.sleep(3) + + # Try to start a server bound to the same UNIX socket. + try: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + assert check_msg in e.output.decode() + else: + raise AssertionError('expected error, instead got %s' % e.output.decode()) + + # Stop the background server, ensure it's stopped and wait on the process to exit. + subprocess.check_call(['jupyter-notebook', 'stop', sock]) + + _ensure_stopped() + + p1.wait() diff --git a/notebook/utils.py b/notebook/utils.py index 9ec10773fb..f0ad0e26c6 100644 --- a/notebook/utils.py +++ b/notebook/utils.py @@ -11,6 +11,7 @@ import errno import inspect import os +import socket import stat import sys from distutils.version import LooseVersion @@ -22,6 +23,7 @@ # in tornado >=5 with Python 3 from tornado.concurrent import Future as TornadoFuture from tornado import gen +import requests_unixsocket from ipython_genutils import py3compat # UF_HIDDEN is a stat flag not defined in the stat module. @@ -367,3 +369,34 @@ def wrapped(): result = asyncio.ensure_future(maybe_async) return result return wrapped() + + +def urlencode_unix_socket_path(socket_path): + """Encodes a UNIX socket path string from a socket path for the `http+unix` URI form.""" + return socket_path.replace('/', '%2F') + + +def urldecode_unix_socket_path(socket_path): + """Decodes a UNIX sock path string from an encoded sock path for the `http+unix` URI form.""" + return socket_path.replace('%2F', '/') + + +def urlencode_unix_socket(socket_path): + """Encodes a UNIX socket URL from a socket path for the `http+unix` URI form.""" + return 'http+unix://%s' % urlencode_unix_socket_path(socket_path) + + +def unix_socket_in_use(socket_path): + """Checks whether a UNIX socket path on disk is in use by attempting to connect to it.""" + if not os.path.exists(socket_path): + return False + + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(socket_path) + except socket.error: + return False + else: + return True + finally: + sock.close() diff --git a/setup.py b/setup.py index a2f07d0023..c970aac746 100755 --- a/setup.py +++ b/setup.py @@ -115,7 +115,8 @@ ], extras_require = { 'test': ['nose', 'coverage', 'requests', 'nose_warnings_filters', - 'nbval', 'nose-exclude', 'selenium', 'pytest', 'pytest-cov'], + 'nbval', 'nose-exclude', 'selenium', 'pytest', 'pytest-cov', + 'requests-unixsocket'], 'test:sys_platform == "win32"': ['nose-exclude'], }, python_requires = '>=3.5',