Skip to content

Commit

Permalink
Use safehttpx.get() instead of async_get_with_secure_transport() (#…
Browse files Browse the repository at this point in the history
…9795)

* changes

* add changeset

* format

* format

* format

* add changeset

* remove tests

* bump

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Oct 23, 2024
1 parent 5e89b6d commit ff5be45
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 114 deletions.
5 changes: 5 additions & 0 deletions .changeset/green-rings-create.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Use `safehttpx.get()` instead of `async_get_with_secure_transport()`
91 changes: 8 additions & 83 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import mimetypes
import os
import shutil
import socket
import ssl
import subprocess
import tempfile
import warnings
Expand All @@ -24,6 +22,7 @@
import aiofiles
import httpx
import numpy as np
import safehttpx as sh
from gradio_client import utils as client_utils
from PIL import Image, ImageOps, ImageSequence, PngImagePlugin

Expand Down Expand Up @@ -326,84 +325,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
return decorator


@lru_cache_async(maxsize=256)
async def async_resolve_hostname_google(hostname: str) -> list[str]:
async with httpx.AsyncClient() as client:
try:
response_v4 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=A"
)
response_v6 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=AAAA"
)

ips = []
for response in [response_v4.json(), response_v6.json()]:
ips.extend([answer["data"] for answer in response.get("Answer", [])])
return ips
except Exception:
return []


class AsyncSecureTransport(httpx.AsyncHTTPTransport):
def __init__(self, verified_ip: str):
self.verified_ip = verified_ip
super().__init__()

async def connect(
self,
hostname: str,
port: int,
_timeout: float | None = None,
ssl_context: ssl.SSLContext | None = None,
**_kwargs: Any,
):
loop = asyncio.get_event_loop()
sock = await loop.getaddrinfo(self.verified_ip, port)
sock = socket.socket(sock[0][0], sock[0][1])
await loop.sock_connect(sock, (self.verified_ip, port))
if ssl_context:
sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
return sock


async def async_validate_url(url: str) -> str:
hostname = urlparse(url).hostname
if not hostname:
raise ValueError(f"URL {url} does not have a valid hostname")
try:
loop = asyncio.get_event_loop()
addrinfo = await loop.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

for family, _, _, _, sockaddr in addrinfo:
ip_address = sockaddr[0]
if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
return ip_address

if not wasm_utils.IS_WASM:
for ip_address in await async_resolve_hostname_google(hostname):
if is_public_ip(ip_address):
return ip_address

raise ValueError(f"Hostname {hostname} failed validation")


async def async_get_with_secure_transport(
url: str, trust_hostname: bool = False
) -> httpx.Response:
if wasm_utils.IS_WASM:
transport = PyodideHttpTransport()
elif trust_hostname:
transport = None
else:
verified_ip = await async_validate_url(url)
transport = AsyncSecureTransport(verified_ip)
async with httpx.AsyncClient(transport=transport) as client:
return await client.get(url, follow_redirects=False)


async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
temp_dir = Path(cache_dir) / hash_url(url)
temp_dir.mkdir(exist_ok=True, parents=True)
Expand All @@ -416,8 +337,8 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
parsed_url = urlparse(url)
hostname = parsed_url.hostname

response = await async_get_with_secure_transport(
url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST
response = await sh.get(
url, domain_whitelist=PUBLIC_HOSTNAME_WHITELIST, _transport=async_transport
)

while response.is_redirect:
Expand All @@ -427,7 +348,11 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
if not redirect_parsed.hostname:
redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}"

response = await async_get_with_secure_transport(redirect_url)
response = await sh.get(
redirect_url,
domain_whitelist=PUBLIC_HOSTNAME_WHITELIST,
_transport=async_transport,
)

if response.status_code != 200:
raise Exception(f"Failed to download file. Status code: {response.status_code}")
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ python-multipart>=0.0.9,!=0.0.13 # required for fastapi forms. 0.0.13 was yanke
pydub
pyyaml>=5.0,<7.0
ruff>=0.2.2; sys.platform != 'emscripten'
safehttpx>=0.1.1,<1.0
semantic_version~=2.0
starlette>=0.40.0,<1.0; sys.platform != 'emscripten'
tomlkit==0.12.0
Expand Down
31 changes: 0 additions & 31 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,37 +408,6 @@ async def test_json_data_not_moved_to_cache():
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"url",
[
"https://localhost",
"http://127.0.0.1/file/a/b/c",
"http://[::1]",
"https://192.168.0.1",
"http://10.0.0.1?q=a",
"http://192.168.1.250.nip.io",
],
)
async def test_local_urls_fail(url):
with pytest.raises(ValueError, match="failed validation"):
await processing_utils.async_validate_url(url)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"url",
[
"https://google.com",
"https://8.8.8.8/",
"http://93.184.215.14.nip.io/",
"https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/luigi/luigi.ply",
],
)
async def test_public_urls_pass(url):
await processing_utils.async_validate_url(url)


def test_public_request_pass():
tempdir = tempfile.TemporaryDirectory()
file = processing_utils.ssrf_protected_download(
Expand Down

0 comments on commit ff5be45

Please sign in to comment.