From a0d07901a6d41ec6baf81f44aeda69ec2a47bad6 Mon Sep 17 00:00:00 2001
From: Vincent Michel
Date: Fri, 26 Apr 2024 18:43:20 +0200
Subject: [PATCH] Add sources_operator

---
 aiostream/core.py           | 185 ++++++++++++++++++++++++++++++++++++
 aiostream/stream/combine.py |  22 ++---
 tests/test_combine.py       |  20 ++++
 3 files changed, 216 insertions(+), 11 deletions(-)

diff --git a/aiostream/core.py b/aiostream/core.py
index 9bf4fdc..af71fe1 100644
--- a/aiostream/core.py
+++ b/aiostream/core.py
@@ -282,6 +282,31 @@ def pipe(
         ...
 
 
+B = TypeVar("B", contravariant=True)
+U = TypeVar("U", covariant=True)
+
+
+class SourcesCallable(Protocol[B, U]):
+    def __call__(
+        self,
+        *sources: AsyncIterable[B],
+    ) -> AsyncIterator[U]:
+        ...
+
+
+class SourcesOperatorType(Protocol[A, T]):
+    def __call__(self, *sources: AsyncIterable[A]) -> Stream[T]:
+        ...
+
+    def raw(self, *sources: AsyncIterable[A]) -> AsyncIterator[T]:
+        ...
+
+    def pipe(
+        self, *sources: AsyncIterable[A]
+    ) -> Callable[[AsyncIterable[A]], Stream[T]]:
+        ...
+
+
 # Operator decorator
@@ -554,3 +579,163 @@ def pipe(
         "PipableOperatorType[X, P, T]",
         type(name, bases, attrs),
     )
+
+
+def sources_operator(
+    func: SourcesCallable[X, T],
+) -> SourcesOperatorType[X, T]:
+    """Create a pipable stream operator from an asynchronous generator
+    (or any function returning an asynchronous iterable).
+
+    Decorator usage::
+
+        @sources_operator
+        async def chain(*sources):
+            for source in sources:
+                async with streamcontext(source) as streamer:
+                    async for item in streamer:
+                        yield item
+
+    Positional arguments are expected to be asynchronous iterables.
+
+    Keyword arguments are not supported at the moment.
+
+    When used in a pipable context, the asynchronous iterable injected by
+    the pipe operator is used as the first argument.
+
+    The return value is a dynamically created class.
+    It has the same name, module and doc as the original function.
+
+    A new stream is created by simply instantiating the operator::
+
+        empty_chained = chain()
+        single_chained = chain(random())
+        multiple_chained = chain(stream.just(0.0), stream.just(1.0), random())
+
+    The original function is called at instantiation to check that
+    the signatures match. The sources are also checked for asynchronous iteration.
+
+    The operator also has a pipe class method that can be used along
+    with the piping syntax::
+
+        just_zero = stream.just(0.0)
+        multiple_chained = just_zero | chain.pipe(stream.just(1.0), random())
+
+    This is strictly equivalent to the previous example.
+
+    Other methods are available:
+
+      - `original`: the original function as a static method
+      - `raw`: same as `original` but with extra checking
+
+    The raw method is useful to create new operators from existing ones::
+
+        @sources_operator
+        def chain_twice(*sources):
+            return chain.raw(*sources, *sources)
+    """
+
+    # First check for classmethod instance, to avoid more confusing errors later on
+    if isinstance(func, classmethod):
+        raise ValueError(
+            "An operator cannot be created from a class method, "
+            "since the decorated function becomes an operator class"
+        )
+
+    # Gather data
+    bases = (Stream,)
+    name = func.__name__  # type: ignore
+    module = func.__module__
+    extra_doc = func.__doc__
+    doc = extra_doc or f"Regular {name} stream operator."
+
+    # Extract signature
+    signature = inspect.signature(func)
+    parameters = list(signature.parameters.values())
+    if parameters and parameters[0].name in ("self", "cls"):
+        raise ValueError(
+            "An operator cannot be created from a method, "
+            "since the decorated function becomes an operator class"
+        )
+    if not parameters or parameters[0].kind not in (inspect.Parameter.VAR_POSITIONAL,):
+        raise ValueError(
+            "The first argument of a sources operator must be a variadic positional argument."
+        )
+
+    # Injected parameters
+    self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
+    cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)
+
+    # Wrapped static method
+    original = func
+    original.__qualname__ = name + ".original"  # type: ignore
+
+    # Raw static method
+    def raw(*sources: AsyncIterable[X]) -> AsyncIterator[T]:
+        for source in sources:
+            assert_async_iterable(source)
+        return func(*sources)
+
+    # Customize raw method
+    raw.__signature__ = signature  # type: ignore[attr-defined]
+    raw.__qualname__ = name + ".raw"
+    raw.__module__ = module
+    raw.__doc__ = doc
+
+    # Init method
+    def init(self: BaseStream[T], *sources: AsyncIterable[X]) -> None:
+        for source in sources:
+            assert_async_iterable(source)
+        factory = functools.partial(raw, *sources)
+        return BaseStream.__init__(self, factory)
+
+    # Customize init signature
+    new_parameters = [self_parameter] + parameters
+    init.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]
+
+    # Customize init method
+    init.__qualname__ = name + ".__init__"
+    init.__name__ = "__init__"
+    init.__module__ = module
+    init.__doc__ = f"Initialize the {name} stream."
+
+    # Pipe class method
+    def pipe(
+        cls: SourcesOperatorType[X, T],
+        /,
+        *sources: AsyncIterable[X],
+    ) -> Callable[[AsyncIterable[X]], Stream[T]]:
+        return lambda source: cls(source, *sources)
+
+    # Customize pipe signature
+    if parameters and parameters[0].kind in (
+        inspect.Parameter.POSITIONAL_ONLY,
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+    ):
+        new_parameters = [cls_parameter] + parameters[1:]
+    else:
+        new_parameters = [cls_parameter] + parameters
+    pipe.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]
+
+    # Customize pipe method
+    pipe.__qualname__ = name + ".pipe"
+    pipe.__module__ = module
+    pipe.__doc__ = f'Pipable "{name}" stream operator.'
+    if extra_doc:
+        pipe.__doc__ += "\n\n    " + extra_doc
+
+    # Gather attributes
+    attrs = {
+        "__init__": init,
+        "__module__": module,
+        "__doc__": doc,
+        "raw": staticmethod(raw),
+        "original": staticmethod(original),
+        "pipe": classmethod(pipe),  # type: ignore[arg-type]
+    }
+
+    # Create operator class
+    return cast(
+        "SourcesOperatorType[X, T]",
+        type(name, bases, attrs),
+    )
diff --git a/aiostream/stream/combine.py b/aiostream/stream/combine.py
index a782a73..db02231 100644
--- a/aiostream/stream/combine.py
+++ b/aiostream/stream/combine.py
@@ -16,7 +16,7 @@
 from typing_extensions import ParamSpec
 
 from ..aiter_utils import AsyncExitStack, anext
-from ..core import streamcontext, pipable_operator
+from ..core import streamcontext, pipable_operator, sources_operator
 
 from . import create
 from . import select
@@ -48,9 +48,9 @@ async def chain(
             yield item
 
 
-@pipable_operator
+@sources_operator
 async def zip(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+    *sources: AsyncIterable[T],
 ) -> AsyncIterator[tuple[T, ...]]:
     """Combine and forward the elements of several asynchronous sequences.
@@ -61,7 +61,9 @@ async def zip(
     Note: the different sequences are awaited in parallel, so that their
     waiting times don't add up.
     """
-    sources = source, *more_sources
+    # No sources
+    if not sources:
+        return
 
     # One source
     if len(sources) == 1:
@@ -211,18 +213,15 @@ def map(
     return smap.raw(source, sync_func, *more_sources)
 
 
-@pipable_operator
-def merge(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
-) -> AsyncIterator[T]:
+@sources_operator
+def merge(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
     """Merge several asynchronous sequences together.
 
     All the sequences are iterated simultaneously and their elements
     are forwarded as soon as they're available. The generation continues
     until all the sequences are exhausted.
     """
-    sources = [source, *more_sources]
-    source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
+    source_stream: AsyncIterator[AsyncIterable[T]] = create.iterate.raw(sources)
     return advanced.flatten.raw(source_stream)
 
 
@@ -263,7 +262,8 @@ def func(x: T, *_: object) -> dict[int, T]:
     new_sources = [smap.raw(source, make_func(i)) for i, source in enumerate(sources)]
 
     # Merge the sources
-    merged = merge.raw(*new_sources)
+    # TODO: sources_operator causes type inference to fail here:
+    merged: AsyncIterator[dict[int, T]] = merge.raw(*new_sources)  # type: ignore[assignment, arg-type]
 
     # Accumulate the current state in a dict
     accumulated = aggregate.accumulate.raw(merged, lambda x, e: {**x, **e})
diff --git a/tests/test_combine.py b/tests/test_combine.py
index c21afda..19a05ef 100644
--- a/tests/test_combine.py
+++ b/tests/test_combine.py
@@ -27,6 +27,12 @@ async def test_zip(assert_run, event_loop):
         expected = [(x,) * 3 for x in range(5)]
         await assert_run(ys, expected)
 
+    # Empty zip (issue #95)
+
+    with event_loop.assert_cleanup():
+        xs = stream.zip()
+        await assert_run(xs, [])
+
 
 @pytest.mark.asyncio
 async def test_map(assert_run, event_loop):
@@ -160,6 +166,12 @@ async def agen2():
         xs = stream.merge(agen1(), agen2()) | pipe.delay(1) | pipe.take(1)
         await assert_run(xs, [1])
 
+    # Empty merge (issue #95)
+
+    with event_loop.assert_cleanup():
+        xs = stream.merge()
+        await assert_run(xs, [])
+
 
 @pytest.mark.asyncio
 async def test_ziplatest(assert_run, event_loop):
@@ -176,3 +188,11 @@ async def test_ziplatest(assert_run, event_loop):
         zs = stream.ziplatest(xs, ys, partial=False)
         await assert_run(zs, [(0, 1), (2, 1), (2, 3), (4, 3)])
         assert event_loop.steps == [1, 1, 1, 1]
+
+    # Empty ziplatest (issue #95)
+    # This is not supported yet due to the `sources_operator` decorator
+    # not supporting keyword arguments.
+    #
+    # with event_loop.assert_cleanup():
+    #     xs = stream.ziplatest()
+    #     await assert_run(xs, [])
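
Usage sketch (not part of the patch): a minimal check of what this change
enables, assuming it is applied on top of aiostream. The `concat_all`
operator below is a hypothetical example, only there to show the new
`sources_operator` decorator in action; `stream.zip`, `stream.merge`,
`stream.range`, `stream.list` and `streamcontext` are existing aiostream
APIs::

    import asyncio

    from aiostream import stream, streamcontext
    from aiostream.core import sources_operator


    @sources_operator
    async def concat_all(*sources):
        # Hypothetical operator: exhaust each source in order,
        # forwarding every item (same behavior as the chain example
        # in the sources_operator docstring).
        for source in sources:
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield item


    async def main():
        # Zero sources are now allowed (issue #95): the streams simply end.
        assert await stream.list(stream.zip()) == []
        assert await stream.list(stream.merge()) == []
        assert await stream.list(concat_all()) == []

        # Combining several sources still works as before.
        xs = stream.range(0, 3)
        ys = stream.range(10, 13)
        assert await stream.list(stream.zip(xs, ys)) == [(0, 10), (1, 11), (2, 12)]


    asyncio.run(main())

Note that `concat_all.pipe(...)`, `concat_all.raw(...)` and
`concat_all.original(...)` are generated as well, matching the behavior
documented in the `sources_operator` docstring.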