diff --git a/docs/source/server-process.md b/docs/source/server-process.md index 44b09c85..92a3a661 100644 --- a/docs/source/server-process.md +++ b/docs/source/server-process.md @@ -169,6 +169,18 @@ Defaults to _True_. (server-process:callable-arguments)= +### `raw_socket_proxy` + +_True_ to proxy only websocket connections into raw stream connections. +_False_ (default) if the proxied server speaks full HTTP. + +If _True_, the proxied server is treated a raw TCP (or unix socket) server that +does not use HTTP. +In this mode, only websockets are handled, and messages are sent to the backend +server as raw stream data. This is similar to running a +[websockify](https://github.com/novnc/websockify) wrapper. +All other HTTP requests return 405. + #### Callable arguments Any time you specify a callable in the config, it can ask for any arguments it needs diff --git a/jupyter_server_proxy/config.py b/jupyter_server_proxy/config.py index ebed718f..954842e6 100644 --- a/jupyter_server_proxy/config.py +++ b/jupyter_server_proxy/config.py @@ -16,6 +16,7 @@ from traitlets.config import Configurable from .handlers import AddSlashHandler, NamedLocalProxyHandler, SuperviseAndProxyHandler +from .rawsocket import RawSocketHandler, SuperviseAndRawSocketHandler try: # Traitlets >= 4.3.3 @@ -43,54 +44,56 @@ "request_headers_override", "rewrite_response", "update_last_activity", + "raw_socket_proxy", ], ) -def _make_namedproxy_handler(sp: ServerProcess): - class _Proxy(NamedLocalProxyHandler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = sp.name - self.proxy_base = sp.name - self.absolute_url = sp.absolute_url - self.port = sp.port - self.unix_socket = sp.unix_socket - self.mappath = sp.mappath - self.rewrite_response = sp.rewrite_response - self.update_last_activity = sp.update_last_activity - - def get_request_headers_override(self): - return self._realize_rendered_template(sp.request_headers_override) - - return _Proxy - - -def _make_supervisedproxy_handler(sp: ServerProcess): +def _make_proxy_handler(sp: ServerProcess): """ - Create a SuperviseAndProxyHandler subclass with given parameters + Create an appropriate handler with given parameters """ + if sp.command: + cls = SuperviseAndRawSocketHandler if sp.raw_socket_proxy else SuperviseAndProxyHandler + args = dict(state={}) + elif not (sp.port or isinstance(sp.unix_socket, str)): + warn( + f"Server proxy {sp.name} does not have a command, port " + f"number or unix_socket path. At least one of these is " + f"required." + ) + return + else: + cls = RawSocketHandler if sp.raw_socket_proxy else NamedLocalProxyHandler + args = {} # FIXME: Set 'name' properly - class _Proxy(SuperviseAndProxyHandler): + class _Proxy(cls): + kwargs = args + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.name = sp.name self.command = sp.command self.proxy_base = sp.name self.absolute_url = sp.absolute_url - self.requested_port = sp.port - self.requested_unix_socket = sp.unix_socket + if sp.command: + self.requested_port = sp.port + self.requested_unix_socket = sp.unix_socket + else: + self.port = sp.port + self.unix_socket = sp.unix_socket self.mappath = sp.mappath self.rewrite_response = sp.rewrite_response self.update_last_activity = sp.update_last_activity - def get_env(self): - return self._realize_rendered_template(sp.environment) - def get_request_headers_override(self): return self._realize_rendered_template(sp.request_headers_override) + # these two methods are only used in supervise classes, but do no harm otherwise + def get_env(self): + return self._realize_rendered_template(sp.environment) + def get_timeout(self): return sp.timeout @@ -116,24 +119,14 @@ def make_handlers(base_url, server_processes): """ handlers = [] for sp in server_processes: - if sp.command: - handler = _make_supervisedproxy_handler(sp) - kwargs = dict(state={}) - else: - if not (sp.port or isinstance(sp.unix_socket, str)): - warn( - f"Server proxy {sp.name} does not have a command, port " - f"number or unix_socket path. At least one of these is " - f"required." - ) - continue - handler = _make_namedproxy_handler(sp) - kwargs = {} + handler = _make_proxy_handler(sp) + if not handler: + continue handlers.append( ( ujoin(base_url, sp.name, r"(.*)"), handler, - kwargs, + handler.kwargs ) ) handlers.append((ujoin(base_url, sp.name), AddSlashHandler)) @@ -169,6 +162,7 @@ def make_server_process(name, server_process_config, serverproxy_config): update_last_activity=server_process_config.get( "update_last_activity", True ), + raw_socket_proxy=server_process_config.get("raw_socket_proxy", False), ) @@ -292,6 +286,12 @@ def cats_only(response, path): update_last_activity Will cause the proxy to report activity back to jupyter server. + + raw_socket_proxy + Proxy websocket requests as a raw TCP (or unix socket) stream. + In this mode, only websockets are handled, and messages are sent to the backend, + similar to running a websockify layer (https://github.com/novnc/websockify). + All other HTTP requests return 405 (and thus this will also bypass rewrite_response). """, config=True, ) diff --git a/jupyter_server_proxy/rawsocket.py b/jupyter_server_proxy/rawsocket.py new file mode 100644 index 00000000..a7a6ba14 --- /dev/null +++ b/jupyter_server_proxy/rawsocket.py @@ -0,0 +1,89 @@ +""" +A simple translation layer between tornado websockets and asyncio stream +connections. + +This provides similar functionality to websockify +(https://github.com/novnc/websockify) without needing an extra proxy hop +or process through with all messages pass for translation. +""" + +import asyncio + +from .handlers import NamedLocalProxyHandler, SuperviseAndProxyHandler + +class RawSocketProtocol(asyncio.Protocol): + """ + A protocol handler for the proxied stream connection. + Sends any received blocks directly as websocket messages. + """ + def __init__(self, handler): + self.handler = handler + + def data_received(self, data): + "Send the buffer as a websocket message." + self.handler._record_activity() + # ignore async "semi-synchronous" result, waiting is only needed for control flow and errors + # (see https://github.com/tornadoweb/tornado/blob/bdfc017c66817359158185561cee7878680cd841/tornado/websocket.py#L1073) + self.handler.write_message(data, binary=True) + + def connection_lost(self, exc): + "Close the websocket connection." + self.handler.log.info(f"Raw websocket {self.handler.name} connection lost: {exc}") + self.handler.close() + +class RawSocketHandler(NamedLocalProxyHandler): + """ + HTTP handler that proxies websocket connections into a backend stream. + All other HTTP requests return 405. + """ + def _create_ws_connection(self, proto: asyncio.BaseProtocol): + "Create the appropriate backend asyncio connection" + loop = asyncio.get_running_loop() + if self.unix_socket is not None: + self.log.info(f"RawSocket {self.name} connecting to {self.unix_socket}") + return loop.create_unix_connection(proto, self.unix_socket) + else: + self.log.info(f"RawSocket {self.name} connecting to port {self.port}") + return loop.create_connection(proto, 'localhost', self.port) + + async def proxy(self, port, path): + raise web.HTTPError(405, "this raw_socket_proxy backend only supports websocket connections") + + async def proxy_open(self, host, port, proxied_path=""): + """ + Open the backend connection. host and port are ignored (as they are in + the parent for unix sockets) since they are always passed known values. + """ + transp, proto = await self._create_ws_connection(lambda: RawSocketProtocol(self)) + self.ws_transp = transp + self.ws_proto = proto + self._record_activity() + self.log.info(f"RawSocket {self.name} connected") + + def on_message(self, message): + "Send websocket messages as stream writes, encoding if necessary." + self._record_activity() + if isinstance(message, str): + message = message.encode('utf-8') + self.ws_transp.write(message) # buffered non-blocking. should block (needs new enough tornado) + + def on_ping(self, message): + "No-op" + self._record_activity() + + def on_close(self): + "Close the backend connection." + self.log.info(f"RawSocket {self.name} connection closed") + if hasattr(self, "ws_transp"): + self.ws_transp.close() + +class SuperviseAndRawSocketHandler(SuperviseAndProxyHandler, RawSocketHandler): + async def _http_ready_func(self, p): + # not really HTTP here, just try an empty connection + try: + transp, _ = await self._create_ws_connection(asyncio.Protocol) + except OSError as exc: + self.log.debug(f"RawSocket {self.name} connection check failed: {exc}") + return False + transp.close() + return True diff --git a/tests/resources/jupyter_server_config.py b/tests/resources/jupyter_server_config.py index 97136d2f..d8dee097 100644 --- a/tests/resources/jupyter_server_config.py +++ b/tests/resources/jupyter_server_config.py @@ -127,6 +127,15 @@ def my_env(): "rewrite_response": [cats_only, dog_to_cat], }, "python-proxyto54321-no-command": {"port": 54321}, + "python-rawsocket-tcp": { + "command": [sys.executable, "./tests/resources/rawsocket.py", "{port}"], + "raw_socket_proxy": True + }, + "python-rawsocket-unix": { + "command": [sys.executable, "./tests/resources/rawsocket.py", "{unix_socket}"], + "unix_socket": True, + "raw_socket_proxy": True + }, } c.ServerProxy.non_service_rewrite_response = hello_to_foo diff --git a/tests/resources/rawsocket.py b/tests/resources/rawsocket.py new file mode 100644 index 00000000..23f0c322 --- /dev/null +++ b/tests/resources/rawsocket.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +import os +import socket +import sys + +if len(sys.argv) != 2: + print(f"Usage: {sys.argv[0]} TCPPORT|SOCKPATH") + sys.exit(1) +where = sys.argv[1] +try: + port = int(where) + family = socket.AF_INET + addr = ('localhost', port) +except ValueError: + family = socket.AF_UNIX + addr = where + +with socket.create_server(addr, family=family) as serv: + while True: + # only handle a single connection at a time + sock, caddr = serv.accept() + while True: + s = sock.recv(1024) + if not s: + break + sock.send(s.swapcase()) + sock.close() diff --git a/tests/test_proxies.py b/tests/test_proxies.py index 8c2de5a0..4c0bb80f 100644 --- a/tests/test_proxies.py +++ b/tests/test_proxies.py @@ -469,3 +469,25 @@ def test_callable_environment_formatting( PORT, TOKEN = a_server_port_and_token r = request_get(PORT, "/python-http-callable-env/test", TOKEN) assert r.code == 200 + + +@pytest.mark.parametrize("rawsocket_type", [ + "tcp", + pytest.param( + "unix", + marks=pytest.mark.skipif( + sys.platform == "win32", reason="Unix socket not supported on Windows" + ), + ), +]) +async def test_server_proxy_rawsocket( + rawsocket_type: str, + a_server_port_and_token: Tuple[int, str] +) -> None: + PORT, TOKEN = a_server_port_and_token + url = f"ws://{LOCALHOST}:{PORT}/python-rawsocket-{rawsocket_type}/?token={TOKEN}" + conn = await websocket_connect(url) + for msg in [b"Hello,", b"world!"]: + await conn.write_message(msg) + res = await conn.read_message() + assert res == msg.swapcase()