Skip to content

Commit

Permalink
Cache (#9)
Browse files Browse the repository at this point in the history
* Cache draft

* Added cache stats

* Added test notebook

* Added running statistics calculation for the in-memory cache

* Fixed some typing hints and edge cases

* Minor changes

* Removed test notebook

* Removed lost line

* Minor changes

* Fixed typo

* typo

* Added the possibility to set a cache for each batch to reduce the number of read/writes in case the main cache is on disk

* Added correct stats calculation for the disk cache

* Cache initial

* Nit

* small fix

---------

Co-authored-by: pau-mensa <pau.mopuig@gmail.com>
Co-authored-by: rui <rui@mixedbread.ai>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent 853d8bf commit 1a1797f
Show file tree
Hide file tree
Showing 11 changed files with 427 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ format: ## Run code autoformatters (ruff).
uv run ruff format .

lint: ## Run linters: ruff
uv run ruff check .
uv run ruff check . --fix

test: ## Run tests via pytest
uv run pytest
Expand Down
100 changes: 92 additions & 8 deletions batched/aio/batch_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import asyncio
import contextlib
from collections.abc import Mapping, Sized
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic

from batched.types import T, U
from batched.utils import batch_iter
from batched.types import AsyncCache, T, U
from batched.utils import batch_iter, batch_iter_by_length, first

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable


@dataclass(order=True)
Expand All @@ -30,6 +31,7 @@ class AsyncBatchItem(Generic[T, U]):
content: T = field(compare=False)
future: asyncio.Future[U] = field(compare=False)
priority: int = field(default=1, compare=True)
_len_fn: Callable[[T], int] = field(default=None, compare=False)

def set_result(self, result: U) -> None:
"""
Expand All @@ -55,6 +57,27 @@ def done(self) -> bool:
"""Check if the item's future is done."""
return self.future.done()

@staticmethod
def _get_len(content: T) -> int:
"""Best-effort length of *content* for batch-length accounting.

Resolution order:
- ``None`` counts as 0.
- ``Sized`` values (str, list, anything with ``__len__``) use ``len()``.
- ``Mapping`` values fall back to the length of their first value.
- Anything else counts as a single unit.
"""
# Empty slot: contributes nothing to the batch length.
if content is None:
return 0

# Anything exposing __len__ reports its own length.
if isinstance(content, Sized):
return len(content)

# NOTE(review): concrete Mappings implement __len__ and therefore match
# the Sized branch above, so this recursion into the first value looks
# unreachable for ordinary dicts — confirm whether the Mapping check was
# meant to come first.
if isinstance(content, Mapping):
value = first(content.values())
return AsyncBatchItem._get_len(value)

# Scalar / opaque content counts as one item.
return 1

def __len__(self) -> int:
"""Length of this item's content, via the configured length function.

Uses the ``_len_fn`` supplied at construction when present; otherwise
binds the generic ``AsyncBatchItem._get_len`` helper on first use.
"""
# Lazy default: the field default is None, so resolve the helper once.
if self._len_fn is None:
self._len_fn = AsyncBatchItem._get_len
return self._len_fn(self.content)


class AsyncBatchGenerator(Generic[T, U]):
"""
Expand All @@ -64,9 +87,11 @@ class AsyncBatchGenerator(Generic[T, U]):
based on the specified batch size and timeout.
Attributes:
cache (AsyncCache[T, U] | None): An optional cache to use for storing results.
_queue (asyncio.PriorityQueue): An asyncio priority queue to store items.
_batch_size (int): The maximum size of each batch.
_timeout (float): The timeout in seconds between batch generation attempts.
_max_batch_length (int | None): Used to count the length of each item and stay within the limit.
Type Parameters:
T: The type of the content in the BatchItem.
Expand All @@ -76,18 +101,31 @@ class AsyncBatchGenerator(Generic[T, U]):
def __init__(
self,
batch_size: int = 32,
*,
timeout_ms: float = 5.0,
cache: AsyncCache[T, U] | None = None,
max_batch_length: int | None = None,
sort_by_priority: bool = False,
) -> None:
"""
Initialize the BatchGenerator.
Args:
batch_size (int): The maximum size of each batch. Defaults to 32.
timeout_ms (float): The timeout in milliseconds between batch generation attempts. Defaults to 5.0.
cache (AsyncCache[T, U] | None): An optional cache to use for storing results.
max_batch_length (int | None): Used to count the length of each item and stay within the limit.
sort_by_priority (bool): Whether to sort the queue by priority. Defaults to False.
"""
self._queue: asyncio.PriorityQueue[AsyncBatchItem[T, U]] = asyncio.PriorityQueue()
self.cache = cache

self._queue: asyncio.PriorityQueue[AsyncBatchItem[T, U]] = (
asyncio.PriorityQueue() if sort_by_priority else asyncio.Queue()
)
self._batch_size = batch_size
self._timeout = timeout_ms / 1000 # Convert to seconds
self._max_batch_length = max_batch_length
self._background_tasks = set()

def __len__(self) -> int:
"""
Expand All @@ -98,15 +136,51 @@ def __len__(self) -> int:
"""
return self._queue.qsize()

def _wrap_set_result(self, item: AsyncBatchItem[T, U]) -> Callable[[U], None]:
"""Return a replacement for ``item.set_result`` that also populates the cache.

The returned callable first delegates to the item's original
``set_result`` (so consumers are not delayed), then — if a cache is
configured — schedules a background task that stores the result keyed
by the item's content.
"""

def wrapper(original_set_result):
def wrapped_set_result(result: U) -> None:
# Resolve the caller's future before the (possibly slow) cache write.
original_set_result(result)

if self.cache is not None:
# Fire-and-forget cache write. Hold a strong reference in
# _background_tasks so the task is not garbage-collected
# mid-flight; the done-callback removes it when finished.
task = asyncio.create_task(self.cache.set(item.content, result))
task.add_done_callback(self._background_tasks.discard)
self._background_tasks.add(task)

return wrapped_set_result

return wrapper(item.set_result)

async def _check_cache(self, item: AsyncBatchItem[T, U]) -> AsyncBatchItem[T, U]:
"""Resolve *item* from the cache, or arm it to write back on completion.

On a cache hit the item's future is completed immediately, so the item
never needs to be queued. On a miss, ``item.set_result`` is swapped for
a wrapper that stores the eventual result in the cache.
"""
# No cache configured: pass the item through unchanged.
if not self.cache:
return item

hit = await self.cache.get(item.content)
if hit is not None:
# Hit: complete the future now; callers skip queueing done items.
item.set_result(hit)
else:
# Miss: ensure the computed result is written back to the cache.
item.set_result = self._wrap_set_result(item)

return item

async def extend(self, items: list[AsyncBatchItem[T, U]]) -> None:
"""
Add multiple items to the queue.
Args:
items (list[BatchItem[T, U]]): A list of items to add to the queue.
"""
for item in items:
await self._queue.put(item)
if self.cache is None:
for item in items:
await self._queue.put(item)
return

for item in asyncio.as_completed([self._check_cache(item) for item in items]):
result = await item
if not result.done():
await self._queue.put(result)

async def optimal_batches(self) -> AsyncGenerator[list[AsyncBatchItem[T, U]], None]:
"""
Expand All @@ -129,5 +203,15 @@ async def optimal_batches(self) -> AsyncGenerator[list[AsyncBatchItem[T, U]], No
n_batches = max(1, queue_size // self._batch_size)
size_batches = min(self._batch_size * n_batches, queue_size)
batch_items = [self._queue._get() for _ in range(size_batches)] # noqa: SLF001
for batch in batch_iter(batch_items, self._batch_size):
yield batch

if self._max_batch_length:
batch_items = batch_iter_by_length(
batch_items, max_batch_length=self._max_batch_length, batch_size=self._batch_size
)
else:
batch_items = batch_iter(batch_items, self._batch_size)

for batch in batch_items:
filtered = [item for item in batch if not item.done()]
if filtered:
yield filtered
91 changes: 70 additions & 21 deletions batched/aio/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import batched.utils as utils
from batched.aio.batch_generator import AsyncBatchGenerator, AsyncBatchItem
from batched.decorator import _dynamic_batch
from batched.types import BatchFunc, BatchProcessorStats, T, U, _validate_batch_output
from batched.types import AsyncCache, BatchFunc, BatchProcessorStats, PriorityStrategy, T, U, _validate_batch_output


class AsyncBatchProcessor(Generic[T, U]):
Expand All @@ -16,8 +16,9 @@ class AsyncBatchProcessor(Generic[T, U]):
Attributes:
batch_func (Callable[[list[T]], Awaitable[list[U]]]): The function to process batches.
batch_queue (BatchGenerator[T, U]): The generator for creating optimal batches.
batch_queue (AsyncBatchGenerator[T, U]): The generator for creating optimal batches.
small_batch_threshold (int): The threshold for considering a batch as small.
batch_item_cls (type[AsyncBatchItem[T, U]]): The class to use for batch items. Defaults to AsyncBatchItem.
_loop (asyncio.AbstractEventLoop): The event loop for asynchronous operations.
_task (Optional[asyncio.Task]): The task for processing batches.
_stats (BatchProcessorStats): Statistics about the batch processing.
Expand All @@ -29,27 +30,45 @@ class AsyncBatchProcessor(Generic[T, U]):

def __init__(
self,
_func: BatchFunc[T, U],
func: BatchFunc[T, U],
*,
batch_size: int = 32,
timeout_ms: float = 5.0,
small_batch_threshold: int = 8,
cache: AsyncCache[T, U] | None = None,
max_batch_length: int | None = None,
priority_strategy: PriorityStrategy = PriorityStrategy.NONE,
batch_item_cls: type[AsyncBatchItem[T, U]] = AsyncBatchItem[T, U],
):
"""
Initialize the BatchProcessor.
Initialize the AsyncBatchProcessor.
Args:
_func (BatchFunc[T, U]): The function to process batches.
batch_size (int): The maximum size of each batch. Defaults to 32.
timeout_ms (float): The timeout in milliseconds between batch generation attempts. Defaults to 5.0.
small_batch_threshold (int): The threshold to give priority to small batches. Defaults to 8.
cache (AsyncCache[T, U] | None): An optional cache for storing results. Defaults to None.
max_batch_length (int | None): The maximum length of a batch. Defaults to None.
priority_strategy (PriorityStrategy): The strategy to use for prioritizing items.
batch_item_cls (type[AsyncBatchItem[T, U]]): The class to use for batch items. Defaults to AsyncBatchItem.
You can use a custom subclass to add additional attributes to batch items.
"""
self.batch_func = utils.ensure_async(func)
self.batch_queue = AsyncBatchGenerator[T, U](
batch_size=batch_size,
timeout_ms=timeout_ms,
cache=cache,
max_batch_length=max_batch_length,
sort_by_priority=priority_strategy != PriorityStrategy.NONE,
)
self.small_batch_threshold = small_batch_threshold
self.batch_func = utils.ensure_async(_func)
self.batch_queue = AsyncBatchGenerator[T, U](batch_size, timeout_ms)
self.batch_item_cls = batch_item_cls
self.priority_strategy = priority_strategy

self._start_lock = asyncio.Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._task: Optional[asyncio.Task] = None
self._loop: asyncio.AbstractEventLoop | None = None
self._task: asyncio.Task | None = None
self._stats = BatchProcessorStats()

async def _start(self) -> None:
Expand All @@ -63,16 +82,21 @@ async def _start(self) -> None:

def _determine_priority(self, items: list[T]) -> list[int]:
"""
Determine if items should be prioritized based on the batch size.
Determine the priority of items based on batch size and content length.
Args:
items (list[T]): The list of items to prioritize.
Returns:
list[bool]: A list of boolean values indicating the priority of each item.
list[int]: A list of integer values indicating the priority of each item.
"""
priority = 0 if len(items) <= self.small_batch_threshold else 1
return [priority] * len(items)
if self.priority_strategy == PriorityStrategy.NONE or len(items) <= self.small_batch_threshold:
return [0] * len(items)

if self.priority_strategy == PriorityStrategy.LENGTH:
return [len(item) for item in items]

return [1] * len(items)

async def _schedule(self, items: list[T]) -> list[U]:
"""
Expand All @@ -90,7 +114,7 @@ async def _schedule(self, items: list[T]) -> list[U]:
prioritized = self._determine_priority(items)

batch_items = [
AsyncBatchItem[T, U](
self.batch_item_cls(
content=item,
priority=prio,
future=self._loop.create_future(),
Expand Down Expand Up @@ -132,15 +156,16 @@ def stats(self) -> BatchProcessorStats:
Returns:
BatchProcessorStats: The current statistics.
"""
return self._stats.clone(queue_size=len(self.batch_queue))
return self._stats.clone(
queue_size=len(self.batch_queue),
cache_stats=self.batch_queue.cache.stats() if self.batch_queue.cache else None,
)

@overload
async def __call__(self, item: T) -> U:
...
async def __call__(self, item: T) -> U: ...

@overload
async def __call__(self, items: list[T]) -> list[U]:
...
async def __call__(self, items: list[T]) -> list[U]: ...

async def __call__(self, items: Union[T, list[T]]) -> Union[U, list[U]]:
"""
Expand All @@ -157,19 +182,28 @@ async def __call__(self, items: Union[T, list[T]]) -> Union[U, list[U]]:
return (await self._schedule([items]))[0]

def __del__(self):
if self._task is None or not self._loop.is_running():
if self._loop is None or self._loop.is_closed():
return
if self._task is None:
return

self._task.cancel()
self._loop = None

def clear_stats(self) -> None:
"""Reset the processor's accumulated batch statistics to a fresh state."""
self._stats = BatchProcessorStats()


def dynamically(
func: Optional[BatchFunc[T, U]] = None,
/,
*,
batch_size: int = 32,
timeout_ms: float = 5.0,
small_batch_threshold: int = 8,
max_batch_length: int | None = None,
priority_strategy: PriorityStrategy = PriorityStrategy.NONE,
cache: AsyncCache[T, U] | None = None,
batch_item_cls: type[AsyncBatchItem[T, U]] = AsyncBatchItem[T, U],
) -> Callable:
"""
Dynamically batch inputs for processing using asyncio.
Expand All @@ -185,6 +219,12 @@ def dynamically(
batch_size (int): The maximum size of each batch. Defaults to 32.
timeout_ms (float): The timeout in milliseconds between batch generation attempts. Defaults to 5.0.
small_batch_threshold (int): The threshold for considering a batch as small. Defaults to 8.
Only used if sort_by_priority is True.
max_batch_length (int | None): The maximum length of a batch. Defaults to None.
priority_strategy (PriorityStrategy): The strategy to use for prioritizing items.
cache (AsyncCache[T, U] | None): An optional cache for storing results.
batch_item_cls (type[AsyncBatchItem[T, U]]): The class to use for batch items. Defaults to AsyncBatchItem[T, U].
You can use a custom subclass to add additional attributes to batch items.
Returns:
Callable: A decorator that creates an AsyncBatchProcessor for the given function.
Expand Down Expand Up @@ -213,6 +253,15 @@ def sync_process(items: list[str]) -> list[int]:
"""

def make_processor(_func: BatchFunc[T, U]) -> AsyncBatchProcessor[T, U]:
return AsyncBatchProcessor(_func, batch_size, timeout_ms, small_batch_threshold)
return AsyncBatchProcessor(
func=_func,
batch_size=batch_size,
timeout_ms=timeout_ms,
small_batch_threshold=small_batch_threshold,
max_batch_length=max_batch_length,
cache=cache,
priority_strategy=priority_strategy,
batch_item_cls=batch_item_cls,
)

return _dynamic_batch(make_processor, func)
Loading

0 comments on commit 1a1797f

Please sign in to comment.