From ff5be457dc2ed901f3de1493bdbcb80ec341207a Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 23 Oct 2024 12:33:13 -0700 Subject: [PATCH] Use `safehttpx.get()` instead of `async_get_with_secure_transport()` (#9795) * changes * add changeset * format * format * format * add changeset * remove tests * bump --------- Co-authored-by: gradio-pr-bot --- .changeset/green-rings-create.md | 5 ++ gradio/processing_utils.py | 91 +++----------------------------- requirements.txt | 1 + test/test_processing_utils.py | 31 ----------- 4 files changed, 14 insertions(+), 114 deletions(-) create mode 100644 .changeset/green-rings-create.md diff --git a/.changeset/green-rings-create.md b/.changeset/green-rings-create.md new file mode 100644 index 0000000000000..0812653bcebe8 --- /dev/null +++ b/.changeset/green-rings-create.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Use `safehttpx.get()` instead of `async_get_with_secure_transport()` diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 626030f954d5d..332bb0b8d2143 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -9,8 +9,6 @@ import mimetypes import os import shutil -import socket -import ssl import subprocess import tempfile import warnings @@ -24,6 +22,7 @@ import aiofiles import httpx import numpy as np +import safehttpx as sh from gradio_client import utils as client_utils from PIL import Image, ImageOps, ImageSequence, PngImagePlugin @@ -326,84 +325,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]: return decorator -@lru_cache_async(maxsize=256) -async def async_resolve_hostname_google(hostname: str) -> list[str]: - async with httpx.AsyncClient() as client: - try: - response_v4 = await client.get( - f"https://dns.google/resolve?name={hostname}&type=A" - ) - response_v6 = await client.get( - f"https://dns.google/resolve?name={hostname}&type=AAAA" - ) - - ips = [] - for response in [response_v4.json(), response_v6.json()]: - ips.extend([answer["data"] for answer in response.get("Answer", [])]) - return ips - except Exception: - return [] - - -class AsyncSecureTransport(httpx.AsyncHTTPTransport): - def __init__(self, verified_ip: str): - self.verified_ip = verified_ip - super().__init__() - - async def connect( - self, - hostname: str, - port: int, - _timeout: float | None = None, - ssl_context: ssl.SSLContext | None = None, - **_kwargs: Any, - ): - loop = asyncio.get_event_loop() - sock = await loop.getaddrinfo(self.verified_ip, port) - sock = socket.socket(sock[0][0], sock[0][1]) - await loop.sock_connect(sock, (self.verified_ip, port)) - if ssl_context: - sock = ssl_context.wrap_socket(sock, server_hostname=hostname) - return sock - - -async def async_validate_url(url: str) -> str: - hostname = urlparse(url).hostname - if not hostname: - raise ValueError(f"URL {url} does not have a valid hostname") - try: - loop = asyncio.get_event_loop() - addrinfo = await loop.getaddrinfo(hostname, None) - except socket.gaierror as e: - raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e - - for family, _, _, _, sockaddr in addrinfo: - ip_address = sockaddr[0] - if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address): - return ip_address - - if not wasm_utils.IS_WASM: - for ip_address in await async_resolve_hostname_google(hostname): - if is_public_ip(ip_address): - return ip_address - - raise ValueError(f"Hostname {hostname} failed validation") - - -async def async_get_with_secure_transport( - url: str, trust_hostname: bool = False -) -> httpx.Response: - if wasm_utils.IS_WASM: - transport = PyodideHttpTransport() - elif trust_hostname: - transport = None - else: - verified_ip = await async_validate_url(url) - transport = AsyncSecureTransport(verified_ip) - async with httpx.AsyncClient(transport=transport) as client: - return await client.get(url, follow_redirects=False) - - async def async_ssrf_protected_download(url: str, cache_dir: str) -> str: temp_dir = Path(cache_dir) / hash_url(url) temp_dir.mkdir(exist_ok=True, parents=True) @@ -416,8 +337,8 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str: parsed_url = urlparse(url) hostname = parsed_url.hostname - response = await async_get_with_secure_transport( - url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST + response = await sh.get( + url, domain_whitelist=PUBLIC_HOSTNAME_WHITELIST, _transport=async_transport ) while response.is_redirect: @@ -427,7 +348,11 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str: if not redirect_parsed.hostname: redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}" - response = await async_get_with_secure_transport(redirect_url) + response = await sh.get( + redirect_url, + domain_whitelist=PUBLIC_HOSTNAME_WHITELIST, + _transport=async_transport, + ) if response.status_code != 200: raise Exception(f"Failed to download file. Status code: {response.status_code}") diff --git a/requirements.txt b/requirements.txt index 675a0d605b8a3..de88f2b6c24c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ python-multipart>=0.0.9,!=0.0.13 # required for fastapi forms. 0.0.13 was yanke pydub pyyaml>=5.0,<7.0 ruff>=0.2.2; sys.platform != 'emscripten' +safehttpx>=0.1.1,<1.0 semantic_version~=2.0 starlette>=0.40.0,<1.0; sys.platform != 'emscripten' tomlkit==0.12.0 diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 7fd62e4b854ba..5a6fbef6dc649 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -408,37 +408,6 @@ async def test_json_data_not_moved_to_cache(): ) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "url", - [ - "https://localhost", - "http://127.0.0.1/file/a/b/c", - "http://[::1]", - "https://192.168.0.1", - "http://10.0.0.1?q=a", - "http://192.168.1.250.nip.io", - ], -) -async def test_local_urls_fail(url): - with pytest.raises(ValueError, match="failed validation"): - await processing_utils.async_validate_url(url) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "url", - [ - "https://google.com", - "https://8.8.8.8/", - "http://93.184.215.14.nip.io/", - "https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/luigi/luigi.ply", - ], -) -async def test_public_urls_pass(url): - await processing_utils.async_validate_url(url) - - def test_public_request_pass(): tempdir = tempfile.TemporaryDirectory() file = processing_utils.ssrf_protected_download(