From ba9a4ebe984f58d38605fd5c5746a5e544e29fea Mon Sep 17 00:00:00 2001 From: daanelson Date: Fri, 3 Mar 2023 10:29:19 -0800 Subject: [PATCH] added timings, tests (#75) Signed-off-by: dan nelson --- replicate/prediction.py | 3 ++ tests/test_prediction.py | 68 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index 54f5db66..fbfcdc2f 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -17,6 +17,9 @@ class Prediction(BaseModel): output: Optional[Any] status: str version: Optional[Version] + started_at: Optional[str] + created_at: Optional[str] + completed_at: Optional[str] def wait(self): """Wait for prediction to finish.""" diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 4d0dcef8..bfe8a3ec 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -48,3 +48,71 @@ def test_cancel(): rsp = responses.post("https://api.replicate.com/v1/predictions/p1/cancel", json={}) prediction.cancel() assert rsp.call_count == 1 + + +@responses.activate +def test_async_timings(): + client = create_client() + version = create_version(client) + + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher( + { + "version": "v1", + "input": {"text": "hello"}, + "webhook_completed": "https://example.com/webhook", + } + ), + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "source": "api", + "status": "processing", + "input": {"text": "hello"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "succeeded", + "input": {"text": "hello"}, + "output": "hello world", + "error": None, + "logs": "", + }, + ) + + prediction = client.predictions.create( + version=version, + input={"text": "hello"}, + webhook_completed="https://example.com/webhook", + ) + + assert prediction.created_at == "2022-04-26T20:00:40.658234Z" + assert prediction.completed_at == None + assert prediction.output == None + prediction.wait() + assert prediction.created_at == "2022-04-26T20:00:40.658234Z" + assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" + assert prediction.output == "hello world"