Skip to content

Commit

Permalink
added timings, tests (#75)
Browse files Browse the repository at this point in the history
Signed-off-by: dan nelson <dan.nelson8@gmail.com>
  • Loading branch information
daanelson authored Mar 3, 2023
1 parent 6a7739b commit ba9a4eb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
3 changes: 3 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class Prediction(BaseModel):
output: Optional[Any]
status: str
version: Optional[Version]
started_at: Optional[str]
created_at: Optional[str]
completed_at: Optional[str]

def wait(self):
"""Wait for prediction to finish."""
Expand Down
68 changes: 68 additions & 0 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,71 @@ def test_cancel():
rsp = responses.post("https://api.replicate.com/v1/predictions/p1/cancel", json={})
prediction.cancel()
assert rsp.call_count == 1


@responses.activate
def test_async_timings():
client = create_client()
version = create_version(client)

responses.post(
"https://api.replicate.com/v1/predictions",
match=[
matchers.json_params_matcher(
{
"version": "v1",
"input": {"text": "hello"},
"webhook_completed": "https://example.com/webhook",
}
),
],
json={
"id": "p1",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
"created_at": "2022-04-26T20:00:40.658234Z",
"source": "api",
"status": "processing",
"input": {"text": "hello"},
"output": None,
"error": None,
"logs": "",
},
)

responses.get(
"https://api.replicate.com/v1/predictions/p1",
json={
"id": "p1",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
"created_at": "2022-04-26T20:00:40.658234Z",
"completed_at": "2022-04-26T20:02:27.648305Z",
"source": "api",
"status": "succeeded",
"input": {"text": "hello"},
"output": "hello world",
"error": None,
"logs": "",
},
)

prediction = client.predictions.create(
version=version,
input={"text": "hello"},
webhook_completed="https://example.com/webhook",
)

assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
assert prediction.completed_at == None
assert prediction.output == None
prediction.wait()
assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
assert prediction.output == "hello world"

0 comments on commit ba9a4eb

Please sign in to comment.