diff --git a/fauna/client/client.py b/fauna/client/client.py index 3758315c..4446ea86 100644 --- a/fauna/client/client.py +++ b/fauna/client/client.py @@ -378,11 +378,11 @@ def _query( schema_version=schema_version, ) - def stream(self, fql: StreamToken | Query) -> Iterator[Any]: - return RetryStreamIter(self, fql) + def stream(self, fql: StreamToken | Query) -> "StreamIterator": + return StreamIterator(self, fql) @contextmanager - def _stream(self, fql: StreamToken | Query) -> Iterator[Any]: + def _stream(self, fql: StreamToken | Query): headers = self._headers.copy() headers[_Header.Format] = "tagged" headers[_Header.Authorization] = self._auth.bearer() @@ -615,7 +615,7 @@ def _set_endpoint(self, endpoint): self._endpoint = endpoint -class RetryStreamIter: +class StreamIterator: """A class that mix a ContextManager and an Iterator so we can detected retryable errors.""" def __init__(self, client, fql): @@ -638,7 +638,8 @@ def __iter__(self): def __next__(self): try: - return next(self.stream) + if self.stream is not None: + return next(self.stream) except NetworkError as e: return self._retry() @@ -648,7 +649,8 @@ def _retry(self): return self.__next__() def close(self): - self.stream.close() + if self.stream is not None: + self.stream.close() class QueryIterator: diff --git a/fauna/http/http_client.py b/fauna/http/http_client.py index 205525f4..c9dcfd23 100644 --- a/fauna/http/http_client.py +++ b/fauna/http/http_client.py @@ -1,4 +1,5 @@ import abc +import contextlib from typing import Iterator, Mapping, Any from dataclasses import dataclass @@ -62,6 +63,7 @@ def request( pass @abc.abstractmethod + @contextlib.contextmanager def stream( self, url: str, diff --git a/tests/integration/test_client_with_query_limits.py b/tests/integration/test_client_with_query_limits.py index f5ce09d8..86339aae 100644 --- a/tests/integration/test_client_with_query_limits.py +++ b/tests/integration/test_client_with_query_limits.py @@ -9,7 +9,7 @@ from fauna.errors.errors import ThrottlingError -def query_collection(client: Client) -> QuerySuccess: +def query_collection(client: Client) -> QuerySuccess | None: coll_name = os.environ.get("QUERY_LIMITS_COLL") or "" try: return client.query(fql("${coll}.all().paginate(50)", coll=fql(coll_name))) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6024babd..1e1ec69d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -442,7 +442,7 @@ def test_client_close_stream(subtests, httpx_mock: HTTPXMock): http_client = HTTPXClient(mockClient) c = Client(http_client=http_client) with c.stream(StreamToken("token")) as stream: - next(stream) == 10 + assert next(stream) == 10 stream.close() with pytest.raises(StopIteration): diff --git a/tests/unit/test_httpx_client.py b/tests/unit/test_httpx_client.py index 7451b815..3b2930ac 100644 --- a/tests/unit/test_httpx_client.py +++ b/tests/unit/test_httpx_client.py @@ -37,7 +37,7 @@ def to_json_bytes(obj): with httpx.Client() as mockClient: http_client = HTTPXClient(mockClient) with http_client.stream("http://localhost:8443", {}, {}) as stream: - next(stream) == expected[0] + assert next(stream) == expected[0] stream.close() with pytest.raises(StopIteration):