Skip to content

Commit

Permalink
add support for websocket proxys
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Dec 12, 2024
1 parent e69b87c commit b299653
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ jobs:
--build-arg SQLITE_VERSION=3310100 \
--build-arg SQLITE_YEAR=2020 \
-f old-sqlite.Dockerfile . &&
docker run prefect-server-new-sqlite sh -c "prefect server database downgrade --yes -r base && prefect server database upgrade --yes"
docker run prefect-server-new-sqlite sh -c "prefect server database downgrade --yes -r base && prefect server database upgrade --yes"
74 changes: 74 additions & 0 deletions .github/workflows/proxy-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
# simple server and a squid proxy server. The proxy server is inaccessible from the host machine, only the proxy
# so we can confirm the proxy is actually working.

name: Proxy Test
on:
pull_request:
paths:
- .github/workflows/proxy-test.yaml
- scripts/proxy-test/*
- "src/prefect/utilities/proxy.py"
- requirements.txt
- requirements-client.txt
- requirements-dev.txt
push:
branches:
- main
paths:
- .github/workflows/proxy-test.yaml
- scripts/proxy-test/*
- "src/prefect/utilities/proxy.py"
- requirements.txt
- requirements-client.txt
- requirements-dev.txt

jobs:
proxy-test:
name: Proxy Test
timeout-minutes: 10
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
fetch-depth: 0

- name: Set up Python 3.10
uses: actions/setup-python@v5
id: setup_python
with:
python-version: "3.10"

- name: Create Docker networks
run: |
docker network create internal_net --internal
docker network create external_net
- name: Start API server container
working-directory: scripts/proxy-test
run: |
docker build -t api-server .
docker run -d --network internal_net --name server api-server
- name: Start Squid Proxy container
run: |
docker run -d \
--network internal_net \
--network external_net \
-p 3128:3128 \
-v $(pwd)/scripts/proxy-test/squid.conf:/etc/squid/squid.conf \
--name proxy \
ubuntu/squid
- name: Install Dependencies
run: |
python -m pip install -U uv
uv pip install --upgrade --system .
- name: Run Proxy Tests
env:
HTTP_PROXY: http://localhost:3128
HTTPS_PROXY: http://localhost:3128
run: python scripts/proxy-test/client.py
1 change: 1 addition & 0 deletions requirements-client.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pydantic_extra_types >= 2.8.2, < 3.0.0
pydantic_settings > 2.2.1
python_dateutil >= 2.8.2, < 3.0.0
python-slugify >= 5.0, < 9.0
python-socks[asyncio] >= 2.5.3, < 3.0
pyyaml >= 5.4.1, < 7.0.0
rfc3339-validator >= 0.1.4, < 0.2.0
rich >= 11.0, < 14.0
Expand Down
11 changes: 11 additions & 0 deletions scripts/proxy-test/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install uv
RUN uv pip install --no-cache-dir --system -r requirements.txt

COPY server.py .

CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
9 changes: 9 additions & 0 deletions scripts/proxy-test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
simple server and a squid proxy server. The proxy server is inaccessible from the host machine, so we
can confirm the proxy connection is working.

```
$ uv pip install -r requirements.txt
$ docker compose up --build
$ python client.py
```
27 changes: 27 additions & 0 deletions scripts/proxy-test/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio
import os

from prefect.utilities.proxy import websocket_connect

PROXY_URL = "http://localhost:3128"
WS_SERVER_URL = "ws://server:8000/ws"


async def test_websocket_proxy_with_compat():
"""WebSocket through proxy with proxy compatibility code - should work"""
os.environ["HTTP_PROXY"] = "http://localhost:3128"

async with websocket_connect("ws://server:8000/ws") as websocket:
await websocket.send("Hello!")
response = await websocket.recv()
print("Response: ", response)
assert response == "Server received: Hello!"


async def main():
print("Testing WebSocket through proxy with compatibility code")
await test_websocket_proxy_with_compat()


if __name__ == "__main__":
asyncio.run(main())
20 changes: 20 additions & 0 deletions scripts/proxy-test/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
services:
server:
build: .
networks:
- internal_net

forward_proxy:
image: ubuntu/squid
ports:
- "3128:3128"
volumes:
- ./squid.conf:/etc/squid/squid.conf
networks:
- internal_net
- external_net

networks:
internal_net:
internal: true
external_net:
6 changes: 6 additions & 0 deletions scripts/proxy-test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fastapi==0.111.1
uvicorn==0.28.1
uv==0.5.7
websockets==13.1
python-socks==2.5.3
httpx==0.28.1
10 changes: 10 additions & 0 deletions scripts/proxy-test/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
async for data in websocket.iter_text():
await websocket.send_text(f"Server received: {data}")
5 changes: 5 additions & 0 deletions scripts/proxy-test/squid.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
http_port 3128
acl CONNECT method CONNECT
acl SSL_ports port 443 8000
http_access allow CONNECT SSL_ports
http_access allow all
9 changes: 5 additions & 4 deletions src/prefect/events/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from prometheus_client import Counter
from typing_extensions import Self
from websockets import Subprotocol
from websockets.client import WebSocketClientProtocol, connect
from websockets.client import WebSocketClientProtocol
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
Expand All @@ -37,6 +37,7 @@
PREFECT_DEBUG_MODE,
PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
)
from prefect.utilities.proxy import websocket_connect

if TYPE_CHECKING:
from prefect.events.filters import EventFilter
Expand Down Expand Up @@ -265,7 +266,7 @@ def __init__(
)

self._events_socket_url = events_in_socket_from_api_url(api_url)
self._connect = connect(self._events_socket_url)
self._connect = websocket_connect(self._events_socket_url)
self._websocket = None
self._reconnection_attempts = reconnection_attempts
self._unconfirmed_events = []
Expand Down Expand Up @@ -435,7 +436,7 @@ def __init__(
reconnection_attempts=reconnection_attempts,
checkpoint_every=checkpoint_every,
)
self._connect = connect(
self._connect = websocket_connect(
self._events_socket_url,
extra_headers={"Authorization": f"bearer {api_key}"},
)
Expand Down Expand Up @@ -494,7 +495,7 @@ def __init__(

logger.debug("Connecting to %s", socket_url)

self._connect = connect(
self._connect = websocket_connect(
socket_url,
subprotocols=[Subprotocol("prefect")],
)
Expand Down
58 changes: 58 additions & 0 deletions src/prefect/utilities/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
from typing import (
Any,
Generator,
)
from urllib.parse import urlparse

from python_socks.async_.asyncio import Proxy
from typing_extensions import Self
from websockets.client import WebSocketClientProtocol
from websockets.legacy.client import Connect


class WebsocketProxyConnect(Connect):
def __init__(self: Self, uri: str, **kwargs: Any):
# super() is intentionally deferred to the __proxy_connect__ method
# to allow for the proxy to be established before the connection is made

self.uri = uri
self.__kwargs = kwargs

u = urlparse(uri)
host = u.hostname

if u.scheme == "ws":
port = u.port or 80
proxy_url = os.environ.get("HTTP_PROXY")
elif u.scheme == "wss":
port = u.port or 443
proxy_url = os.environ.get("HTTPS_PROXY")
kwargs["server_hostname"] = host
else:
raise ValueError(
"Unsupported scheme %s. Expected 'ws' or 'wss'. " % u.scheme
)

self.__proxy = Proxy.from_url(proxy_url) if proxy_url else None
self.__host = host
self.__port = port

async def __proxy_connect__(self: Self) -> WebSocketClientProtocol:
if self.__proxy:
sock = await self.__proxy.connect(
dest_host=self.__host,
dest_port=self.__port,
)
self.__kwargs["sock"] = sock

super().__init__(self.uri, **self.__kwargs)
proto = await self.__await_impl__()
return proto

def __await__(self: Self) -> Generator[Any, None, WebSocketClientProtocol]:
return self.__proxy_connect__().__await__()


def websocket_connect(uri: str, **kwargs: Any) -> WebsocketProxyConnect:
return WebsocketProxyConnect(uri, **kwargs)
2 changes: 1 addition & 1 deletion tests/events/client/test_events_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
def mock_connect(*args, **kwargs):
return MockConnect()

monkeypatch.setattr("prefect.events.clients.connect", mock_connect)
monkeypatch.setattr("prefect.events.clients.websocket_connect", mock_connect)

with caplog.at_level(logging.WARNING):
with pytest.raises(Exception, match="Connection failed"):
Expand Down
42 changes: 42 additions & 0 deletions tests/utilities/test_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from unittest.mock import Mock

from prefect.utilities.proxy import WebsocketProxyConnect


def test_init_ws_without_proxy():
client = WebsocketProxyConnect("ws://example.com")
assert client.uri == "ws://example.com"
assert client._WebsocketProxyConnect__host == "example.com"
assert client._WebsocketProxyConnect__port == 80
assert client._WebsocketProxyConnect__proxy is None


def test_init_wss_without_proxy():
client = WebsocketProxyConnect("wss://example.com")
assert client.uri == "wss://example.com"
assert client._WebsocketProxyConnect__host == "example.com"
assert client._WebsocketProxyConnect__port == 443
assert "server_hostname" in client._WebsocketProxyConnect__kwargs
assert client._WebsocketProxyConnect__proxy is None


def test_init_ws_with_proxy(monkeypatch):
monkeypatch.setenv("HTTP_PROXY", "http://proxy:3128")
mock_proxy = Mock()
monkeypatch.setattr("prefect.utilities.proxy.Proxy", mock_proxy)

client = WebsocketProxyConnect("ws://example.com")

mock_proxy.from_url.assert_called_once_with("http://proxy:3128")
assert client._WebsocketProxyConnect__proxy is not None


def test_init_wss_with_proxy(monkeypatch):
monkeypatch.setenv("HTTPS_PROXY", "https://proxy:3128")
mock_proxy = Mock()
monkeypatch.setattr("prefect.utilities.proxy.Proxy", mock_proxy)

client = WebsocketProxyConnect("wss://example.com")

mock_proxy.from_url.assert_called_once_with("https://proxy:3128")
assert client._WebsocketProxyConnect__proxy is not None

0 comments on commit b299653

Please sign in to comment.