Commit
Add sources_operator
vxgmichel committed Apr 27, 2024
1 parent 8e55bcd commit a0d0790
Showing 3 changed files with 216 additions and 11 deletions.
185 changes: 185 additions & 0 deletions aiostream/core.py
@@ -282,6 +282,31 @@ def pipe(
...


B = TypeVar("B", contravariant=True)
U = TypeVar("U", covariant=True)


class SourcesCallable(Protocol[B, U]):
def __call__(
self,
*sources: AsyncIterable[B],
) -> AsyncIterator[U]:
...


class SourcesOperatorType(Protocol[A, T]):
def __call__(self, *sources: AsyncIterable[A]) -> Stream[T]:
...

def raw(self, *sources: AsyncIterable[A]) -> AsyncIterator[T]:
...

def pipe(
self, *sources: AsyncIterable[A]
) -> Callable[[AsyncIterable[A]], Stream[T]]:
...
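
For orientation, here is how these protocols line up with an operator created by
the decorator below (a hypothetical sketch, assuming a `chain` operator built with
`sources_operator` as in the docstring example further down)::

    s: Stream[int] = chain(xs, ys)              # __call__ builds a Stream
    it: AsyncIterator[int] = chain.raw(xs, ys)  # raw returns the bare async iterator
    piped: Stream[int] = xs | chain.pipe(ys)    # pipe supports the | piping syntax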


# Operator decorator


@@ -554,3 +579,163 @@ def pipe(
"PipableOperatorType[X, P, T]",
type(name, bases, attrs),
)


def sources_operator(
func: SourcesCallable[X, T],
) -> SourcesOperatorType[X, T]:
"""Create a pipable stream operator from an asynchronous generator
(or any function returning an asynchronous iterable).
Decorator usage::
@sources_operator
async def chain(*sources):
for source in sources:
async with streamcontext(source) as streamer:
async for item in streamer:
yield item
Positional arguments are expected to be the asynchronous iteratables.
Keyword arguments are not supported at the moment.
When used in a pipable context, the asynchronous iterable injected by
the pipe operator is used as the first argument.
The return value is a dynamically created class.
It has the same name, module and doc as the original function.
A new stream is created by simply instanciating the operator::
empty_chained = chain()
single_chained = chain(random())
multiple_chained = chain(stream.just(0.0), stream.just(1.0), random())
The original function is called at instanciation to check that
signature match. The source is also checked for asynchronous iteration.
The operator also have a pipe class method that can be used along
with the piping synthax::
just_zero = stream.just(0.0)
multiple_chained = just_zero | chain.pipe(stream.just(1.0, random())
This is strictly equivalent to the previous example.
Other methods are available:
- `original`: the original function as a static method
- `raw`: same as original but add extra checking
The raw method is useful to create new operators from existing ones::
@chain_operator
def chain_twice(*sources):
return chain.raw(*sources, *sources)
"""

# First check for classmethod instance, to avoid more confusing errors later on
if isinstance(func, classmethod):
raise ValueError(
"An operator cannot be created from a class method, "
"since the decorated function becomes an operator class"
)

# Gather data
bases = (Stream,)
name = func.__name__ # type: ignore
module = func.__module__
extra_doc = func.__doc__
doc = extra_doc or f"Regular {name} stream operator."

# Extract signature
signature = inspect.signature(func)
parameters = list(signature.parameters.values())
if parameters and parameters[0].name in ("self", "cls"):
raise ValueError(
"An operator cannot be created from a method, "
"since the decorated function becomes an operator class"
)
if not parameters or parameters[0].kind not in (inspect.Parameter.VAR_POSITIONAL,):
raise ValueError(
"The first argument of a sources operator must be a variadic positional argument."
)

# Injected parameters
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)

# Wrapped static method
original = func
original.__qualname__ = name + ".original" # type: ignore

# Raw static method
def raw(*sources: AsyncIterable[X]) -> AsyncIterator[T]:
for source in sources:
assert_async_iterable(source)
return func(*sources)

# Customize raw method
raw.__signature__ = signature # type: ignore[attr-defined]
raw.__qualname__ = name + ".raw"
raw.__module__ = module
raw.__doc__ = doc

# Init method
def init(self: BaseStream[T], *sources: AsyncIterable[X]) -> None:
for source in sources:
assert_async_iterable(source)
factory = functools.partial(raw, *sources)
return BaseStream.__init__(self, factory)

# Customize init signature
new_parameters = [self_parameter] + parameters
init.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]

# Customize init method
init.__qualname__ = name + ".__init__"
init.__name__ = "__init__"
init.__module__ = module
init.__doc__ = f"Initialize the {name} stream."

# Pipe class method
def pipe(
cls: SourcesOperatorType[X, T],
/,
*sources: AsyncIterable[X],
) -> Callable[[AsyncIterable[X]], Stream[T]]:
return lambda source: cls(source, *sources)

# Customize pipe signature
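# Note: the check above guarantees that the first parameter of a sources
# operator is variadic, so the positional branch below is never taken; it
# appears to be kept to mirror pipable_operator, which drops the injected
# source parameter from the pipe signature.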
if parameters and parameters[0].kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
new_parameters = [cls_parameter] + parameters[1:]
else:
new_parameters = [cls_parameter] + parameters
pipe.__signature__ = signature.replace(parameters=new_parameters) # type: ignore[attr-defined]

# Customize pipe method
pipe.__qualname__ = name + ".pipe"
pipe.__module__ = module
pipe.__doc__ = f'Pipable "{name}" stream operator.'
if extra_doc:
pipe.__doc__ += "\n\n " + extra_doc

# Gather attributes
attrs = {
"__init__": init,
"__module__": module,
"__doc__": doc,
"raw": staticmethod(raw),
"original": staticmethod(original),
"pipe": classmethod(pipe), # type: ignore[arg-type]
}

# Create operator class
return cast(
"SourcesOperatorType[X, T]",
type(name, bases, attrs),
)
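
Taken together, a minimal end-to-end sketch of the new decorator (the
`chain_all` operator and values below are hypothetical; only the public
aiostream API is assumed)::

    import asyncio

    from aiostream import stream
    from aiostream.core import sources_operator, streamcontext

    @sources_operator
    async def chain_all(*sources):
        # Exhaust each source in turn, forwarding every item.
        for source in sources:
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield item

    async def main():
        xs = stream.just(1)
        ys = stream.just(2)
        assert await stream.list(chain_all(xs, ys)) == [1, 2]
        assert await stream.list(chain_all()) == []  # zero sources are now allowed

    asyncio.run(main())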
22 changes: 11 additions & 11 deletions aiostream/stream/combine.py
@@ -16,7 +16,7 @@
from typing_extensions import ParamSpec

from ..aiter_utils import AsyncExitStack, anext
-from ..core import streamcontext, pipable_operator
+from ..core import streamcontext, pipable_operator, sources_operator

from . import create
from . import select
@@ -48,9 +48,9 @@ async def chain(
yield item


-@pipable_operator
+@sources_operator
async def zip(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+    *sources: AsyncIterable[T],
) -> AsyncIterator[tuple[T, ...]]:
"""Combine and forward the elements of several asynchronous sequences.
@@ -61,7 +61,9 @@ async def zip(
Note: the different sequences are awaited in parallel, so that their
waiting times don't add up.
"""
-sources = source, *more_sources
+# No sources
+if not sources:
+    return

# One source
if len(sources) == 1:
@@ -211,18 +213,15 @@ def map(
return smap.raw(source, sync_func, *more_sources)


-@pipable_operator
-def merge(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
-) -> AsyncIterator[T]:
+@sources_operator
+def merge(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
"""Merge several asynchronous sequences together.
All the sequences are iterated simultaneously and their elements
are forwarded as soon as they're available. The generation continues
until all the sequences are exhausted.
"""
-sources = [source, *more_sources]
-source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
+source_stream: AsyncIterator[AsyncIterable[T]] = create.iterate.raw(sources)
return advanced.flatten.raw(source_stream)


@@ -263,7 +262,8 @@ def func(x: T, *_: object) -> dict[int, T]:
new_sources = [smap.raw(source, make_func(i)) for i, source in enumerate(sources)]

# Merge the sources
-merged = merge.raw(*new_sources)
+# TODO: sources_operator causes type inference to fail here:
+merged: AsyncIterator[dict[int, T]] = merge.raw(*new_sources)  # type: ignore[assignment, arg-type]

# Accumulate the current state in a dict
accumulated = aggregate.accumulate.raw(merged, lambda x, e: {**x, **e})
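
The user-visible effect of the switch in this file is that the zero-source call
now degrades gracefully instead of raising a TypeError for a missing argument
(a sketch; the values are illustrative)::

    empty = stream.merge()                              # finishes without items
    pairs = stream.zip(stream.just(1), stream.just(2))  # yields (1, 2)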
20 changes: 20 additions & 0 deletions tests/test_combine.py
@@ -27,6 +27,12 @@ async def test_zip(assert_run, event_loop):
expected = [(x,) * 3 for x in range(5)]
await assert_run(ys, expected)

# Empty zip (issue #95)

with event_loop.assert_cleanup():
xs = stream.zip()
await assert_run(xs, [])


@pytest.mark.asyncio
async def test_map(assert_run, event_loop):
@@ -160,6 +166,12 @@ async def agen2():
xs = stream.merge(agen1(), agen2()) | pipe.delay(1) | pipe.take(1)
await assert_run(xs, [1])

# Empty merge (issue #95)

with event_loop.assert_cleanup():
xs = stream.merge()
await assert_run(xs, [])


@pytest.mark.asyncio
async def test_ziplatest(assert_run, event_loop):
@@ -176,3 +188,11 @@ async def test_ziplatest(assert_run, event_loop):
zs = stream.ziplatest(xs, ys, partial=False)
await assert_run(zs, [(0, 1), (2, 1), (2, 3), (4, 3)])
assert event_loop.steps == [1, 1, 1, 1]

# Empty ziplatest (issue #95)
# This is not supported yet due to the `sources_operator` decorator
# not supporting keyword arguments.
#
# with event_loop.assert_cleanup():
# xs = stream.ziplatest()
# await assert_run(xs, [])
