Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix iterator support for replicate.run() #383

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 54 additions & 23 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from replicate.exceptions import ModelError
from replicate.helpers import transform_output
from replicate.model import Model
from replicate.prediction import Prediction
from replicate.schema import make_schema_backwards_compatible
from replicate.version import Version, Versions

Expand Down Expand Up @@ -59,15 +58,36 @@ def run(
if not version and (owner and name and version_id):
version = Versions(client, model=(owner, name)).get(version_id)

if version and (iterator := _make_output_iterator(version, prediction)):
return iterator
# Currently the "Prefer: wait" interface will return a prediction with a status
# of "processing" rather than a terminal state because it returns before the
# prediction has been fully processed. If the request exceeds the wait time, even if
# it is actually processing, the prediction will be in a "starting" state.
#
# We should fix this in the blocking API itself. Predictions that are done should
# be in a terminal state and predictions that are processing should be in state
# "processing".
in_terminal_state = is_blocking and prediction.status != "starting"
if not in_terminal_state:
# Return a "polling" iterator if the model has an output iterator array type.
if version and _has_output_iterator_array_type(version):
return (
transform_output(chunk, client)
for chunk in prediction.output_iterator()
)

if not (is_blocking and prediction.status != "starting"):
prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction)

# Return an iterator for the completed prediction when needed.
if (
version
and _has_output_iterator_array_type(version)
and prediction.output is not None
):
return (transform_output(chunk, client) for chunk in prediction.output)

if use_file_output:
return transform_output(prediction.output, client)

Expand Down Expand Up @@ -108,15 +128,39 @@ async def async_run(
if not version and (owner and name and version_id):
version = await Versions(client, model=(owner, name)).async_get(version_id)

if version and (iterator := _make_async_output_iterator(version, prediction)):
return iterator
# Currently the "Prefer: wait" interface will return a prediction with a status
# of "processing" rather than a terminal state because it returns before the
# prediction has been fully processed. If the request exceeds the wait time, even if
# it is actually processing, the prediction will be in a "starting" state.
#
# We should fix this in the blocking API itself. Predictions that are done should
# be in a terminal state and predictions that are processing should be in state
# "processing".
in_terminal_state = is_blocking and prediction.status != "starting"
if not in_terminal_state:
# Return a "polling" iterator if the model has an output iterator array type.
if version and _has_output_iterator_array_type(version):
return (
transform_output(chunk, client)
async for chunk in prediction.async_output_iterator()
)

if not (is_blocking and prediction.status != "starting"):
await prediction.async_wait()

if prediction.status == "failed":
raise ModelError(prediction)

# Return an iterator for completed output if the model has an output iterator array type.
if (
version
and _has_output_iterator_array_type(version)
and prediction.output is not None
):
return (
transform_output(chunk, client)
async for chunk in _make_async_iterator(prediction.output)
)

if use_file_output:
return transform_output(prediction.output, client)

Expand All @@ -133,22 +177,9 @@ def _has_output_iterator_array_type(version: Version) -> bool:
)


def _make_output_iterator(
    version: Version, prediction: Prediction
) -> Optional[Iterator[Any]]:
    """Return the prediction's streaming output iterator, or ``None``.

    An iterator is only produced when *version*'s output schema is an
    iterator-style array type; otherwise the caller should fall back to
    the prediction's plain output.
    """
    if not _has_output_iterator_array_type(version):
        return None

    return prediction.output_iterator()


def _make_async_output_iterator(
    version: Version, prediction: Prediction
) -> Optional[AsyncIterator[Any]]:
    """Return the prediction's async streaming output iterator, or ``None``.

    An async iterator is only produced when *version*'s output schema is an
    iterator-style array type; otherwise the caller should fall back to
    the prediction's plain output.
    """
    if not _has_output_iterator_array_type(version):
        return None

    return prediction.async_output_iterator()
async def _make_async_iterator(list: list) -> AsyncIterator:
for item in list:
yield item


__all__: List = []
Loading