diff --git a/CHANGELOG.md b/CHANGELOG.md
index 768f733a2e030..495df1f5ed234 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,9 @@ No changes to highlight.
 
 ## Bug Fixes:
 
+- Fixed Gallery/AnnotatedImage components not respecting GRADIO_DEFAULT_DIR variable by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
+- Fixed Gallery/AnnotatedImage components resaving identical images by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
+- Fixed Audio/Video/File components creating empty tempfiles on each run by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
 - Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258).
 - Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271)
 
diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py
index 599e22c6ac72b..9114729d5adba 100644
--- a/client/python/gradio_client/client.py
+++ b/client/python/gradio_client/client.py
@@ -785,8 +785,8 @@ def serialize(self, *data) -> tuple:
             if t in ["file", "uploadbutton"]
         ]
         uploaded_files = self._upload(files)
-        self._add_uploaded_files_to_data(uploaded_files, list(data))
-
+        data = list(data)
+        self._add_uploaded_files_to_data(uploaded_files, data)
         o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
         return o
 
diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py
index 77a582319bec9..28368383ce438 100644
--- a/client/python/test/test_client.py
+++ b/client/python/test/test_client.py
@@ -252,13 +252,19 @@ def test_upload_file_private_space(self):
         with patch.object(
             client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
         ) as upload:
-            with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
-                f.write("Hello from private space!")
-
-            output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
+            with patch.object(
+                client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
+            ) as serialize:
+                with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
+                    f.write("Hello from private space!")
+
+                output = client.submit(
+                    1, "foo", f.name, api_name="/file_upload"
+                ).result()
             with open(output) as f:
                 assert f.read() == "Hello from private space!"
             upload.assert_called_once()
+            assert all(f["is_file"] for f in serialize.return_value())
 
         with patch.object(
             client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload
diff --git a/gradio/components.py b/gradio/components.py
index 2f9f03553baf9..0b729b2680552 100644
--- a/gradio/components.py
+++ b/gradio/components.py
@@ -20,7 +20,7 @@
 from enum import Enum
 from pathlib import Path
 from types import ModuleType
-from typing import TYPE_CHECKING, Any, Callable, Dict, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict
 
 import aiofiles
 import altair as alt
@@ -217,14 +217,16 @@ def __init__(
         if callable(load_fn):
             self.attach_load_event(load_fn, every)
 
-    def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
+    @staticmethod
+    def hash_file(file_path: str, chunk_num_blocks: int = 128) -> str:
         sha1 = hashlib.sha1()
         with open(file_path, "rb") as f:
             for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
                 sha1.update(chunk)
         return sha1.hexdigest()
 
-    def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
+    @staticmethod
+    def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
         sha1 = hashlib.sha1()
         remote = urllib.request.urlopen(url)
         max_file_size = 100 * 1024 * 1024  # 100MB
@@ -237,7 +239,14 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
             sha1.update(data)
         return sha1.hexdigest()
 
-    def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
+    @staticmethod
+    def hash_bytes(bytes: bytes):
+        sha1 = hashlib.sha1()
+        sha1.update(bytes)
+        return sha1.hexdigest()
+
+    @staticmethod
+    def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
         sha1 = hashlib.sha1()
         for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
             data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
@@ -251,9 +260,8 @@ def make_temp_copy_if_needed(self, file_path: str) -> str:
         temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
         temp_dir.mkdir(exist_ok=True, parents=True)
 
-        f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-        f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
-        full_temp_file_path = str(utils.abspath(temp_dir / f.name))
+        name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
+        full_temp_file_path = str(utils.abspath(temp_dir / name))
 
         if not Path(full_temp_file_path).exists():
             shutil.copy2(file_path, full_temp_file_path)
@@ -267,15 +275,14 @@ async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
         )  # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
         temp_dir = Path(upload_dir) / temp_dir
         temp_dir.mkdir(exist_ok=True, parents=True)
-        output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
 
         if file.filename:
             file_name = Path(file.filename).name
-            output_file_obj.name = client_utils.strip_invalid_filename_characters(
-                file_name
-            )
+            name = client_utils.strip_invalid_filename_characters(file_name)
+        else:
+            name = f"tmp{secrets.token_hex(5)}"
 
-        full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name))
+        full_temp_file_path = str(utils.abspath(temp_dir / name))
 
         async with aiofiles.open(full_temp_file_path, "wb") as output_file:
             while True:
@@ -292,10 +299,9 @@ def download_temp_copy_if_needed(self, url: str) -> str:
         temp_dir = self.hash_url(url)
         temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
         temp_dir.mkdir(exist_ok=True, parents=True)
 
-        f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-        f.name = client_utils.strip_invalid_filename_characters(Path(url).name)
-        full_temp_file_path = str(utils.abspath(temp_dir / f.name))
+        name = client_utils.strip_invalid_filename_characters(Path(url).name)
+        full_temp_file_path = str(utils.abspath(temp_dir / name))
 
         if not Path(full_temp_file_path).exists():
             with requests.get(url, stream=True) as r, open(
@@ -323,8 +329,7 @@ def base64_to_temp_file_if_needed(
             file_name = f"file.{guess_extension}"
         else:
             file_name = "file"
-        f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-        f.name = file_name  # type: ignore
+
         full_temp_file_path = str(utils.abspath(temp_dir / file_name))  # type: ignore
 
         if not Path(full_temp_file_path).exists():
@@ -335,6 +340,36 @@
         self.temp_files.add(full_temp_file_path)
         return full_temp_file_path
 
+    def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
+        bytes_data = processing_utils.encode_pil_to_bytes(img, format)
+        temp_dir = Path(dir) / self.hash_bytes(bytes_data)
+        temp_dir.mkdir(exist_ok=True, parents=True)
+        filename = str(temp_dir / f"image.{format}")
+        img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
+        return filename
+
+    def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
+        pil_image = _Image.fromarray(
+            processing_utils._convert(arr, np.uint8, force_copy=False)
+        )
+        return self.pil_to_temp_file(pil_image, dir, format="png")
+
+    def audio_to_temp_file(
+        self, data: np.ndarray, sample_rate: int, dir: str, format: str
+    ):
+        temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
+        temp_dir.mkdir(exist_ok=True, parents=True)
+        filename = str(temp_dir / f"audio.{format}")
+        processing_utils.audio_to_file(sample_rate, data, filename, format=format)
+        return filename
+
+    def file_bytes_to_file(self, data: bytes, dir: str, file_name: str):
+        path = Path(dir) / self.hash_bytes(data)
+        path.mkdir(exist_ok=True, parents=True)
+        path = path / Path(file_name).name
+        path.write_bytes(data)
+        return path
+
     def get_config(self):
         config = {
             "label": self.label,
@@ -1758,12 +1793,11 @@ def _format_image(
         elif self.type == "numpy":
             return np.array(im)
         elif self.type == "filepath":
-            file_obj = tempfile.NamedTemporaryFile(
-                delete=False,
-                suffix=(f".{fmt.lower()}" if fmt is not None else ".png"),
+            path = self.pil_to_temp_file(
+                im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
             )
-            im.save(file_obj.name)
-            return self.make_temp_copy_if_needed(file_obj.name)
+            self.temp_files.add(path)
+            return path
         else:
             raise ValueError(
                 "Unknown type: "
@@ -2259,8 +2293,7 @@ def srt_to_vtt(srt_file_path, vtt_file_path):
         # HTML5 only support vtt format
Path(subtitle).suffix == ".srt": temp_file = tempfile.NamedTemporaryFile( - delete=False, - suffix=".vtt", + delete=False, suffix=".vtt", dir=self.DEFAULT_TEMP_DIR ) srt_to_vtt(subtitle, temp_file.name) @@ -2483,7 +2516,9 @@ def tokenize(self, x): # Handle the leave one outs leave_one_out_data = np.copy(data) leave_one_out_data[start:stop] = 0 - file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + file = tempfile.NamedTemporaryFile( + delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR + ) processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name) out_data = client_utils.encode_file_to_base64(file.name) leave_one_out_sets.append(out_data) @@ -2494,7 +2529,9 @@ def tokenize(self, x): token = np.copy(data) token[0:start] = 0 token[stop:] = 0 - file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + file = tempfile.NamedTemporaryFile( + delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR + ) processing_utils.audio_to_file(sample_rate, token, file.name) token_data = client_utils.encode_file_to_base64(file.name) file.close() @@ -2525,7 +2562,7 @@ def get_masked_inputs(self, tokens, binary_mask_matrix): masked_input = np.copy(zero_input) for t, b in zip(token_data, binary_mask_vector): masked_input = masked_input + t * int(b) - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile(delete=False, dir=self.DEFAULT_TEMP_DIR) processing_utils.audio_to_file(sample_rate, masked_input, file.name) masked_data = client_utils.encode_file_to_base64(file.name) file.close() @@ -2546,11 +2583,9 @@ def postprocess(self, y: tuple[int, np.ndarray] | str | None) -> str | dict | No return {"name": y, "data": None, "is_file": True} if isinstance(y, tuple): sample_rate, data = y - file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False) - processing_utils.audio_to_file( - sample_rate, data, file.name, format=self.format + file_path = self.audio_to_temp_file( + data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format ) - file_path = str(utils.abspath(file.name)) self.temp_files.add(file_path) else: file_path = self.make_temp_copy_if_needed(y) @@ -2720,14 +2755,21 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: ) if self.type == "file": if is_file: - temp_file_path = self.make_temp_copy_if_needed(file_name) - file = tempfile.NamedTemporaryFile(delete=False) - file.name = temp_file_path - file.orig_name = file_name # type: ignore + path = self.make_temp_copy_if_needed(file_name) else: - file = client_utils.decode_base64_to_file(data, file_path=file_name) - file.orig_name = file_name # type: ignore - self.temp_files.add(str(utils.abspath(file.name))) + data, _ = client_utils.decode_base64_to_binary(data) + path = self.file_bytes_to_file( + data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name + ) + path = str(utils.abspath(path)) + self.temp_files.add(path) + + # Creation of tempfiles here + file = tempfile.NamedTemporaryFile( + delete=False, dir=self.DEFAULT_TEMP_DIR + ) + file.name = path + file.orig_name = file_name # type: ignore return file elif ( self.type == "binary" or self.type == "bytes" @@ -2777,13 +2819,14 @@ def postprocess( for file in y ] else: - return { + d = { "orig_name": Path(y).name, "name": self.make_temp_copy_if_needed(y), "size": Path(y).stat().st_size, "data": None, "is_file": True, } + return d def style( self, @@ -3472,14 +3515,19 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: ) if self.type == "file": if is_file: - temp_file_path = 
-                    temp_file_path = self.make_temp_copy_if_needed(file_name)
-                    file = tempfile.NamedTemporaryFile(delete=False)
-                    file.name = temp_file_path
-                    file.orig_name = file_name  # type: ignore
+                    path = self.make_temp_copy_if_needed(file_name)
                 else:
-                    file = client_utils.decode_base64_to_file(data, file_path=file_name)
-                    file.orig_name = file_name  # type: ignore
-                self.temp_files.add(str(utils.abspath(file.name)))
+                    data, _ = client_utils.decode_base64_to_binary(data)
+                    path = self.file_bytes_to_file(
+                        data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
+                    )
+                path = str(utils.abspath(path))
+                self.temp_files.add(path)
+                file = tempfile.NamedTemporaryFile(
+                    delete=False, dir=self.DEFAULT_TEMP_DIR
+                )
+                file.name = path
+                file.orig_name = file_name  # type: ignore
                 return file
             elif self.type == "bytes":
                 if is_file:
@@ -4068,11 +4116,11 @@ def postprocess(
             base_img_path = base_img
             base_img = np.array(_Image.open(base_img))
         elif isinstance(base_img, np.ndarray):
-            base_file = processing_utils.save_array_to_file(base_img)
-            base_img_path = str(utils.abspath(base_file.name))
+            base_file = self.img_array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
+            base_img_path = str(utils.abspath(base_file))
         elif isinstance(base_img, _Image.Image):
-            base_file = processing_utils.save_pil_to_file(base_img)
-            base_img_path = str(utils.abspath(base_file.name))
+            base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
+            base_img_path = str(utils.abspath(base_file))
             base_img = np.array(base_img)
         else:
             raise ValueError(
@@ -4116,8 +4164,10 @@ def hex_to_rgb(value):
 
             colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8))
 
-            mask_file = processing_utils.save_pil_to_file(colored_mask_img)
-            mask_file_path = str(utils.abspath(mask_file.name))
+            mask_file = self.pil_to_temp_file(
+                colored_mask_img, dir=self.DEFAULT_TEMP_DIR
+            )
+            mask_file_path = str(utils.abspath(mask_file))
             self.temp_files.add(mask_file_path)
 
             sections.append(
@@ -4404,12 +4454,12 @@ def postprocess(
             if isinstance(img, (tuple, list)):
                 img, caption = img
             if isinstance(img, np.ndarray):
-                file = processing_utils.save_array_to_file(img)
-                file_path = str(utils.abspath(file.name))
+                file = self.img_array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
+                file_path = str(utils.abspath(file))
                 self.temp_files.add(file_path)
             elif isinstance(img, _Image.Image):
-                file = processing_utils.save_pil_to_file(img)
-                file_path = str(utils.abspath(file.name))
+                file = self.pil_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
+                file_path = str(utils.abspath(file))
                 self.temp_files.add(file_path)
             elif isinstance(img, str):
                 if utils.validate_url(img):
diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py
index 1702df930b4b8..7789f6b2fee6a 100644
--- a/gradio/processing_utils.py
+++ b/gradio/processing_utils.py
@@ -2,6 +2,7 @@
 
 import base64
 import json
+import os
 import shutil
 import subprocess
 import tempfile
@@ -64,13 +65,6 @@ def encode_plot_to_base64(plt):
     return "data:image/png;base64," + base64_str
 
 
-def save_array_to_file(image_array, dir=None):
-    pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
-    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
-    pil_image.save(file_obj)
-    return file_obj
-
-
 def get_pil_metadata(pil_image):
     # Copy any text-only metadata
     metadata = PngImagePlugin.PngInfo()
@@ -81,16 +75,14 @@ def get_pil_metadata(pil_image):
     return metadata
 
 
-def save_pil_to_file(pil_image, dir=None):
-    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
-    pil_image.save(file_obj, pnginfo=get_pil_metadata(pil_image))
-    return file_obj
+def encode_pil_to_bytes(pil_image, format="png"):
+    with BytesIO() as output_bytes:
+        pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image))
+        return output_bytes.getvalue()
 
 
 def encode_pil_to_base64(pil_image):
-    with BytesIO() as output_bytes:
-        pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
-        bytes_data = output_bytes.getvalue()
+    bytes_data = encode_pil_to_bytes(pil_image)
     base64_str = str(base64.b64encode(bytes_data), "utf-8")
     return "data:image/png;base64," + base64_str
 
@@ -519,8 +511,8 @@ def video_is_playable(video_filepath: str) -> bool:
 def convert_video_to_playable_mp4(video_path: str) -> str:
     """Convert the video to mp4. If something goes wrong return the original video."""
     try:
-        output_path = Path(video_path).with_suffix(".mp4")
         with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+            output_path = Path(video_path).with_suffix(".mp4")
             shutil.copy2(video_path, tmp_file.name)
             # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
             ff = FFmpeg(
@@ -532,4 +524,7 @@ def convert_video_to_playable_mp4(video_path: str) -> str:
     except FFRuntimeError as e:
         print(f"Error converting video to browser-playable format {str(e)}")
         output_path = video_path
+    finally:
+        # Remove temp file
+        os.remove(tmp_file.name)  # type: ignore
     return str(output_path)
diff --git a/test/conftest.py b/test/conftest.py
index 247e6e7ad7ab1..c4a86e685dc18 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,7 +1,9 @@
 import inspect
 import pathlib
+from contextlib import contextmanager
 
 import pytest
+from gradio_client import Client
 
 import gradio as gr
 
@@ -32,3 +34,24 @@ def io_components():
         subclasses.append(subclass)
 
     return subclasses
+
+
+@pytest.fixture
+def connect():
+    @contextmanager
+    def _connect(demo: gr.Blocks, serialize=True):
+        _, local_url, _ = demo.launch(prevent_thread_lock=True)
+        try:
+            yield Client(local_url, serialize=serialize)
+        finally:
+            # A more verbose version of .close()
+            # because we should set a timeout
+            # the tests that call .cancel() can get stuck
+            # waiting for the thread to join
+            if demo.enable_queue:
+                demo._queue.close()
+            demo.is_running = False
+            demo.server.should_exit = True
+            demo.server.thread.join(timeout=1)
+
+    return _connect
diff --git a/test/test_blocks.py b/test/test_blocks.py
index 76b237200bdac..5cae83a4c2016 100644
--- a/test/test_blocks.py
+++ b/test/test_blocks.py
@@ -15,11 +15,14 @@
 from string import capwords
 from unittest.mock import patch
 
+import gradio_client as grc
+import numpy as np
 import pytest
 import uvicorn
 import websockets
 from fastapi.testclient import TestClient
 from gradio_client import media_data
+from PIL import Image
 
 import gradio as gr
 from gradio.events import SelectData
@@ -463,6 +466,106 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self):
 
         demo.close()
 
+class TestTempFile:
+    def test_pil_images_hashed(self, tmp_path, connect, monkeypatch):
+        images = [
+            Image.new("RGB", (512, 512), color) for color in ("red", "green", "blue")
+        ]
+
+        def create_images(n_images):
+            return random.sample(images, n_images)
+
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(
+            create_images,
+            inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
+            outputs=[gr.Gallery().style(grid=2, preview=True)],
+        )
+        with connect(demo) as client:
+            _ = client.predict(3)
+            _ = client.predict(3)
+        # only three files created
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
+
+    def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
+        file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
+        image = str(file_dir / "bus.png")
+
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(
+            lambda x: x,
+            inputs=gr.Image(type="filepath"),
+            outputs=gr.Image(),
+        )
+        with connect(demo) as client:
+            _ = client.predict(image)
+            _ = client.predict(image)
+            _ = client.predict(image)
+        # only one file created
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1
+
+    @pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
+    def test_file_component_uploads(self, component, tmp_path, connect, monkeypatch):
+        code_file = str(pathlib.Path(__file__))
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(lambda x: x.name, component(), gr.File())
+        with connect(demo) as client:
+            _ = client.predict(code_file)
+            _ = client.predict(code_file)
+        # the upload route does not hash the file so 2 files from there
+        # We create two tempfiles (empty) because API says we return
+        # preprocess/postprocess will only create one file since we hash
+        # so 2 + 2 + 1 = 5
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5
+
+    @pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
+    def test_file_component_uploads_no_serialize(
+        self, component, tmp_path, connect, monkeypatch
+    ):
+        code_file = str(pathlib.Path(__file__))
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(lambda x: x.name, component(), gr.File())
+        with connect(demo, serialize=False) as client:
+            _ = client.predict(gr.File().serialize(code_file))
+            _ = client.predict(gr.File().serialize(code_file))
+        # We skip the upload route in this case
+        # We create two tempfiles (empty) because API says we return
+        # preprocess/postprocess will only create one file since we hash
+        # so 2 + 1 = 3
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
+
+    def test_no_empty_video_files(self, tmp_path, monkeypatch, connect):
+        file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
+        video = str(file_dir / "video_sample.mp4")
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
+        with connect(demo) as client:
+            _, url, _ = demo.launch(prevent_thread_lock=True)
+            client = grc.Client(url)
+            _ = client.predict(video)
+            _ = client.predict(video)
+        # During preprocessing we compute the hash based on base64
+        # In postprocessing we compute it based on the file
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
+
+    def test_no_empty_audio_files(self, tmp_path, monkeypatch, connect):
+        file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
+        audio = str(file_dir / "audio_sample.wav")
+
+        def reverse_audio(audio):
+            sr, data = audio
+            return (sr, np.flipud(data))
+
+        monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
+        demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
+        with connect(demo) as client:
+            _ = client.predict(audio)
+            _ = client.predict(audio)
+        # During preprocessing we compute the hash based on base64
+        # In postprocessing we compute it based on the file
+        assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
+
+
 class TestComponentsInBlocks:
     def test_slider_random_value_config(self):
         with gr.Blocks() as demo:
diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py
index 5530b2a6b76f8..e68a7d8ac7af0 100644
--- a/test/test_processing_utils.py
+++ b/test/test_processing_utils.py
@@ -9,9 +9,9 @@
 import numpy as np
 import pytest
 from gradio_client import media_data
-from PIL import Image
+from PIL import Image, ImageCms
 
-from gradio import processing_utils, utils
+from gradio import components, processing_utils, utils
 
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 
@@ -54,16 +54,49 @@ def test_encode_pil_to_base64(self):
         output_base64 = processing_utils.encode_pil_to_base64(img)
         assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
 
-    def test_save_pil_to_file_keeps_pnginfo(self):
+    def test_save_pil_to_file_keeps_pnginfo(self, tmp_path):
         input_img = Image.open("gradio/test_data/test_image.png")
         input_img = input_img.convert("RGB")
         input_img.info = {"key1": "value1", "key2": "value2"}
 
-        file_obj = processing_utils.save_pil_to_file(input_img)
+        file_obj = components.Image().pil_to_temp_file(input_img, dir=tmp_path)
 
         output_img = Image.open(file_obj)
         assert output_img.info == input_img.info
 
+    def test_np_pil_encode_to_the_same(self, tmp_path):
+        arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
+        pil = Image.fromarray(arr)
+        comp = components.Image()
+        assert comp.pil_to_temp_file(pil, dir=tmp_path) == comp.img_array_to_temp_file(
+            arr, dir=tmp_path
+        )
+
+    def test_encode_pil_to_temp_file_metadata_color_profile(self, tmp_path):
+        # Read image
+        img = Image.open("gradio/test_data/test_image.png")
+        img_metadata = Image.open("gradio/test_data/test_image.png")
+        img_metadata.info = {"key1": "value1", "key2": "value2"}
+
+        # Creating sRGB profile
+        profile = ImageCms.createProfile("sRGB")
+        profile2 = ImageCms.ImageCmsProfile(profile)
+        img.save(tmp_path / "img_color_profile.png", icc_profile=profile2.tobytes())
+        img_cp1 = Image.open(str(tmp_path / "img_color_profile.png"))
+
+        # Creating XYZ profile
+        profile = ImageCms.createProfile("XYZ")
+        profile2 = ImageCms.ImageCmsProfile(profile)
+        img.save(tmp_path / "img_color_profile_2.png", icc_profile=profile2.tobytes())
+        img_cp2 = Image.open(str(tmp_path / "img_color_profile_2.png"))
+
+        comp = components.Image()
+        img_path = comp.pil_to_temp_file(img, dir=tmp_path)
+        img_metadata_path = comp.pil_to_temp_file(img_metadata, dir=tmp_path)
+        img_cp1_path = comp.pil_to_temp_file(img_cp1, dir=tmp_path)
+        img_cp2_path = comp.pil_to_temp_file(img_cp2, dir=tmp_path)
+        assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4
+
     def test_encode_pil_to_base64_keeps_pnginfo(self):
         input_img = Image.open("gradio/test_data/test_image.png")
         input_img = input_img.convert("RGB")
@@ -205,9 +238,12 @@ def test_convert_video_to_playable_mp4(self, test_file_dir):
             shutil.copy(
                 str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
             )
-            playable_vid = processing_utils.convert_video_to_playable_mp4(
-                tmp_not_playable_vid.name
-            )
+            with patch("os.remove", wraps=os.remove) as mock_remove:
+                playable_vid = processing_utils.convert_video_to_playable_mp4(
+                    tmp_not_playable_vid.name
+                )
+            # check tempfile got deleted
+            assert not Path(mock_remove.call_args[0][0]).exists()
             assert processing_utils.video_is_playable(playable_vid)
 
     @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)