Skip to content

Commit

Permalink
Fix bug with wait parameter in replicate.run() (#363)
Browse files Browse the repository at this point in the history
This PR fixes an issue with `replicate.run()` where it would fall back to
polling regardless of the `wait` parameter. We now skip the waiting if
it looks like we already have a prediction with output back.

There are a whole bunch of issues with the `vcr` library we're using at
the moment. I've added tests and have working fixtures but I don't 
think the files are in an ideal state.

Some known issues:

 1. The `vcr` recording mode doesn't work with asyncio, so you need to
    generate the fixtures using the sync API by commenting out all of the
    `async` & `await` syntax. The test will then run fine without them.
 2. The path argument passed to the `vcr()` decorator is not respected while
    recording. Apparently it does work if you rename the generated file, but
    I've not tried this yet.
  • Loading branch information
aron authored Oct 4, 2024
1 parent 5458c51 commit 08ee31a
Show file tree
Hide file tree
Showing 14 changed files with 3,585 additions and 172 deletions.
5 changes: 5 additions & 0 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,14 @@ def create(
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
"POST",
url,
json=body,
headers=headers,
)

return _json_to_prediction(self._client, resp.json())
Expand All @@ -522,12 +524,15 @@ async def async_create(
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)

headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
"POST",
url,
json=body,
headers=headers,
)

return _json_to_prediction(self._client, resp.json())
Expand Down
4 changes: 3 additions & 1 deletion replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run(
Run a model and wait for its output.
"""

is_blocking = "wait" in params
version, owner, name, version_id = identifier._resolve(ref)

if version_id is not None:
Expand All @@ -57,7 +58,8 @@ def run(
if version and (iterator := _make_output_iterator(version, prediction)):
return iterator

prediction.wait()
if not (is_blocking and prediction.status != "starting"):
prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
interactions:
- request:
body: '{"input": {"prompt": "Please write a haiku about llamas"}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '58'
content-type:
- application/json
host:
- api.replicate.com
prefer:
- wait=10
user-agent:
- replicate-python/0.32.1
method: POST
uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions
response:
body:
string: '{"id":"pw050dtb51rj40cjb2vrcw97b8","model":"meta/meta-llama-3-8b-instruct","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"Please
write a haiku about llamas"},"logs":"","output":["\n\n","Here is"," a ha","iku
about"," llamas",":\n\nF","uzzy,"," gentle eyes","\nSoft","ly munch","ing
on"," the"," grass\n","Peaceful",", quiet"," soul",""],"data_removed":false,"error":null,"status":"processing","created_at":"2024-10-04T18:07:40.328Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/pw050dtb51rj40cjb2vrcw97b8/cancel","get":"https://api.replicate.com/v1/predictions/pw050dtb51rj40cjb2vrcw97b8","stream":"https://streaming-api.svc.rno2.c.replicate.net/v1/streams/3qkvm3bnqhaadgonsecw4twhnzlqi5s4sb5tgts72bo6c5wmjppq"}}'
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 8cd71ce48ee406a9-SJC
Connection:
- keep-alive
Content-Length:
- '746'
Content-Type:
- application/json; charset=UTF-8
Date:
- Fri, 04 Oct 2024 18:07:42 GMT
NEL:
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
Preference-Applied:
- wait=10
Report-To:
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=4YyDOuNpYpRJEACcgObeEJEYYwpW6CTpVfvwCyw1qUoBMXc1nsyw8ZGCm5zxwJLpOa44BG7xB2wdbrPosNEdqNt1TqHQpKWhFj2%2FF3MlHZ96Bhrhx17KcPCC7QfbchoPUFo5"}],"group":"cf-nel","max_age":604800}'
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
Vary:
- Accept-Encoding
alt-svc:
- h3=":443"; ma=86400
ratelimit-remaining:
- '599'
ratelimit-reset:
- '1'
replicate-edge-cluster:
- services-aws-us-west-2
replicate-target-cluster:
- coreweave-rno2
status:
code: 201
message: Created
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
interactions:
- request:
body: '{"input": {"prompt": "Please write a haiku about llamas"}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '58'
content-type:
- application/json
host:
- api.replicate.com
prefer:
- wait
user-agent:
- replicate-python/0.32.1
method: POST
uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions
response:
body:
string: '{"id":"7vr9my1z3nrj40cjb2vv9pshw0","model":"meta/meta-llama-3-8b-instruct","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"Please
write a haiku about llamas"},"logs":"","output":["\n\n","Fuzzy",", gentle","
soul\n","Llama","''s soft"," eyes meet"," my"," gaze","\nPeace","ful,"," gentle","
friend",""],"data_removed":false,"error":null,"status":"processing","created_at":"2024-10-04T18:07:37.245Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/7vr9my1z3nrj40cjb2vv9pshw0/cancel","get":"https://api.replicate.com/v1/predictions/7vr9my1z3nrj40cjb2vv9pshw0","stream":"https://streaming-api.svc.rno2.c.replicate.net/v1/streams/yr4xaj6trsqlyhoxrxg3fcl2c2vifsatl7paqk2ltc7tbtrvfb2q"}}'
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 8cd71cd159ba06a9-SJC
Connection:
- keep-alive
Content-Length:
- '709'
Content-Type:
- application/json; charset=UTF-8
Date:
- Fri, 04 Oct 2024 18:07:40 GMT
NEL:
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
Preference-Applied:
- wait
Report-To:
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=71TSJZwByfLF51QiGNpTc3HqXXGR%2BTyDJViyA75xfguBlFdel4vpaOp2YGz0lZTndn7Wu8P%2BEMLi%2BjsvNg7QCQxblOezZ9NgGMT%2Fpqxh1gvACXTCOWlzHgg2XPlRxe1WJAyN"}],"group":"cf-nel","max_age":604800}'
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
Vary:
- Accept-Encoding
alt-svc:
- h3=":443"; ma=86400
ratelimit-remaining:
- '599'
ratelimit-reset:
- '1'
replicate-edge-cluster:
- services-aws-us-west-2
replicate-target-cluster:
- coreweave-rno2
status:
code: 201
message: Created
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
interactions:
- request:
body: '{"input": {"prompt": "Please write a haiku about llamas"}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '58'
content-type:
- application/json
host:
- api.replicate.com
prefer:
- wait=10
user-agent:
- replicate-python/0.32.1
method: POST
uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions
response:
body:
string: '{"id":"kr2d2jhqbsrj60cjb2vtmekgh4","model":"meta/meta-llama-3-8b-instruct","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"Please
write a haiku about llamas"},"logs":"","output":["\n\n","Here is"," a ha","iku
about"," llamas",":\n\nF","uzzy,"," gentle eyes","\nL","lama''s"," soft humming","
fills air","\n","Peaceful"," Andean"," charm",""],"data_removed":false,"error":null,"status":"processing","created_at":"2024-10-04T18:07:35.262Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/kr2d2jhqbsrj60cjb2vtmekgh4/cancel","get":"https://api.replicate.com/v1/predictions/kr2d2jhqbsrj60cjb2vtmekgh4","stream":"https://streaming-api.svc.rno2.c.replicate.net/v1/streams/xywaxc2hfosderaab2cbqucgfnlhm5u46ncmwsmqgr3bfhqllmxq"}}'
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 8cd71cc4fc3906a9-SJC
Connection:
- keep-alive
Content-Length:
- '749'
Content-Type:
- application/json; charset=UTF-8
Date:
- Fri, 04 Oct 2024 18:07:37 GMT
NEL:
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
Preference-Applied:
- wait=10
Report-To:
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=7BfLUzx88hdyJfG2uImr63gZJVfWvTlkK1HOtlLZVXgPj6XGa9QGGD1TziR3NKOexuxTAyPJctSSCeMBNfASez%2FVmqNYZ48sTT6ST2mjJsGTnbN6E39fykqUN31UNTeinsn4"}],"group":"cf-nel","max_age":604800}'
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
Vary:
- Accept-Encoding
alt-svc:
- h3=":443"; ma=86400
ratelimit-remaining:
- '599'
ratelimit-reset:
- '1'
replicate-edge-cluster:
- services-aws-us-west-2
replicate-target-cluster:
- coreweave-rno2
status:
code: 201
message: Created
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
interactions:
- request:
body: '{"input": {"prompt": "Please write a haiku about llamas"}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '58'
content-type:
- application/json
host:
- api.replicate.com
prefer:
- wait
user-agent:
- replicate-python/0.32.1
method: POST
uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions
response:
body:
string: '{"id":"jp9nrd1g2hrj20cjb2vrb55mkr","model":"meta/meta-llama-3-8b-instruct","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"Please
write a haiku about llamas"},"logs":"","output":["\n\n","Fuzzy",", gentle","
beasts","\nSoft","ly grazing",", quiet"," eyes\n","Llama","''s gentle"," charm",""],"data_removed":false,"error":null,"status":"processing","created_at":"2024-10-04T18:07:33.396Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/jp9nrd1g2hrj20cjb2vrb55mkr/cancel","get":"https://api.replicate.com/v1/predictions/jp9nrd1g2hrj20cjb2vrb55mkr","stream":"https://streaming-api.svc.rno2.c.replicate.net/v1/streams/b4yonjrmynb65tnkucuqc4duawdekslfzexk5itczufef2u36b7a"}}'
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 8cd71cb9480406a9-SJC
Connection:
- keep-alive
Content-Length:
- '698'
Content-Type:
- application/json; charset=UTF-8
Date:
- Fri, 04 Oct 2024 18:07:35 GMT
NEL:
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
Preference-Applied:
- wait
Report-To:
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=u8iXtEhO5u6BKquzZ5B1DiM%2B7%2FhDRatREA2SNoe1dFgyyMrWuo9xl8rvi2sN4EjnuxmgK%2BfB8qa%2B1wbo4bChZM6wstlZmAS8P%2F7iccqmK9tm6JbQKT3PUOJcO2FjESZqaY12"}],"group":"cf-nel","max_age":604800}'
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
Vary:
- Accept-Encoding
alt-svc:
- h3=":443"; ma=86400
ratelimit-remaining:
- '599'
ratelimit-reset:
- '1'
replicate-edge-cluster:
- services-aws-us-west-2
replicate-target-cluster:
- coreweave-rno2
status:
code: 201
message: Created
version: 1
Loading

0 comments on commit 08ee31a

Please sign in to comment.