From ac9d09bf93953598c754997242e3fd1671caca39 Mon Sep 17 00:00:00 2001
From: kavorite
Date: Thu, 12 Dec 2024 23:42:49 +0000
Subject: [PATCH 1/3] feat: prefetch

---
 aiostream/pipe.py             |  1 +
 aiostream/stream/transform.py | 40 ++++++++++++++++++++++++++++++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/aiostream/pipe.py b/aiostream/pipe.py
index ec872a6..91d26f9 100644
--- a/aiostream/pipe.py
+++ b/aiostream/pipe.py
@@ -38,3 +38,4 @@
 until = stream.until.pipe
 zip = stream.zip.pipe
 ziplatest = stream.ziplatest.pipe
+prefetch = stream.prefetch.pipe
diff --git a/aiostream/stream/transform.py b/aiostream/stream/transform.py
index 72ffcad..2faa132 100644
--- a/aiostream/stream/transform.py
+++ b/aiostream/stream/transform.py
@@ -20,7 +20,7 @@
 from . import aggregate
 from .combine import map, amap, smap
 
-__all__ = ["map", "enumerate", "starmap", "cycle", "chunks"]
+__all__ = ["map", "enumerate", "starmap", "cycle", "chunks", "prefetch"]
 
 T = TypeVar("T")
 
@@ -122,3 +122,41 @@ async def chunks(source: AsyncIterable[T], n: int) -> AsyncIterator[list[T]]:
         async for first in streamer:
             xs = select.take(create.preserve(streamer), n - 1)
             yield [first] + await aggregate.list(xs)
+
+
+@pipable_operator
+async def prefetch(source: AsyncIterable[T], buffer_size: int = 1) -> AsyncIterator[T]:
+    """Prefetch items from an asynchronous sequence into a buffer.
+
+    Args:
+        source: The source async iterable
+        buffer_size: Size of the prefetch buffer. <= 0 means unlimited buffer.
+    """
+    sentinel = object()
+    queue: asyncio.Queue[T] = asyncio.Queue(
+        maxsize=buffer_size if buffer_size > 0 else 0
+    )
+
+    async def _worker():
+        try:
+            async with streamcontext(source) as streamer:
+                async for item in streamer:
+                    await queue.put(item)
+        finally:
+            await queue.put(sentinel)  # Sentinel value
+
+    worker = asyncio.create_task(_worker())
+
+    try:
+        while True:
+            item = await queue.get()
+            if item is sentinel:  # End of stream
+                break
+            yield item
+
+    finally:
+        worker.cancel()
+        try:
+            await worker
+        except (asyncio.CancelledError, Exception):
+            pass

From 6ad0be84049d6cd83be2968e88674af860460e56 Mon Sep 17 00:00:00 2001
From: kavorite
Date: Fri, 13 Dec 2024 00:09:13 +0000
Subject: [PATCH 2/3] remove overly ambitious suppression clause

---
 aiostream/stream/transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aiostream/stream/transform.py b/aiostream/stream/transform.py
index 2faa132..d05da99 100644
--- a/aiostream/stream/transform.py
+++ b/aiostream/stream/transform.py
@@ -158,5 +158,5 @@ async def _worker():
         worker.cancel()
         try:
             await worker
-        except (asyncio.CancelledError, Exception):
+        except asyncio.CancelledError:
             pass

From 26eeac1068fd680d8ad529e1a8e332837ff93f77 Mon Sep 17 00:00:00 2001
From: kavorite
Date: Thu, 19 Dec 2024 13:04:28 +0000
Subject: [PATCH 3/3] add a prefetch test

---
 tests/test_transform.py | 63 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/tests/test_transform.py b/tests/test_transform.py
index 817fc79..0f9682b 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -68,3 +68,66 @@ async def test_chunks(assert_run, assert_cleanup):
     with assert_cleanup():
         xs = stream.count(interval=1) | add_resource.pipe(1) | pipe.chunks(3)
         await assert_run(xs[:1], [[0, 1, 2]])
+
+
+@pytest.mark.asyncio
+async def test_prefetch(assert_run, assert_cleanup):
+    # Test basic prefetching
+    with assert_cleanup():
+        xs = stream.range(5, interval=1)
+        ys = xs | pipe.prefetch(1)  # Default buffer_size=1
+        await assert_run(ys, [0, 1, 2, 3, 4])
+
+    # Test with custom buffer size
+    with assert_cleanup():
+        xs = stream.range(5, interval=1)
+        ys = xs | pipe.prefetch(3)
+        await assert_run(ys, [0, 1, 2, 3, 4])
+
+    # Test with buffer_size=0
+    with assert_cleanup():
+        xs = stream.range(5, interval=1)
+        ys = xs | pipe.prefetch(0)
+        await assert_run(ys, [0, 1, 2, 3, 4])
+
+    # Test cleanup on early exit
+    with assert_cleanup():
+        xs = stream.range(100, interval=1)
+        ys = xs | pipe.prefetch(buffer_size=1) | pipe.take(3)
+        await assert_run(ys, [0, 1, 2])
+
+    # Test with empty stream
+    with assert_cleanup():
+        xs = stream.empty() | pipe.prefetch()
+        await assert_run(xs, [])
+
+    # Test with error propagation
+    with assert_cleanup():
+        xs = stream.throw(ValueError()) | pipe.prefetch()
+        await assert_run(xs, [], ValueError())
+
+
+@pytest.mark.asyncio
+async def test_prefetch_timing(assert_run, assert_cleanup):
+    async def slow_fetch(x: int) -> int:
+        await asyncio.sleep(0.1)
+        return x
+
+    async def slow_processor(x: int) -> int:
+        await asyncio.sleep(0.4)
+        return x
+
+    with assert_cleanup() as loop:
+        # Without prefetch (sequential):
+        xs = stream.range(3) | pipe.map(slow_fetch, task_limit=1)
+        ys = xs | pipe.map(slow_processor, task_limit=1)  # Process time
+        await assert_run(ys, [0, 1, 2])
+        assert loop.steps == pytest.approx([0.1, 0.4, 0.1, 0.4, 0.1, 0.4])
+
+    with assert_cleanup() as loop:
+        # With prefetch:
+        xs = stream.range(3) | pipe.map(slow_fetch, task_limit=1)
+        ys = xs | pipe.prefetch(1) | pipe.map(slow_processor, task_limit=1)
+        await assert_run(ys, [0, 1, 2])
+        # Instead of taking 0.1 + 0.4 seconds per element, we should now just take max(0.1, 0.4) = 0.4
+        assert loop.steps == pytest.approx([0.1, 0.1, 0.1, 0.2, 0.4, 0.4])
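
For reviewers who want to try the series locally, here is a minimal usage sketch of the new operator; it is not part of the patches above. It assumes all three patches are applied (pipe.prefetch is introduced by this series), and the slow_fetch/slow_process helpers and timing figures are illustrative only, mirroring test_prefetch_timing.

# Usage sketch: overlap a slow producer with a slow consumer via prefetch.
# Assumes the patches above are applied; pipe.prefetch is new in this series.
import asyncio
import time

from aiostream import stream, pipe


async def slow_fetch(x: int) -> int:
    # Stand-in for an I/O-bound producer step (e.g. a network request).
    await asyncio.sleep(0.1)
    return x


async def slow_process(x: int) -> int:
    # Stand-in for a slower downstream consumer step.
    await asyncio.sleep(0.4)
    return x


async def main() -> None:
    start = time.perf_counter()
    xs = (
        stream.range(3)
        | pipe.map(slow_fetch, task_limit=1)
        | pipe.prefetch(1)  # keep one fetched item buffered ahead of the consumer
        | pipe.map(slow_process, task_limit=1)
    )
    result = await stream.list(xs)
    # Fetching item n+1 overlaps with processing item n, so the total is
    # roughly 0.1 + 3 * 0.4 = 1.3s instead of 3 * (0.1 + 0.4) = 1.5s.
    print(result, f"{time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())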