Skip to content

Commit

Permalink
Merge pull request #447 from dylex/main
Browse files Browse the repository at this point in the history
Add `raw_socket_proxy` to directly proxy websockets to TCP/unix sockets
  • Loading branch information
yuvipanda committed Jun 27, 2024
2 parents 887c3d1 + b068325 commit 52b0dec
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 41 deletions.
12 changes: 12 additions & 0 deletions docs/source/server-process.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ Defaults to _True_.

(server-process:callable-arguments)=

### `raw_socket_proxy`

_True_ to proxy only websocket connections into raw stream connections.
_False_ (default) if the proxied server speaks full HTTP.

If _True_, the proxied server is treated a raw TCP (or unix socket) server that
does not use HTTP.
In this mode, only websockets are handled, and messages are sent to the backend
server as raw stream data. This is similar to running a
[websockify](https://github.com/novnc/websockify) wrapper.
All other HTTP requests return 405.

#### Callable arguments

Any time you specify a callable in the config, it can ask for any arguments it needs
Expand Down
82 changes: 41 additions & 41 deletions jupyter_server_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from traitlets.config import Configurable

from .handlers import AddSlashHandler, NamedLocalProxyHandler, SuperviseAndProxyHandler
from .rawsocket import RawSocketHandler, SuperviseAndRawSocketHandler

try:
# Traitlets >= 4.3.3
Expand Down Expand Up @@ -43,54 +44,56 @@
"request_headers_override",
"rewrite_response",
"update_last_activity",
"raw_socket_proxy",
],
)


def _make_namedproxy_handler(sp: ServerProcess):
class _Proxy(NamedLocalProxyHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = sp.name
self.proxy_base = sp.name
self.absolute_url = sp.absolute_url
self.port = sp.port
self.unix_socket = sp.unix_socket
self.mappath = sp.mappath
self.rewrite_response = sp.rewrite_response
self.update_last_activity = sp.update_last_activity

def get_request_headers_override(self):
return self._realize_rendered_template(sp.request_headers_override)

return _Proxy


def _make_supervisedproxy_handler(sp: ServerProcess):
def _make_proxy_handler(sp: ServerProcess):
"""
Create a SuperviseAndProxyHandler subclass with given parameters
Create an appropriate handler with given parameters
"""
if sp.command:
cls = SuperviseAndRawSocketHandler if sp.raw_socket_proxy else SuperviseAndProxyHandler
args = dict(state={})
elif not (sp.port or isinstance(sp.unix_socket, str)):
warn(
f"Server proxy {sp.name} does not have a command, port "
f"number or unix_socket path. At least one of these is "
f"required."
)
return
else:
cls = RawSocketHandler if sp.raw_socket_proxy else NamedLocalProxyHandler
args = {}

# FIXME: Set 'name' properly
class _Proxy(SuperviseAndProxyHandler):
class _Proxy(cls):
kwargs = args

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = sp.name
self.command = sp.command
self.proxy_base = sp.name
self.absolute_url = sp.absolute_url
self.requested_port = sp.port
self.requested_unix_socket = sp.unix_socket
if sp.command:
self.requested_port = sp.port
self.requested_unix_socket = sp.unix_socket
else:
self.port = sp.port
self.unix_socket = sp.unix_socket
self.mappath = sp.mappath
self.rewrite_response = sp.rewrite_response
self.update_last_activity = sp.update_last_activity

def get_env(self):
return self._realize_rendered_template(sp.environment)

def get_request_headers_override(self):
return self._realize_rendered_template(sp.request_headers_override)

# these two methods are only used in supervise classes, but do no harm otherwise
def get_env(self):
return self._realize_rendered_template(sp.environment)

def get_timeout(self):
return sp.timeout

Expand All @@ -116,24 +119,14 @@ def make_handlers(base_url, server_processes):
"""
handlers = []
for sp in server_processes:
if sp.command:
handler = _make_supervisedproxy_handler(sp)
kwargs = dict(state={})
else:
if not (sp.port or isinstance(sp.unix_socket, str)):
warn(
f"Server proxy {sp.name} does not have a command, port "
f"number or unix_socket path. At least one of these is "
f"required."
)
continue
handler = _make_namedproxy_handler(sp)
kwargs = {}
handler = _make_proxy_handler(sp)
if not handler:
continue
handlers.append(
(
ujoin(base_url, sp.name, r"(.*)"),
handler,
kwargs,
handler.kwargs
)
)
handlers.append((ujoin(base_url, sp.name), AddSlashHandler))
Expand Down Expand Up @@ -169,6 +162,7 @@ def make_server_process(name, server_process_config, serverproxy_config):
update_last_activity=server_process_config.get(
"update_last_activity", True
),
raw_socket_proxy=server_process_config.get("raw_socket_proxy", False),
)


Expand Down Expand Up @@ -292,6 +286,12 @@ def cats_only(response, path):
update_last_activity
Will cause the proxy to report activity back to jupyter server.
raw_socket_proxy
Proxy websocket requests as a raw TCP (or unix socket) stream.
In this mode, only websockets are handled, and messages are sent to the backend,
similar to running a websockify layer (https://github.com/novnc/websockify).
All other HTTP requests return 405 (and thus this will also bypass rewrite_response).
""",
config=True,
)
Expand Down
89 changes: 89 additions & 0 deletions jupyter_server_proxy/rawsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
A simple translation layer between tornado websockets and asyncio stream
connections.
This provides similar functionality to websockify
(https://github.com/novnc/websockify) without needing an extra proxy hop
or process through with all messages pass for translation.
"""

import asyncio

from .handlers import NamedLocalProxyHandler, SuperviseAndProxyHandler

class RawSocketProtocol(asyncio.Protocol):
"""
A protocol handler for the proxied stream connection.
Sends any received blocks directly as websocket messages.
"""
def __init__(self, handler):
self.handler = handler

def data_received(self, data):
"Send the buffer as a websocket message."
self.handler._record_activity()
# ignore async "semi-synchronous" result, waiting is only needed for control flow and errors
# (see https://github.com/tornadoweb/tornado/blob/bdfc017c66817359158185561cee7878680cd841/tornado/websocket.py#L1073)
self.handler.write_message(data, binary=True)

def connection_lost(self, exc):
"Close the websocket connection."
self.handler.log.info(f"Raw websocket {self.handler.name} connection lost: {exc}")
self.handler.close()

class RawSocketHandler(NamedLocalProxyHandler):
"""
HTTP handler that proxies websocket connections into a backend stream.
All other HTTP requests return 405.
"""
def _create_ws_connection(self, proto: asyncio.BaseProtocol):
"Create the appropriate backend asyncio connection"
loop = asyncio.get_running_loop()
if self.unix_socket is not None:
self.log.info(f"RawSocket {self.name} connecting to {self.unix_socket}")
return loop.create_unix_connection(proto, self.unix_socket)
else:
self.log.info(f"RawSocket {self.name} connecting to port {self.port}")
return loop.create_connection(proto, 'localhost', self.port)

async def proxy(self, port, path):
raise web.HTTPError(405, "this raw_socket_proxy backend only supports websocket connections")

async def proxy_open(self, host, port, proxied_path=""):
"""
Open the backend connection. host and port are ignored (as they are in
the parent for unix sockets) since they are always passed known values.
"""
transp, proto = await self._create_ws_connection(lambda: RawSocketProtocol(self))
self.ws_transp = transp
self.ws_proto = proto
self._record_activity()
self.log.info(f"RawSocket {self.name} connected")

def on_message(self, message):
"Send websocket messages as stream writes, encoding if necessary."
self._record_activity()
if isinstance(message, str):
message = message.encode('utf-8')
self.ws_transp.write(message) # buffered non-blocking. should block (needs new enough tornado)

def on_ping(self, message):
"No-op"
self._record_activity()

def on_close(self):
"Close the backend connection."
self.log.info(f"RawSocket {self.name} connection closed")
if hasattr(self, "ws_transp"):
self.ws_transp.close()

class SuperviseAndRawSocketHandler(SuperviseAndProxyHandler, RawSocketHandler):
async def _http_ready_func(self, p):
# not really HTTP here, just try an empty connection
try:
transp, _ = await self._create_ws_connection(asyncio.Protocol)
except OSError as exc:
self.log.debug(f"RawSocket {self.name} connection check failed: {exc}")
return False
transp.close()
return True
9 changes: 9 additions & 0 deletions tests/resources/jupyter_server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def my_env():
"rewrite_response": [cats_only, dog_to_cat],
},
"python-proxyto54321-no-command": {"port": 54321},
"python-rawsocket-tcp": {
"command": [sys.executable, "./tests/resources/rawsocket.py", "{port}"],
"raw_socket_proxy": True
},
"python-rawsocket-unix": {
"command": [sys.executable, "./tests/resources/rawsocket.py", "{unix_socket}"],
"unix_socket": True,
"raw_socket_proxy": True
},
}

c.ServerProxy.non_service_rewrite_response = hello_to_foo
Expand Down
28 changes: 28 additions & 0 deletions tests/resources/rawsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python

import os
import socket
import sys

if len(sys.argv) != 2:
print(f"Usage: {sys.argv[0]} TCPPORT|SOCKPATH")
sys.exit(1)
where = sys.argv[1]
try:
port = int(where)
family = socket.AF_INET
addr = ('localhost', port)
except ValueError:
family = socket.AF_UNIX
addr = where

with socket.create_server(addr, family=family) as serv:
while True:
# only handle a single connection at a time
sock, caddr = serv.accept()
while True:
s = sock.recv(1024)
if not s:
break
sock.send(s.swapcase())
sock.close()
22 changes: 22 additions & 0 deletions tests/test_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,25 @@ def test_callable_environment_formatting(
PORT, TOKEN = a_server_port_and_token
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
assert r.code == 200


@pytest.mark.parametrize("rawsocket_type", [
"tcp",
pytest.param(
"unix",
marks=pytest.mark.skipif(
sys.platform == "win32", reason="Unix socket not supported on Windows"
),
),
])
async def test_server_proxy_rawsocket(
rawsocket_type: str,
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT, TOKEN = a_server_port_and_token
url = f"ws://{LOCALHOST}:{PORT}/python-rawsocket-{rawsocket_type}/?token={TOKEN}"
conn = await websocket_connect(url)
for msg in [b"Hello,", b"world!"]:
await conn.write_message(msg)
res = await conn.read_message()
assert res == msg.swapcase()

0 comments on commit 52b0dec

Please sign in to comment.