Skip to content

Commit

Permalink
Fix bug where file examples can be corrupted if has multiple extensio…
Browse files Browse the repository at this point in the history
…ns (#4440)

* Fix bug

* Add to changelog

* Add test

* Remove breakpoint

* fix test

* increment version

* update client version req

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
freddyaboulton and abidlabs authored Jun 7, 2023
1 parent e364f81 commit 4a58cce
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 33 deletions.
19 changes: 19 additions & 0 deletions client/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@ No changes to highlight.

## Bug Fixes:

No changes to highlight.

## Breaking Changes:

No changes to highlight.

## Full Changelog:

No changes to highlight.

# 0.2.6

## New Features:

No changes to highlight.

## Bug Fixes:

- Fixed bug file deserialization didn't preserve all file extensions by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4440](https://github.com/gradio-app/gradio/pull/4440)
- Fixed bug where mounted apps could not be called via the client by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4435](https://github.com/gradio-app/gradio/pull/4435)

## Breaking Changes:
Expand Down
6 changes: 2 additions & 4 deletions client/python/gradio_client/serializing.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,9 @@ def _deserialize_single(
root_url + "file=" + filepath,
hf_token=hf_token,
dir=save_dir,
).name
)
else:
file_name = utils.create_tmp_copy_of_file(
filepath, dir=save_dir
).name
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"
Expand Down
40 changes: 14 additions & 26 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import mimetypes
import os
import pkgutil
import secrets
import shutil
import tempfile
from concurrent.futures import CancelledError
Expand Down Expand Up @@ -273,40 +274,27 @@ async def get_pred_from_ws(

def download_tmp_copy_of_file(
url_path: str, hf_token: str | None = None, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
) -> str:
if dir is not None:
os.makedirs(dir, exist_ok=True)
headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
prefix = Path(url_path).stem
suffix = Path(url_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)
file_path = directory / Path(url_path).name

with requests.get(url_path, headers=headers, stream=True) as r, open(
file_obj.name, "wb"
file_path, "wb"
) as f:
shutil.copyfileobj(r.raw, f)
return file_obj
return str(file_path.resolve())


def create_tmp_copy_of_file(
file_path: str, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
if dir is not None:
os.makedirs(dir, exist_ok=True)
prefix = Path(file_path).stem
suffix = Path(file_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
shutil.copy2(file_path, file_obj.name)
return file_obj
def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)
dest = directory / Path(file_path).name
shutil.copy2(file_path, dest)
return str(dest.resolve())


def get_mimetype(filename: str) -> str | None:
Expand Down
2 changes: 1 addition & 1 deletion client/python/gradio_client/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.5
0.2.6
2 changes: 1 addition & 1 deletion client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_download_private_file():
url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
hf_token = "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" # Intentionally revealing this key for testing purposes
file = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token)
assert file.name.endswith(".jpg")
assert Path(file).name.endswith(".jpg")


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ aiohttp
altair>=4.2.0
fastapi
ffmpy
gradio_client>=0.2.4
gradio_client>=0.2.6
httpx
huggingface_hub>=0.14.0
Jinja2
Expand Down
25 changes: 25 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -372,3 +373,27 @@ async def test_multiple_file_flagging(tmp_path):

assert len(prediction[0]) == 2
assert all(isinstance(d, dict) for d in prediction[0])


@pytest.mark.asyncio
async def test_examples_keep_all_suffixes(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
file_2 = tmp_path / "file_2"
file_2.mkdir(parents=True)
file_2 = file_2 / "foo.bar.txt"
file_2.write_text("file 2")
io = gr.Interface(
fn=lambda x: x.name,
inputs=gr.File(),
outputs=[gr.File()],
examples=[[str(file_1)], [str(file_2)]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
assert Path(prediction[0]["name"]).read_text() == "file 1"
assert prediction[0]["orig_name"] == "foo.bar.txt"
prediction = await io.examples_handler.load_from_cache(1)
assert Path(prediction[0]["name"]).read_text() == "file 2"
assert prediction[0]["orig_name"] == "foo.bar.txt"

0 comments on commit 4a58cce

Please sign in to comment.