diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index ab3c069447ab..36f0a3496eb2 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -141,4 +141,4 @@ jobs: --build-arg SQLITE_VERSION=3310100 \ --build-arg SQLITE_YEAR=2020 \ -f old-sqlite.Dockerfile . && - docker run prefect-server-new-sqlite sh -c "prefect server database downgrade --yes -r base && prefect server database upgrade --yes" + docker run prefect-server-new-sqlite sh -c "prefect server database downgrade --yes -r base && prefect server database upgrade --yes" \ No newline at end of file diff --git a/.github/workflows/proxy-test.yaml b/.github/workflows/proxy-test.yaml new file mode 100644 index 000000000000..cf6d92daf04e --- /dev/null +++ b/.github/workflows/proxy-test.yaml @@ -0,0 +1,74 @@ +# This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a +# simple server and a squid proxy server. The simple server is inaccessible from the host machine except through the +# proxy, so we can confirm the proxy is actually working.
+ +name: Proxy Test +on: + pull_request: + paths: + - .github/workflows/proxy-test.yaml + - scripts/proxy-test/* + - "src/prefect/utilities/proxy.py" + - requirements.txt + - requirements-client.txt + - requirements-dev.txt + push: + branches: + - main + paths: + - .github/workflows/proxy-test.yaml + - scripts/proxy-test/* + - "src/prefect/utilities/proxy.py" + - requirements.txt + - requirements-client.txt + - requirements-dev.txt + +jobs: + proxy-test: + name: Proxy Test + timeout-minutes: 10 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + fetch-depth: 0 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + id: setup_python + with: + python-version: "3.10" + + - name: Create Docker networks + run: | + docker network create internal_net --internal + docker network create external_net + + - name: Start API server container + working-directory: scripts/proxy-test + run: | + docker build -t api-server . + docker run -d --network internal_net --name server api-server + + - name: Start Squid Proxy container + run: | + docker run -d \ + --network internal_net \ + --network external_net \ + -p 3128:3128 \ + -v $(pwd)/scripts/proxy-test/squid.conf:/etc/squid/squid.conf \ + --name proxy \ + ubuntu/squid + + - name: Install Dependencies + run: | + python -m pip install -U uv + uv pip install --upgrade --system . 
+ + - name: Run Proxy Tests + env: + HTTP_PROXY: http://localhost:3128 + HTTPS_PROXY: http://localhost:3128 + run: python scripts/proxy-test/client.py diff --git a/requirements-client.txt b/requirements-client.txt index de5e2b5ab1e5..e5424a1c85c1 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -26,6 +26,7 @@ pydantic_extra_types >= 2.8.2, < 3.0.0 pydantic_settings > 2.2.1 python_dateutil >= 2.8.2, < 3.0.0 python-slugify >= 5.0, < 9.0 +python-socks[asyncio] >= 2.5.3, < 3.0 pyyaml >= 5.4.1, < 7.0.0 rfc3339-validator >= 0.1.4, < 0.2.0 rich >= 11.0, < 14.0 diff --git a/scripts/proxy-test/Dockerfile b/scripts/proxy-test/Dockerfile new file mode 100644 index 000000000000..93b6c4db9107 --- /dev/null +++ b/scripts/proxy-test/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install uv +RUN uv pip install --no-cache-dir --system -r requirements.txt + +COPY server.py . + +CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/scripts/proxy-test/README.md b/scripts/proxy-test/README.md new file mode 100644 index 000000000000..76a8c8fdc55f --- /dev/null +++ b/scripts/proxy-test/README.md @@ -0,0 +1,9 @@ +This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a +simple server and a squid proxy server. The simple server is reachable from the host machine only through the proxy, so we +can confirm the proxy connection is working.
async def test_websocket_proxy_with_compat():
    """WebSocket through proxy with proxy compatibility code - should work.

    Points ``HTTP_PROXY`` at the local squid proxy, opens a WebSocket to the
    echo server via ``websocket_connect`` and verifies the echoed reply.

    Raises:
        AssertionError: if the server's reply is not the expected echo.
    """
    # Use the module-level constants rather than repeating the literals, so
    # the proxy/server addresses are defined in exactly one place.
    os.environ["HTTP_PROXY"] = PROXY_URL

    expected = "Server received: Hello!"

    async with websocket_connect(WS_SERVER_URL) as websocket:
        await websocket.send("Hello!")
        response = await websocket.recv()
        print("Response: ", response)
        # Explicit raise instead of a bare `assert`: asserts are stripped when
        # Python runs with -O, which would make this check silently pass.
        if response != expected:
            raise AssertionError(
                f"Unexpected response {response!r}; expected {expected!r}"
            )


async def main():
    """Run the proxy WebSocket check, announcing it on stdout first."""
    print("Testing WebSocket through proxy with compatibility code")
    await test_websocket_proxy_with_compat()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept the WebSocket and echo each text message back with a prefix."""
    await websocket.accept()
    incoming = websocket.iter_text()
    async for message in incoming:
        reply = f"Server received: {message}"
        await websocket.send_text(reply)
class WebsocketProxyConnect(Connect):
    """A ``websockets`` connector that can tunnel through a proxy.

    The proxy is taken from the ``HTTP_PROXY`` environment variable for
    ``ws://`` URIs and from ``HTTPS_PROXY`` for ``wss://`` URIs. When a proxy
    is configured (and the destination host is not excluded from proxying),
    a socket is opened through the proxy first and handed to the underlying
    ``Connect`` via its ``sock`` keyword argument.
    """

    def __init__(self: Self, uri: str, **kwargs: Any):
        """Parse ``uri`` and work out which proxy, if any, applies.

        Args:
            uri: WebSocket URI to connect to; scheme must be ``ws`` or ``wss``.
            **kwargs: Forwarded to ``Connect.__init__`` when the connection is
                actually awaited.

        Raises:
            ValueError: if the scheme is not ``ws``/``wss`` or the URI has no
                hostname.
        """
        # super().__init__() is intentionally deferred to __proxy_connect__
        # to allow for the proxy to be established before the connection is made
        from urllib.request import proxy_bypass

        self.uri = uri
        self.__kwargs = kwargs

        u = urlparse(uri)
        host = u.hostname

        if u.scheme == "ws":
            port = u.port or 80
            proxy_url = os.environ.get("HTTP_PROXY")
        elif u.scheme == "wss":
            port = u.port or 443
            proxy_url = os.environ.get("HTTPS_PROXY")
        else:
            raise ValueError(
                "Unsupported scheme %s. Expected 'ws' or 'wss'. " % u.scheme
            )

        # Fail early with a clear message instead of passing ``None`` on to
        # the proxy / connection machinery.
        if not host:
            raise ValueError("Invalid URI %r: no hostname found. " % uri)

        if u.scheme == "wss":
            # Needed for TLS when connecting over a pre-established socket;
            # setdefault keeps an explicit value supplied by the caller.
            kwargs.setdefault("server_hostname", host)

        # Honour NO_PROXY-style exclusions so hosts that must be reached
        # directly are not forced through the proxy.
        use_proxy = bool(proxy_url) and not proxy_bypass(host)

        self.__proxy = Proxy.from_url(proxy_url) if use_proxy else None
        self.__host = host
        self.__port = port

    async def __proxy_connect__(self: Self) -> WebSocketClientProtocol:
        """Open the proxy tunnel if one is configured, then connect.

        Returns:
            The connected WebSocket client protocol.
        """
        if self.__proxy:
            sock = await self.__proxy.connect(
                dest_host=self.__host,
                dest_port=self.__port,
            )
            self.__kwargs["sock"] = sock

        # Deferred base-class initialisation: only now is the (optional)
        # pre-connected socket known.
        super().__init__(self.uri, **self.__kwargs)
        proto = await self.__await_impl__()
        return proto

    def __await__(self: Self) -> Generator[Any, None, WebSocketClientProtocol]:
        """Make ``await connector`` go through the proxy-aware connect path."""
        return self.__proxy_connect__().__await__()


def websocket_connect(uri: str, **kwargs: Any) -> WebsocketProxyConnect:
    """Return a proxy-aware WebSocket connector for ``uri``.

    Args:
        uri: WebSocket URI (``ws://`` or ``wss://``).
        **kwargs: Passed through to ``WebsocketProxyConnect``.
    """
    return WebsocketProxyConnect(uri, **kwargs)
def test_init_wss_without_proxy(monkeypatch):
    """A wss:// URI with no HTTPS_PROXY set gets port 443, TLS hostname, no proxy."""
    # Isolate the test from the ambient environment: a machine that exports
    # HTTPS_PROXY would otherwise get a real proxy object and fail the final
    # assertion.
    monkeypatch.delenv("HTTPS_PROXY", raising=False)

    client = WebsocketProxyConnect("wss://example.com")
    assert client.uri == "wss://example.com"
    assert client._WebsocketProxyConnect__host == "example.com"
    assert client._WebsocketProxyConnect__port == 443
    assert "server_hostname" in client._WebsocketProxyConnect__kwargs
    assert client._WebsocketProxyConnect__proxy is None


def test_init_ws_with_proxy(monkeypatch):
    """A ws:// URI builds its proxy from the HTTP_PROXY environment variable."""
    monkeypatch.setenv("HTTP_PROXY", "http://proxy:3128")
    mock_proxy = Mock()
    monkeypatch.setattr("prefect.utilities.proxy.Proxy", mock_proxy)

    client = WebsocketProxyConnect("ws://example.com")

    mock_proxy.from_url.assert_called_once_with("http://proxy:3128")
    assert client._WebsocketProxyConnect__proxy is not None


def test_init_wss_with_proxy(monkeypatch):
    """A wss:// URI builds its proxy from the HTTPS_PROXY environment variable."""
    monkeypatch.setenv("HTTPS_PROXY", "https://proxy:3128")
    mock_proxy = Mock()
    monkeypatch.setattr("prefect.utilities.proxy.Proxy", mock_proxy)

    client = WebsocketProxyConnect("wss://example.com")

    mock_proxy.from_url.assert_called_once_with("https://proxy:3128")
    assert client._WebsocketProxyConnect__proxy is not None