diff --git a/experimental/caching/multiprocessing.py b/experimental/caching/multiprocessing.py index 44729013..584a0e80 100644 --- a/experimental/caching/multiprocessing.py +++ b/experimental/caching/multiprocessing.py @@ -1,5 +1,5 @@ from typing import Iterator -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + import torch.multiprocessing as mp from transformers import AutoTokenizer, AutoModel import torch @@ -88,15 +88,15 @@ def loop_forever(self): pass class TokenizePipeline(BoringPipeline): - def post_init(self, device: str): + def post_init(self): self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5") - self.device = device + def working_function(self, item): assert isinstance(item, list) and all(isinstance(i, str) for i in item) try: with torch.inference_mode(): - return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt").to(self.device) + return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt") except Exception as ex: print(ex) return None @@ -109,7 +109,9 @@ def post_init(self, model_device: str): def working_function(self, item): with torch.inference_mode(): - return self.model(**item).last_hidden_state.shape + item = item.to(self.model.device) + output = self.model(**item).last_hidden_state + return output.detach().cpu().shape def main(): mp.set_start_method('spawn') diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index b01796a5..5bbb94a0 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -4,7 +4,6 @@ from infinity_emb.inference import ( BatchHandler, Device, - select_model, ) from infinity_emb.log_handler import logger from infinity_emb.primitives import EmbeddingReturnType, ModelCapabilites @@ -51,19 +50,18 @@ def __init__( self.running = False self._vector_disk_cache_path = vector_disk_cache_path self._model_name_or_path = model_name_or_path + self._model_name_or_pathengine = engine + self._model_warmup = model_warmup self._lengths_via_tokenize = lengths_via_tokenize + if isinstance(engine, str): - engine = InferenceEngine[engine] + self._engine_type = InferenceEngine[engine] + else: + self._engine_type = engine if isinstance(device, str): - device = Device[device] - - self._model, self._min_inference_t = select_model( - model_name_or_path=model_name_or_path, - batch_size=batch_size, - engine=engine, - model_warmup=model_warmup, - device=device, - ) + self.device = Device[device] + else: + self.device = device async def astart(self): """startup engine""" @@ -75,20 +73,23 @@ async def astart(self): ) self.running = True self._batch_handler = BatchHandler( + model_name_or_path=self._model_name_or_path, + engine=self._engine_type, max_batch_size=self.batch_size, - model=self._model, - batch_delay=self._min_inference_t / 2, + model_warmup=self._model_warmup, vector_disk_cache_path=self._vector_disk_cache_path, verbose=logger.level <= 10, lengths_via_tokenize=self._lengths_via_tokenize, + device=self.device, ) - await self._batch_handler.spawn() + await self._batch_handler.astart() async def astop(self): """stop engine""" - self._check_running() + self._assert_running() self.running = False - await self._batch_handler.shutdown() + await self._batch_handler.astop() + self._batch_handler = None async def __aenter__(self): await self.astart() @@ -97,16 +98,17 @@ async def __aexit__(self, *args): await self.astop() def overload_status(self): - self._check_running() + 
self._assert_running() return self._batch_handler.overload_status() def is_overloaded(self) -> bool: - self._check_running() + self._assert_running() return self._batch_handler.is_overloaded() @property def capabilities(self) -> Set[ModelCapabilites]: - return self._model.capabilities + self._assert_running() + return self._batch_handler.capabilities async def embed( self, sentences: List[str] @@ -125,7 +127,7 @@ async def embed( Usage: """ - self._check_running() + self._assert_running() embeddings, usage = await self._batch_handler.embed(sentences) return embeddings, usage @@ -139,7 +141,7 @@ async def rerank( docs (List[str]): docs to be reranked raw_scores (bool): return raw scores instead of sigmoid """ - self._check_running() + self._assert_running() scores, usage = await self._batch_handler.rerank( query=query, docs=docs, raw_scores=raw_scores ) @@ -156,12 +158,12 @@ async def classify( docs (List[str]): docs to be reranked raw_scores (bool): return raw scores instead of sigmoid """ - self._check_running() + self._assert_running() scores, usage = await self._batch_handler.classify(sentences=sentences) return scores, usage - def _check_running(self): + def _assert_running(self): if not self.running: raise ValueError( "didn't start `AsyncEmbeddingEngine` " diff --git a/libs/infinity_emb/infinity_emb/inference/__init__.py b/libs/infinity_emb/infinity_emb/inference/__init__.py index 7ff36e27..b0c89fae 100644 --- a/libs/infinity_emb/infinity_emb/inference/__init__.py +++ b/libs/infinity_emb/infinity_emb/inference/__init__.py @@ -1,5 +1,4 @@ from infinity_emb.inference.batch_handler import BatchHandler -from infinity_emb.inference.select_model import select_model from infinity_emb.primitives import ( Device, DeviceTypeHint, @@ -15,5 +14,4 @@ "Device", "DeviceTypeHint", "BatchHandler", - "select_model", ] diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 07240e2e..e4117998 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -4,19 +4,24 @@ import threading import time from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Manager, Process, set_start_method from queue import Queue from typing import Any, Dict, List, Sequence, Set, Tuple import numpy as np from infinity_emb.inference.caching_layer import Cache +from infinity_emb.inference.model_worker import ModelWorker from infinity_emb.inference.queue import ( CustomFIFOQueue, ResultKVStoreFuture, ) -from infinity_emb.inference.threading_asyncio import to_thread +from infinity_emb.inference.select_model import ( + get_engine_type_from_config, +) from infinity_emb.log_handler import logger from infinity_emb.primitives import ( + Device, EmbeddingInner, EmbeddingReturnType, EmbeddingSingle, @@ -27,48 +32,58 @@ PredictInner, PredictSingle, PrioritizedQueueItem, + QueueSignalMessages, ReRankInner, ReRankSingle, ) -from infinity_emb.transformer.abstract import BaseTransformer -from infinity_emb.transformer.utils import get_lengths_with_tokenize +from infinity_emb.transformer.utils import ( + CapableEngineType, + InferenceEngine, + get_lengths_with_tokenize, +) + +_mp_manager = None + + +def get_mp_manager(): + global _mp_manager + _mp_manager = Manager() + return _mp_manager class BatchHandler: def __init__( self, - model: BaseTransformer, + model_name_or_path: str, max_batch_size: int, + engine: InferenceEngine = InferenceEngine.torch, + model_warmup: 
bool = True, max_queue_wait: int = int(os.environ.get("INFINITY_QUEUE_SIZE", 32_000)), - batch_delay: float = 5e-3, vector_disk_cache_path: str = "", verbose=False, lengths_via_tokenize: bool = False, + device: Device = Device.auto, + num_workers: int = 1, + worker_method: str = "thread", ) -> None: """ performs batching around the model. - model: BaseTransformer, implements fn (core|pre|post)_encode max_batch_size: max batch size of the models - max_queue_wait: max items to queue in the batch, default 32_000 sentences - batch_delay: sleep in seconds, wait time for pre/post methods. - Best result: setting to 1/2 the minimal expected - time for core_encode method / "gpu inference". - Dont set it above 1x minimal expected time of interence. - Should not be 0 to not block Python's GIL. vector_disk_cache_path: path to cache vectors on disk. lengths_via_tokenize: if True, use the tokenizer to get the lengths else len() """ - self.model = model + self.model_name_or_path = model_name_or_path self.max_batch_size = max_batch_size + self.model_warmup = model_warmup self.max_queue_wait = max_queue_wait - self._lengths_via_tokenize = lengths_via_tokenize - self._verbose = verbose + self.lengths_via_tokenize = lengths_via_tokenize + self.device = device + self.num_workers = num_workers + self.worker_method = worker_method + self._shutdown = threading.Event() - self._feature_queue: Queue = Queue(6) - self._postprocess_queue: Queue = Queue(4) - self._batch_delay = float(max(1e-4, batch_delay)) - self._threadpool = ThreadPoolExecutor() + self._queue_prio = CustomFIFOQueue() cache = ( Cache( @@ -81,16 +96,52 @@ def __init__( self._result_store = ResultKVStoreFuture(cache) self._ready = False - self._last_inference = time.perf_counter() - - if batch_delay > 0.1: - logger.warn(f"high batch delay of {self._batch_delay}") - if max_batch_size > max_queue_wait * 10: - logger.warn( - f"queue_size={self.max_queue_wait} to small " - f"over batch_size={self.max_batch_size}." 
- " Consider increasing queue size" - ) + + self.capable_engine: CapableEngineType = get_engine_type_from_config( + model_name_or_path=model_name_or_path, + engine=engine, + ) + self.model_capabilities = self.capable_engine.value.capabilities + + if self.worker_method == "thread": + self._shared_queue_model_out: Queue = Queue() + self._shared_queue_model_in: Queue = Queue() + else: + set_start_method("spawn", force=True) + + mp_manager = get_mp_manager() + self._shared_queue_model_out = mp_manager.Queue() + self._shared_queue_model_in = mp_manager.Queue() + + self.worker_args = dict( + in_queue=self._shared_queue_model_in, + out_queue=self._shared_queue_model_out, + max_batch_size=self.max_batch_size, + model_name_or_path=self.model_name_or_path, + capable_engine=self.capable_engine, + model_warmup=self.model_warmup, + device=self.device, + ) + + def run_model_worker(self): + if self.worker_method == "thread": + with ThreadPoolExecutor(max_workers=self.num_workers) as pool: + tasks = [ + pool.submit(ModelWorker, **self.worker_args) + for _ in range(self.num_workers) + ] + for task in tasks: + task.result() + else: + tasks = [ + Process(target=ModelWorker, kwargs=self.worker_args) + for _ in range(self.num_workers) + ] + for task in tasks: + task.start() + for task in tasks: + task.join() + logger.info("ModelWorker: all workers finished.") async def embed( self, sentences: List[str] @@ -103,11 +154,10 @@ async def embed( Returns: EmbeddingReturnType: list of embedding as 1darray """ - if "embed" not in self.model.capabilities: + if "embed" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `embed`." - f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) input_sentences = [EmbeddingSingle(s) for s in sentences] @@ -127,11 +177,10 @@ async def rerank( List[float]: list of scores int: token usage """ - if "rerank" not in self.model.capabilities: + if "rerank" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `rerank`." - f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) rerankables = [ReRankSingle(query=query, document=doc) for doc in docs] scores, usage = await self._schedule(rerankables) @@ -154,11 +203,10 @@ async def classify( Returns: EmbeddingReturnType: embedding as 1darray """ - if "classify" not in self.model.capabilities: + if "classify" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `classify`." 
- f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) items = [PredictSingle(sentence=s) for s in sentences] classifications, usage = await self._schedule(items) @@ -190,6 +238,11 @@ async def _schedule( item=inner_item(content=re), # type: ignore ) new_prioqueue.append(item) + + await asyncio.gather( + *[self._result_store.register_item(item.item) for item in new_prioqueue] + ) + await self._queue_prio.extend(new_prioqueue) result = await asyncio.gather( @@ -199,7 +252,7 @@ async def _schedule( @property def capabilities(self) -> Set[ModelCapabilites]: - return self.model.capabilities + return self.model_capabilities def is_overloaded(self) -> bool: """checks if more items can be queued.""" @@ -227,20 +280,20 @@ async def _get_prios_usage( Returns: Tuple[List[int], int]: prios, length """ - if not self._lengths_via_tokenize: + if not self.lengths_via_tokenize: return get_lengths_with_tokenize([it.str_repr() for it in items]) else: - return await to_thread( + # TODO: fix lengths_via_tokenize + return await asyncio.to_thread( get_lengths_with_tokenize, - self._threadpool, _sentences=[it.str_repr() for it in items], - tokenize=self.model.tokenize_lengths, + # tokenize=self.model.tokenize_lengths, # TODO: fix ) def _preprocess_batch(self): """loops and checks if the _core_batch has worked on all items""" self._ready = True - logger.info("ready to batch requests.") + logger.info("Batch_handler: Ready to batch requests.") try: while not self._shutdown.is_set(): # patience: @@ -248,12 +301,12 @@ def _preprocess_batch(self): # - until GPU / _core_batch starts processing the previous item # - or if many items are queued anyhow, so that a good batch # may be popped already. - if not self._feature_queue.empty() and ( - self._feature_queue.full() + if not self._shared_queue_model_in.empty() and ( + self._shared_queue_model_in.full() or (len(self._queue_prio) < self.max_batch_size * 4) ): # add some stochastic delay - time.sleep(self._batch_delay) + time.sleep(2e-4) continue # decision to attemp to pop a batch # -> will happen if a single datapoint is available @@ -267,24 +320,24 @@ def _preprocess_batch(self): # optimal batch has been selected -> # lets tokenize it and move tensors to GPU. 
for batch in batches: - if self._feature_queue.qsize() > 2: + if self._shared_queue_model_in.qsize() > 2: # add some stochastic delay - time.sleep(self._batch_delay * 2) + time.sleep(2e-4) items_for_pre = [item.content.to_input() for item in batch] - feat = self.model.encode_pre(items_for_pre) - if self._verbose: - logger.debug( - "[📦] batched %s requests, queue remaining: %s", - len(items_for_pre), - len(self._queue_prio), - ) - if self._shutdown.is_set(): - break + + logger.debug( + "[📦] batched %s requests, queue remaining: %s", + len(items_for_pre), + self._shared_queue_model_in.qsize(), + ) + # while-loop just for shutdown while not self._shutdown.is_set(): try: - self._feature_queue.put((feat, batch), timeout=1) + self._shared_queue_model_in.put( + (items_for_pre, batch), timeout=1 + ) break except queue.Full: continue @@ -293,108 +346,64 @@ def _preprocess_batch(self): raise ValueError("_preprocess_batch crashed") self._ready = False - def _core_batch(self): - """waiting for preprocessed batches (on device) - and do the forward pass / `.encode` - """ + async def _queue_finalizer(self): try: - while not self._shutdown.is_set(): + while True: try: - core_batch = self._feature_queue.get(timeout=0.5) - except queue.Empty: - continue - (feat, batch) = core_batch - if self._verbose: - logger.debug("[🏃] Inference on batch_size=%s", len(batch)) - self._last_inference = time.perf_counter() - embed = self.model.encode_core(feat) - - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - self._postprocess_queue.put((embed, batch), timeout=1) - break - except queue.Full: - continue - self._feature_queue.task_done() - except Exception as ex: - logger.exception(ex) - raise ValueError("_core_batch crashed.") - - async def _postprocess_batch(self): - """collecting forward(.encode) results and put them into the result store""" - # TODO: the ugly asyncio.sleep() could add to 3-8ms of latency worst case - # In constrast, at full batch size, sleep releases cruical CPU at time of - # the forward pass to GPU (after which there is crical time again) - # and not affecting the latency - try: - while not self._shutdown.is_set(): - try: - post_batch = self._postprocess_queue.get_nowait() + batch = self._shared_queue_model_out.get_nowait() except queue.Empty: # instead use async await to get - try: - post_batch = await to_thread( - self._postprocess_queue.get, self._threadpool, timeout=1 - ) - except queue.Empty: - # in case of timeout start again - continue - - if ( - self._postprocess_queue.empty() - and self._last_inference - < time.perf_counter() + self._batch_delay * 2 - ): - # 5 ms, assuming this is below - # 3-50ms for inference on avg. 
- # give the CPU some time to focus - # on moving the next batch to GPU on the forward pass - # before proceeding - await asyncio.sleep(self._batch_delay) - embed, batch = post_batch - results = self.model.encode_post(embed) - for i, item in enumerate(batch): - item.set_result(results[i]) - await self._result_store.mark_item_ready(item) + batch = await asyncio.to_thread(self._shared_queue_model_out.get) + if batch == QueueSignalMessages.KILL: + break - self._postprocess_queue.task_done() + for item in batch[1]: + await self._result_store.mark_item_ready(item) except Exception as ex: logger.exception(ex) - raise ValueError("Postprocessor crashed") + raise ValueError("_queue_finalizer crashed") async def _delayed_warmup(self): """in case there is no warmup -> perform some warmup.""" await asyncio.sleep(5) if not self._shutdown.is_set(): - logger.debug("Sending a warm up through embedding.") try: - if "embed" in self.model.capabilities: + if "embed" in self.model_capabilities: + logger.debug("Sending a warm up to `embed`.") await self.embed(sentences=["test"] * self.max_batch_size) - if "rerank" in self.model.capabilities: + if "rerank" in self.model_capabilities: + logger.debug("Sending a warm up to `rerank`.") await self.rerank( query="query", docs=["test"] * self.max_batch_size ) - if "classify" in self.model.capabilities: + if "classify" in self.model_capabilities: + logger.debug("Sending a warm up to `classify`.") await self.classify(sentences=["test"] * self.max_batch_size) - except Exception: - pass + except Exception as ex: + logger.exception(ex) - async def spawn(self): + async def astart(self): """set up the resources in batch""" if self._ready: raise ValueError("previous threads are still running.") logger.info("creating batching engine") - self._threadpool.submit(self._preprocess_batch) - self._threadpool.submit(self._core_batch) - asyncio.create_task(self._postprocess_batch()) + self.model_worker = asyncio.create_task( + asyncio.to_thread(self.run_model_worker) + ) + asyncio.create_task(self._queue_finalizer()) + asyncio.create_task(asyncio.to_thread(self._preprocess_batch)) asyncio.create_task(self._delayed_warmup()) - async def shutdown(self): + async def astop(self): """ set the shutdown event and close threadpool. Blocking event, until shutdown. """ + logger.debug("batch handler -> start astop") self._shutdown.set() - with ThreadPoolExecutor() as tp_temp: - await to_thread(self._threadpool.shutdown, tp_temp) + + self._shared_queue_model_out.put(QueueSignalMessages.KILL) + # at last, kill the models, which kill the MP Manager. 
+ self._shared_queue_model_in.put(QueueSignalMessages.KILL) + time.sleep(1) + logger.debug("batch handler <- done astop") diff --git a/libs/infinity_emb/infinity_emb/inference/caching_layer.py b/libs/infinity_emb/infinity_emb/inference/caching_layer.py index 79501d77..9c0fa696 100644 --- a/libs/infinity_emb/infinity_emb/inference/caching_layer.py +++ b/libs/infinity_emb/infinity_emb/inference/caching_layer.py @@ -2,10 +2,8 @@ import os import queue import threading -from concurrent.futures import ThreadPoolExecutor from typing import Any, List, Union -from infinity_emb.inference.threading_asyncio import to_thread from infinity_emb.log_handler import logger from infinity_emb.primitives import EmbeddingReturnType, QueueItemInner from infinity_emb.transformer.utils import infinity_cache_dir @@ -36,18 +34,17 @@ def __init__(self, cache_name: str, shutdown: threading.Event) -> None: logger.info(f"caching vectors under: {dir}") self._cache = dc.Cache(dir, size_limit=2**28) self.is_running = False - self.startup() - def startup(self): + async def _verify_running(self): if not self.is_running: - self._threadpool = ThreadPoolExecutor() - self._threadpool.submit(self._consume_queue) + asyncio.create_task(asyncio.to_thread(self._consume_queue)) @staticmethod def _hash(key: Union[str, Any]) -> str: return str(key) def _consume_queue(self) -> None: + self.is_running = True while not self._shutdown.is_set(): try: item = self._add_q.get(timeout=1) @@ -57,7 +54,7 @@ def _consume_queue(self) -> None: k, v = item self._cache.add(key=self._hash(k), value=v, expire=86400) self._add_q.task_done() - self._threadpool.shutdown(wait=True) + self.is_running = False def _get(self, sentence: str) -> Union[None, EmbeddingReturnType, List[float]]: return self._cache.get(key=self._hash(sentence)) @@ -66,8 +63,9 @@ async def aget(self, item: QueueItemInner, future: asyncio.Future) -> None: """Sets result to item and future, if in cache. If not in cache, sets future to be done when result is set. 
""" + await self._verify_running() item_as_str = item.content.str_repr() - result = await to_thread(self._get, self._threadpool, item_as_str) + result = await asyncio.to_thread(self._get, item_as_str) if result is not None: # update item with cached result if item.get_result() is None: diff --git a/libs/infinity_emb/infinity_emb/inference/model_worker.py b/libs/infinity_emb/infinity_emb/inference/model_worker.py new file mode 100644 index 00000000..7e0142f2 --- /dev/null +++ b/libs/infinity_emb/infinity_emb/inference/model_worker.py @@ -0,0 +1,154 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Queue as MpQueue +from queue import Empty, Full, Queue +from typing import Any, Union + +from infinity_emb.inference.select_model import ( + select_model, +) +from infinity_emb.log_handler import logger +from infinity_emb.primitives import Device, QueueSignalMessages +from infinity_emb.transformer.utils import ( + CapableEngineType, +) + +AnyQueue = Union[Queue, MpQueue] + + +class ModelWorker: + def __init__( + self, + in_queue: AnyQueue, + out_queue: AnyQueue, + model_name_or_path: str, + max_batch_size: int, + capable_engine: CapableEngineType, + model_warmup: bool, + device: Device, + ) -> None: + self.model, batch_delay = select_model( + model_name_or_path=model_name_or_path, + batch_size=max_batch_size, + capable_engine=capable_engine, + model_warmup=model_warmup, + device=device, + ) + self._batch_delay = float(max(1e-4, batch_delay)) + + if batch_delay > 0.1: + logger.warn(f"high batch delay of {self._batch_delay}") + self.max_batch_size = max_batch_size + + self._shared_in_queue = in_queue + self._feature_queue: Queue = Queue(6) + self._postprocess_queue: Queue = Queue(4) + self._shared_out_queue = out_queue + + self._verbose = logger.level <= 5 + self._last_inference = time.perf_counter() + + # spawn threads + logger.info("creating ModelWorker") + + with ThreadPoolExecutor(max_workers=3) as pool: + tasks = [ + pool.submit( + self._general_batch, + "preprocess", + self._shared_in_queue, + self._feature_queue, + self.model.encode_pre, + ), + pool.submit( + self._general_batch, + "forward", + self._feature_queue, + self._postprocess_queue, + self.model.encode_core, + is_model_fn=True, + ), + pool.submit( + self._general_batch, + "postprocess", + self._postprocess_queue, + self._shared_out_queue, + self.model.encode_post, + is_post_fn=True, + ), + ] + # block until all tasks are done + for future in tasks: + future.result() + + logger.info("stopped ModelWorker") + + def _general_batch( + self, + alias_name: str, + in_queue: AnyQueue, + out_queue: AnyQueue, + batch_fn: Any, + is_model_fn: bool = False, + is_post_fn: bool = False, + ): + try: + while True: + fetched_batch = in_queue.get() + if type(in_queue) == Queue: + in_queue.task_done() + if fetched_batch == QueueSignalMessages.KILL: + self.destruct() + break + + feat, meta = fetched_batch + + if self._verbose: + logger.debug("[🏃] %s on batch_size=%s", alias_name, len(meta)) + + if is_model_fn: + self._last_inference = time.perf_counter() + elif self._last_inference < time.perf_counter() + self._batch_delay: + # 5 ms, assuming this is below + # 3-50ms for inference on avg. 
+ # give the CPU some time to focus + # on moving the next batch to GPU on the forward pass + # before proceeding + time.sleep(self._batch_delay) + + processed = batch_fn(feat) + + if is_post_fn: + for i, item in enumerate(meta): + item.set_result(processed[i]) + processed = None + + out_queue.put((processed, meta)) + + except Exception as ex: + logger.exception(ex) + self.destruct() + raise ValueError(f"{alias_name} crashed.") + + def destruct(self): + """kill all tasks""" + for q in [ + self._shared_in_queue, + self._postprocess_queue, + self._feature_queue, + self._shared_out_queue, + ]: + while not q.empty(): + try: + res = q.get_nowait() + if res == QueueSignalMessages.KILL: + break + if type(q) == Queue: + q.task_done() + except Empty: + pass + for _ in range(10): + try: + q.put_nowait(QueueSignalMessages.KILL) + except Full: + pass diff --git a/libs/infinity_emb/infinity_emb/inference/queue.py b/libs/infinity_emb/infinity_emb/inference/queue.py index c5ccd6c6..8edb1bd3 100644 --- a/libs/infinity_emb/infinity_emb/inference/queue.py +++ b/libs/infinity_emb/infinity_emb/inference/queue.py @@ -94,22 +94,29 @@ def loop(self) -> asyncio.AbstractEventLoop: def __len__(self): return len(self._kv) + async def register_item(self, item: QueueItemInner) -> None: + """wait for future to return""" + uuid = item.get_id() + self._kv[uuid] = self.loop.create_future() + return None + async def wait_for_response(self, item: QueueItemInner) -> EmbeddingReturnType: """wait for future to return""" uuid = item.get_id() - fut = self.loop.create_future() - self._kv[uuid] = fut if self._cache: - asyncio.create_task(self._cache.aget(item, fut)) - await fut + asyncio.create_task(self._cache.aget(item, self._kv[uuid])) + item = await self._kv[uuid] + return item.get_result() async def mark_item_ready(self, item: QueueItemInner) -> None: """mark item as ready. 
Item.get_result() must be set before calling this""" uuid = item.get_id() + # faster than .pop fut = self._kv[uuid] + del self._kv[uuid] try: - fut.set_result(None) + fut.set_result(item) + # logger.debug(f"marked {uuid} as ready with {len(item.get_result())}") except asyncio.InvalidStateError: pass - del self._kv[uuid] diff --git a/libs/infinity_emb/infinity_emb/inference/select_model.py b/libs/infinity_emb/infinity_emb/inference/select_model.py index d76653ff..07d5b540 100644 --- a/libs/infinity_emb/infinity_emb/inference/select_model.py +++ b/libs/infinity_emb/infinity_emb/inference/select_model.py @@ -8,6 +8,7 @@ ) from infinity_emb.transformer.abstract import BaseCrossEncoder, BaseEmbedder from infinity_emb.transformer.utils import ( + CapableEngineType, EmbedderEngine, InferenceEngine, PredictEngine, @@ -17,7 +18,7 @@ def get_engine_type_from_config( model_name_or_path: str, engine: InferenceEngine -) -> Union[EmbedderEngine, RerankEngine]: +) -> CapableEngineType: if engine in [InferenceEngine.debugengine, InferenceEngine.fastembed]: return EmbedderEngine.from_inference_engine(engine) @@ -50,18 +51,15 @@ def get_engine_type_from_config( def select_model( model_name_or_path: str, batch_size: int, - engine: InferenceEngine = InferenceEngine.torch, + capable_engine: CapableEngineType, model_warmup=True, device: Device = Device.auto, ) -> Tuple[Union[BaseCrossEncoder, BaseEmbedder], float]: logger.info( - f"model=`{model_name_or_path}` selected, using engine=`{engine.value}`" + f"model=`{model_name_or_path}` selected, using engine=`{capable_engine.value}`" f" and device=`{device.value}`" ) - # TODO: add EncoderEngine - unloaded_engine = get_engine_type_from_config(model_name_or_path, engine) - - loaded_engine = unloaded_engine.value(model_name_or_path, device=device.value) + loaded_engine = capable_engine.value(model_name_or_path, device=device.value) min_inference_t = 4e-3 if model_warmup: diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index b5229b39..a490ed67 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -85,7 +85,7 @@ async def _startup(): ) @app.on_event("shutdown") - async def _shutdown(): + async def _astop(): await app.model.astop() @app.get("/ready") diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py index 6f22cfd5..15a013cb 100644 --- a/libs/infinity_emb/infinity_emb/primitives.py +++ b/libs/infinity_emb/infinity_emb/primitives.py @@ -10,6 +10,11 @@ EmbeddingReturnType = npt.NDArray[Union[np.float32, np.float32]] +class QueueSignalMessages(enum.Enum): + KILL = 1 + WAKE_ON_NOTIFY = 2 + + class Device(enum.Enum): cpu = "cpu" cuda = "cuda" diff --git a/libs/infinity_emb/infinity_emb/transformer/utils.py b/libs/infinity_emb/infinity_emb/transformer/utils.py index 962d7c8d..fb1acee9 100644 --- a/libs/infinity_emb/infinity_emb/transformer/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/utils.py @@ -1,7 +1,7 @@ import os from enum import Enum from pathlib import Path -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, Union from infinity_emb.transformer.classifier.torch import SentenceClassifier from infinity_emb.transformer.crossencoder.torch import ( @@ -75,6 +75,7 @@ def from_inference_engine(engine: InferenceEngine): _types: Dict[str, str] = {e.name: e.name for e in InferenceEngine} InferenceEngineTypeHint = 
Enum("InferenceEngineTypeHint", _types) # type: ignore +CapableEngineType = Union[EmbedderEngine, RerankEngine, PredictEngine] def length_tokenizer( diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock index e9598f9d..6790f4a5 100644 --- a/libs/infinity_emb/poetry.lock +++ b/libs/infinity_emb/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2360,6 +2360,20 @@ pytest = ">=5.0" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] +[[package]] +name = "pytest-timeout" +version = "2.2.0" +description = "pytest plugin to abort hanging tests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-timeout-2.2.0.tar.gz", hash = "sha256:3b0b95dabf3cb50bac9ef5ca912fa0cfc286526af17afc806824df20c2f72c90"}, + {file = "pytest_timeout-2.2.0-py3-none-any.whl", hash = "sha256:bde531e096466f49398a59f2dde76fa78429a09a12411466f88a07213e220de2"}, +] + +[package.dependencies] +pytest = ">=5.0.0" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -3971,4 +3985,4 @@ torch = ["sentence-transformers", "torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "52bf5a4c6ccda2b33e08cd9107bf8dbd6d0973f5ed92f1d1070f924fef6c8264" +content-hash = "864fc8c7a667c0cdfe819415607c5ba6bf9f1d263a9b6e4fe775a4807fb95721" diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 46340ea1..cd7f925e 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -36,6 +36,7 @@ infinity_emb = "infinity_emb.infinity_server:cli" [tool.poetry.group.test.dependencies] pytest = "^7.0.0" pytest-mock = "*" +pytest-timeout = "*" httpx = "*" asgi_lifespan = "*" anyio = "*" @@ -70,6 +71,8 @@ all=["ctranslate2", "fastapi", "fastembed", "optimum", "orjson", "prometheus-fas markers = [ "performance: tests that measure performance (deselect with '-m \"not performance\"')", ] +[tool.pytest] +timeout = 300 [build-system] requires = ["poetry-core"] diff --git a/libs/infinity_emb/tests/script_live.py b/libs/infinity_emb/tests/script_live.py index fba78a53..8c7b9d53 100644 --- a/libs/infinity_emb/tests/script_live.py +++ b/libs/infinity_emb/tests/script_live.py @@ -45,25 +45,22 @@ def remote(json_data: bytes, iters=1): cosine_sim = np.dot(r, e) / (np.linalg.norm(e) * np.linalg.norm(r)) assert cosine_sim > 0.99 print("Both methods provide the identical output.") - - print("Measuring latency via SentenceTransformers") - latency_st = timeit.timeit("local(sample, iters=3)", number=1, globals=locals()) - print("SentenceTransformers latency: ", latency_st) - model = None - print("Measuring latency via requests") latency_request = timeit.timeit( - "remote(json_d, iters=3)", number=1, globals=locals() + "remote(json_d, iters=5)", number=3, globals=locals() ) - print(f"Request latency: {latency_request}") + print(f"Infinity request latency: {latency_request}") - assert latency_st * 1.1 > latency_request + print("Measuring latency via SentenceTransformers") + latency_st = timeit.timeit("local(sample, iters=5)", number=3, globals=locals()) + print(f"SentenceTransformers latency: {latency_st}") def latency_single(): session = requests.Session() def _post(i): + time.sleep(0.05) json_d = json.dumps({"input": [str(i)], "model": "model"}) s = time.perf_counter() res = session.post(f"{LIVE_URL}/embeddings", data=json_d) @@ -77,4 +74,4 
@@ def _post(i): if __name__ == "__main__": - embedding_live_performance() + latency_single() diff --git a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py index f49385e2..63b5fe4a 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py @@ -1,5 +1,6 @@ import asyncio import copy +import logging import random import time from typing import Tuple @@ -12,9 +13,10 @@ from infinity_emb.transformer.embedder.sentence_transformer import ( SentenceTransformerPatched, ) +from infinity_emb.transformer.utils import InferenceEngine BATCH_SIZE = 32 -N_TIMINGS = 3 +N_TIMINGS = 1 LIMIT_SLOWDOWN = 1.25 if torch.cuda.is_available() else 1.35 @@ -23,15 +25,21 @@ async def load_patched_bh() -> Tuple[SentenceTransformerPatched, BatchHandler]: model = SentenceTransformerPatched(pytest.DEFAULT_BERT_MODEL) model.encode(["hello " * 512] * BATCH_SIZE) - bh = BatchHandler(model=model, max_batch_size=BATCH_SIZE) - await bh.spawn() + bh = BatchHandler( + model_name_or_path=str(pytest.DEFAULT_BERT_MODEL), + engine=InferenceEngine.torch, + max_batch_size=BATCH_SIZE, + ) + await bh.astart() return model, bh +@pytest.mark.timeout(300) @pytest.mark.performance @pytest.mark.anyio async def test_batch_performance_raw(get_sts_bechmark_dataset, load_patched_bh): model, bh = load_patched_bh + assert bh.capabilities == {"embed"} try: sentences = [] @@ -53,6 +61,7 @@ async def method_batch_handler(_sentences): ] _ = await asyncio.gather(*tasks) end = time.perf_counter() + logging.info(f"batch_handler: {end - start}") return round(end - start, 4) def method_patched(_sentences): @@ -67,6 +76,7 @@ def method_patched(_sentences): emb.append(model.encode_post(embed)) np.concatenate(emb).tolist() end = time.perf_counter() + logging.info(f"method_patched: {end - start}") return round(end - start, 4) def method_st(_sentences): @@ -74,24 +84,26 @@ def method_st(_sentences): start = time.perf_counter() _ = model.encode(_sentences, batch_size=BATCH_SIZE).tolist() end = time.perf_counter() + logging.info(f"method_st: {end - start}") return round(end - start, 4) # yappi.get_func_stats().print_all() # yappi.stop() method_st(sentences[::10]) await method_batch_handler(sentences[::10]) - time.sleep(2) + time.sleep(0.5) time_batch_handler = np.median( [(await method_batch_handler(sentences)) for _ in range(N_TIMINGS)] ) - time.sleep(2) + return await bh.astop() + time.sleep(0.5) time_st = np.median([method_st(sentences) for _ in range(N_TIMINGS)]) - time.sleep(2) + time.sleep(0.5) time_st_patched = np.median( [method_patched(sentences) for _ in range(N_TIMINGS)] ) - print( + logging.info( f"times are sentence-transformers: {time_st}," " patched-sentence-transformers: " f" {time_st_patched}, batch-handler: {time_batch_handler}" @@ -114,4 +126,4 @@ def method_st(_sentences): ) finally: - await bh.shutdown() + await bh.astop() diff --git a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py index c0bb008b..b00b0c12 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py @@ -4,21 +4,19 @@ import numpy as np import pytest -from infinity_emb.inference import caching_layer +from infinity_emb.inference.caching_layer import Cache from infinity_emb.primitives import EmbeddingInner, EmbeddingSingle 
+@pytest.mark.timeout(20) @pytest.mark.anyio async def test_cache(): - global INFINITY_CACHE_VECTORS - loop = asyncio.get_event_loop() shutdown = threading.Event() try: - INFINITY_CACHE_VECTORS = True sentence = "dummy" embedding = np.random.random(5).tolist() - c = caching_layer.Cache( + c = Cache( cache_name=f"pytest_{hash((sentence, tuple(embedding)))}", shutdown=shutdown ) @@ -37,5 +35,4 @@ async def test_cache(): assert result is not None np.testing.assert_array_equal(result, embedding) finally: - INFINITY_CACHE_VECTORS = False shutdown.set() diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index ebc28927..2bcea96e 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -28,8 +28,9 @@ async def test_async_api_torch(): engine=transformer.InferenceEngine.torch, device="auto", ) - assert engine.capabilities == {"embed"} + async with engine: + assert engine.capabilities == {"embed"} embeddings, usage = await engine.embed(sentences) assert isinstance(embeddings, list) assert isinstance(embeddings[0], np.ndarray) @@ -60,9 +61,8 @@ async def test_async_api_torch_CROSSENCODER(): model_warmup=True, ) - assert engine.capabilities == {"rerank"} - async with engine: + assert engine.capabilities == {"rerank"} rankings, usage = await engine.rerank(query=query, docs=documents) assert usage == sum([len(query) + len(d) for d in documents]) @@ -104,9 +104,9 @@ async def test_async_api_torch_CLASSIFY(): engine="torch", model_warmup=True, ) - assert engine.capabilities == {"classify"} async with engine: + assert engine.capabilities == {"classify"} predictions, usage = await engine.classify(sentences=sentences) assert usage == sum([len(s) for s in sentences]) assert len(predictions) == len(sentences)
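
Usage note on the reworked engine lifecycle in libs/infinity_emb/infinity_emb/engine.py: model loading is deferred from __init__ to astart(), so capabilities and the inference methods are only valid while the engine is running (i.e. inside the async context manager, as the updated tests show). A minimal sketch, assuming the constructor arguments visible in this diff and an example model name; the import path is assumed and not part of the patch:

import asyncio

from infinity_emb import AsyncEmbeddingEngine  # import path assumed


async def main():
    engine = AsyncEmbeddingEngine(
        model_name_or_path="BAAI/bge-small-en-v1.5",  # example model
        engine="torch",
        device="auto",
        model_warmup=False,
    )
    # __aenter__/__aexit__ call astart()/astop(); capabilities is now read
    # from the running BatchHandler, so query it inside the context.
    async with engine:
        assert "embed" in engine.capabilities
        embeddings, usage = await engine.embed(["hello", "world"])
        print(len(embeddings), usage)


asyncio.run(main())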
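The register_item/mark_item_ready split added to ResultKVStoreFuture creates the future before the request is enqueued, so a worker that finishes quickly cannot try to resolve a key that does not exist yet. A stripped-down sketch of that pattern with simplified names; this is not the library API:

import asyncio
from typing import Any, Dict


class ResultStore:
    def __init__(self) -> None:
        self._kv: Dict[str, asyncio.Future] = {}

    async def register_item(self, uid: str) -> None:
        # create the future *before* handing the item to a worker
        self._kv[uid] = asyncio.get_running_loop().create_future()

    async def wait_for_response(self, uid: str) -> Any:
        return await self._kv[uid]

    async def mark_item_ready(self, uid: str, result: Any) -> None:
        fut = self._kv.pop(uid)
        if not fut.done():
            fut.set_result(result)


async def demo():
    store = ResultStore()
    await store.register_item("req-1")
    # the worker may finish at any point after this; the future already exists
    asyncio.create_task(store.mark_item_ready("req-1", [0.1, 0.2]))
    print(await store.wait_for_response("req-1"))


asyncio.run(demo())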
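The new ModelWorker chains encode_pre -> encode_core -> encode_post as three threads connected by queues and stops every stage with a KILL sentinel. A generic sketch of that pipeline shape with placeholder stage functions; it is not the ModelWorker implementation itself:

from concurrent.futures import ThreadPoolExecutor
from queue import Queue

KILL = object()  # stands in for QueueSignalMessages.KILL


def stage(in_q: Queue, out_q: Queue, fn):
    # run one pipeline stage until the shutdown sentinel arrives
    while True:
        item = in_q.get()
        if item is KILL:
            out_q.put(KILL)  # propagate shutdown downstream
            break
        out_q.put(fn(item))


def run_pipeline(requests):
    # bounded middle queues (6 and 4) mirror the worker's internal queues
    q_in, q_mid, q_out, q_done = Queue(), Queue(6), Queue(4), Queue()
    with ThreadPoolExecutor(max_workers=3) as pool:
        pool.submit(stage, q_in, q_mid, lambda x: x.lower())   # "preprocess"
        pool.submit(stage, q_mid, q_out, lambda x: len(x))     # "forward"
        pool.submit(stage, q_out, q_done, lambda x: x * 2)     # "postprocess"
        for r in requests:
            q_in.put(r)
        q_in.put(KILL)
        results = []
        while (res := q_done.get()) is not KILL:
            results.append(res)
    return results


print(run_pipeline(["Hello", "WORLD"]))  # -> [10, 10]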