diff --git a/tests/test_stream.py b/tests/test_stream.py index aabe327b..3f5574b9 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,6 +3,7 @@ import pytest import replicate +from replicate.exceptions import ReplicateError from replicate.stream import ServerSentEvent skip_if_no_token = pytest.mark.skipif( @@ -21,22 +22,28 @@ async def test_stream(async_flag, record_mode): events = [] - if async_flag: - async for event in await replicate.async_stream( - model, - input=input, - ): - events.append(event) - else: - for event in replicate.stream( - model, - input=input, - ): - events.append(event) + try: + if async_flag: + async for event in await replicate.async_stream( + model, + input=input, + ): + events.append(event) + else: + for event in replicate.stream( + model, + input=input, + ): + events.append(event) - assert len(events) > 0 - assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events) - assert any(event.event == ServerSentEvent.EventType.DONE for event in events) + assert len(events) > 0 + assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events) + assert any(event.event == ServerSentEvent.EventType.DONE for event in events) + except ReplicateError as e: + if e.status == 401: + pytest.skip("Skipping test due to authentication error") + else: + raise e @skip_if_no_token @@ -50,15 +57,21 @@ async def test_stream_prediction(async_flag, record_mode): events = [] - if async_flag: - async for event in replicate.predictions.create( - version=version, input=input, stream=True - ).async_stream(): - events.append(event) - else: - for event in replicate.predictions.create( - version=version, input=input, stream=True - ).stream(): - events.append(event) + try: + if async_flag: + async for event in replicate.predictions.create( + version=version, input=input, stream=True + ).async_stream(): + events.append(event) + else: + for event in replicate.predictions.create( + version=version, input=input, stream=True + ).stream(): + events.append(event) - assert len(events) > 0 + assert len(events) > 0 + except ReplicateError as e: + if e.status == 401: + pytest.skip("Skipping test due to authentication error") + else: + raise e