diff --git a/replicate/run.py b/replicate/run.py
index 3b6bddb..4457b6d 100644
--- a/replicate/run.py
+++ b/replicate/run.py
@@ -59,15 +59,24 @@ def run(
     if not version and (owner and name and version_id):
         version = Versions(client, model=(owner, name)).get(version_id)
 
-    if version and (iterator := _make_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If request exceeds the wait time, the
+    # prediction will be in a "starting" state.
     if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and (iterator := _make_output_iterator(client, version, prediction)):
+            return iterator
+
         prediction.wait()
 
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    # Return an iterator for the completed prediction when needed.
+    if version and (iterator := _make_output_iterator(client, version, prediction)):
+        return iterator
+
     if use_file_output:
         return transform_output(prediction.output, client)
 
@@ -108,12 +117,25 @@ async def async_run(
     if not version and (owner and name and version_id):
         version = await Versions(client, model=(owner, name)).async_get(version_id)
 
-    if version and (iterator := _make_async_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If request exceeds the wait time, the
+    # prediction will be in a "starting" state.
     if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+ if version and ( + iterator := _make_async_output_iterator(client, version, prediction) + ): + return iterator + await prediction.async_wait() + # Return an iterator for completed output if the model has an output iterator array type. + if version and ( + iterator := _make_async_output_iterator(client, version, prediction) + ): + return iterator + if prediction.status == "failed": raise ModelError(prediction) @@ -134,21 +159,48 @@ def _has_output_iterator_array_type(version: Version) -> bool: def _make_output_iterator( - version: Version, prediction: Prediction + client: "Client", version: Version, prediction: Prediction ) -> Optional[Iterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.output_iterator() + if not _has_output_iterator_array_type(version): + return None + + if prediction.status == "starting": + iterator = prediction.output_iterator() + elif prediction.output is not None: + iterator = iter(prediction.output) + else: + return None - return None + def _iterate(iter: Iterator[Any]) -> Iterator[Any]: + for chunk in iter: + yield transform_output(chunk, client) + + return _iterate(iterator) def _make_async_output_iterator( - version: Version, prediction: Prediction + client: "Client", version: Version, prediction: Prediction ) -> Optional[AsyncIterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.async_output_iterator() + if not _has_output_iterator_array_type(version): + return None + + if prediction.status == "starting": + iterator = prediction.async_output_iterator() + elif prediction.output is not None: + + async def _list_to_aiter(lst: list) -> AsyncIterator: + for item in lst: + yield item + + iterator = _list_to_aiter(prediction.output) + else: + return None + + async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator: + async for chunk in iter: + yield transform_output(chunk, client) - return None + return _transform(iterator) __all__: List = [] diff --git a/tests/test_run.py 
b/tests/test_run.py index 8eac091..0c41bb1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import cast +from typing import AsyncIterator, Iterator, Optional, cast import httpx import pytest @@ -48,6 +48,274 @@ async def test_run(async_flag, record_mode): assert output[0].url.startswith("https://") +@pytest.mark.asyncio +async def test_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output1 = next(stream) + output2 = next(stream) + with pytest.raises(StopIteration): + next(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_async_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output1 = await anext(stream) + output2 = await anext(stream) + with pytest.raises(StopAsyncIteration): + await anext(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output1 = next(stream) + output2 = next(stream) + with pytest.raises(StopIteration): + next(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_async_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output1 = await anext(stream) + output2 = await anext(stream) + with pytest.raises(StopAsyncIteration): + await anext(stream) + + assert output1 == "Hello, " + assert output2 == "world!" + + @pytest.mark.vcr("run__concurrently.yaml") @pytest.mark.asyncio @pytest.mark.skipif( @@ -104,35 +372,17 @@ async def test_run_with_invalid_token(): @pytest.mark.asyncio async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": "Hello, world!" 
if status == "succeeded" else None, - "error": None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("succeeded"), + json=_prediction_with_status("succeeded", "Hello, world!"), ) ) router.route( @@ -141,37 +391,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2022-03-16T00:35:56.210272Z", - "cog_version": "dev", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": {}, - "components": { - "schemas": { - "Input": { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "The text input", - }, - }, - }, - "Output": { - "type": "string", - "title": "Output", - }, - } - }, - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -193,35 +413,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_model_error(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": None, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = 
respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("failed"), + json=_prediction_with_status("failed"), ) ) router.route( @@ -230,14 +432,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -262,37 +457,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_file_output(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -303,14 +478,7 @@ def 
prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/output.txt").mock( @@ -347,31 +515,11 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_blocking(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") predictions_create_route = router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status( + json=_prediction_with_status( "processing", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" ), ) @@ -379,7 +527,7 @@ def prediction_with_status( predictions_get_route = router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -387,26 +535,14 @@ def prediction_with_status( router.route( method="GET", path="/models/test/example/versions/v1", - ).mock( - return_value=httpx.Response( - 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", 
- }, - }, - ) - ) + ).mock(return_value=httpx.Response(201, json=_version_with_schema())) client = Client( api_token="test-token", transport=httpx.MockTransport(router.handler) ) client.poll_interval = 0.001 output = cast( - list[FileOutput], + FileOutput, client.run( "test/example:v1", input={ @@ -434,37 +570,17 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_array(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", [ "https://api.replicate.com/v1/assets/hello.txt", @@ -479,14 +595,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/hello.txt").mock( @@ -521,38 +630,103 @@ def prediction_with_status( @pytest.mark.asyncio -async def test_run_with_file_output_data_uri(mock_replicate_api_token): - def prediction_with_status( - status: str, 
output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", +async def test_run_with_file_output_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "https://api.replicate.com/v1/assets/hello.txt", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "https://api.replicate.com/v1/assets/hello.txt", + "https://api.replicate.com/v1/assets/world.txt", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + "format": "uri", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + router.route(method="GET", path="/assets/hello.txt").mock( + return_value=httpx.Response(200, content=b"Hello,") + ) + router.route(method="GET", path="/assets/world.txt").mock( + return_value=httpx.Response(200, content=b" world!") + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } + use_file_output=True, + wait=False, + ), + ) + + output1 
= next(stream) + output2 = next(stream) + + assert output1.url == "https://api.replicate.com/v1/assets/hello.txt" + assert output2.url == "https://api.replicate.com/v1/assets/world.txt" + + assert output1.read() == b"Hello," + assert output2.read() == b" world!" + +@pytest.mark.asyncio +async def test_run_with_file_output_data_uri(mock_replicate_api_token): router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", ), @@ -564,14 +738,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) @@ -600,3 +767,57 @@ def prediction_with_status( assert await output.aread() == b"Hello, world!" async for chunk in output: assert chunk == b"Hello, world!" 
+ + +def _prediction_with_status(status: str, output: str | list[str] | None = None) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + +def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None): + return { + "id": id, + "created_at": "2022-03-16T00:35:56.210272Z", + "cog_version": "dev", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "The text input", + }, + }, + }, + "Output": output_schema + or { + "type": "string", + "title": "Output", + }, + } + }, + }, + }