Skip to content

Commit

Permalink
Tighten CORS rules (#7503)
Browse files Browse the repository at this point in the history
* tighten cors rules

* add changeset

* cors policy

* cors

* add changeset

* lint

* changes

* changes

* changes

* logging

* add null

* changes

* changes

* options

* options

* safe changes

* let browser enforce cors

* clean

* route utils

* fix

* fix test

* fix

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Feb 22, 2024
1 parent b186767 commit 84802ee
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/olive-symbols-heal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Tighten CORS rules
57 changes: 57 additions & 0 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
from urllib.parse import urlparse

import fastapi
import httpx
Expand All @@ -17,6 +18,7 @@
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from starlette.middleware.base import BaseHTTPMiddleware

from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
Expand Down Expand Up @@ -583,3 +585,58 @@ def starts_with_protocol(string: str) -> bool:
"""
pattern = r"^[a-zA-Z][a-zA-Z0-9+\-.]*://"
return re.match(pattern, string) is not None


def get_hostname(url: str) -> str:
"""
Returns the hostname of a given url, or an empty string if the url cannot be parsed.
Examples:
get_hostname("https://www.gradio.app") -> "www.gradio.app"
get_hostname("localhost:7860") -> "localhost"
get_hostname("127.0.0.1") -> "127.0.0.1"
"""
if not url:
return ""
if "://" not in url:
url = "http://" + url
try:
return urlparse(url).hostname or ""
except Exception:
return ""


class CustomCORSMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: fastapi.Request, call_next):
host: str = request.headers.get("host", "")
origin: str = request.headers.get("origin", "")
host_name = get_hostname(host)
origin_name = get_hostname(origin)

# Any of these hosts suggests that the Gradio app is running locally.
# Note: "null" is a special case that happens if a Gradio app is running
# as an embedded web component in a local static webpage.
localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
is_preflight = (
request.method == "OPTIONS"
and "access-control-request-method" in request.headers
)

if host_name in localhost_aliases and origin_name not in localhost_aliases:
allow_origin_header = None
else:
allow_origin_header = origin

if is_preflight:
response = fastapi.Response()
else:
response = await call_next(request)

if allow_origin_header:
response.headers["Access-Control-Allow-Origin"] = allow_origin_header
response.headers[
"Access-Control-Allow-Methods"
] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers[
"Access-Control-Allow-Headers"
] = "Origin, Content-Type, Accept"
return response
9 changes: 2 additions & 7 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import markupsafe
import orjson
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
FileResponse,
HTMLResponse,
Expand All @@ -55,6 +54,7 @@
from gradio.processing_utils import add_root_url
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
CustomCORSMiddleware,
FileUploadProgress,
FileUploadProgressNotQueuedError,
FileUploadProgressNotTrackedError,
Expand Down Expand Up @@ -196,12 +196,7 @@ def create_app(
app.configure_app(blocks)

if not wasm_utils.IS_WASM:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(CustomCORSMiddleware)

@app.get("/user")
@app.get("/user/")
Expand Down
18 changes: 18 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,24 @@ def test_can_get_config_that_includes_non_pickle_able_objects(self):
response = client.get("/config/")
assert response.is_success

def test_cors_restrictions(self):
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
custom_headers = {
"host": "localhost:7860",
"origin": "https://example.com",
}
file_response = client.get("/config", headers=custom_headers)
assert "access-control-allow-origin" not in file_response.headers
custom_headers = {
"host": "localhost:7860",
"origin": "127.0.0.1",
}
file_response = client.get("/config", headers=custom_headers)
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
io.close()


class TestApp:
def test_create_app(self):
Expand Down

0 comments on commit 84802ee

Please sign in to comment.