Skip to content

Commit

Permalink
Update integration tests to catch 401 errors (#317)
Browse files Browse the repository at this point in the history
Follow-up to #316

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt authored Jun 28, 2024
1 parent f21086e commit cd422fc
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import replicate
from replicate.exceptions import ReplicateError
from replicate.stream import ServerSentEvent

skip_if_no_token = pytest.mark.skipif(
Expand All @@ -21,22 +22,28 @@ async def test_stream(async_flag, record_mode):

events = []

if async_flag:
async for event in await replicate.async_stream(
model,
input=input,
):
events.append(event)
else:
for event in replicate.stream(
model,
input=input,
):
events.append(event)
try:
if async_flag:
async for event in await replicate.async_stream(
model,
input=input,
):
events.append(event)
else:
for event in replicate.stream(
model,
input=input,
):
events.append(event)

assert len(events) > 0
assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
assert any(event.event == ServerSentEvent.EventType.DONE for event in events)
assert len(events) > 0
assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
assert any(event.event == ServerSentEvent.EventType.DONE for event in events)
except ReplicateError as e:
if e.status == 401:
pytest.skip("Skipping test due to authentication error")
else:
raise e


@skip_if_no_token
Expand All @@ -50,15 +57,21 @@ async def test_stream_prediction(async_flag, record_mode):

events = []

if async_flag:
async for event in replicate.predictions.create(
version=version, input=input, stream=True
).async_stream():
events.append(event)
else:
for event in replicate.predictions.create(
version=version, input=input, stream=True
).stream():
events.append(event)
try:
if async_flag:
async for event in replicate.predictions.create(
version=version, input=input, stream=True
).async_stream():
events.append(event)
else:
for event in replicate.predictions.create(
version=version, input=input, stream=True
).stream():
events.append(event)

assert len(events) > 0
assert len(events) > 0
except ReplicateError as e:
if e.status == 401:
pytest.skip("Skipping test due to authentication error")
else:
raise e

0 comments on commit cd422fc

Please sign in to comment.