Fix iterator support for replicate.run()
Prior to 1.0.0, `replicate.run()` would return an iterator for cog models
that output the type `Iterator[Any]`. This iterator would poll the
`predictions.get` endpoint for the in-progress prediction and yield any
new output.

When implementing the new file interface we introduced two bugs:

1. The iterator didn't convert URLs returned by the model into
   `FileOutput` types, making it inconsistent with the non-iterator
   interface. This behaviour is controlled by the `use_file_output` argument.
2. The iterator was returned without checking whether we are using the new
   blocking API, which is enabled by default and controlled by the `wait`
   argument.

This commit fixes both issues, consistently applying the
`transform_output` function to the output of the iterator, and
returning the polling iterator (`prediction.output_iterator`) when
the blocking API has not returned a completed prediction.

The tests have been updated to exercise both of these code paths.
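
For context, this is the behaviour the fixes restore; a minimal sketch,
assuming a hypothetical model whose output schema is an iterator (the
`wait` and `use_file_output` arguments are the ones described above):

```python
import replicate

# Hypothetical iterator-output model; wait=False opts out of the blocking
# "Prefer: wait" API, so run() must return the polling iterator (bug 2).
for chunk in replicate.run(
    "owner/streaming-model",
    input={"prompt": "a photo of a cat"},
    wait=False,
):
    # With use_file_output enabled (the default), any URL the model yields
    # arrives as a FileOutput rather than a plain string (bug 1).
    print(chunk)
```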
aron committed Oct 24, 2024
1 parent 23bd903 commit 69e3773
Showing 2 changed files with 492 additions and 219 deletions.
80 changes: 66 additions & 14 deletions replicate/run.py
@@ -59,15 +59,24 @@ def run(
     if not version and (owner and name and version_id):
         version = Versions(client, model=(owner, name)).get(version_id)
 
-    if version and (iterator := _make_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state, because it returns before the
+    # prediction has been fully processed. If the request exceeds the wait time, the
+    # prediction will be in a "starting" state.
+    if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and (iterator := _make_output_iterator(client, version, prediction)):
+            return iterator
 
     prediction.wait()
 
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    # Return an iterator for the completed prediction when needed.
+    if version and (iterator := _make_output_iterator(client, version, prediction)):
+        return iterator
+
     if use_file_output:
         return transform_output(prediction.output, client)
@@ -108,12 +117,25 @@ def async_run(
     if not version and (owner and name and version_id):
         version = await Versions(client, model=(owner, name)).async_get(version_id)
 
-    if version and (iterator := _make_async_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state, because it returns before the
+    # prediction has been fully processed. If the request exceeds the wait time, the
+    # prediction will be in a "starting" state.
+    if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and (
+            iterator := _make_async_output_iterator(client, version, prediction)
+        ):
+            return iterator
 
     await prediction.async_wait()
 
+    # Return an iterator for completed output if the model has an output iterator array type.
+    if version and (
+        iterator := _make_async_output_iterator(client, version, prediction)
+    ):
+        return iterator
+
     if prediction.status == "failed":
         raise ModelError(prediction)
@@ -134,21 +156,48 @@ def _has_output_iterator_array_type(version: Version) -> bool:
 
 
 def _make_output_iterator(
-    version: Version, prediction: Prediction
+    client: "Client", version: Version, prediction: Prediction
 ) -> Optional[Iterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.output_iterator()
+    if not _has_output_iterator_array_type(version):
+        return None
 
+    if prediction.status == "starting":
+        iterator = prediction.output_iterator()
+    elif prediction.output is not None:
+        iterator = iter(prediction.output)
+    else:
+        return None
+
-    return None
+    def _iterate(iter: Iterator[Any]) -> Iterator[Any]:
+        for chunk in iter:
+            yield transform_output(chunk, client)
+
+    return _iterate(iterator)
 
 
 def _make_async_output_iterator(
-    version: Version, prediction: Prediction
+    client: "Client", version: Version, prediction: Prediction
 ) -> Optional[AsyncIterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.async_output_iterator()
+    if not _has_output_iterator_array_type(version):
+        return None
+
+    if prediction.status == "starting":
+        iterator = prediction.async_output_iterator()
+    elif prediction.output is not None:
+
+        async def _list_to_aiter(lst: list) -> AsyncIterator:
+            for item in lst:
+                yield item
+
+        iterator = _list_to_aiter(prediction.output)
+    else:
+        return None
+
+    async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator:
+        async for chunk in iter:
+            yield transform_output(chunk, client)
 
-    return None
+    return _transform(iterator)
 
 
 __all__: List = []
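
The async path mirrors the sync one; a minimal consumption sketch, assuming a
hypothetical streaming model and the module-level `replicate.async_run`
helper:

```python
import asyncio
import replicate

async def main() -> None:
    # For a version whose output schema is an iterator, async_run() resolves
    # to an AsyncIterator; each chunk passes through transform_output, so
    # URLs arrive as FileOutput objects here too.
    output = await replicate.async_run(
        "owner/streaming-model",  # hypothetical identifier
        input={"prompt": "Hello"},
    )
    async for chunk in output:
        print(chunk)

asyncio.run(main())
```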
