diff --git a/CHANGELOG.md b/CHANGELOG.md index e8720b8378dd8..256c876da439f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Make `Blocks.load` behave like other event listeners (allows chaining `then` off of it) [@anentropic](https://github.com/anentropic/) in [PR 4304](https://github.com/gradio-app/gradio/pull/4304) - Respect `interactive=True` in output components of a `gr.Interface` by [@abidlabs](https://github.com/abidlabs) in [PR 4356](https://github.com/gradio-app/gradio/pull/4356). - Remove unused frontend code by [@akx](https://github.com/akx) in [PR 4275](https://github.com/gradio-app/gradio/pull/4275) +- Prevent path traversal in `/file` routes by [@abidlabs](https://github.com/abidlabs) in [PR 4370](https://github.com/gradio-app/gradio/pull/4370). - Do not send HF token to other domains via `/proxy` route by [@abidlabs](https://github.com/abidlabs) in [PR 4368](https://github.com/gradio-app/gradio/pull/4368). ## Other Changes: diff --git a/gradio/routes.py b/gradio/routes.py index 4148a3e69d6f4..8990ce9a28cb7 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -327,49 +327,47 @@ async def file(path_or_url: str, request: fastapi.Request): return RedirectResponse( url=path_or_url, status_code=status.HTTP_302_FOUND ) + abs_path = utils.abspath(path_or_url) + in_blocklist = any( utils.is_in_or_equal(abs_path, blocked_path) for blocked_path in blocks.blocked_paths ) - if in_blocklist or any(part.startswith(".") for part in abs_path.parts): + is_dotfile = any(part.startswith(".") for part in abs_path.parts) + is_dir = abs_path.is_dir() + + if in_blocklist or is_dotfile or is_dir: raise HTTPException(403, f"File not allowed: {path_or_url}.") + if not abs_path.exists(): + raise HTTPException(404, f"File not found: {path_or_url}.") - in_app_dir = utils.abspath(app.cwd) in abs_path.parents + in_app_dir = utils.is_in_or_equal(abs_path, app.cwd) created_by_app = str(abs_path) in set().union(*blocks.temp_file_sets) - in_file_dir = any( + in_allowlist = any( utils.is_in_or_equal(abs_path, allowed_path) for allowed_path in blocks.allowed_paths ) - was_uploaded = utils.abspath(app.uploaded_file_dir) in abs_path.parents - - if in_app_dir or created_by_app or in_file_dir or was_uploaded: - if not abs_path.exists(): - raise HTTPException(404, "File not found") - if abs_path.is_dir(): - raise HTTPException(403) - - range_val = request.headers.get("Range", "").strip() - if range_val.startswith("bytes=") and "-" in range_val: - range_val = range_val[6:] - start, end = range_val.split("-") - if start.isnumeric() and end.isnumeric(): - start = int(start) - end = int(end) - response = ranged_response.RangedFileResponse( - abs_path, - ranged_response.OpenRange(start, end), - dict(request.headers), - stat_result=os.stat(abs_path), - ) - return response - return FileResponse(abs_path, headers={"Accept-Ranges": "bytes"}) + was_uploaded = utils.is_in_or_equal(abs_path, app.uploaded_file_dir) - else: - raise HTTPException( - 403, - f"File cannot be fetched: {path_or_url}. All files must contained within the Gradio python app working directory, or be a temp file created by the Gradio python app.", - ) + if not (in_app_dir or created_by_app or in_allowlist or was_uploaded): + raise HTTPException(403, f"File not allowed: {path_or_url}.") + + range_val = request.headers.get("Range", "").strip() + if range_val.startswith("bytes=") and "-" in range_val: + range_val = range_val[6:] + start, end = range_val.split("-") + if start.isnumeric() and end.isnumeric(): + start = int(start) + end = int(end) + response = ranged_response.RangedFileResponse( + abs_path, + ranged_response.OpenRange(start, end), + dict(request.headers), + stat_result=os.stat(abs_path), + ) + return response + return FileResponse(abs_path, headers={"Accept-Ranges": "bytes"}) @app.get("/file/{path:path}", dependencies=[Depends(login_check)]) async def file_deprecated(path: str, request: fastapi.Request): diff --git a/gradio/utils.py b/gradio/utils.py index 90d8657713b1f..9af39fbc1c6cf 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -844,12 +844,16 @@ def is_in_or_equal(path_1: str | Path, path_2: str | Path): True if path_1 is a descendant (i.e. located within) path_2 or if the paths are the same, returns False otherwise. Parameters: - path_1: str or Path (can be a file or directory) + path_1: str or Path (should be a file) path_2: str or Path (can be a file or directory) """ - return (abspath(path_2) in abspath(path_1).parents) or abspath(path_1) == abspath( - path_2 - ) + path_1, path_2 = abspath(path_1), abspath(path_2) + try: + if str(path_1.relative_to(path_2)).startswith(".."): # prevent path traversal + return False + except ValueError: + return False + return True def get_serializer_name(block: Block) -> str | None: diff --git a/test/test_utils.py b/test/test_utils.py index 911b851aef308..8f08a516242e7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -31,6 +31,7 @@ format_ner_list, get_type_hints, ipython_check, + is_in_or_equal, is_special_typed_parameter, kaggle_check, readme_to_html, @@ -623,3 +624,12 @@ def test_tex2svg_preserves_matplotlib_backend(): ): tex2svg("$$$1+1=2$$$") assert matplotlib.get_backend() == "svg" + + +def test_is_in_or_equal(): + assert is_in_or_equal("files/lion.jpg", "files/lion.jpg") + assert is_in_or_equal("files/lion.jpg", "files") + assert not is_in_or_equal("files", "files/lion.jpg") + assert is_in_or_equal("/home/usr/notes.txt", "/home/usr/") + assert not is_in_or_equal("/home/usr/subdirectory", "/home/usr/notes.txt") + assert not is_in_or_equal("/home/usr/../../etc/notes.txt", "/home/usr/")