Skip to content

Commit

Permalink
Add strict parameter to stream.zip (issue #118, PR #119)
Browse files Browse the repository at this point in the history
* Add strict parameter to stream.zip (issue #118)

* Use shortcut for anext called without default

* Add tests for exception passthrough in zip

* Add (failing) test case for early exit from zip

* Exit from non-strict zip as early as possible

Fixes failing test from previous commit.

* Make UNSET an enum

Co-authored-by: Vincent Michel <vxgmichel@gmail.com>

* Make STOP_SENTINEL an enum

Co-authored-by: Vincent Michel <vxgmichel@gmail.com>

* Fix imports for enums

* Move strict condition further up & fix typing

* Update aiostream/stream/combine.py

Fix Pyton 3.8 compat

Co-authored-by: Vincent Michel <vxgmichel@gmail.com>

* Move STOP_SENTINEL construction out of function

* Type inner anext wrapper function

* Improve un-overloaded anext() type signature

* Use ellipsis instead of pass

---------

Co-authored-by: Vincent Michel <vxgmichel@gmail.com>
  • Loading branch information
smheidrich and vxgmichel authored Jul 21, 2024
1 parent 5e0547a commit 30665d9
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 13 deletions.
28 changes: 26 additions & 2 deletions aiostream/aiter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from types import TracebackType

import enum
import warnings
import functools
from typing import (
Expand All @@ -19,6 +20,7 @@
AsyncIterator,
Any,
cast,
overload,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -46,16 +48,37 @@
# Magic method shorcuts


_UnsetType = enum.Enum("_UnsetType", "UNSET")
UNSET = _UnsetType.UNSET


def aiter(obj: AsyncIterable[T]) -> AsyncIterator[T]:
"""Access aiter magic method."""
assert_async_iterable(obj)
return obj.__aiter__()


def anext(obj: AsyncIterator[T]) -> Awaitable[T]:
@overload
def anext(obj: AsyncIterator[T]) -> Awaitable[T]: ...


@overload
def anext(obj: AsyncIterator[T], default: U) -> Awaitable[T | U]: ...


def anext(obj: AsyncIterator[T], default: U | _UnsetType = UNSET) -> Awaitable[T | U]:
"""Access anext magic method."""
assert_async_iterator(obj)
return obj.__anext__()
if default is UNSET:
return obj.__anext__()

async def anext_default_handling_wrapper() -> T | U:
try:
return await obj.__anext__()
except StopAsyncIteration:
return default

return anext_default_handling_wrapper()


# Async / await helper functions
Expand Down Expand Up @@ -109,6 +132,7 @@ def assert_async_iterator(obj: object) -> None:

T = TypeVar("T", covariant=True)
Self = TypeVar("Self", bound="AsyncIteratorContext[Any]")
U = TypeVar("U")


class AsyncIteratorContext(
Expand Down
31 changes: 24 additions & 7 deletions aiostream/stream/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import builtins
import enum

from typing import (
Awaitable,
Expand Down Expand Up @@ -46,8 +47,14 @@ async def chain(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
yield item


_StopSentinelType = enum.Enum("_StopSentinelType", "STOP_SENTINEL")
STOP_SENTINEL = _StopSentinelType.STOP_SENTINEL


@sources_operator
async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
async def zip(
*sources: AsyncIterable[T], strict: bool = False
) -> AsyncIterator[tuple[T, ...]]:
"""Combine and forward the elements of several asynchronous sequences.
Each generated value is a tuple of elements, using the same order as
Expand Down Expand Up @@ -76,14 +83,24 @@ async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
await stack.enter_async_context(streamcontext(source)) for source in sources
]
# Loop over items
items: list[T]
while True:
try:
coros = builtins.map(anext, streamers)
items = await asyncio.gather(*coros)
except StopAsyncIteration:
break
if strict:
coros = (anext(streamer, STOP_SENTINEL) for streamer in streamers)
_items = await asyncio.gather(*coros)
if all(item == STOP_SENTINEL for item in _items):
break
elif any(item == STOP_SENTINEL for item in _items):
raise ValueError("iterables have different lengths")
# This holds because we've ruled out STOP_SENTINEL above:
items = cast("list[T]", _items)
else:
yield tuple(items)
coros = (anext(streamer) for streamer in streamers)
try:
items = await asyncio.gather(*coros)
except StopAsyncIteration:
break
yield tuple(items)


X = TypeVar("X", contravariant=True)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,50 @@ async def test_zip(assert_run):
expected = [(x,) * 3 for x in range(5)]
await assert_run(ys, expected)

# Exceptions from iterables are propagated
xs = stream.zip(stream.range(2), stream.throw(AttributeError))
with pytest.raises(AttributeError):
await xs

# Empty zip (issue #95)
xs = stream.zip()
await assert_run(xs, [])

# Strict mode (issue #118): Iterable length mismatch raises
xs = stream.zip(stream.range(2), stream.range(1), strict=True)
with pytest.raises(ValueError):
await xs

# Strict mode (issue #118): No raise for matching-length iterables
xs = stream.zip(stream.range(2), stream.range(2), strict=True)
await assert_run(xs, [(0, 0), (1, 1)])

# Strict mode (issue #118): Exceptions from iterables are propagated
xs = stream.zip(stream.range(2), stream.throw(AttributeError), strict=True)
with pytest.raises(AttributeError):
await xs

# Strict mode (issue #118): Non-strict mode works as before
xs = stream.zip(stream.range(2), stream.range(1))
await assert_run(xs, [(0, 0)])

# Strict mode (issue #118): In particular, we stop immediately if any
# one iterable is exhausted, not waiting for the others
slow_iterable_continued_after_sleep = asyncio.Event()

async def fast_iterable():
yield 0
await asyncio.sleep(1)

async def slow_iterable():
yield 0
await asyncio.sleep(2)
slow_iterable_continued_after_sleep.set()

xs = stream.zip(fast_iterable(), slow_iterable())
await assert_run(xs, [(0, 0)])
assert not slow_iterable_continued_after_sleep.is_set()


@pytest.mark.asyncio
async def test_map(assert_run, assert_cleanup):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_introspection_for_sources_operator():
)
assert (
str(inspect.signature(original))
== "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'"
== "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'AsyncIterator[tuple[T, ...]]'"
)

# Check the stream operator
Expand All @@ -251,7 +251,7 @@ def test_introspection_for_sources_operator():
assert stream.zip.raw.__doc__ == original_doc
assert (
str(inspect.signature(stream.zip.raw))
== "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'"
== "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'AsyncIterator[tuple[T, ...]]'"
)

# Check the __call__ method
Expand All @@ -260,7 +260,7 @@ def test_introspection_for_sources_operator():
assert stream.zip.__call__.__doc__ == original_doc
assert (
str(inspect.signature(stream.zip.__call__))
== "(*sources: 'AsyncIterable[T]') -> 'Stream[tuple[T, ...]]'"
== "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'Stream[tuple[T, ...]]'"
)

# Check the pipe method
Expand All @@ -272,5 +272,5 @@ def test_introspection_for_sources_operator():
)
assert (
str(inspect.signature(stream.zip.pipe))
== "(*sources: 'AsyncIterable[T]') -> 'Callable[[AsyncIterable[Any]], Stream[tuple[T, ...]]]'"
== "(*sources: 'AsyncIterable[T]', strict: 'bool' = False) -> 'Callable[[AsyncIterable[Any]], Stream[tuple[T, ...]]]'"
)

0 comments on commit 30665d9

Please sign in to comment.