Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNS resolver on ip check #9150

Merged
merged 14 commits into from
Aug 27, 2024
5 changes: 5 additions & 0 deletions .changeset/solid-chicken-love.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat: DNS resolver on IP check
42 changes: 32 additions & 10 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import socket
import subprocess
import tempfile
import urllib.request
import warnings
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse

import aiofiles
import httpx
Expand Down Expand Up @@ -271,7 +272,19 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
return full_temp_file_path


def check_public_url(url: str):
def resolve_with_google_dns(hostname: str) -> str | None:
    """Resolve *hostname* to an IPv4 address via Google's DNS-over-HTTPS API.

    Used as a best-effort secondary resolver when the local resolver does not
    yield a public IP for the host. Returns the first A-record address as a
    string, or ``None`` if the lookup fails or no A record is present.
    """
    url = f"https://dns.google/resolve?name={hostname}&type=A"

    try:
        # Timeout so a slow/unreachable DoH endpoint cannot hang the caller's
        # URL-download path indefinitely.
        with urllib.request.urlopen(url, timeout=10) as response:
            data = json.loads(response.read().decode())
    except Exception:
        # Best-effort: any network/HTTP/JSON failure means "could not
        # resolve", not an error the caller should see.
        return None

    # Status 0 is DNS NOERROR; record type 1 is an A record (IPv4).
    if data.get("Status") == 0 and "Answer" in data:
        for answer in data["Answer"]:
            if answer["type"] == 1:
                return answer["data"]
    return None


def get_public_url(url: str) -> str:
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise httpx.RequestError(f"Invalid URL: {url}")
Expand All @@ -289,18 +302,27 @@ def check_public_url(url: str):
if family == socket.AF_INET6:
ip = ip.split("%")[0] # Remove scope ID if present

if not ipaddress.ip_address(ip).is_global:
raise httpx.RequestError(
f"Non-public IP address found: {ip} for URL: {url}"
if ipaddress.ip_address(ip).is_global:
return url

google_resolved_ip = resolve_with_google_dns(hostname)
if google_resolved_ip and ipaddress.ip_address(google_resolved_ip).is_global:
if parsed_url.scheme == "https":
return url
new_parsed = parsed_url._replace(netloc=google_resolved_ip)
if parsed_url.port:
new_parsed = new_parsed._replace(
netloc=f"{google_resolved_ip}:{parsed_url.port}"
)
return urlunparse(new_parsed)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we've verified that the url maps to a public domain, why can't we use the original url?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


return True
raise httpx.RequestError(f"No public IP address found for URL: {url}")


def save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
check_public_url(url)
url = get_public_url(url)

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
Expand All @@ -314,7 +336,7 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
open(full_temp_file_path, "wb") as f,
):
for redirect in response.history:
check_public_url(str(redirect.url))
get_public_url(str(redirect.url))

for chunk in response.iter_raw():
f.write(chunk)
Expand All @@ -325,7 +347,7 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
check_public_url(url)
url = get_public_url(url)

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
Expand All @@ -336,7 +358,7 @@ async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
if not Path(full_temp_file_path).exists():
async with async_client.stream("GET", url, follow_redirects=True) as response:
for redirect in response.history:
check_public_url(str(redirect.url))
get_public_url(str(redirect.url))

async with aiofiles.open(full_temp_file_path, "wb") as f:
async for chunk in response.aiter_raw():
Expand Down
6 changes: 3 additions & 3 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,8 @@ async def test_json_data_not_moved_to_cache():
],
)
def test_local_urls_fail(url):
with pytest.raises(httpx.RequestError, match="Non-public IP address found"):
processing_utils.check_public_url(url)
with pytest.raises(httpx.RequestError, match="No public IP address found for URL"):
processing_utils.get_public_url(url)


@pytest.mark.parametrize(
Expand All @@ -433,4 +433,4 @@ def test_local_urls_fail(url):
],
)
def test_public_urls_pass(url):
assert processing_utils.check_public_url(url)
assert processing_utils.get_public_url(url)
Loading