Skip to content

Commit

Permalink
Get Intermediate Results from Python Client (#3694)
Browse files Browse the repository at this point in the history
* Add status + unit test (flaky) for now

* Install client

* Fix tests

* Lint backend + tests

* Add non-queue test

* Fix name

* Use lock instead

* Add simplify implementation + fix tests

* Restore changes to scripts

* Fix README typo

* Fix CI

* Add intermediate results to python client

* Type check

* Typecheck again

* Catch exception

* Thinking

* Don't read generator from config

* add no queue test

* Remove unused method

* Fix types

* Remove breakpoint

* Fix code

* Fix test

* Fix tests

* Unpack list

* Add docstring

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
freddyaboulton and abidlabs authored Apr 4, 2023
1 parent c4ad09b commit 9325cba
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 28 deletions.
135 changes: 108 additions & 27 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import json
import re
import threading
import time
import uuid
from concurrent.futures import Future
from concurrent.futures import Future, TimeoutError
from datetime import datetime
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
Expand Down Expand Up @@ -76,7 +77,7 @@ def predict(
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | List[Callable] | None = None,
) -> Future:
) -> Job:
"""
Parameters:
*args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
Expand All @@ -90,7 +91,9 @@ def predict(

helper = None
if self.endpoints[inferred_fn_index].use_ws:
helper = Communicator(Lock(), JobStatus())
helper = Communicator(
Lock(), JobStatus(), self.endpoints[inferred_fn_index].deserialize
)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)

Expand Down Expand Up @@ -389,19 +392,14 @@ def _inner(*data):
raise utils.InvalidAPIEndpointError()
inputs = self.serialize(*data)
predictions = _predict(*inputs)
outputs = self.deserialize(*predictions)
if (
len(
[
oct
for oct in self.output_component_types
if not oct == utils.STATE_COMPONENT
]
)
== 1
):
return outputs[0]
return outputs
output = self.deserialize(*predictions)
# Append final output only if not already present
# for consistency between generators and not generators
if helper:
with helper.lock:
if not helper.job.outputs:
helper.job.outputs.append(output)
return output

return _inner

Expand Down Expand Up @@ -461,11 +459,11 @@ def serialize(self, *data) -> Tuple:
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
return tuple([s.serialize(d) for s, d in zip(self.serializers, data)])

def deserialize(self, *data) -> Tuple:
def deserialize(self, *data) -> Tuple | Any:
assert len(data) == len(
self.deserializers
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
return tuple(
outputs = tuple(
[
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
for s, d, oct in zip(
Expand All @@ -474,6 +472,20 @@ def deserialize(self, *data) -> Tuple:
if not oct == utils.STATE_COMPONENT
]
)
if (
len(
[
oct
for oct in self.output_component_types
if not oct == utils.STATE_COMPONENT
]
)
== 1
):
output = outputs[0]
else:
output = outputs
return output

def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
inputs = self.dependency["inputs"]
Expand Down Expand Up @@ -529,31 +541,100 @@ def _use_websocket(self, dependency: Dict) -> bool:

async def _ws_fn(self, data, hash_data, helper: Communicator):
async with websockets.connect( # type: ignore
self.client.ws_url, open_timeout=10, extra_headers=self.client.headers
self.client.ws_url,
open_timeout=10,
extra_headers=self.client.headers,
max_size=1024 * 1024 * 1024,
) as websocket:
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)


class Job(Future):
"""A Job is a thin wrapper over the Future class that can be cancelled."""

def __init__(self, future: Future, communicator: Communicator | None = None):
def __init__(
self,
future: Future,
communicator: Communicator | None = None,
):
self.future = future
self.communicator = communicator

def status(self) -> StatusUpdate:
def outputs(self) -> List[Tuple | Any]:
"""Returns a list containing the latest outputs from the Job.
If the endpoint has multiple output components, the list will contain
a tuple of results. Otherwise, it will contain the results without storing them
in tuples.
For endpoints that are queued, this list will contain the final job output even
if that endpoint does not use a generator function.
"""
if not self.communicator:
time = datetime.now()
if self.done():
return []
else:
with self.communicator.lock:
return self.communicator.job.outputs

def result(self, timeout=None):
    """Return the result of the call that the future represents.
    Args:
        timeout: The number of seconds to wait for the result if the future
            isn't done. If None, then there is no limit on the wait time.
    Returns:
        The result of the call that the future represents.
    Raises:
        CancelledError: If the future was cancelled.
        TimeoutError: If the future didn't finish executing before the given
            timeout.
        Exception: If the call raised then that exception will be raised.
    """
    if self.communicator:
        # `timeout or float("inf")` would turn an explicit timeout=0 into
        # "wait forever"; compare against None instead.
        timeout = float("inf") if timeout is None else timeout
        if self.future._exception:  # type: ignore
            raise self.future._exception  # type: ignore
        # Fast path: the websocket listener may already have pushed an output.
        # NOTE(review): this returns the *first* output; for generator
        # endpoints the final value may differ — confirm intent.
        with self.communicator.lock:
            if self.communicator.job.outputs:
                return self.communicator.job.outputs[0]
        start = datetime.now()
        # Poll until the websocket listener appends the first output or the
        # wrapped future records an exception.
        while True:
            # timedelta.seconds is only the whole-seconds component (it
            # ignores days and the fractional part), so sub-second timeouts
            # would never fire on time; use total_seconds() instead.
            if (datetime.now() - start).total_seconds() > timeout:
                raise TimeoutError()
            if self.future._exception:  # type: ignore
                raise self.future._exception  # type: ignore
            with self.communicator.lock:
                if self.communicator.job.outputs:
                    return self.communicator.job.outputs[0]
            time.sleep(0.01)
    else:
        # No Communicator (websocket/queue path not used): fall back to the
        # plain concurrent.futures.Future semantics.
        return super().result(timeout=timeout)

def status(self) -> StatusUpdate:
time = datetime.now()
if self.done():
if not self.future._exception: # type: ignore
return StatusUpdate(
code=Status.FINISHED,
rank=0,
queue_size=None,
success=None,
success=True,
time=time,
eta=None,
)
else:
return StatusUpdate(
code=Status.FINISHED,
rank=0,
queue_size=None,
success=False,
time=time,
eta=None,
)
else:
if not self.communicator:
return StatusUpdate(
code=Status.PROCESSING,
rank=0,
Expand All @@ -562,9 +643,9 @@ def status(self) -> StatusUpdate:
time=time,
eta=None,
)
else:
with self.communicator.lock:
return self.communicator.job.latest_status
else:
with self.communicator.lock:
return self.communicator.job.latest_status

def __getattr__(self, name):
"""Forwards any properties to the Future class."""
Expand Down
8 changes: 8 additions & 0 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class Communicator:

lock: Lock
job: JobStatus
deserialize: Callable[..., Tuple]


########################
Expand Down Expand Up @@ -166,6 +167,13 @@ async def get_pred_from_ws(
time=datetime.now(),
eta=resp.get("rank_eta"),
)
output = resp.get("output", {}).get("data", [])
if output and status_update.code != Status.FINISHED:
try:
result = helper.deserialize(*output)
except Exception as e:
result = [e]
helper.job.outputs.append(result)
helper.job.latest_status = status_update
if resp["msg"] == "queue_full":
raise QueueError("Queue is full! Please try again.")
Expand Down
42 changes: 41 additions & 1 deletion client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pathlib
import time
from concurrent.futures import TimeoutError
from datetime import datetime, timedelta
from unittest.mock import patch

Expand Down Expand Up @@ -81,14 +82,53 @@ def test_job_status(self):
def test_job_status_queue_disabled(self):
statuses = []
client = Client(src="freddyaboulton/sentiment-classification")
job = client.predict("I love the gradio python client", fn_index=0)
job = client.predict("I love the gradio python client", api_name="/classify")
while not job.done():
time.sleep(0.02)
statuses.append(job.status())
statuses.append(job.status())
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)

@pytest.mark.flaky
def test_intermediate_outputs(
    self,
):
    # A generator endpoint should surface every yielded value through
    # Job.outputs(): counting to 3 yields "0", "1", "2".
    # Marked flaky: hits the live "gradio/count_generator" Space over the
    # network.
    client = Client(src="gradio/count_generator")
    job = client.predict(3, api_name="/count")

    # Poll until the job finishes so all intermediate outputs have arrived.
    while not job.done():
        time.sleep(0.1)

    assert job.outputs() == [str(i) for i in range(3)]

@pytest.mark.flaky
def test_timeout(self):
    # A tiny timeout on the long-running /sleep endpoint must raise
    # TimeoutError from Job.result. Marked flaky: depends on the live
    # "gradio/count_generator" Space.
    with pytest.raises(TimeoutError):
        client = Client(src="gradio/count_generator")
        job = client.predict(api_name="/sleep")
        job.result(timeout=0.05)

@pytest.mark.flaky
def test_timeout_no_queue(self):
    # Same timeout contract on a Space without a queue — presumably the
    # path where Job.result defers to Future.result(timeout=...); confirm
    # against the Job implementation.
    with pytest.raises(TimeoutError):
        client = Client(src="freddyaboulton/sentiment-classification")
        job = client.predict(api_name="/sleep")
        job.result(timeout=0.1)

@pytest.mark.flaky
def test_raises_exception(self):
    # A non-numeric first argument ("foo") should make the remote calculator
    # call fail, and Job.result() must re-raise that exception locally.
    with pytest.raises(Exception):
        client = Client(src="freddyaboulton/calculator")
        job = client.predict("foo", "add", 9, fn_index=0)
        job.result()

@pytest.mark.flaky
def test_raises_exception_no_queue(self):
    # Exception propagation through Job.result() on a Space without a queue;
    # the list argument [5] is presumably an invalid input for /sleep —
    # confirm against the Space's signature.
    with pytest.raises(Exception):
        client = Client(src="freddyaboulton/sentiment-classification")
        job = client.predict([5], api_name="/sleep")
        job.result()

def test_job_output_video(self):
client = Client(src="gradio/video_component")
job = client.predict(
Expand Down

0 comments on commit 9325cba

Please sign in to comment.