Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: prefetch #125

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiostream/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@
until = stream.until.pipe
zip = stream.zip.pipe
ziplatest = stream.ziplatest.pipe
prefetch = stream.prefetch.pipe
40 changes: 39 additions & 1 deletion aiostream/stream/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import aggregate
from .combine import map, amap, smap

__all__ = ["map", "enumerate", "starmap", "cycle", "chunks"]
__all__ = ["map", "enumerate", "starmap", "cycle", "chunks", "prefetch"]


T = TypeVar("T")
Expand Down Expand Up @@ -122,3 +122,41 @@ async def chunks(source: AsyncIterable[T], n: int) -> AsyncIterator[list[T]]:
async for first in streamer:
xs = select.take(create.preserve(streamer), n - 1)
yield [first] + await aggregate.list(xs)


@pipable_operator
async def prefetch(source: AsyncIterable[T], buffer_size: int = 1) -> AsyncIterator[T]:
"""Prefetch items from an asynchronous sequence into a buffer.

Args:
source: The source async iterable
buffer_size: Size of the prefetch buffer. <= 0 means unlimited buffer.
"""
sentinel = object()
queue: asyncio.Queue[T] = asyncio.Queue(
maxsize=buffer_size if buffer_size > 0 else 0
)

async def _worker():
try:
async with streamcontext(source) as streamer:
async for item in streamer:
await queue.put(item)
finally:
await queue.put(sentinel) # Sentinel value

worker = asyncio.create_task(_worker())

try:
while True:
item = await queue.get()
if item is sentinel: # End of stream
break
yield item

finally:
worker.cancel()
try:
await worker
except asyncio.CancelledError:
pass
63 changes: 63 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,66 @@ async def test_chunks(assert_run, assert_cleanup):
with assert_cleanup():
xs = stream.count(interval=1) | add_resource.pipe(1) | pipe.chunks(3)
await assert_run(xs[:1], [[0, 1, 2]])


@pytest.mark.asyncio
async def test_prefetch(assert_run, assert_cleanup):
# Test basic prefetching
with assert_cleanup():
xs = stream.range(5, interval=1)
ys = xs | pipe.prefetch(1) # Default buffer_size=1
await assert_run(ys, [0, 1, 2, 3, 4])

# Test with custom buffer size
with assert_cleanup():
xs = stream.range(5, interval=1)
ys = xs | pipe.prefetch(1)
await assert_run(ys, [0, 1, 2, 3, 4])

# Test with buffer_size=0
with assert_cleanup():
xs = stream.range(5, interval=1)
ys = xs | pipe.prefetch(0)
await assert_run(ys, [0, 1, 2, 3, 4])

# Test cleanup on early exit
with assert_cleanup():
xs = stream.range(100, interval=1)
ys = xs | pipe.prefetch(buffer_size=1) | pipe.take(3)
await assert_run(ys, [0, 1, 2])

# Test with empty stream
with assert_cleanup():
xs = stream.empty() | pipe.prefetch()
await assert_run(xs, [])

# Test with error propagation
with assert_cleanup():
xs = stream.throw(ValueError()) | pipe.prefetch()
await assert_run(xs, [], ValueError())


@pytest.mark.asyncio
async def test_prefetch_timing(assert_run, assert_cleanup):
async def slow_fetch(x: int) -> int:
await asyncio.sleep(0.1)
return x

async def slow_processor(x: int) -> int:
await asyncio.sleep(0.4)
return x

with assert_cleanup() as loop:
# Without prefetch (sequential):
xs = stream.range(3) | pipe.map(slow_fetch, task_limit=1)
ys = xs | pipe.map(slow_processor, task_limit=1) # Process time
await assert_run(ys, [0, 1, 2])
assert loop.steps == pytest.approx([0.1, 0.4, 0.1, 0.4, 0.1, 0.4])

with assert_cleanup() as loop:
# With prefetch:
xs = stream.range(3) | pipe.map(slow_fetch, task_limit=1)
ys = xs | pipe.prefetch(1) | pipe.map(slow_processor, task_limit=1)
await assert_run(ys, [0, 1, 2])
# instead of taking 0.1 + 0.4 seconds per element, we should now just take max(0.1, 0.4) = 0.4
assert loop.steps == pytest.approx([0.1, 0.1, 0.1, 0.2, 0.4, 0.4])
Loading