From 8f0f434ff5ed2378b6fa41f245409df561e625af Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 17 May 2023 13:15:44 -0400 Subject: [PATCH 1/9] Fix bug --- gradio/components.py | 60 +++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 60a69131a95c1..521ff7ee53a9e 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -237,13 +237,21 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: sha1.update(data) return sha1.hexdigest() - def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: + def _hash( + self, base64_encoding: str | bytes, chunk_num_blocks: int = 128, encode=True + ): sha1 = hashlib.sha1() for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] - sha1.update(data.encode("utf-8")) + sha1.update(data.encode("utf-8") if encode else data) return sha1.hexdigest() + def hash_bytes(self, bytes: bytes, chunk_num_blocks: int = 128): + return self._hash(bytes, chunk_num_blocks, encode=False) + + def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: + return self._hash(base64_encoding, chunk_num_blocks, encode=False) + def make_temp_copy_if_needed(self, file_path: str) -> str: """Returns a temporary file path for a copy of the given file path if it does not already exist. Otherwise returns the path to the existing temp file.""" @@ -251,9 +259,8 @@ def make_temp_copy_if_needed(self, file_path: str) -> str: temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name) - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + name = client_utils.strip_invalid_filename_characters(Path(file_path).name) + full_temp_file_path = str(utils.abspath(temp_dir / name)) if not Path(full_temp_file_path).exists(): shutil.copy2(file_path, full_temp_file_path) @@ -292,10 +299,9 @@ def download_temp_copy_if_needed(self, url: str) -> str: temp_dir = self.hash_url(url) temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = client_utils.strip_invalid_filename_characters(Path(url).name) - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + name = client_utils.strip_invalid_filename_characters(Path(url).name) + full_temp_file_path = str(utils.abspath(temp_dir / name)) if not Path(full_temp_file_path).exists(): with requests.get(url, stream=True) as r, open( @@ -323,8 +329,7 @@ def base64_to_temp_file_if_needed( file_name = f"file.{guess_extension}" else: file_name = "file" - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = file_name # type: ignore + full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore if not Path(full_temp_file_path).exists(): @@ -335,6 +340,19 @@ def base64_to_temp_file_if_needed( self.temp_files.add(full_temp_file_path) return full_temp_file_path + def pil_to_temp_file(self, img: _Image.Image, dir: str) -> str: + filename = str(Path(dir) / f"{self.hash_base64(img.tobytes())}.png") + img.save(filename, pnginfo=processing_utils.get_pil_metadata(img)) + return filename + + def array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: + filename = str(Path(dir) / f"{self.hash_base64(arr.tobytes())}.png") + pil_image = _Image.fromarray( + processing_utils._convert(arr, np.uint8, force_copy=False) + ) + pil_image.save(filename) + return filename + def get_config(self): config = { "label": self.label, @@ -4068,11 +4086,11 @@ def postprocess( base_img_path = base_img base_img = np.array(_Image.open(base_img)) elif isinstance(base_img, np.ndarray): - base_file = processing_utils.save_array_to_file(base_img) - base_img_path = str(utils.abspath(base_file.name)) + base_file = self.array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR) + base_img_path = str(utils.abspath(base_file)) elif isinstance(base_img, _Image.Image): - base_file = processing_utils.save_pil_to_file(base_img) - base_img_path = str(utils.abspath(base_file.name)) + base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR) + base_img_path = str(utils.abspath(base_file)) base_img = np.array(base_img) else: raise ValueError( @@ -4116,8 +4134,10 @@ def hex_to_rgb(value): colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8)) - mask_file = processing_utils.save_pil_to_file(colored_mask_img) - mask_file_path = str(utils.abspath(mask_file.name)) + mask_file = self.pil_to_temp_file( + colored_mask_img, dir=self.DEFAULT_TEMP_DIR + ) + mask_file_path = str(utils.abspath(mask_file)) self.temp_files.add(mask_file_path) sections.append( @@ -4404,12 +4424,12 @@ def postprocess( if isinstance(img, (tuple, list)): img, caption = img if isinstance(img, np.ndarray): - file = processing_utils.save_array_to_file(img) - file_path = str(utils.abspath(file.name)) + file = self.array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR) + file_path = str(utils.abspath(file)) self.temp_files.add(file_path) elif isinstance(img, _Image.Image): - file = processing_utils.save_pil_to_file(img) - file_path = str(utils.abspath(file.name)) + file = self.pil_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR) + file_path = str(utils.abspath(file)) self.temp_files.add(file_path) elif isinstance(img, str): if utils.validate_url(img): From 5d91178a59496d027bc03729c829a65865b18797 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 17 May 2023 13:20:27 -0400 Subject: [PATCH 2/9] Linting --- gradio/components.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 521ff7ee53a9e..f0928e246bec8 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -237,20 +237,18 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: sha1.update(data) return sha1.hexdigest() - def _hash( - self, base64_encoding: str | bytes, chunk_num_blocks: int = 128, encode=True - ): + def _hash(self, base64_encoding: str | bytes, chunk_num_blocks: int = 128): sha1 = hashlib.sha1() for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] - sha1.update(data.encode("utf-8") if encode else data) + sha1.update(data.encode("utf-8") if isinstance(data, str) else data) return sha1.hexdigest() def hash_bytes(self, bytes: bytes, chunk_num_blocks: int = 128): - return self._hash(bytes, chunk_num_blocks, encode=False) + return self._hash(bytes, chunk_num_blocks) def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: - return self._hash(base64_encoding, chunk_num_blocks, encode=False) + return self._hash(base64_encoding, chunk_num_blocks) def make_temp_copy_if_needed(self, file_path: str) -> str: """Returns a temporary file path for a copy of the given file path if it does @@ -341,12 +339,12 @@ def base64_to_temp_file_if_needed( return full_temp_file_path def pil_to_temp_file(self, img: _Image.Image, dir: str) -> str: - filename = str(Path(dir) / f"{self.hash_base64(img.tobytes())}.png") + filename = str(Path(dir) / f"{self.hash_bytes(img.tobytes())}.png") img.save(filename, pnginfo=processing_utils.get_pil_metadata(img)) return filename def array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: - filename = str(Path(dir) / f"{self.hash_base64(arr.tobytes())}.png") + filename = str(Path(dir) / f"{self.hash_bytes(arr.tobytes())}.png") pil_image = _Image.fromarray( processing_utils._convert(arr, np.uint8, force_copy=False) ) From 8866de599fa830a1094d907a41d740ece9d2c735 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 17 May 2023 13:37:43 -0400 Subject: [PATCH 3/9] CHANGELOG --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36353b282ab78..5f683b4c2a781 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,9 @@ No changes to highlight. ## Bug Fixes: -No changes to highlight. +- Fixed Gallery/AnnotatedImage components not respecting GRADIO_DEFAULT_DIR variable by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256) +- Fixed Gallery/AnnotatedImage components resaving identical images by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256) +- Fixed Audio/Video/File components creating empty tempfiles on each run by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256) ## Other Changes: From 4134b3970436f30722c277ea26bc419c0b3d21e7 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 17 May 2023 14:51:08 -0400 Subject: [PATCH 4/9] Add tests --- test/test_blocks.py | 55 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/test/test_blocks.py b/test/test_blocks.py index 63b586f33f6c6..01d41a2b37770 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -5,7 +5,9 @@ import os import pathlib import random +import shutil import sys +import tempfile import time import unittest.mock as mock import uuid @@ -15,11 +17,13 @@ from string import capwords from unittest.mock import patch +import gradio_client as grc import pytest import uvicorn import websockets from fastapi.testclient import TestClient from gradio_client import media_data +from PIL import Image import gradio as gr from gradio.events import SelectData @@ -465,6 +469,57 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self): demo.close() +class TestTempFile: + def test_pil_images_hashed(self): + def create_images(n_images): + + a = Image.new("RGB", (512, 512), "red") + b = Image.new("RGB", (512, 512), "green") + c = Image.new("RGB", (512, 512), "blue") + + res = [a, b, c][:n_images] + random.shuffle(res) + return res + + dir_ = tempfile.mkdtemp() + + try: + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": dir_}): + demo = gr.Interface( + create_images, + inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)], + outputs=[gr.Gallery().style(grid=2, preview=True)], + ) + _, url, _ = demo.launch(prevent_thread_lock=True) + client = grc.Client(url) + _ = client.predict(3) + _ = client.predict(3) + # only three files created + assert len(list(pathlib.Path(dir_).glob("**/*"))) == 3 + finally: + demo.close() + shutil.rmtree(dir_) + + def test_no_empty_files(self): + file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") + video = str(file_dir / "video_sample.mp4") + dir_ = tempfile.mkdtemp() + + try: + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": dir_}): + demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video()) + _, url, _ = demo.launch(prevent_thread_lock=True) + client = grc.Client(url) + _ = client.predict(video) + _ = client.predict(video) + # During preprocessing we compute the hash based on base64 + # In postprocessing we compute it based on the file + assert len([f for f in pathlib.Path(dir_).glob("**/*") if f.is_file()]) == 2 + finally: + demo.close() + shutil.rmtree(dir_) + + class TestComponentsInBlocks: def test_slider_random_value_config(self): with gr.Blocks() as demo: From 6c864e9da9b50b6d15813cb0be983e9bd7d3e9f6 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 18 May 2023 11:11:41 -0400 Subject: [PATCH 5/9] Update test --- gradio/processing_utils.py | 4 ++++ test/test_processing_utils.py | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 1702df930b4b8..920308f6eddc9 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -2,6 +2,7 @@ import base64 import json +import os import shutil import subprocess import tempfile @@ -532,4 +533,7 @@ def convert_video_to_playable_mp4(video_path: str) -> str: except FFRuntimeError as e: print(f"Error converting video to browser-playable format {str(e)}") output_path = video_path + finally: + # Remove temp_file + os.remove(tmp_file.name) return str(output_path) diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 5530b2a6b76f8..abc2b887e03e0 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -205,9 +205,12 @@ def test_convert_video_to_playable_mp4(self, test_file_dir): shutil.copy( str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name ) - playable_vid = processing_utils.convert_video_to_playable_mp4( - tmp_not_playable_vid.name - ) + with patch("os.remove", wraps=os.remove) as mock_remove: + playable_vid = processing_utils.convert_video_to_playable_mp4( + tmp_not_playable_vid.name + ) + # check tempfile got deleted + assert not Path(mock_remove.call_args[0][0]).exists() assert processing_utils.video_is_playable(playable_vid) @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception) From 9c1c6e6467afa6e535d2e56127646cd163b19568 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 18 May 2023 17:21:20 -0400 Subject: [PATCH 6/9] Fix remaining components + add tests --- client/python/gradio_client/client.py | 4 +- client/python/test/test_client.py | 14 ++- gradio/components.py | 123 +++++++++++++++++--------- gradio/processing_utils.py | 19 +--- test/conftest.py | 23 +++++ test/test_blocks.py | 105 +++++++++++++++++----- 6 files changed, 198 insertions(+), 90 deletions(-) diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 599e22c6ac72b..9114729d5adba 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -785,8 +785,8 @@ def serialize(self, *data) -> tuple: if t in ["file", "uploadbutton"] ] uploaded_files = self._upload(files) - self._add_uploaded_files_to_data(uploaded_files, list(data)) - + data = list(data) + self._add_uploaded_files_to_data(uploaded_files, data) o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) return o diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 77a582319bec9..28368383ce438 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -252,13 +252,19 @@ def test_upload_file_private_space(self): with patch.object( client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload ) as upload: - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - f.write("Hello from private space!") - - output = client.submit(1, "foo", f.name, api_name="/file_upload").result() + with patch.object( + client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize + ) as serialize: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + f.write("Hello from private space!") + + output = client.submit( + 1, "foo", f.name, api_name="/file_upload" + ).result() with open(output) as f: assert f.read() == "Hello from private space!" upload.assert_called_once() + assert all(f["is_file"] for f in serialize.return_value()) with patch.object( client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload diff --git a/gradio/components.py b/gradio/components.py index f0928e246bec8..af76a36c41195 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -217,14 +217,16 @@ def __init__( if callable(load_fn): self.attach_load_event(load_fn, every) - def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str: + @staticmethod + def hash_file(file_path: str, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""): sha1.update(chunk) return sha1.hexdigest() - def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: + @staticmethod + def hash_url(url: str, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() remote = urllib.request.urlopen(url) max_file_size = 100 * 1024 * 1024 # 100MB @@ -237,19 +239,20 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: sha1.update(data) return sha1.hexdigest() - def _hash(self, base64_encoding: str | bytes, chunk_num_blocks: int = 128): + @staticmethod + def hash_bytes(bytes: bytes): + sha1 = hashlib.sha1() + sha1.update(bytes) + return sha1.hexdigest() + + @staticmethod + def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] - sha1.update(data.encode("utf-8") if isinstance(data, str) else data) + sha1.update(data.encode("utf-8")) return sha1.hexdigest() - def hash_bytes(self, bytes: bytes, chunk_num_blocks: int = 128): - return self._hash(bytes, chunk_num_blocks) - - def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: - return self._hash(base64_encoding, chunk_num_blocks) - def make_temp_copy_if_needed(self, file_path: str) -> str: """Returns a temporary file path for a copy of the given file path if it does not already exist. Otherwise returns the path to the existing temp file.""" @@ -272,15 +275,14 @@ async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str: ) # Since the full file is being uploaded anyways, there is no benefit to hashing the file. temp_dir = Path(upload_dir) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) - output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) if file.filename: file_name = Path(file.filename).name - output_file_obj.name = client_utils.strip_invalid_filename_characters( - file_name - ) + name = client_utils.strip_invalid_filename_characters(file_name) + else: + name = f"tmp{secrets.token_hex(5)}" - full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name)) + full_temp_file_path = str(utils.abspath(temp_dir / name)) async with aiofiles.open(full_temp_file_path, "wb") as output_file: while True: @@ -338,19 +340,41 @@ def base64_to_temp_file_if_needed( self.temp_files.add(full_temp_file_path) return full_temp_file_path - def pil_to_temp_file(self, img: _Image.Image, dir: str) -> str: - filename = str(Path(dir) / f"{self.hash_bytes(img.tobytes())}.png") + def pil_to_temp_file(self, img: _Image.Image, dir: str, format=".png") -> str: + temp_dir = Path(dir) / self.hash_bytes(img.tobytes()) + temp_dir.mkdir(exist_ok=True, parents=True) + filename = str(temp_dir / f"image.{format}") img.save(filename, pnginfo=processing_utils.get_pil_metadata(img)) return filename - def array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: - filename = str(Path(dir) / f"{self.hash_bytes(arr.tobytes())}.png") + def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: + temp_dir = Path(dir) / self.hash_bytes(arr.tobytes()) + temp_dir.mkdir(exist_ok=True, parents=True) + filename = str(temp_dir / "image.png") pil_image = _Image.fromarray( processing_utils._convert(arr, np.uint8, force_copy=False) ) pil_image.save(filename) return filename + def audio_to_temp_file( + self, data: np.ndarray, sample_rate: int, dir: str, format: str + ): + temp_dir = Path(dir) / self.hash_bytes(data.tobytes()) + temp_dir.mkdir(exist_ok=True, parents=True) + filename = str(temp_dir / f"audio.{format}") + processing_utils.audio_to_file(sample_rate, data, filename, format=format) + return filename + + def file_bytes_to_file(self, data: bytes, dir: str, file_name: str): + path = Path(dir) / self.hash_bytes(data) + path.mkdir(exist_ok=True, parents=True) + path = path / Path(file_name).name + with open(path, "wb") as f: + f.write(data) + f.flush() + return path + def get_config(self): config = { "label": self.label, @@ -1762,6 +1786,9 @@ def update( "__type__": "update", } + ### GPT 4 WEIGHTS !!!! + ### + def _format_image( self, im: _Image.Image | None ) -> np.ndarray | _Image.Image | str | None: @@ -1774,12 +1801,9 @@ def _format_image( elif self.type == "numpy": return np.array(im) elif self.type == "filepath": - file_obj = tempfile.NamedTemporaryFile( - delete=False, - suffix=(f".{fmt.lower()}" if fmt is not None else ".png"), + self.pil_to_temp_file( + im, dir=self.DEFAULT_TEMP_DIR, format=fmt if fmt else ".png" ) - im.save(file_obj.name) - return self.make_temp_copy_if_needed(file_obj.name) else: raise ValueError( "Unknown type: " @@ -2275,8 +2299,7 @@ def srt_to_vtt(srt_file_path, vtt_file_path): # HTML5 only support vtt format if Path(subtitle).suffix == ".srt": temp_file = tempfile.NamedTemporaryFile( - delete=False, - suffix=".vtt", + delete=False, suffix=".vtt", dir=self.DEFAULT_TEMP_DIR ) srt_to_vtt(subtitle, temp_file.name) @@ -2499,7 +2522,9 @@ def tokenize(self, x): # Handle the leave one outs leave_one_out_data = np.copy(data) leave_one_out_data[start:stop] = 0 - file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + file = tempfile.NamedTemporaryFile( + delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR + ) processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name) out_data = client_utils.encode_file_to_base64(file.name) leave_one_out_sets.append(out_data) @@ -2510,7 +2535,9 @@ def tokenize(self, x): token = np.copy(data) token[0:start] = 0 token[stop:] = 0 - file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + file = tempfile.NamedTemporaryFile( + delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR + ) processing_utils.audio_to_file(sample_rate, token, file.name) token_data = client_utils.encode_file_to_base64(file.name) file.close() @@ -2541,7 +2568,7 @@ def get_masked_inputs(self, tokens, binary_mask_matrix): masked_input = np.copy(zero_input) for t, b in zip(token_data, binary_mask_vector): masked_input = masked_input + t * int(b) - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile(delete=False, dir=self.DEFAULT_TEMP_DIR) processing_utils.audio_to_file(sample_rate, masked_input, file.name) masked_data = client_utils.encode_file_to_base64(file.name) file.close() @@ -2562,11 +2589,9 @@ def postprocess(self, y: tuple[int, np.ndarray] | str | None) -> str | dict | No return {"name": y, "data": None, "is_file": True} if isinstance(y, tuple): sample_rate, data = y - file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False) - processing_utils.audio_to_file( - sample_rate, data, file.name, format=self.format + file_path = self.audio_to_temp_file( + data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format ) - file_path = str(utils.abspath(file.name)) self.temp_files.add(file_path) else: file_path = self.make_temp_copy_if_needed(y) @@ -2736,14 +2761,21 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: ) if self.type == "file": if is_file: - temp_file_path = self.make_temp_copy_if_needed(file_name) - file = tempfile.NamedTemporaryFile(delete=False) - file.name = temp_file_path - file.orig_name = file_name # type: ignore + path = self.make_temp_copy_if_needed(file_name) else: - file = client_utils.decode_base64_to_file(data, file_path=file_name) - file.orig_name = file_name # type: ignore - self.temp_files.add(str(utils.abspath(file.name))) + data, _ = client_utils.decode_base64_to_binary(data) + path = self.file_bytes_to_file( + data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name + ) + path = str(utils.abspath(path)) + self.temp_files.add(path) + + # Creation of tempfiles here + file = tempfile.NamedTemporaryFile( + delete=False, dir=self.DEFAULT_TEMP_DIR + ) + file.name = path + file.orig_name = file_name # type: ignore return file elif ( self.type == "binary" or self.type == "bytes" @@ -2793,13 +2825,14 @@ def postprocess( for file in y ] else: - return { + d = { "orig_name": Path(y).name, "name": self.make_temp_copy_if_needed(y), "size": Path(y).stat().st_size, "data": None, "is_file": True, } + return d def style( self, @@ -3489,7 +3522,9 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: if self.type == "file": if is_file: temp_file_path = self.make_temp_copy_if_needed(file_name) - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile( + delete=False, dir=self.DEFAULT_TEMP_DIR + ) file.name = temp_file_path file.orig_name = file_name # type: ignore else: @@ -4084,7 +4119,7 @@ def postprocess( base_img_path = base_img base_img = np.array(_Image.open(base_img)) elif isinstance(base_img, np.ndarray): - base_file = self.array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR) + base_file = self.img_array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR) base_img_path = str(utils.abspath(base_file)) elif isinstance(base_img, _Image.Image): base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR) @@ -4422,7 +4457,7 @@ def postprocess( if isinstance(img, (tuple, list)): img, caption = img if isinstance(img, np.ndarray): - file = self.array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR) + file = self.img_array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR) file_path = str(utils.abspath(file)) self.temp_files.add(file_path) elif isinstance(img, _Image.Image): diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 920308f6eddc9..aaa8f51ffdbdf 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -65,13 +65,6 @@ def encode_plot_to_base64(plt): return "data:image/png;base64," + base64_str -def save_array_to_file(image_array, dir=None): - pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False)) - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj) - return file_obj - - def get_pil_metadata(pil_image): # Copy any text-only metadata metadata = PngImagePlugin.PngInfo() @@ -82,12 +75,6 @@ def get_pil_metadata(pil_image): return metadata -def save_pil_to_file(pil_image, dir=None): - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=get_pil_metadata(pil_image)) - return file_obj - - def encode_pil_to_base64(pil_image): with BytesIO() as output_bytes: pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image)) @@ -520,8 +507,8 @@ def video_is_playable(video_filepath: str) -> bool: def convert_video_to_playable_mp4(video_path: str) -> str: """Convert the video to mp4. If something goes wrong return the original video.""" try: - output_path = Path(video_path).with_suffix(".mp4") with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + output_path = Path(video_path).with_suffix(".mp4") shutil.copy2(video_path, tmp_file.name) # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4 ff = FFmpeg( @@ -534,6 +521,6 @@ def convert_video_to_playable_mp4(video_path: str) -> str: print(f"Error converting video to browser-playable format {str(e)}") output_path = video_path finally: - # Remove temp_file - os.remove(tmp_file.name) + # Remove temp file + os.remove(tmp_file.name) # type: ignore return str(output_path) diff --git a/test/conftest.py b/test/conftest.py index 247e6e7ad7ab1..c4a86e685dc18 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,7 +1,9 @@ import inspect import pathlib +from contextlib import contextmanager import pytest +from gradio_client import Client import gradio as gr @@ -32,3 +34,24 @@ def io_components(): subclasses.append(subclass) return subclasses + + +@pytest.fixture +def connect(): + @contextmanager + def _connect(demo: gr.Blocks, serialize=True): + _, local_url, _ = demo.launch(prevent_thread_lock=True) + try: + yield Client(local_url, serialize=serialize) + finally: + # A more verbose version of .close() + # because we should set a timeout + # the tests that call .cancel() can get stuck + # waiting for the thread to join + if demo.enable_queue: + demo._queue.close() + demo.is_running = False + demo.server.should_exit = True + demo.server.thread.join(timeout=1) + + return _connect diff --git a/test/test_blocks.py b/test/test_blocks.py index 2f0c0e00feb3e..9bdf3273e6e8b 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -5,9 +5,7 @@ import os import pathlib import random -import shutil import sys -import tempfile import time import unittest.mock as mock import uuid @@ -18,6 +16,7 @@ from unittest.mock import patch import gradio_client as grc +import numpy as np import pytest import uvicorn import websockets @@ -468,9 +467,8 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self): class TestTempFile: - def test_pil_images_hashed(self): + def test_pil_images_hashed(self, tmp_path, connect): def create_images(n_images): - a = Image.new("RGB", (512, 512), "red") b = Image.new("RGB", (512, 512), "green") c = Image.new("RGB", (512, 512), "blue") @@ -479,32 +477,69 @@ def create_images(n_images): random.shuffle(res) return res - dir_ = tempfile.mkdtemp() - - try: - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": dir_}): - demo = gr.Interface( - create_images, - inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)], - outputs=[gr.Gallery().style(grid=2, preview=True)], - ) - _, url, _ = demo.launch(prevent_thread_lock=True) - client = grc.Client(url) + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): + demo = gr.Interface( + create_images, + inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)], + outputs=[gr.Gallery().style(grid=2, preview=True)], + ) + with connect(demo) as client: _ = client.predict(3) _ = client.predict(3) # only three files created - assert len(list(pathlib.Path(dir_).glob("**/*"))) == 3 - finally: - demo.close() - shutil.rmtree(dir_) + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 - def test_no_empty_files(self): + def test_no_empty_image_files(self, tmp_path, connect): + file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") + image = str(file_dir / "bus.png") + + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): + demo = gr.Interface( + lambda x: x, + inputs=gr.Image(type="filepath"), + outputs=gr.Image(), + ) + with connect(demo) as client: + _ = client.predict(image) + _ = client.predict(image) + _ = client.predict(image) + # only three files created + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1 + + @pytest.mark.parametrize("component", [gr.UploadButton, gr.File]) + def test_file_component_uploads(self, component, tmp_path, connect): + code_file = str(pathlib.Path(__file__)) + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): + demo = gr.Interface(lambda x: x.name, component(), gr.File()) + with connect(demo) as client: + _ = client.predict(code_file) + _ = client.predict(code_file) + # the upload route does not hash the file so 2 files from there + # We create two tempfiles (empty) because API says we return + # preprocess/postprocess will only create one file since we hash + # so 2 + 2 + 1 = 5 + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5 + + @pytest.mark.parametrize("component", [gr.UploadButton, gr.File]) + def test_file_component_uploads_no_serialize(self, tmp_path, connect): + code_file = str(pathlib.Path(__file__)) + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): + demo = gr.Interface(lambda x: x.name, gr.File(), gr.File()) + with connect(demo, serialize=False) as client: + _ = client.predict(gr.File().serialize(code_file)) + _ = client.predict(gr.File().serialize(code_file)) + # We skip the upload route in this case + # We create two tempfiles (empty) because API says we return + # preprocess/postprocess will only create one file since we hash + # so 2 + 1 = 3 + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 + + def test_no_empty_video_files(self, tmp_path): file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") video = str(file_dir / "video_sample.mp4") - dir_ = tempfile.mkdtemp() try: - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": dir_}): + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video()) _, url, _ = demo.launch(prevent_thread_lock=True) client = grc.Client(url) @@ -512,10 +547,32 @@ def test_no_empty_files(self): _ = client.predict(video) # During preprocessing we compute the hash based on base64 # In postprocessing we compute it based on the file - assert len([f for f in pathlib.Path(dir_).glob("**/*") if f.is_file()]) == 2 + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2 + finally: + demo.close() + + def test_no_empty_audio_files(self, tmp_path): + file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") + audio = str(file_dir / "audio_sample.wav") + + def reverse_audio(audio): + sr, data = audio + return (sr, np.flipud(data)) + + try: + with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): + demo = gr.Interface( + fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio() + ) + _, url, _ = demo.launch(prevent_thread_lock=True) + client = grc.Client(url) + _ = client.predict(audio) + client.predict(audio) + # During preprocessing we compute the hash based on base64 + # In postprocessing we compute it based on the file + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2 finally: demo.close() - shutil.rmtree(dir_) class TestComponentsInBlocks: From 546d8427ff2b6113d3d4bfb566779edacdc00e17 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 18 May 2023 17:35:26 -0400 Subject: [PATCH 7/9] Fix tests --- gradio/components.py | 21 ++++++++++++--------- test/test_blocks.py | 4 ++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index af76a36c41195..cac6c7e81e639 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3521,16 +3521,19 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: ) if self.type == "file": if is_file: - temp_file_path = self.make_temp_copy_if_needed(file_name) - file = tempfile.NamedTemporaryFile( - delete=False, dir=self.DEFAULT_TEMP_DIR - ) - file.name = temp_file_path - file.orig_name = file_name # type: ignore + path = self.make_temp_copy_if_needed(file_name) else: - file = client_utils.decode_base64_to_file(data, file_path=file_name) - file.orig_name = file_name # type: ignore - self.temp_files.add(str(utils.abspath(file.name))) + data, _ = client_utils.decode_base64_to_binary(data) + path = self.file_bytes_to_file( + data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name + ) + path = str(utils.abspath(path)) + self.temp_files.add(path) + file = tempfile.NamedTemporaryFile( + delete=False, dir=self.DEFAULT_TEMP_DIR + ) + file.name = path + file.orig_name = file_name # type: ignore return file elif self.type == "bytes": if is_file: diff --git a/test/test_blocks.py b/test/test_blocks.py index 9bdf3273e6e8b..d83e85e3e6c44 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -521,10 +521,10 @@ def test_file_component_uploads(self, component, tmp_path, connect): assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5 @pytest.mark.parametrize("component", [gr.UploadButton, gr.File]) - def test_file_component_uploads_no_serialize(self, tmp_path, connect): + def test_file_component_uploads_no_serialize(self, component, tmp_path, connect): code_file = str(pathlib.Path(__file__)) with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface(lambda x: x.name, gr.File(), gr.File()) + demo = gr.Interface(lambda x: x.name, component(), gr.File()) with connect(demo, serialize=False) as client: _ = client.predict(gr.File().serialize(code_file)) _ = client.predict(gr.File().serialize(code_file)) From fca28f6ee48d89663f89ed15ab5061177609baff Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 18 May 2023 17:55:36 -0400 Subject: [PATCH 8/9] Fix tests --- gradio/components.py | 6 ++++-- test/test_processing_utils.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 848e22098fbd0..5d3e66724ab13 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -20,7 +20,7 @@ from enum import Enum from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, cast +from typing import TYPE_CHECKING, Any, Callable, Dict import aiofiles import altair as alt @@ -1801,9 +1801,11 @@ def _format_image( elif self.type == "numpy": return np.array(im) elif self.type == "filepath": - self.pil_to_temp_file( + path = self.pil_to_temp_file( im, dir=self.DEFAULT_TEMP_DIR, format=fmt if fmt else ".png" ) + self.temp_files.add(path) + return path else: raise ValueError( "Unknown type: " diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index abc2b887e03e0..4befcb802e5e9 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -11,7 +11,7 @@ from gradio_client import media_data from PIL import Image -from gradio import processing_utils, utils +from gradio import components, processing_utils, utils os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -54,12 +54,12 @@ def test_encode_pil_to_base64(self): output_base64 = processing_utils.encode_pil_to_base64(img) assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE) - def test_save_pil_to_file_keeps_pnginfo(self): + def test_save_pil_to_file_keeps_pnginfo(self, tmp_path): input_img = Image.open("gradio/test_data/test_image.png") input_img = input_img.convert("RGB") input_img.info = {"key1": "value1", "key2": "value2"} - file_obj = processing_utils.save_pil_to_file(input_img) + file_obj = components.Image().pil_to_temp_file(input_img, dir=tmp_path) output_img = Image.open(file_obj) assert output_img.info == input_img.info From d34c69c2967d8d43bc8f07d56f6d4c2487bd6653 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 19 May 2023 12:46:50 -0400 Subject: [PATCH 9/9] Address comments --- gradio/components.py | 20 ++--- gradio/processing_utils.py | 10 ++- test/test_blocks.py | 155 ++++++++++++++++------------------ test/test_processing_utils.py | 35 +++++++- 4 files changed, 120 insertions(+), 100 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 5d3e66724ab13..0b729b2680552 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -340,22 +340,19 @@ def base64_to_temp_file_if_needed( self.temp_files.add(full_temp_file_path) return full_temp_file_path - def pil_to_temp_file(self, img: _Image.Image, dir: str, format=".png") -> str: - temp_dir = Path(dir) / self.hash_bytes(img.tobytes()) + def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str: + bytes_data = processing_utils.encode_pil_to_bytes(img, format) + temp_dir = Path(dir) / self.hash_bytes(bytes_data) temp_dir.mkdir(exist_ok=True, parents=True) filename = str(temp_dir / f"image.{format}") img.save(filename, pnginfo=processing_utils.get_pil_metadata(img)) return filename def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str: - temp_dir = Path(dir) / self.hash_bytes(arr.tobytes()) - temp_dir.mkdir(exist_ok=True, parents=True) - filename = str(temp_dir / "image.png") pil_image = _Image.fromarray( processing_utils._convert(arr, np.uint8, force_copy=False) ) - pil_image.save(filename) - return filename + return self.pil_to_temp_file(pil_image, dir, format="png") def audio_to_temp_file( self, data: np.ndarray, sample_rate: int, dir: str, format: str @@ -370,9 +367,7 @@ def file_bytes_to_file(self, data: bytes, dir: str, file_name: str): path = Path(dir) / self.hash_bytes(data) path.mkdir(exist_ok=True, parents=True) path = path / Path(file_name).name - with open(path, "wb") as f: - f.write(data) - f.flush() + path.write_bytes(data) return path def get_config(self): @@ -1786,9 +1781,6 @@ def update( "__type__": "update", } - ### GPT 4 WEIGHTS !!!! - ### - def _format_image( self, im: _Image.Image | None ) -> np.ndarray | _Image.Image | str | None: @@ -1802,7 +1794,7 @@ def _format_image( return np.array(im) elif self.type == "filepath": path = self.pil_to_temp_file( - im, dir=self.DEFAULT_TEMP_DIR, format=fmt if fmt else ".png" + im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png" ) self.temp_files.add(path) return path diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index aaa8f51ffdbdf..7789f6b2fee6a 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -75,10 +75,14 @@ def get_pil_metadata(pil_image): return metadata -def encode_pil_to_base64(pil_image): +def encode_pil_to_bytes(pil_image, format="png"): with BytesIO() as output_bytes: - pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image)) - bytes_data = output_bytes.getvalue() + pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image)) + return output_bytes.getvalue() + + +def encode_pil_to_base64(pil_image): + bytes_data = encode_pil_to_bytes(pil_image) base64_str = str(base64.b64encode(bytes_data), "utf-8") return "data:image/png;base64," + base64_str diff --git a/test/test_blocks.py b/test/test_blocks.py index d83e85e3e6c44..5cae83a4c2016 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -467,91 +467,88 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self): class TestTempFile: - def test_pil_images_hashed(self, tmp_path, connect): - def create_images(n_images): - a = Image.new("RGB", (512, 512), "red") - b = Image.new("RGB", (512, 512), "green") - c = Image.new("RGB", (512, 512), "blue") + def test_pil_images_hashed(self, tmp_path, connect, monkeypatch): + images = [ + Image.new("RGB", (512, 512), color) for color in ("red", "green", "blue") + ] - res = [a, b, c][:n_images] - random.shuffle(res) - return res + def create_images(n_images): + return random.sample(images, n_images) - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface( - create_images, - inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)], - outputs=[gr.Gallery().style(grid=2, preview=True)], - ) - with connect(demo) as client: - _ = client.predict(3) - _ = client.predict(3) - # only three files created - assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface( + create_images, + inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)], + outputs=[gr.Gallery().style(grid=2, preview=True)], + ) + with connect(demo) as client: + _ = client.predict(3) + _ = client.predict(3) + # only three files created + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 - def test_no_empty_image_files(self, tmp_path, connect): + def test_no_empty_image_files(self, tmp_path, connect, monkeypatch): file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") image = str(file_dir / "bus.png") - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface( - lambda x: x, - inputs=gr.Image(type="filepath"), - outputs=gr.Image(), - ) - with connect(demo) as client: - _ = client.predict(image) - _ = client.predict(image) - _ = client.predict(image) - # only three files created - assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1 + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface( + lambda x: x, + inputs=gr.Image(type="filepath"), + outputs=gr.Image(), + ) + with connect(demo) as client: + _ = client.predict(image) + _ = client.predict(image) + _ = client.predict(image) + # only three files created + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1 @pytest.mark.parametrize("component", [gr.UploadButton, gr.File]) - def test_file_component_uploads(self, component, tmp_path, connect): + def test_file_component_uploads(self, component, tmp_path, connect, monkeypatch): code_file = str(pathlib.Path(__file__)) - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface(lambda x: x.name, component(), gr.File()) - with connect(demo) as client: - _ = client.predict(code_file) - _ = client.predict(code_file) - # the upload route does not hash the file so 2 files from there - # We create two tempfiles (empty) because API says we return - # preprocess/postprocess will only create one file since we hash - # so 2 + 2 + 1 = 5 - assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5 + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface(lambda x: x.name, component(), gr.File()) + with connect(demo) as client: + _ = client.predict(code_file) + _ = client.predict(code_file) + # the upload route does not hash the file so 2 files from there + # We create two tempfiles (empty) because API says we return + # preprocess/postprocess will only create one file since we hash + # so 2 + 2 + 1 = 5 + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5 @pytest.mark.parametrize("component", [gr.UploadButton, gr.File]) - def test_file_component_uploads_no_serialize(self, component, tmp_path, connect): + def test_file_component_uploads_no_serialize( + self, component, tmp_path, connect, monkeypatch + ): code_file = str(pathlib.Path(__file__)) - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface(lambda x: x.name, component(), gr.File()) - with connect(demo, serialize=False) as client: - _ = client.predict(gr.File().serialize(code_file)) - _ = client.predict(gr.File().serialize(code_file)) - # We skip the upload route in this case - # We create two tempfiles (empty) because API says we return - # preprocess/postprocess will only create one file since we hash - # so 2 + 1 = 3 - assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 - - def test_no_empty_video_files(self, tmp_path): + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface(lambda x: x.name, component(), gr.File()) + with connect(demo, serialize=False) as client: + _ = client.predict(gr.File().serialize(code_file)) + _ = client.predict(gr.File().serialize(code_file)) + # We skip the upload route in this case + # We create two tempfiles (empty) because API says we return + # preprocess/postprocess will only create one file since we hash + # so 2 + 1 = 3 + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3 + + def test_no_empty_video_files(self, tmp_path, monkeypatch, connect): file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") video = str(file_dir / "video_sample.mp4") - - try: - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video()) - _, url, _ = demo.launch(prevent_thread_lock=True) - client = grc.Client(url) - _ = client.predict(video) - _ = client.predict(video) - # During preprocessing we compute the hash based on base64 - # In postprocessing we compute it based on the file - assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2 - finally: - demo.close() - - def test_no_empty_audio_files(self, tmp_path): + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video()) + with connect(demo) as client: + _, url, _ = demo.launch(prevent_thread_lock=True) + client = grc.Client(url) + _ = client.predict(video) + _ = client.predict(video) + # During preprocessing we compute the hash based on base64 + # In postprocessing we compute it based on the file + assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2 + + def test_no_empty_audio_files(self, tmp_path, monkeypatch, connect): file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") audio = str(file_dir / "audio_sample.wav") @@ -559,20 +556,14 @@ def reverse_audio(audio): sr, data = audio return (sr, np.flipud(data)) - try: - with mock.patch.dict(os.environ, {"GRADIO_TEMP_DIR": str(tmp_path)}): - demo = gr.Interface( - fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio() - ) - _, url, _ = demo.launch(prevent_thread_lock=True) - client = grc.Client(url) - _ = client.predict(audio) - client.predict(audio) + monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path)) + demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio()) + with connect(demo) as client: + _ = client.predict(audio) + _ = client.predict(audio) # During preprocessing we compute the hash based on base64 # In postprocessing we compute it based on the file assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2 - finally: - demo.close() class TestComponentsInBlocks: diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 4befcb802e5e9..e68a7d8ac7af0 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -9,7 +9,7 @@ import numpy as np import pytest from gradio_client import media_data -from PIL import Image +from PIL import Image, ImageCms from gradio import components, processing_utils, utils @@ -64,6 +64,39 @@ def test_save_pil_to_file_keeps_pnginfo(self, tmp_path): assert output_img.info == input_img.info + def test_np_pil_encode_to_the_same(self, tmp_path): + arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8) + pil = Image.fromarray(arr) + comp = components.Image() + assert comp.pil_to_temp_file(pil, dir=tmp_path) == comp.img_array_to_temp_file( + arr, dir=tmp_path + ) + + def test_encode_pil_to_temp_file_metadata_color_profile(self, tmp_path): + # Read image + img = Image.open("gradio/test_data/test_image.png") + img_metadata = Image.open("gradio/test_data/test_image.png") + img_metadata.info = {"key1": "value1", "key2": "value2"} + + # Creating sRGB profile + profile = ImageCms.createProfile("sRGB") + profile2 = ImageCms.ImageCmsProfile(profile) + img.save(tmp_path / "img_color_profile.png", icc_profile=profile2.tobytes()) + img_cp1 = Image.open(str(tmp_path / "img_color_profile.png")) + + # Creating XYZ profile + profile = ImageCms.createProfile("XYZ") + profile2 = ImageCms.ImageCmsProfile(profile) + img.save(tmp_path / "img_color_profile_2.png", icc_profile=profile2.tobytes()) + img_cp2 = Image.open(str(tmp_path / "img_color_profile_2.png")) + + comp = components.Image() + img_path = comp.pil_to_temp_file(img, dir=tmp_path) + img_metadata_path = comp.pil_to_temp_file(img_metadata, dir=tmp_path) + img_cp1_path = comp.pil_to_temp_file(img_cp1, dir=tmp_path) + img_cp2_path = comp.pil_to_temp_file(img_cp2, dir=tmp_path) + assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4 + def test_encode_pil_to_base64_keeps_pnginfo(self): input_img = Image.open("gradio/test_data/test_image.png") input_img = input_img.convert("RGB")