Support http proxies for websockets #16326

Merged: 11 commits, Dec 13, 2024
74 changes: 74 additions & 0 deletions .github/workflows/proxy-test.yaml
@@ -0,0 +1,74 @@
# This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
# simple server and a squid proxy server. The server is inaccessible from the host machine and reachable only
# through the proxy, so we can confirm the proxy is actually working.

name: Proxy Test
on:
  pull_request:
    paths:
      - .github/workflows/proxy-test.yaml
      - scripts/proxy-test/*
      - "src/prefect/events/clients.py"
      - requirements.txt
      - requirements-client.txt
      - requirements-dev.txt
  push:
    branches:
      - main
    paths:
      - .github/workflows/proxy-test.yaml
      - scripts/proxy-test/*
      - "src/prefect/events/clients.py"
      - requirements.txt
      - requirements-client.txt
      - requirements-dev.txt

jobs:
  proxy-test:
    name: Proxy Test
    timeout-minutes: 10
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
          fetch-depth: 0

      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        id: setup_python
        with:
          python-version: "3.10"

      - name: Create Docker networks
        run: |
          docker network create internal_net --internal
          docker network create external_net
      - name: Start API server container
        working-directory: scripts/proxy-test
        run: |
          docker build -t api-server .
          docker run -d --network internal_net --name server api-server
      - name: Start Squid Proxy container
        run: |
          docker run -d \
            --network internal_net \
            --network external_net \
            -p 3128:3128 \
            -v $(pwd)/scripts/proxy-test/squid.conf:/etc/squid/squid.conf \
            --name proxy \
            ubuntu/squid
      - name: Install Dependencies
        run: |
          python -m pip install -U uv
          uv pip install --upgrade --system .
      - name: Run Proxy Tests
        env:
          HTTP_PROXY: http://localhost:3128
          HTTPS_PROXY: http://localhost:3128
        run: python scripts/proxy-test/client.py
1 change: 1 addition & 0 deletions requirements-client.txt
@@ -26,6 +26,7 @@ pydantic_extra_types >= 2.8.2, < 3.0.0
 pydantic_settings > 2.2.1
 python_dateutil >= 2.8.2, < 3.0.0
 python-slugify >= 5.0, < 9.0
+python-socks[asyncio] >= 2.5.3, < 3.0
 pyyaml >= 5.4.1, < 7.0.0
 rfc3339-validator >= 0.1.4, < 0.2.0
 rich >= 11.0, < 14.0
11 changes: 11 additions & 0 deletions scripts/proxy-test/Dockerfile
@@ -0,0 +1,11 @@
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install uv
RUN uv pip install --no-cache-dir --system -r requirements.txt

COPY server.py .

CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
9 changes: 9 additions & 0 deletions scripts/proxy-test/README.md
@@ -0,0 +1,9 @@
This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
simple server and a Squid proxy server. The server is inaccessible from the host machine and reachable only
through the proxy, so we can confirm the proxy connection is working.

```
$ uv pip install -r requirements.txt
$ docker compose up --build
$ python client.py
```
28 changes: 28 additions & 0 deletions scripts/proxy-test/client.py
@@ -0,0 +1,28 @@
import asyncio
import os

from prefect.events.clients import websocket_connect

PROXY_URL = "http://localhost:3128"
WS_SERVER_URL = "ws://server:8000/ws"


async def test_websocket_proxy_with_compat():
    """WebSocket through proxy with proxy compatibility code - should work"""
    os.environ["HTTP_PROXY"] = PROXY_URL

    async with websocket_connect(WS_SERVER_URL) as websocket:
        message = "Hello!"
        await websocket.send(message)
        response = await websocket.recv()
        print("Response: ", response)
        assert response == f"Server received: {message}"


async def main():
    print("Testing WebSocket through proxy with compatibility code")
    await test_websocket_proxy_with_compat()


if __name__ == "__main__":
    asyncio.run(main())
20 changes: 20 additions & 0 deletions scripts/proxy-test/docker-compose.yml
@@ -0,0 +1,20 @@
services:
  server:
    build: .
    networks:
      - internal_net

  forward_proxy:
    image: ubuntu/squid
    ports:
      - "3128:3128"
    volumes:
      - ./squid.conf:/etc/squid/squid.conf
    networks:
      - internal_net
      - external_net

networks:
  internal_net:
    internal: true
  external_net:
6 changes: 6 additions & 0 deletions scripts/proxy-test/requirements.txt
@@ -0,0 +1,6 @@
fastapi==0.111.1
uvicorn==0.28.1
uv==0.5.7
websockets==13.1
python-socks==2.5.3
httpx==0.28.1
10 changes: 10 additions & 0 deletions scripts/proxy-test/server.py
@@ -0,0 +1,10 @@
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    async for data in websocket.iter_text():
        await websocket.send_text(f"Server received: {data}")
5 changes: 5 additions & 0 deletions scripts/proxy-test/squid.conf
@@ -0,0 +1,5 @@
http_port 3128
acl CONNECT method CONNECT
acl SSL_ports port 443 8000
http_access allow CONNECT SSL_ports
http_access allow all
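
For context on the `CONNECT` ACL above: an HTTP forward proxy tunnels a WebSocket connection after the client sends an HTTP `CONNECT` request naming the destination host and port, which is presumably why port 8000 (the test server's port) is listed next to 443. The sketch below only builds such a request as bytes so the wire format is visible. `build_connect_request` is a hypothetical helper for illustration; it is not part of this PR and is not a claim about how `python-socks` is implemented.

```python
def build_connect_request(host: str, port: int) -> bytes:
    """Build a minimal HTTP/1.1 CONNECT request for host:port (illustrative only)."""
    authority = f"{host}:{port}"
    lines = [
        f"CONNECT {authority} HTTP/1.1",  # request line names the tunnel target
        f"Host: {authority}",             # Host header repeats the authority
        "",                               # blank line terminates the header block
        "",
    ]
    return "\r\n".join(lines).encode("ascii")


if __name__ == "__main__":
    # Show the request a client would send for the test server in this PR.
    print(build_connect_request("server", 8000).decode("ascii"))
```

A proxy that accepts the request replies with a 2xx status, after which the client speaks the WebSocket handshake over the same socket.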
59 changes: 55 additions & 4 deletions src/prefect/events/clients.py
@@ -1,32 +1,36 @@
 import abc
 import asyncio
+import os
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
     Any,
     ClassVar,
     Dict,
+    Generator,
     List,
     MutableMapping,
     Optional,
     Tuple,
     Type,
     cast,
 )
+from urllib.parse import urlparse
 from uuid import UUID

 import orjson
 import pendulum
 from cachetools import TTLCache
 from prometheus_client import Counter
+from python_socks.async_.asyncio import Proxy
 from typing_extensions import Self
 from websockets import Subprotocol
-from websockets.client import WebSocketClientProtocol, connect
 from websockets.exceptions import (
     ConnectionClosed,
     ConnectionClosedError,
     ConnectionClosedOK,
 )
+from websockets.legacy.client import Connect, WebSocketClientProtocol

 from prefect.events import Event
 from prefect.logging import get_logger
@@ -80,6 +84,53 @@ def events_out_socket_from_api_url(url: str):
     return http_to_ws(url) + "/events/out"


+class WebsocketProxyConnect(Connect):
+    def __init__(self: Self, uri: str, **kwargs: Any):
+        # super() is intentionally deferred to the _proxy_connect method
+        # to allow for the socket to be established first
+
+        self.uri = uri
+        self._kwargs = kwargs
+
+        u = urlparse(uri)
+        host = u.hostname
+
+        if u.scheme == "ws":
+            port = u.port or 80
+            proxy_url = os.environ.get("HTTP_PROXY")
+        elif u.scheme == "wss":
+            port = u.port or 443
+            proxy_url = os.environ.get("HTTPS_PROXY")
+            kwargs["server_hostname"] = host
+        else:
+            raise ValueError(
+                "Unsupported scheme %s. Expected 'ws' or 'wss'. " % u.scheme
+            )
+
+        self._proxy = Proxy.from_url(proxy_url) if proxy_url else None
+        self._host = host
+        self._port = port
+
+    async def _proxy_connect(self: Self) -> WebSocketClientProtocol:
+        if self._proxy:
+            sock = await self._proxy.connect(
+                dest_host=self._host,
+                dest_port=self._port,
+            )
+            self._kwargs["sock"] = sock
+
+        super().__init__(self.uri, **self._kwargs)
+        proto = await self.__await_impl__()
+        return proto
+
+    def __await__(self: Self) -> Generator[Any, None, WebSocketClientProtocol]:
+        return self._proxy_connect().__await__()
+
+
+def websocket_connect(uri: str, **kwargs: Any) -> WebsocketProxyConnect:
+    return WebsocketProxyConnect(uri, **kwargs)
+
+
 def get_events_client(
     reconnection_attempts: int = 10,
     checkpoint_every: int = 700,
@@ -265,7 +316,7 @@ def __init__(
             )

         self._events_socket_url = events_in_socket_from_api_url(api_url)
-        self._connect = connect(self._events_socket_url)
+        self._connect = websocket_connect(self._events_socket_url)
         self._websocket = None
         self._reconnection_attempts = reconnection_attempts
         self._unconfirmed_events = []
@@ -435,7 +486,7 @@ def __init__(
             reconnection_attempts=reconnection_attempts,
             checkpoint_every=checkpoint_every,
         )
-        self._connect = connect(
+        self._connect = websocket_connect(
             self._events_socket_url,
             extra_headers={"Authorization": f"bearer {api_key}"},
         )
@@ -494,7 +545,7 @@ def __init__(

         logger.debug("Connecting to %s", socket_url)

-        self._connect = connect(
+        self._connect = websocket_connect(
             socket_url,
             subprotocols=[Subprotocol("prefect")],
         )
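
The deferred `super().__init__` in `WebsocketProxyConnect` works because `__await__` delegates to a coroutine, so asynchronous setup (here, the proxy handshake) can run before the parent class is initialized with the resulting socket. A minimal, self-contained sketch of the same idea follows. `BaseConnect` and `DeferredConnect` are hypothetical stand-ins for the websockets and Prefect classes, and the "socket" is just a string.

```python
import asyncio
from typing import Any, Generator, Optional


class BaseConnect:
    """Hypothetical stand-in for a connect class that accepts a ready-made socket."""

    def __init__(self, uri: str, sock: Optional[str] = None) -> None:
        self.uri = uri
        self.sock = sock


class DeferredConnect(BaseConnect):
    def __init__(self, uri: str) -> None:
        # Deliberately skip super().__init__ here: the "socket" does not exist yet.
        self._uri = uri

    async def _prepare_and_connect(self) -> "DeferredConnect":
        # Stand-in for an async proxy handshake that yields a connected socket.
        await asyncio.sleep(0)
        sock = f"tunnel-to-{self._uri}"
        # Only now initialize the parent, handing it the prepared socket.
        super().__init__(self._uri, sock=sock)
        return self

    def __await__(self) -> Generator[Any, None, "DeferredConnect"]:
        # Awaiting the object runs the coroutine above.
        return self._prepare_and_connect().__await__()


async def main() -> DeferredConnect:
    return await DeferredConnect("ws://server:8000/ws")


if __name__ == "__main__":
    conn = asyncio.run(main())
    print(conn.sock)
```

The trade-off of this pattern is that the object is only partially initialized until it has been awaited, so attributes set by the parent class should not be read before then.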
2 changes: 1 addition & 1 deletion tests/events/client/test_events_client.py
@@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
     def mock_connect(*args, **kwargs):
         return MockConnect()

-    monkeypatch.setattr("prefect.events.clients.connect", mock_connect)
+    monkeypatch.setattr("prefect.events.clients.websocket_connect", mock_connect)

     with caplog.at_level(logging.WARNING):
         with pytest.raises(Exception, match="Connection failed"):
42 changes: 42 additions & 0 deletions tests/utilities/test_proxy.py
@@ -0,0 +1,42 @@
from unittest.mock import Mock

from prefect.events.clients import WebsocketProxyConnect


def test_init_ws_without_proxy():
    client = WebsocketProxyConnect("ws://example.com")
    assert client.uri == "ws://example.com"
    assert client._host == "example.com"
    assert client._port == 80
    assert client._proxy is None


def test_init_wss_without_proxy():
    client = WebsocketProxyConnect("wss://example.com")
    assert client.uri == "wss://example.com"
    assert client._host == "example.com"
    assert client._port == 443
    assert "server_hostname" in client._kwargs
    assert client._proxy is None


def test_init_ws_with_proxy(monkeypatch):
    monkeypatch.setenv("HTTP_PROXY", "http://proxy:3128")
    mock_proxy = Mock()
    monkeypatch.setattr("prefect.events.clients.Proxy", mock_proxy)

    client = WebsocketProxyConnect("ws://example.com")

    mock_proxy.from_url.assert_called_once_with("http://proxy:3128")
    assert client._proxy is not None


def test_init_wss_with_proxy(monkeypatch):
    monkeypatch.setenv("HTTPS_PROXY", "https://proxy:3128")
    mock_proxy = Mock()
    monkeypatch.setattr("prefect.events.clients.Proxy", mock_proxy)

    client = WebsocketProxyConnect("wss://example.com")

    mock_proxy.from_url.assert_called_once_with("https://proxy:3128")
    assert client._proxy is not None
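
The four tests above pin down how the URI scheme selects the default port and the proxy environment variable. As a compact restatement, that rule can be sketched as a pure function. `resolve_proxy_target` is a hypothetical name, and passing the environment in as a mapping (rather than reading `os.environ`) is an assumption made here so the example is easy to test in isolation.

```python
from typing import Mapping, Optional, Tuple
from urllib.parse import urlparse


def resolve_proxy_target(
    uri: str, env: Mapping[str, str]
) -> Tuple[Optional[str], int, Optional[str]]:
    """Return (host, port, proxy_url) for a ws:// or wss:// URI.

    Illustrative only: mirrors the scheme handling shown in the diff,
    where ws uses HTTP_PROXY and wss uses HTTPS_PROXY.
    """
    u = urlparse(uri)
    if u.scheme == "ws":
        return u.hostname, u.port or 80, env.get("HTTP_PROXY")
    if u.scheme == "wss":
        return u.hostname, u.port or 443, env.get("HTTPS_PROXY")
    raise ValueError(f"Unsupported scheme {u.scheme!r}. Expected 'ws' or 'wss'.")


if __name__ == "__main__":
    env = {"HTTP_PROXY": "http://localhost:3128"}
    print(resolve_proxy_target("ws://server:8000/ws", env))
```

The unsupported-scheme branch is not covered by the tests in this file; a case such as `resolve_proxy_target("http://example.com", {})` raising `ValueError` shows what such a test could check.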