From d34129cf08fc73d856d588e6774057222d785cec Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sun, 7 Jan 2024 15:07:22 +0100 Subject: [PATCH 1/9] add poc --- experimental/caching/multiprocessing.py | 12 +- .../infinity_emb/inference/batch_handler.py | 330 +++++++++++------- .../infinity_emb/inference/caching_layer.py | 14 +- libs/infinity_emb/tests/script_live.py | 9 +- 4 files changed, 212 insertions(+), 153 deletions(-) diff --git a/experimental/caching/multiprocessing.py b/experimental/caching/multiprocessing.py index 44729013..584a0e80 100644 --- a/experimental/caching/multiprocessing.py +++ b/experimental/caching/multiprocessing.py @@ -1,5 +1,5 @@ from typing import Iterator -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + import torch.multiprocessing as mp from transformers import AutoTokenizer, AutoModel import torch @@ -88,15 +88,15 @@ def loop_forever(self): pass class TokenizePipeline(BoringPipeline): - def post_init(self, device: str): + def post_init(self): self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5") - self.device = device + def working_function(self, item): assert isinstance(item, list) and all(isinstance(i, str) for i in item) try: with torch.inference_mode(): - return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt").to(self.device) + return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt") except Exception as ex: print(ex) return None @@ -109,7 +109,9 @@ def post_init(self, model_device: str): def working_function(self, item): with torch.inference_mode(): - return self.model(**item).last_hidden_state.shape + item = item.to(self.model.device) + output = self.model(**item).last_hidden_state + return output.detach().cpu().shape def main(): mp.set_start_method('spawn') diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 07240e2e..685e99c4 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -3,7 +3,6 @@ import queue import threading import time -from concurrent.futures import ThreadPoolExecutor from queue import Queue from typing import Any, Dict, List, Sequence, Set, Tuple @@ -14,7 +13,6 @@ CustomFIFOQueue, ResultKVStoreFuture, ) -from infinity_emb.inference.threading_asyncio import to_thread from infinity_emb.log_handler import logger from infinity_emb.primitives import ( EmbeddingInner, @@ -34,6 +32,175 @@ from infinity_emb.transformer.utils import get_lengths_with_tokenize +class ModelWorker: + def __init__( + self, + in_queue: CustomFIFOQueue, + out_queue: Queue, + shutdown_event: threading.Event, + batch_delay: float, + model: BaseTransformer, + max_batch_size: int, + verbose: bool = False, + ) -> None: + self.model = model + self._shutdown = shutdown_event + self.max_batch_size = max_batch_size + + self._in_queue = in_queue + self._feature_queue: Queue = Queue(6) + self._postprocess_queue: Queue = Queue(4) + self._out_queue = out_queue + + self._batch_delay = batch_delay + self._verbose = verbose + self._last_inference = time.perf_counter() + + def _preprocess_batch(self): + """loops and checks if the _core_batch has worked on all items""" + self._ready = True + logger.info("ready to batch requests.") + try: + while not self._shutdown.is_set(): + # patience: + # do not pop a batch if self._feature_queue still has an item left + # - until GPU / _core_batch starts processing the previous 
item + # - or if many items are queued anyhow, so that a good batch + # may be popped already. + if not self._feature_queue.empty() and ( + self._feature_queue.full() + or (len(self._in_queue) < self.max_batch_size * 4) + ): + # add some stochastic delay + time.sleep(self._batch_delay) + continue + # decision to attemp to pop a batch + # -> will happen if a single datapoint is available + + batches = self._in_queue.pop_optimal_batches( + self.max_batch_size, latest_first=False + ) + if not batches: + # not a single sentence available / len=0, wait for more + continue + # optimal batch has been selected -> + # lets tokenize it and move tensors to GPU. + for batch in batches: + if self._feature_queue.qsize() > 2: + # add some stochastic delay + time.sleep(self._batch_delay * 2) + + items_for_pre = [item.content.to_input() for item in batch] + feat = self.model.encode_pre(items_for_pre) + if self._verbose: + logger.debug( + "[📦] batched %s requests, queue remaining: %s", + len(items_for_pre), + len(self._in_queue), + ) + if self._shutdown.is_set(): + break + # while-loop just for shutdown + while not self._shutdown.is_set(): + try: + self._feature_queue.put((feat, batch), timeout=1) + break + except queue.Full: + continue + except Exception as ex: + logger.exception(ex) + raise ValueError("_preprocess_batch crashed") + self._ready = False + + def _core_batch(self): + """waiting for preprocessed batches (on device) + and do the forward pass / `.encode` + """ + try: + while not self._shutdown.is_set(): + try: + core_batch = self._feature_queue.get(timeout=0.5) + except queue.Empty: + continue + (feat, batch) = core_batch + if self._verbose: + logger.debug("[🏃] Inference on batch_size=%s", len(batch)) + self._last_inference = time.perf_counter() + embed = self.model.encode_core(feat) + + # while-loop just for shutdown + while not self._shutdown.is_set(): + try: + self._postprocess_queue.put((embed, batch), timeout=1) + break + except queue.Full: + continue + self._feature_queue.task_done() + except Exception as ex: + logger.exception(ex) + raise ValueError("_core_batch crashed.") + + def _postprocess_batch(self): + """collecting forward(.encode) results and put them into the result store""" + # TODO: the ugly asyncio.sleep() could add to 3-8ms of latency worst case + # In constrast, at full batch size, sleep releases cruical CPU at time of + # the forward pass to GPU (after which there is crical time again) + # and not affecting the latency + try: + while not self._shutdown.is_set(): + try: + post_batch = self._postprocess_queue.get(timeout=0.5) + except queue.Empty: + # instead use async await to get + continue + + if ( + self._postprocess_queue.empty() + and self._last_inference + < time.perf_counter() + self._batch_delay * 2 + ): + # 5 ms, assuming this is below + # 3-50ms for inference on avg. 
+ # give the CPU some time to focus + # on moving the next batch to GPU on the forward pass + # before proceeding + time.sleep(self._batch_delay) + embed, batch = post_batch + results = self.model.encode_post(embed) + for i, item in enumerate(batch): + item.set_result(results[i]) + + while not self._shutdown.is_set(): + try: + self._out_queue.put(batch, timeout=1) + break + except queue.Full: + continue + + self._postprocess_queue.task_done() + except Exception as ex: + logger.exception(ex) + raise ValueError("Postprocessor crashed") + + async def spawn(self): + """set up the resources in batch""" + logger.info("creating batching engine") + self.tasks = [ + asyncio.create_task(asyncio.to_thread(self._preprocess_batch)), + asyncio.create_task(asyncio.to_thread(self._core_batch)), + asyncio.create_task(asyncio.to_thread(self._postprocess_batch)), + ] + + async def shutdown(self): + """ + set the shutdown event. + Blocking event, until shutdown. + """ + self._shutdown.set() + + await asyncio.gather(*self.tasks) + + class BatchHandler: def __init__( self, @@ -63,12 +230,11 @@ def __init__( self.max_batch_size = max_batch_size self.max_queue_wait = max_queue_wait self._lengths_via_tokenize = lengths_via_tokenize - self._verbose = verbose + self._shutdown = threading.Event() - self._feature_queue: Queue = Queue(6) - self._postprocess_queue: Queue = Queue(4) + + self._final_queue: Queue = Queue() self._batch_delay = float(max(1e-4, batch_delay)) - self._threadpool = ThreadPoolExecutor() self._queue_prio = CustomFIFOQueue() cache = ( Cache( @@ -81,7 +247,6 @@ def __init__( self._result_store = ResultKVStoreFuture(cache) self._ready = False - self._last_inference = time.perf_counter() if batch_delay > 0.1: logger.warn(f"high batch delay of {self._batch_delay}") @@ -91,6 +256,15 @@ def __init__( f"over batch_size={self.max_batch_size}." " Consider increasing queue size" ) + self.model_worker = ModelWorker( + in_queue=self._queue_prio, + out_queue=self._final_queue, + shutdown_event=self._shutdown, + batch_delay=self._batch_delay, + model=self.model, + verbose=verbose, + max_batch_size=self.max_batch_size, + ) async def embed( self, sentences: List[str] @@ -230,138 +404,25 @@ async def _get_prios_usage( if not self._lengths_via_tokenize: return get_lengths_with_tokenize([it.str_repr() for it in items]) else: - return await to_thread( + return await asyncio.to_thread( get_lengths_with_tokenize, - self._threadpool, _sentences=[it.str_repr() for it in items], tokenize=self.model.tokenize_lengths, ) - def _preprocess_batch(self): - """loops and checks if the _core_batch has worked on all items""" - self._ready = True - logger.info("ready to batch requests.") - try: - while not self._shutdown.is_set(): - # patience: - # do not pop a batch if self._feature_queue still has an item left - # - until GPU / _core_batch starts processing the previous item - # - or if many items are queued anyhow, so that a good batch - # may be popped already. - if not self._feature_queue.empty() and ( - self._feature_queue.full() - or (len(self._queue_prio) < self.max_batch_size * 4) - ): - # add some stochastic delay - time.sleep(self._batch_delay) - continue - # decision to attemp to pop a batch - # -> will happen if a single datapoint is available - - batches = self._queue_prio.pop_optimal_batches( - self.max_batch_size, latest_first=False - ) - if not batches: - # not a single sentence available / len=0, wait for more - continue - # optimal batch has been selected -> - # lets tokenize it and move tensors to GPU. 
- for batch in batches: - if self._feature_queue.qsize() > 2: - # add some stochastic delay - time.sleep(self._batch_delay * 2) - - items_for_pre = [item.content.to_input() for item in batch] - feat = self.model.encode_pre(items_for_pre) - if self._verbose: - logger.debug( - "[📦] batched %s requests, queue remaining: %s", - len(items_for_pre), - len(self._queue_prio), - ) - if self._shutdown.is_set(): - break - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - self._feature_queue.put((feat, batch), timeout=1) - break - except queue.Full: - continue - except Exception as ex: - logger.exception(ex) - raise ValueError("_preprocess_batch crashed") - self._ready = False - - def _core_batch(self): - """waiting for preprocessed batches (on device) - and do the forward pass / `.encode` - """ - try: - while not self._shutdown.is_set(): + async def _queue_finalizer(self): + while not self._shutdown.is_set(): + try: + batch = self._final_queue.get_nowait() + except queue.Empty: + # instead use async await to get try: - core_batch = self._feature_queue.get(timeout=0.5) + batch = await asyncio.to_thread(self._final_queue.get, timeout=1) except queue.Empty: continue - (feat, batch) = core_batch - if self._verbose: - logger.debug("[🏃] Inference on batch_size=%s", len(batch)) - self._last_inference = time.perf_counter() - embed = self.model.encode_core(feat) - - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - self._postprocess_queue.put((embed, batch), timeout=1) - break - except queue.Full: - continue - self._feature_queue.task_done() - except Exception as ex: - logger.exception(ex) - raise ValueError("_core_batch crashed.") - async def _postprocess_batch(self): - """collecting forward(.encode) results and put them into the result store""" - # TODO: the ugly asyncio.sleep() could add to 3-8ms of latency worst case - # In constrast, at full batch size, sleep releases cruical CPU at time of - # the forward pass to GPU (after which there is crical time again) - # and not affecting the latency - try: - while not self._shutdown.is_set(): - try: - post_batch = self._postprocess_queue.get_nowait() - except queue.Empty: - # instead use async await to get - try: - post_batch = await to_thread( - self._postprocess_queue.get, self._threadpool, timeout=1 - ) - except queue.Empty: - # in case of timeout start again - continue - - if ( - self._postprocess_queue.empty() - and self._last_inference - < time.perf_counter() + self._batch_delay * 2 - ): - # 5 ms, assuming this is below - # 3-50ms for inference on avg. 
- # give the CPU some time to focus - # on moving the next batch to GPU on the forward pass - # before proceeding - await asyncio.sleep(self._batch_delay) - embed, batch = post_batch - results = self.model.encode_post(embed) - for i, item in enumerate(batch): - item.set_result(results[i]) - await self._result_store.mark_item_ready(item) - - self._postprocess_queue.task_done() - except Exception as ex: - logger.exception(ex) - raise ValueError("Postprocessor crashed") + for item in batch: + await self._result_store.mark_item_ready(item) async def _delayed_warmup(self): """in case there is no warmup -> perform some warmup.""" @@ -385,9 +446,8 @@ async def spawn(self): if self._ready: raise ValueError("previous threads are still running.") logger.info("creating batching engine") - self._threadpool.submit(self._preprocess_batch) - self._threadpool.submit(self._core_batch) - asyncio.create_task(self._postprocess_batch()) + await self.model_worker.spawn() + asyncio.create_task(self._queue_finalizer()) asyncio.create_task(self._delayed_warmup()) async def shutdown(self): @@ -396,5 +456,5 @@ async def shutdown(self): Blocking event, until shutdown. """ self._shutdown.set() - with ThreadPoolExecutor() as tp_temp: - await to_thread(self._threadpool.shutdown, tp_temp) + await self.model_worker.shutdown() + print("all shutdown") diff --git a/libs/infinity_emb/infinity_emb/inference/caching_layer.py b/libs/infinity_emb/infinity_emb/inference/caching_layer.py index 79501d77..871e6b51 100644 --- a/libs/infinity_emb/infinity_emb/inference/caching_layer.py +++ b/libs/infinity_emb/infinity_emb/inference/caching_layer.py @@ -2,10 +2,8 @@ import os import queue import threading -from concurrent.futures import ThreadPoolExecutor from typing import Any, List, Union -from infinity_emb.inference.threading_asyncio import to_thread from infinity_emb.log_handler import logger from infinity_emb.primitives import EmbeddingReturnType, QueueItemInner from infinity_emb.transformer.utils import infinity_cache_dir @@ -36,18 +34,17 @@ def __init__(self, cache_name: str, shutdown: threading.Event) -> None: logger.info(f"caching vectors under: {dir}") self._cache = dc.Cache(dir, size_limit=2**28) self.is_running = False - self.startup() - def startup(self): + async def _verify_running(self): if not self.is_running: - self._threadpool = ThreadPoolExecutor() - self._threadpool.submit(self._consume_queue) + asyncio.create_task(asyncio.to_thread(self._consume_queue)) @staticmethod def _hash(key: Union[str, Any]) -> str: return str(key) def _consume_queue(self) -> None: + self.is_running = True while not self._shutdown.is_set(): try: item = self._add_q.get(timeout=1) @@ -57,7 +54,7 @@ def _consume_queue(self) -> None: k, v = item self._cache.add(key=self._hash(k), value=v, expire=86400) self._add_q.task_done() - self._threadpool.shutdown(wait=True) + self.is_running = False def _get(self, sentence: str) -> Union[None, EmbeddingReturnType, List[float]]: return self._cache.get(key=self._hash(sentence)) @@ -66,8 +63,9 @@ async def aget(self, item: QueueItemInner, future: asyncio.Future) -> None: """Sets result to item and future, if in cache. If not in cache, sets future to be done when result is set. 
""" + self._verify_running() item_as_str = item.content.str_repr() - result = await to_thread(self._get, self._threadpool, item_as_str) + result = await asyncio.to_thread(self._get, item_as_str) if result is not None: # update item with cached result if item.get_result() is None: diff --git a/libs/infinity_emb/tests/script_live.py b/libs/infinity_emb/tests/script_live.py index fba78a53..ce11d531 100644 --- a/libs/infinity_emb/tests/script_live.py +++ b/libs/infinity_emb/tests/script_live.py @@ -47,23 +47,22 @@ def remote(json_data: bytes, iters=1): print("Both methods provide the identical output.") print("Measuring latency via SentenceTransformers") - latency_st = timeit.timeit("local(sample, iters=3)", number=1, globals=locals()) + latency_st = timeit.timeit("local(sample, iters=1)", number=2, globals=locals()) print("SentenceTransformers latency: ", latency_st) model = None print("Measuring latency via requests") latency_request = timeit.timeit( - "remote(json_d, iters=3)", number=1, globals=locals() + "remote(json_d, iters=1)", number=2, globals=locals() ) - print(f"Request latency: {latency_request}") - - assert latency_st * 1.1 > latency_request + print(f"Infinity request latency: {latency_request}") def latency_single(): session = requests.Session() def _post(i): + time.sleep(0.02) json_d = json.dumps({"input": [str(i)], "model": "model"}) s = time.perf_counter() res = session.post(f"{LIVE_URL}/embeddings", data=json_d) From 15f45e8e406b8bd13b26c073ed211c3e988231e9 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sun, 7 Jan 2024 14:30:06 +0000 Subject: [PATCH 2/9] fixing test --- .../infinity_emb/inference/caching_layer.py | 2 +- libs/infinity_emb/poetry.lock | 18 ++++++++++++++++-- libs/infinity_emb/pyproject.toml | 1 + .../unit_test/inference/test_caching_layer.py | 2 +- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/inference/caching_layer.py b/libs/infinity_emb/infinity_emb/inference/caching_layer.py index 871e6b51..9c0fa696 100644 --- a/libs/infinity_emb/infinity_emb/inference/caching_layer.py +++ b/libs/infinity_emb/infinity_emb/inference/caching_layer.py @@ -63,7 +63,7 @@ async def aget(self, item: QueueItemInner, future: asyncio.Future) -> None: """Sets result to item and future, if in cache. If not in cache, sets future to be done when result is set. """ - self._verify_running() + await self._verify_running() item_as_str = item.content.str_repr() result = await asyncio.to_thread(self._get, item_as_str) if result is not None: diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock index e9598f9d..6790f4a5 100644 --- a/libs/infinity_emb/poetry.lock +++ b/libs/infinity_emb/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -2360,6 +2360,20 @@ pytest = ">=5.0" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] +[[package]] +name = "pytest-timeout" +version = "2.2.0" +description = "pytest plugin to abort hanging tests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-timeout-2.2.0.tar.gz", hash = "sha256:3b0b95dabf3cb50bac9ef5ca912fa0cfc286526af17afc806824df20c2f72c90"}, + {file = "pytest_timeout-2.2.0-py3-none-any.whl", hash = "sha256:bde531e096466f49398a59f2dde76fa78429a09a12411466f88a07213e220de2"}, +] + +[package.dependencies] +pytest = ">=5.0.0" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -3971,4 +3985,4 @@ torch = ["sentence-transformers", "torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "52bf5a4c6ccda2b33e08cd9107bf8dbd6d0973f5ed92f1d1070f924fef6c8264" +content-hash = "864fc8c7a667c0cdfe819415607c5ba6bf9f1d263a9b6e4fe775a4807fb95721" diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 46340ea1..f4bf0780 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -36,6 +36,7 @@ infinity_emb = "infinity_emb.infinity_server:cli" [tool.poetry.group.test.dependencies] pytest = "^7.0.0" pytest-mock = "*" +pytest-timeout = "*" httpx = "*" asgi_lifespan = "*" anyio = "*" diff --git a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py index c0bb008b..e04da461 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py @@ -7,7 +7,7 @@ from infinity_emb.inference import caching_layer from infinity_emb.primitives import EmbeddingInner, EmbeddingSingle - +@pytest.mark.timeout(20) @pytest.mark.anyio async def test_cache(): global INFINITY_CACHE_VECTORS From 4dcbe428f8f8415e2cc517717b6f053cfebed624 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sun, 7 Jan 2024 14:36:23 +0000 Subject: [PATCH 3/9] linting --- libs/infinity_emb/pyproject.toml | 2 ++ .../tests/unit_test/inference/test_caching_layer.py | 1 + 2 files changed, 3 insertions(+) diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index f4bf0780..cd7f925e 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -71,6 +71,8 @@ all=["ctranslate2", "fastapi", "fastembed", "optimum", "orjson", "prometheus-fas markers = [ "performance: tests that measure performance (deselect with '-m \"not performance\"')", ] +[tool.pytest] +timeout = 300 [build-system] requires = ["poetry-core"] diff --git a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py index e04da461..eef04802 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py @@ -7,6 +7,7 @@ from infinity_emb.inference import caching_layer from infinity_emb.primitives import EmbeddingInner, EmbeddingSingle + @pytest.mark.timeout(20) @pytest.mark.anyio async def test_cache(): From b72e9519f6937716a67ad29acf41c30883593b19 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sun, 7 Jan 2024 17:13:47 +0000 Subject: [PATCH 4/9] commit modelworker --- libs/infinity_emb/infinity_emb/engine.py | 44 +-- .../infinity_emb/inference/__init__.py | 2 - .../infinity_emb/inference/batch_handler.py | 339 +++++++----------- 
.../infinity_emb/inference/model_worker.py | 153 ++++++++ .../infinity_emb/inference/select_model.py | 12 +- .../infinity_emb/transformer/utils.py | 3 +- .../unit_test/inference/test_batch_handler.py | 8 +- .../tests/unit_test/test_engine.py | 8 +- 8 files changed, 319 insertions(+), 250 deletions(-) create mode 100644 libs/infinity_emb/infinity_emb/inference/model_worker.py diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index b01796a5..e6b80056 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -4,7 +4,6 @@ from infinity_emb.inference import ( BatchHandler, Device, - select_model, ) from infinity_emb.log_handler import logger from infinity_emb.primitives import EmbeddingReturnType, ModelCapabilites @@ -51,19 +50,18 @@ def __init__( self.running = False self._vector_disk_cache_path = vector_disk_cache_path self._model_name_or_path = model_name_or_path + self._model_name_or_pathengine = engine + self._model_warmup = model_warmup self._lengths_via_tokenize = lengths_via_tokenize + if isinstance(engine, str): - engine = InferenceEngine[engine] + self._engine_type = InferenceEngine[engine] + else: + self._engine_type = engine if isinstance(device, str): - device = Device[device] - - self._model, self._min_inference_t = select_model( - model_name_or_path=model_name_or_path, - batch_size=batch_size, - engine=engine, - model_warmup=model_warmup, - device=device, - ) + self.device = Device[device] + else: + self.device = device async def astart(self): """startup engine""" @@ -75,20 +73,23 @@ async def astart(self): ) self.running = True self._batch_handler = BatchHandler( + model_name_or_path=self._model_name_or_path, + engine=self._engine_type, max_batch_size=self.batch_size, - model=self._model, - batch_delay=self._min_inference_t / 2, + model_warmup=self._model_warmup, vector_disk_cache_path=self._vector_disk_cache_path, verbose=logger.level <= 10, lengths_via_tokenize=self._lengths_via_tokenize, + device=self.device, ) await self._batch_handler.spawn() async def astop(self): """stop engine""" - self._check_running() + self._assert_running() self.running = False await self._batch_handler.shutdown() + self._batch_handler = None async def __aenter__(self): await self.astart() @@ -97,16 +98,17 @@ async def __aexit__(self, *args): await self.astop() def overload_status(self): - self._check_running() + self._assert_running() return self._batch_handler.overload_status() def is_overloaded(self) -> bool: - self._check_running() + self._assert_running() return self._batch_handler.is_overloaded() @property def capabilities(self) -> Set[ModelCapabilites]: - return self._model.capabilities + self._assert_running() + return self._batch_handler.capabilities async def embed( self, sentences: List[str] @@ -125,7 +127,7 @@ async def embed( Usage: """ - self._check_running() + self._assert_running() embeddings, usage = await self._batch_handler.embed(sentences) return embeddings, usage @@ -139,7 +141,7 @@ async def rerank( docs (List[str]): docs to be reranked raw_scores (bool): return raw scores instead of sigmoid """ - self._check_running() + self._assert_running() scores, usage = await self._batch_handler.rerank( query=query, docs=docs, raw_scores=raw_scores ) @@ -156,12 +158,12 @@ async def classify( docs (List[str]): docs to be reranked raw_scores (bool): return raw scores instead of sigmoid """ - self._check_running() + self._assert_running() scores, usage = await 
self._batch_handler.classify(sentences=sentences) return scores, usage - def _check_running(self): + def _assert_running(self): if not self.running: raise ValueError( "didn't start `AsyncEmbeddingEngine` " diff --git a/libs/infinity_emb/infinity_emb/inference/__init__.py b/libs/infinity_emb/infinity_emb/inference/__init__.py index 7ff36e27..b0c89fae 100644 --- a/libs/infinity_emb/infinity_emb/inference/__init__.py +++ b/libs/infinity_emb/infinity_emb/inference/__init__.py @@ -1,5 +1,4 @@ from infinity_emb.inference.batch_handler import BatchHandler -from infinity_emb.inference.select_model import select_model from infinity_emb.primitives import ( Device, DeviceTypeHint, @@ -15,5 +14,4 @@ "Device", "DeviceTypeHint", "BatchHandler", - "select_model", ] diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 685e99c4..6fd4e975 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -3,18 +3,26 @@ import queue import threading import time + +# from multiprocessing import Process +# from multiprocessing import Queue as MPQueue from queue import Queue from typing import Any, Dict, List, Sequence, Set, Tuple import numpy as np from infinity_emb.inference.caching_layer import Cache +from infinity_emb.inference.model_worker import ModelWorker from infinity_emb.inference.queue import ( CustomFIFOQueue, ResultKVStoreFuture, ) +from infinity_emb.inference.select_model import ( + get_engine_type_from_config, +) from infinity_emb.log_handler import logger from infinity_emb.primitives import ( + Device, EmbeddingInner, EmbeddingReturnType, EmbeddingSingle, @@ -28,213 +36,41 @@ ReRankInner, ReRankSingle, ) -from infinity_emb.transformer.abstract import BaseTransformer -from infinity_emb.transformer.utils import get_lengths_with_tokenize - - -class ModelWorker: - def __init__( - self, - in_queue: CustomFIFOQueue, - out_queue: Queue, - shutdown_event: threading.Event, - batch_delay: float, - model: BaseTransformer, - max_batch_size: int, - verbose: bool = False, - ) -> None: - self.model = model - self._shutdown = shutdown_event - self.max_batch_size = max_batch_size - - self._in_queue = in_queue - self._feature_queue: Queue = Queue(6) - self._postprocess_queue: Queue = Queue(4) - self._out_queue = out_queue - - self._batch_delay = batch_delay - self._verbose = verbose - self._last_inference = time.perf_counter() - - def _preprocess_batch(self): - """loops and checks if the _core_batch has worked on all items""" - self._ready = True - logger.info("ready to batch requests.") - try: - while not self._shutdown.is_set(): - # patience: - # do not pop a batch if self._feature_queue still has an item left - # - until GPU / _core_batch starts processing the previous item - # - or if many items are queued anyhow, so that a good batch - # may be popped already. - if not self._feature_queue.empty() and ( - self._feature_queue.full() - or (len(self._in_queue) < self.max_batch_size * 4) - ): - # add some stochastic delay - time.sleep(self._batch_delay) - continue - # decision to attemp to pop a batch - # -> will happen if a single datapoint is available - - batches = self._in_queue.pop_optimal_batches( - self.max_batch_size, latest_first=False - ) - if not batches: - # not a single sentence available / len=0, wait for more - continue - # optimal batch has been selected -> - # lets tokenize it and move tensors to GPU. 
- for batch in batches: - if self._feature_queue.qsize() > 2: - # add some stochastic delay - time.sleep(self._batch_delay * 2) - - items_for_pre = [item.content.to_input() for item in batch] - feat = self.model.encode_pre(items_for_pre) - if self._verbose: - logger.debug( - "[📦] batched %s requests, queue remaining: %s", - len(items_for_pre), - len(self._in_queue), - ) - if self._shutdown.is_set(): - break - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - self._feature_queue.put((feat, batch), timeout=1) - break - except queue.Full: - continue - except Exception as ex: - logger.exception(ex) - raise ValueError("_preprocess_batch crashed") - self._ready = False - - def _core_batch(self): - """waiting for preprocessed batches (on device) - and do the forward pass / `.encode` - """ - try: - while not self._shutdown.is_set(): - try: - core_batch = self._feature_queue.get(timeout=0.5) - except queue.Empty: - continue - (feat, batch) = core_batch - if self._verbose: - logger.debug("[🏃] Inference on batch_size=%s", len(batch)) - self._last_inference = time.perf_counter() - embed = self.model.encode_core(feat) - - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - self._postprocess_queue.put((embed, batch), timeout=1) - break - except queue.Full: - continue - self._feature_queue.task_done() - except Exception as ex: - logger.exception(ex) - raise ValueError("_core_batch crashed.") - - def _postprocess_batch(self): - """collecting forward(.encode) results and put them into the result store""" - # TODO: the ugly asyncio.sleep() could add to 3-8ms of latency worst case - # In constrast, at full batch size, sleep releases cruical CPU at time of - # the forward pass to GPU (after which there is crical time again) - # and not affecting the latency - try: - while not self._shutdown.is_set(): - try: - post_batch = self._postprocess_queue.get(timeout=0.5) - except queue.Empty: - # instead use async await to get - continue - - if ( - self._postprocess_queue.empty() - and self._last_inference - < time.perf_counter() + self._batch_delay * 2 - ): - # 5 ms, assuming this is below - # 3-50ms for inference on avg. - # give the CPU some time to focus - # on moving the next batch to GPU on the forward pass - # before proceeding - time.sleep(self._batch_delay) - embed, batch = post_batch - results = self.model.encode_post(embed) - for i, item in enumerate(batch): - item.set_result(results[i]) - - while not self._shutdown.is_set(): - try: - self._out_queue.put(batch, timeout=1) - break - except queue.Full: - continue - - self._postprocess_queue.task_done() - except Exception as ex: - logger.exception(ex) - raise ValueError("Postprocessor crashed") - - async def spawn(self): - """set up the resources in batch""" - logger.info("creating batching engine") - self.tasks = [ - asyncio.create_task(asyncio.to_thread(self._preprocess_batch)), - asyncio.create_task(asyncio.to_thread(self._core_batch)), - asyncio.create_task(asyncio.to_thread(self._postprocess_batch)), - ] - - async def shutdown(self): - """ - set the shutdown event. - Blocking event, until shutdown. 
- """ - self._shutdown.set() - - await asyncio.gather(*self.tasks) +from infinity_emb.transformer.utils import ( + CapableEngineType, + InferenceEngine, + get_lengths_with_tokenize, +) class BatchHandler: def __init__( self, - model: BaseTransformer, + model_name_or_path: str, max_batch_size: int, + engine: InferenceEngine = InferenceEngine.torch, + model_warmup: bool = True, max_queue_wait: int = int(os.environ.get("INFINITY_QUEUE_SIZE", 32_000)), - batch_delay: float = 5e-3, vector_disk_cache_path: str = "", verbose=False, lengths_via_tokenize: bool = False, + device: Device = Device.auto, ) -> None: """ performs batching around the model. - model: BaseTransformer, implements fn (core|pre|post)_encode max_batch_size: max batch size of the models - max_queue_wait: max items to queue in the batch, default 32_000 sentences - batch_delay: sleep in seconds, wait time for pre/post methods. - Best result: setting to 1/2 the minimal expected - time for core_encode method / "gpu inference". - Dont set it above 1x minimal expected time of interence. - Should not be 0 to not block Python's GIL. vector_disk_cache_path: path to cache vectors on disk. lengths_via_tokenize: if True, use the tokenizer to get the lengths else len() """ - self.model = model - self.max_batch_size = max_batch_size + self._verbose = verbose self.max_queue_wait = max_queue_wait + self.max_batch_size = max_batch_size self._lengths_via_tokenize = lengths_via_tokenize self._shutdown = threading.Event() - - self._final_queue: Queue = Queue() - self._batch_delay = float(max(1e-4, batch_delay)) + self._shared_queue_model_out: Queue = Queue() + self._shared_queue_model_in: Queue = Queue() self._queue_prio = CustomFIFOQueue() cache = ( Cache( @@ -248,23 +84,39 @@ def __init__( self._result_store = ResultKVStoreFuture(cache) self._ready = False - if batch_delay > 0.1: - logger.warn(f"high batch delay of {self._batch_delay}") - if max_batch_size > max_queue_wait * 10: - logger.warn( - f"queue_size={self.max_queue_wait} to small " - f"over batch_size={self.max_batch_size}." - " Consider increasing queue size" - ) + capable_engine: CapableEngineType = get_engine_type_from_config( + model_name_or_path=model_name_or_path, + engine=engine, + ) + self.model_capabilities = capable_engine.value.capabilities + self.model_worker = ModelWorker( - in_queue=self._queue_prio, - out_queue=self._final_queue, + in_queue=self._shared_queue_model_in, + out_queue=self._shared_queue_model_out, shutdown_event=self._shutdown, - batch_delay=self._batch_delay, - model=self.model, - verbose=verbose, + verbose=self._verbose, max_batch_size=self.max_batch_size, + model_name_or_path=model_name_or_path, + capable_engine=capable_engine, + model_warmup=model_warmup, + device=device, ) + # else: + # # start a process + # self.model_worker = Process( + # target=ModelWorker, + # kwargs=dict( + # in_queue=self._shared_queue_model_in, + # out_queue=self._shared_queue_model_out, + # shutdown_event=self._shutdown, + # verbose=self._verbose, + # max_batch_size=self.max_batch_size, + # model_name_or_path=model_name_or_path, + # capable_engine=capable_engine, + # model_warmup=model_warmup, + # device=device, + # ), + # ) async def embed( self, sentences: List[str] @@ -277,11 +129,10 @@ async def embed( Returns: EmbeddingReturnType: list of embedding as 1darray """ - if "embed" not in self.model.capabilities: + if "embed" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `embed`." 
- f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) input_sentences = [EmbeddingSingle(s) for s in sentences] @@ -301,11 +152,10 @@ async def rerank( List[float]: list of scores int: token usage """ - if "rerank" not in self.model.capabilities: + if "rerank" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `rerank`." - f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) rerankables = [ReRankSingle(query=query, document=doc) for doc in docs] scores, usage = await self._schedule(rerankables) @@ -328,11 +178,10 @@ async def classify( Returns: EmbeddingReturnType: embedding as 1darray """ - if "classify" not in self.model.capabilities: + if "classify" not in self.model_capabilities: raise ModelNotDeployedError( "the loaded moded cannot fullyfill `classify`." - f"options are {self.model.capabilities} inherited " - f"from model_class={self.model.__class__}" + f"options are {self.model_capabilities}" ) items = [PredictSingle(sentence=s) for s in sentences] classifications, usage = await self._schedule(items) @@ -373,7 +222,7 @@ async def _schedule( @property def capabilities(self) -> Set[ModelCapabilites]: - return self.model.capabilities + return self.model_capabilities def is_overloaded(self) -> bool: """checks if more items can be queued.""" @@ -404,20 +253,81 @@ async def _get_prios_usage( if not self._lengths_via_tokenize: return get_lengths_with_tokenize([it.str_repr() for it in items]) else: + # TODO: fix lengths_via_tokenize return await asyncio.to_thread( get_lengths_with_tokenize, _sentences=[it.str_repr() for it in items], - tokenize=self.model.tokenize_lengths, + # tokenize=self.model.tokenize_lengths, # TODO: fix ) + def _preprocess_batch(self): + """loops and checks if the _core_batch has worked on all items""" + self._ready = True + logger.info("ready to batch requests.") + try: + while not self._shutdown.is_set(): + # patience: + # do not pop a batch if self._feature_queue still has an item left + # - until GPU / _core_batch starts processing the previous item + # - or if many items are queued anyhow, so that a good batch + # may be popped already. + if not self._shared_queue_model_in.empty() and ( + self._shared_queue_model_in.full() + or (len(self._queue_prio) < self.max_batch_size * 4) + ): + # add some stochastic delay + time.sleep(2e-4) + continue + # decision to attemp to pop a batch + # -> will happen if a single datapoint is available + + batches = self._queue_prio.pop_optimal_batches( + self.max_batch_size, latest_first=False + ) + if not batches: + # not a single sentence available / len=0, wait for more + continue + # optimal batch has been selected -> + # lets tokenize it and move tensors to GPU. 
+ for batch in batches: + if self._shared_queue_model_in.qsize() > 2: + # add some stochastic delay + time.sleep(2e-4) + + items_for_pre = [item.content.to_input() for item in batch] + + if self._verbose: + logger.debug( + "[📦] batched %s requests, queue remaining: %s", + len(items_for_pre), + self._shared_queue_model_in.qsize(), + ) + if self._shutdown.is_set(): + break + # while-loop just for shutdown + while not self._shutdown.is_set(): + try: + self._shared_queue_model_in.put( + (items_for_pre, batch), timeout=1 + ) + break + except queue.Full: + continue + except Exception as ex: + logger.exception(ex) + raise ValueError("_preprocess_batch crashed") + self._ready = False + async def _queue_finalizer(self): while not self._shutdown.is_set(): try: - batch = self._final_queue.get_nowait() + _, batch = self._shared_queue_model_out.get_nowait() except queue.Empty: # instead use async await to get try: - batch = await asyncio.to_thread(self._final_queue.get, timeout=1) + _, batch = await asyncio.to_thread( + self._shared_queue_model_out.get, timeout=1 + ) except queue.Empty: continue @@ -430,13 +340,13 @@ async def _delayed_warmup(self): if not self._shutdown.is_set(): logger.debug("Sending a warm up through embedding.") try: - if "embed" in self.model.capabilities: + if "embed" in self.model_capabilities: await self.embed(sentences=["test"] * self.max_batch_size) - if "rerank" in self.model.capabilities: + if "rerank" in self.model_capabilities: await self.rerank( query="query", docs=["test"] * self.max_batch_size ) - if "classify" in self.model.capabilities: + if "classify" in self.model_capabilities: await self.classify(sentences=["test"] * self.max_batch_size) except Exception: pass @@ -448,6 +358,7 @@ async def spawn(self): logger.info("creating batching engine") await self.model_worker.spawn() asyncio.create_task(self._queue_finalizer()) + asyncio.create_task(asyncio.to_thread(self._preprocess_batch)) asyncio.create_task(self._delayed_warmup()) async def shutdown(self): diff --git a/libs/infinity_emb/infinity_emb/inference/model_worker.py b/libs/infinity_emb/infinity_emb/inference/model_worker.py new file mode 100644 index 00000000..59a9dcd7 --- /dev/null +++ b/libs/infinity_emb/infinity_emb/inference/model_worker.py @@ -0,0 +1,153 @@ +import asyncio +import queue +import threading +import time + +# from multiprocessing import Process +# from multiprocessing import Queue as MPQueue +from queue import Queue +from typing import Any + +from infinity_emb.inference.select_model import ( + select_model, +) +from infinity_emb.log_handler import logger +from infinity_emb.primitives import ( + Device, +) +from infinity_emb.transformer.utils import ( + CapableEngineType, +) + + +class ModelWorker: + def __init__( + self, + in_queue: Queue, + out_queue: Queue, + shutdown_event: threading.Event, + model_name_or_path: str, + max_batch_size: int, + capable_engine: CapableEngineType, + model_warmup: bool, + device: Device, + verbose: bool, + ) -> None: + self.model, batch_delay = select_model( + model_name_or_path=model_name_or_path, + batch_size=max_batch_size, + capable_engine=capable_engine, + model_warmup=model_warmup, + device=device, + ) + self._batch_delay = float(max(1e-4, batch_delay)) + + if batch_delay > 0.1: + logger.warn(f"high batch delay of {self._batch_delay}") + self._shutdown = shutdown_event + self.max_batch_size = max_batch_size + + self._shared_in_queue: Queue = in_queue + self._feature_queue: Queue = Queue(6) + self._postprocess_queue: Queue = Queue(4) + self._shared_out_queue 
= out_queue + + self._verbose = verbose + self._last_inference = time.perf_counter() + + def _general_batch( + self, + alias_name: str, + in_queue: Queue, + out_queue: Queue, + batch_fn: Any, + is_model_fn: bool = False, + is_post_fn: bool = False, + ): + logger.debug(f"starting {alias_name} in ModelWorker") + try: + while not self._shutdown.is_set(): + try: + fetched_batch = in_queue.get(timeout=0.5) + except queue.Empty: + continue + + feat, meta = fetched_batch + + if self._verbose: + logger.debug("[🏃] %s on batch_size=%s", alias_name, len(meta)) + + if is_model_fn: + self._last_inference = time.perf_counter() + elif self._last_inference < time.perf_counter() + self._batch_delay: + # 5 ms, assuming this is below + # 3-50ms for inference on avg. + # give the CPU some time to focus + # on moving the next batch to GPU on the forward pass + # before proceeding + time.sleep(self._batch_delay) + + processed = batch_fn(feat) + + if is_post_fn: + for i, item in enumerate(meta): + item.set_result(processed[i]) + processed = None + + # while-loop just for shutdown + while not self._shutdown.is_set(): + try: + out_queue.put((processed, meta), timeout=1) + break + except queue.Full: + continue + in_queue.task_done() + except Exception as ex: + logger.exception(ex) + self._ready = False + raise ValueError(f"{alias_name} crashed.") + self._ready = False + + async def spawn(self): + """set up the resources in batch""" + logger.info("creating ModelWorker") + self.tasks = [ + asyncio.create_task( + asyncio.to_thread( + self._general_batch, + "preprocess", + self._shared_in_queue, + self._feature_queue, + self.model.encode_pre, + ) + ), + asyncio.create_task( + asyncio.to_thread( + self._general_batch, + "forward", + self._feature_queue, + self._postprocess_queue, + self.model.encode_core, + is_model_fn=True, + ) + ), + asyncio.create_task( + asyncio.to_thread( + self._general_batch, + "postprocess", + self._postprocess_queue, + self._shared_out_queue, + self.model.encode_post, + is_post_fn=True, + ) + ), + ] + + async def shutdown(self): + """ + set the shutdown event. + Blocking event, until shutdown. 
+ """ + self._shutdown.set() + + await asyncio.gather(*self.tasks) diff --git a/libs/infinity_emb/infinity_emb/inference/select_model.py b/libs/infinity_emb/infinity_emb/inference/select_model.py index d76653ff..07d5b540 100644 --- a/libs/infinity_emb/infinity_emb/inference/select_model.py +++ b/libs/infinity_emb/infinity_emb/inference/select_model.py @@ -8,6 +8,7 @@ ) from infinity_emb.transformer.abstract import BaseCrossEncoder, BaseEmbedder from infinity_emb.transformer.utils import ( + CapableEngineType, EmbedderEngine, InferenceEngine, PredictEngine, @@ -17,7 +18,7 @@ def get_engine_type_from_config( model_name_or_path: str, engine: InferenceEngine -) -> Union[EmbedderEngine, RerankEngine]: +) -> CapableEngineType: if engine in [InferenceEngine.debugengine, InferenceEngine.fastembed]: return EmbedderEngine.from_inference_engine(engine) @@ -50,18 +51,15 @@ def get_engine_type_from_config( def select_model( model_name_or_path: str, batch_size: int, - engine: InferenceEngine = InferenceEngine.torch, + capable_engine: CapableEngineType, model_warmup=True, device: Device = Device.auto, ) -> Tuple[Union[BaseCrossEncoder, BaseEmbedder], float]: logger.info( - f"model=`{model_name_or_path}` selected, using engine=`{engine.value}`" + f"model=`{model_name_or_path}` selected, using engine=`{capable_engine.value}`" f" and device=`{device.value}`" ) - # TODO: add EncoderEngine - unloaded_engine = get_engine_type_from_config(model_name_or_path, engine) - - loaded_engine = unloaded_engine.value(model_name_or_path, device=device.value) + loaded_engine = capable_engine.value(model_name_or_path, device=device.value) min_inference_t = 4e-3 if model_warmup: diff --git a/libs/infinity_emb/infinity_emb/transformer/utils.py b/libs/infinity_emb/infinity_emb/transformer/utils.py index 962d7c8d..fb1acee9 100644 --- a/libs/infinity_emb/infinity_emb/transformer/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/utils.py @@ -1,7 +1,7 @@ import os from enum import Enum from pathlib import Path -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, Union from infinity_emb.transformer.classifier.torch import SentenceClassifier from infinity_emb.transformer.crossencoder.torch import ( @@ -75,6 +75,7 @@ def from_inference_engine(engine: InferenceEngine): _types: Dict[str, str] = {e.name: e.name for e in InferenceEngine} InferenceEngineTypeHint = Enum("InferenceEngineTypeHint", _types) # type: ignore +CapableEngineType = Union[EmbedderEngine, RerankEngine, PredictEngine] def length_tokenizer( diff --git a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py index f49385e2..563a2fab 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py @@ -12,6 +12,7 @@ from infinity_emb.transformer.embedder.sentence_transformer import ( SentenceTransformerPatched, ) +from infinity_emb.transformer.utils import InferenceEngine BATCH_SIZE = 32 N_TIMINGS = 3 @@ -23,11 +24,16 @@ async def load_patched_bh() -> Tuple[SentenceTransformerPatched, BatchHandler]: model = SentenceTransformerPatched(pytest.DEFAULT_BERT_MODEL) model.encode(["hello " * 512] * BATCH_SIZE) - bh = BatchHandler(model=model, max_batch_size=BATCH_SIZE) + bh = BatchHandler( + model_name_or_path=str(pytest.DEFAULT_BERT_MODEL), + engine=InferenceEngine.torch, + max_batch_size=BATCH_SIZE, + ) await bh.spawn() return model, bh +@pytest.mark.timeout(300) 
@pytest.mark.performance @pytest.mark.anyio async def test_batch_performance_raw(get_sts_bechmark_dataset, load_patched_bh): diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index ebc28927..2bcea96e 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -28,8 +28,9 @@ async def test_async_api_torch(): engine=transformer.InferenceEngine.torch, device="auto", ) - assert engine.capabilities == {"embed"} + async with engine: + assert engine.capabilities == {"embed"} embeddings, usage = await engine.embed(sentences) assert isinstance(embeddings, list) assert isinstance(embeddings[0], np.ndarray) @@ -60,9 +61,8 @@ async def test_async_api_torch_CROSSENCODER(): model_warmup=True, ) - assert engine.capabilities == {"rerank"} - async with engine: + assert engine.capabilities == {"rerank"} rankings, usage = await engine.rerank(query=query, docs=documents) assert usage == sum([len(query) + len(d) for d in documents]) @@ -104,9 +104,9 @@ async def test_async_api_torch_CLASSIFY(): engine="torch", model_warmup=True, ) - assert engine.capabilities == {"classify"} async with engine: + assert engine.capabilities == {"classify"} predictions, usage = await engine.classify(sentences=sentences) assert usage == sum([len(s) for s in sentences]) assert len(predictions) == len(sentences) From dbce5d13368def966445ed37e517dcd661004da0 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sun, 7 Jan 2024 20:36:48 +0000 Subject: [PATCH 5/9] fix batching error --- libs/infinity_emb/infinity_emb/engine.py | 4 +- .../infinity_emb/inference/batch_handler.py | 52 +++++++++++-------- .../infinity_emb/inference/model_worker.py | 7 ++- .../infinity_emb/inference/queue.py | 18 +++++-- .../infinity_emb/infinity_server.py | 2 +- libs/infinity_emb/infinity_emb/primitives.py | 3 ++ libs/infinity_emb/tests/script_live.py | 2 +- .../unit_test/inference/test_batch_handler.py | 13 +++-- .../unit_test/inference/test_caching_layer.py | 2 +- 9 files changed, 64 insertions(+), 39 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index e6b80056..5bbb94a0 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -82,13 +82,13 @@ async def astart(self): lengths_via_tokenize=self._lengths_via_tokenize, device=self.device, ) - await self._batch_handler.spawn() + await self._batch_handler.astart() async def astop(self): """stop engine""" self._assert_running() self.running = False - await self._batch_handler.shutdown() + await self._batch_handler.astop() self._batch_handler = None async def __aenter__(self): diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 6fd4e975..6ea11847 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -213,8 +213,13 @@ async def _schedule( item=inner_item(content=re), # type: ignore ) new_prioqueue.append(item) + + await asyncio.gather( + *[self._result_store.register_item(item.item) for item in new_prioqueue] + ) + await self._queue_prio.extend(new_prioqueue) - + result = await asyncio.gather( *[self._result_store.wait_for_response(item.item) for item in new_prioqueue] ) @@ -263,7 +268,7 @@ async def _get_prios_usage( def _preprocess_batch(self): """loops and checks if the _core_batch has worked on all items""" self._ready 
= True - logger.info("ready to batch requests.") + logger.info("Batch_handler: Ready to batch requests.") try: while not self._shutdown.is_set(): # patience: @@ -319,53 +324,58 @@ def _preprocess_batch(self): self._ready = False async def _queue_finalizer(self): - while not self._shutdown.is_set(): - try: - _, batch = self._shared_queue_model_out.get_nowait() - except queue.Empty: - # instead use async await to get + try: + while not self._shutdown.is_set(): try: - _, batch = await asyncio.to_thread( - self._shared_queue_model_out.get, timeout=1 - ) + _, batch = self._shared_queue_model_out.get_nowait() except queue.Empty: - continue - - for item in batch: - await self._result_store.mark_item_ready(item) + # instead use async await to get + try: + _, batch = await asyncio.to_thread( + self._shared_queue_model_out.get, timeout=1 + ) + except queue.Empty: + continue + for item in batch: + await self._result_store.mark_item_ready(item) + except Exception as ex: + logger.exception(ex) + raise ValueError("_queue_finalizer crashed") async def _delayed_warmup(self): """in case there is no warmup -> perform some warmup.""" await asyncio.sleep(5) if not self._shutdown.is_set(): - logger.debug("Sending a warm up through embedding.") try: if "embed" in self.model_capabilities: + logger.debug("Sending a warm up to `embed`.") await self.embed(sentences=["test"] * self.max_batch_size) if "rerank" in self.model_capabilities: + logger.debug("Sending a warm up to `rerank`.") await self.rerank( query="query", docs=["test"] * self.max_batch_size ) if "classify" in self.model_capabilities: + logger.debug("Sending a warm up to `classify`.") await self.classify(sentences=["test"] * self.max_batch_size) - except Exception: - pass + except Exception as ex: + logger.exception(ex) - async def spawn(self): + async def astart(self): """set up the resources in batch""" if self._ready: raise ValueError("previous threads are still running.") logger.info("creating batching engine") - await self.model_worker.spawn() + await self.model_worker.astart() asyncio.create_task(self._queue_finalizer()) asyncio.create_task(asyncio.to_thread(self._preprocess_batch)) asyncio.create_task(self._delayed_warmup()) - async def shutdown(self): + async def astop(self): """ set the shutdown event and close threadpool. Blocking event, until shutdown. """ self._shutdown.set() - await self.model_worker.shutdown() + await self.model_worker.astop() print("all shutdown") diff --git a/libs/infinity_emb/infinity_emb/inference/model_worker.py b/libs/infinity_emb/infinity_emb/inference/model_worker.py index 59a9dcd7..9acb21f5 100644 --- a/libs/infinity_emb/infinity_emb/inference/model_worker.py +++ b/libs/infinity_emb/infinity_emb/inference/model_worker.py @@ -52,7 +52,7 @@ def __init__( self._postprocess_queue: Queue = Queue(4) self._shared_out_queue = out_queue - self._verbose = verbose + self._verbose = logger.level <= 5 or verbose self._last_inference = time.perf_counter() def _general_batch( @@ -108,7 +108,7 @@ def _general_batch( raise ValueError(f"{alias_name} crashed.") self._ready = False - async def spawn(self): + async def astart(self): """set up the resources in batch""" logger.info("creating ModelWorker") self.tasks = [ @@ -143,11 +143,10 @@ async def spawn(self): ), ] - async def shutdown(self): + async def astop(self): """ set the shutdown event. Blocking event, until shutdown. 
""" self._shutdown.set() - await asyncio.gather(*self.tasks) diff --git a/libs/infinity_emb/infinity_emb/inference/queue.py b/libs/infinity_emb/infinity_emb/inference/queue.py index c5ccd6c6..744fe8c3 100644 --- a/libs/infinity_emb/infinity_emb/inference/queue.py +++ b/libs/infinity_emb/infinity_emb/inference/queue.py @@ -9,6 +9,7 @@ PrioritizedQueueItem, QueueItemInner, ) +from infinity_emb.log_handler import logger class QueueSignal(enum.Enum): @@ -93,15 +94,20 @@ def loop(self) -> asyncio.AbstractEventLoop: def __len__(self): return len(self._kv) - + + async def register_item(self, item: QueueItemInner) -> None: + """wait for future to return""" + uuid = item.get_id() + self._kv[uuid] = self.loop.create_future() + return None + async def wait_for_response(self, item: QueueItemInner) -> EmbeddingReturnType: """wait for future to return""" uuid = item.get_id() - fut = self.loop.create_future() - self._kv[uuid] = fut if self._cache: - asyncio.create_task(self._cache.aget(item, fut)) - await fut + asyncio.create_task(self._cache.aget(item, self._kv[uuid])) + await self._kv[uuid] + return item.get_result() async def mark_item_ready(self, item: QueueItemInner) -> None: @@ -110,6 +116,8 @@ async def mark_item_ready(self, item: QueueItemInner) -> None: fut = self._kv[uuid] try: fut.set_result(None) + logger.debug(f"marked {uuid} as ready") except asyncio.InvalidStateError: pass del self._kv[uuid] + diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index b5229b39..a490ed67 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -85,7 +85,7 @@ async def _startup(): ) @app.on_event("shutdown") - async def _shutdown(): + async def _astop(): await app.model.astop() @app.get("/ready") diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py index 6f22cfd5..7297719d 100644 --- a/libs/infinity_emb/infinity_emb/primitives.py +++ b/libs/infinity_emb/infinity_emb/primitives.py @@ -9,6 +9,9 @@ EmbeddingReturnType = npt.NDArray[Union[np.float32, np.float32]] +class QueueSignalMessages(enum.Enum): + POISON_KILL = 1 + WAKE_ON_NOTIFY = 2 class Device(enum.Enum): cpu = "cpu" diff --git a/libs/infinity_emb/tests/script_live.py b/libs/infinity_emb/tests/script_live.py index ce11d531..159d13b8 100644 --- a/libs/infinity_emb/tests/script_live.py +++ b/libs/infinity_emb/tests/script_live.py @@ -48,7 +48,7 @@ def remote(json_data: bytes, iters=1): print("Measuring latency via SentenceTransformers") latency_st = timeit.timeit("local(sample, iters=1)", number=2, globals=locals()) - print("SentenceTransformers latency: ", latency_st) + print(f"SentenceTransformers latency: {latency_st}") model = None print("Measuring latency via requests") diff --git a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py index 563a2fab..fa927130 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py @@ -3,7 +3,7 @@ import random import time from typing import Tuple - +import logging import numpy as np import pytest import torch @@ -29,7 +29,7 @@ async def load_patched_bh() -> Tuple[SentenceTransformerPatched, BatchHandler]: engine=InferenceEngine.torch, max_batch_size=BATCH_SIZE, ) - await bh.spawn() + await bh.astart() return model, bh @@ -38,6 +38,8 @@ async def load_patched_bh() -> 
Tuple[SentenceTransformerPatched, BatchHandler]: @pytest.mark.anyio async def test_batch_performance_raw(get_sts_bechmark_dataset, load_patched_bh): model, bh = load_patched_bh + model: SentenceTransformerPatched = model + bh: BatchHandler = bh assert bh.capabilities == {"embed"} try: sentences = [] @@ -59,6 +61,7 @@ async def method_batch_handler(_sentences): ] _ = await asyncio.gather(*tasks) end = time.perf_counter() + logging.info(f"batch_handler: {end - start}") return round(end - start, 4) def method_patched(_sentences): @@ -73,6 +76,7 @@ def method_patched(_sentences): emb.append(model.encode_post(embed)) np.concatenate(emb).tolist() end = time.perf_counter() + logging.info(f"method_patched: {end - start}") return round(end - start, 4) def method_st(_sentences): @@ -80,6 +84,7 @@ def method_st(_sentences): start = time.perf_counter() _ = model.encode(_sentences, batch_size=BATCH_SIZE).tolist() end = time.perf_counter() + logging.info(f"method_st: {end - start}") return round(end - start, 4) # yappi.get_func_stats().print_all() @@ -97,7 +102,7 @@ def method_st(_sentences): [method_patched(sentences) for _ in range(N_TIMINGS)] ) - print( + logging.info( f"times are sentence-transformers: {time_st}," " patched-sentence-transformers: " f" {time_st_patched}, batch-handler: {time_batch_handler}" @@ -120,4 +125,4 @@ def method_st(_sentences): ) finally: - await bh.shutdown() + await bh.astop() diff --git a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py index eef04802..f99d54b0 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py @@ -11,7 +11,7 @@ @pytest.mark.timeout(20) @pytest.mark.anyio async def test_cache(): - global INFINITY_CACHE_VECTORS + loop = asyncio.get_event_loop() shutdown = threading.Event() From 9a7f5b53aaa729e3af6c1f60395a67c523ae34fb Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Mon, 8 Jan 2024 00:39:29 +0000 Subject: [PATCH 6/9] update current state --- .../infinity_emb/inference/batch_handler.py | 140 +++++++++------- .../infinity_emb/inference/model_worker.py | 158 +++++++++--------- .../infinity_emb/inference/queue.py | 15 +- libs/infinity_emb/infinity_emb/primitives.py | 4 +- libs/infinity_emb/tests/script_live.py | 12 +- .../unit_test/inference/test_batch_handler.py | 15 +- .../unit_test/inference/test_caching_layer.py | 8 +- 7 files changed, 188 insertions(+), 164 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 6ea11847..0de7a4de 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -3,9 +3,8 @@ import queue import threading import time - -# from multiprocessing import Process -# from multiprocessing import Queue as MPQueue +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Manager, Process, set_start_method from queue import Queue from typing import Any, Dict, List, Sequence, Set, Tuple @@ -33,6 +32,7 @@ PredictInner, PredictSingle, PrioritizedQueueItem, + QueueSignalMessages, ReRankInner, ReRankSingle, ) @@ -42,6 +42,14 @@ get_lengths_with_tokenize, ) +_mp_manager = None + + +def get_mp_manager(): + global _mp_manager + _mp_manager = Manager() + return _mp_manager + class BatchHandler: def __init__( @@ -63,14 +71,15 @@ def __init__( vector_disk_cache_path: path to 
cache vectors on disk. lengths_via_tokenize: if True, use the tokenizer to get the lengths else len() """ - self._verbose = verbose - self.max_queue_wait = max_queue_wait + self.model_name_or_path = model_name_or_path self.max_batch_size = max_batch_size - self._lengths_via_tokenize = lengths_via_tokenize + self.model_warmup = model_warmup + self.max_queue_wait = max_queue_wait + self.lengths_via_tokenize = lengths_via_tokenize + self.device = device self._shutdown = threading.Event() - self._shared_queue_model_out: Queue = Queue() - self._shared_queue_model_in: Queue = Queue() + self._queue_prio = CustomFIFOQueue() cache = ( Cache( @@ -84,39 +93,51 @@ def __init__( self._result_store = ResultKVStoreFuture(cache) self._ready = False - capable_engine: CapableEngineType = get_engine_type_from_config( + self.capable_engine: CapableEngineType = get_engine_type_from_config( model_name_or_path=model_name_or_path, engine=engine, ) - self.model_capabilities = capable_engine.value.capabilities + self.model_capabilities = self.capable_engine.value.capabilities + self._shared_queue_model_out: Queue = Queue() + self._shared_queue_model_in: Queue = Queue() - self.model_worker = ModelWorker( + def _register_model_kwargs(self): + self.worker_args = dict( in_queue=self._shared_queue_model_in, out_queue=self._shared_queue_model_out, - shutdown_event=self._shutdown, - verbose=self._verbose, max_batch_size=self.max_batch_size, - model_name_or_path=model_name_or_path, - capable_engine=capable_engine, - model_warmup=model_warmup, - device=device, + model_name_or_path=self.model_name_or_path, + capable_engine=self.capable_engine, + model_warmup=self.model_warmup, + device=self.device, ) - # else: - # # start a process - # self.model_worker = Process( - # target=ModelWorker, - # kwargs=dict( - # in_queue=self._shared_queue_model_in, - # out_queue=self._shared_queue_model_out, - # shutdown_event=self._shutdown, - # verbose=self._verbose, - # max_batch_size=self.max_batch_size, - # model_name_or_path=model_name_or_path, - # capable_engine=capable_engine, - # model_warmup=model_warmup, - # device=device, - # ), - # ) + + def run_model_worker(self, num_workers=1, start_method="thread"): + if start_method == "thread": + self._register_model_kwargs() + with ThreadPoolExecutor(max_workers=num_workers) as pool: + tasks = [ + pool.submit(ModelWorker, **self.worker_args) + for _ in range(num_workers) + ] + for task in tasks: + task.result() + else: + set_start_method("spawn", force=True) + if 1: + mp_manager = get_mp_manager() + self._shared_queue_model_out = mp_manager.Queue() + self._shared_queue_model_in = mp_manager.Queue() + self._register_model_kwargs() + tasks = [ + Process(target=ModelWorker, kwargs=self.worker_args) + for _ in range(num_workers) + ] + for task in tasks: + task.start() + for task in tasks: + task.join() + logger.info("ModelWorker: all workers finished.") async def embed( self, sentences: List[str] @@ -213,13 +234,13 @@ async def _schedule( item=inner_item(content=re), # type: ignore ) new_prioqueue.append(item) - + await asyncio.gather( *[self._result_store.register_item(item.item) for item in new_prioqueue] ) - + await self._queue_prio.extend(new_prioqueue) - + result = await asyncio.gather( *[self._result_store.wait_for_response(item.item) for item in new_prioqueue] ) @@ -255,7 +276,7 @@ async def _get_prios_usage( Returns: Tuple[List[int], int]: prios, length """ - if not self._lengths_via_tokenize: + if not self.lengths_via_tokenize: return get_lengths_with_tokenize([it.str_repr() for it in 
items]) else: # TODO: fix lengths_via_tokenize @@ -301,14 +322,12 @@ def _preprocess_batch(self): items_for_pre = [item.content.to_input() for item in batch] - if self._verbose: - logger.debug( - "[📦] batched %s requests, queue remaining: %s", - len(items_for_pre), - self._shared_queue_model_in.qsize(), - ) - if self._shutdown.is_set(): - break + logger.debug( + "[📦] batched %s requests, queue remaining: %s", + len(items_for_pre), + self._shared_queue_model_in.qsize(), + ) + # while-loop just for shutdown while not self._shutdown.is_set(): try: @@ -325,18 +344,16 @@ def _preprocess_batch(self): async def _queue_finalizer(self): try: - while not self._shutdown.is_set(): + while True: try: - _, batch = self._shared_queue_model_out.get_nowait() + batch = self._shared_queue_model_out.get_nowait() except queue.Empty: # instead use async await to get - try: - _, batch = await asyncio.to_thread( - self._shared_queue_model_out.get, timeout=1 - ) - except queue.Empty: - continue - for item in batch: + batch = await asyncio.to_thread(self._shared_queue_model_out.get) + if batch == QueueSignalMessages.KILL: + break + + for item in batch[1]: await self._result_store.mark_item_ready(item) except Exception as ex: logger.exception(ex) @@ -366,7 +383,11 @@ async def astart(self): if self._ready: raise ValueError("previous threads are still running.") logger.info("creating batching engine") - await self.model_worker.astart() + self.model_worker = asyncio.create_task( + asyncio.to_thread(self.run_model_worker) + ) + await asyncio.sleep(5) + asyncio.create_task(self._queue_finalizer()) asyncio.create_task(asyncio.to_thread(self._preprocess_batch)) asyncio.create_task(self._delayed_warmup()) @@ -376,6 +397,11 @@ async def astop(self): set the shutdown event and close threadpool. Blocking event, until shutdown. """ + logger.debug("batch handler -> start astop") self._shutdown.set() - await self.model_worker.astop() - print("all shutdown") + + self._shared_queue_model_out.put(QueueSignalMessages.KILL) + # at last, kill the models, which kill the MP Manager. 
+ self._shared_queue_model_in.put(QueueSignalMessages.KILL) + time.sleep(1) + logger.debug("batch handler <- done astop") diff --git a/libs/infinity_emb/infinity_emb/inference/model_worker.py b/libs/infinity_emb/infinity_emb/inference/model_worker.py index 9acb21f5..7e0142f2 100644 --- a/libs/infinity_emb/infinity_emb/inference/model_worker.py +++ b/libs/infinity_emb/infinity_emb/inference/model_worker.py @@ -1,37 +1,31 @@ -import asyncio -import queue -import threading import time - -# from multiprocessing import Process -# from multiprocessing import Queue as MPQueue -from queue import Queue -from typing import Any +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Queue as MpQueue +from queue import Empty, Full, Queue +from typing import Any, Union from infinity_emb.inference.select_model import ( select_model, ) from infinity_emb.log_handler import logger -from infinity_emb.primitives import ( - Device, -) +from infinity_emb.primitives import Device, QueueSignalMessages from infinity_emb.transformer.utils import ( CapableEngineType, ) +AnyQueue = Union[Queue, MpQueue] + class ModelWorker: def __init__( self, - in_queue: Queue, - out_queue: Queue, - shutdown_event: threading.Event, + in_queue: AnyQueue, + out_queue: AnyQueue, model_name_or_path: str, max_batch_size: int, capable_engine: CapableEngineType, model_warmup: bool, device: Device, - verbose: bool, ) -> None: self.model, batch_delay = select_model( model_name_or_path=model_name_or_path, @@ -44,33 +38,68 @@ def __init__( if batch_delay > 0.1: logger.warn(f"high batch delay of {self._batch_delay}") - self._shutdown = shutdown_event self.max_batch_size = max_batch_size - self._shared_in_queue: Queue = in_queue + self._shared_in_queue = in_queue self._feature_queue: Queue = Queue(6) self._postprocess_queue: Queue = Queue(4) self._shared_out_queue = out_queue - self._verbose = logger.level <= 5 or verbose + self._verbose = logger.level <= 5 self._last_inference = time.perf_counter() + # spawn threads + logger.info("creating ModelWorker") + + with ThreadPoolExecutor(max_workers=3) as pool: + tasks = [ + pool.submit( + self._general_batch, + "preprocess", + self._shared_in_queue, + self._feature_queue, + self.model.encode_pre, + ), + pool.submit( + self._general_batch, + "forward", + self._feature_queue, + self._postprocess_queue, + self.model.encode_core, + is_model_fn=True, + ), + pool.submit( + self._general_batch, + "postprocess", + self._postprocess_queue, + self._shared_out_queue, + self.model.encode_post, + is_post_fn=True, + ), + ] + # block until all tasks are done + for future in tasks: + future.result() + + logger.info("stopped ModelWorker") + def _general_batch( self, alias_name: str, - in_queue: Queue, - out_queue: Queue, + in_queue: AnyQueue, + out_queue: AnyQueue, batch_fn: Any, is_model_fn: bool = False, is_post_fn: bool = False, ): - logger.debug(f"starting {alias_name} in ModelWorker") try: - while not self._shutdown.is_set(): - try: - fetched_batch = in_queue.get(timeout=0.5) - except queue.Empty: - continue + while True: + fetched_batch = in_queue.get() + if type(in_queue) == Queue: + in_queue.task_done() + if fetched_batch == QueueSignalMessages.KILL: + self.destruct() + break feat, meta = fetched_batch @@ -94,59 +123,32 @@ def _general_batch( item.set_result(processed[i]) processed = None - # while-loop just for shutdown - while not self._shutdown.is_set(): - try: - out_queue.put((processed, meta), timeout=1) - break - except queue.Full: - continue - in_queue.task_done() + 
out_queue.put((processed, meta)) + except Exception as ex: logger.exception(ex) - self._ready = False + self.destruct() raise ValueError(f"{alias_name} crashed.") - self._ready = False - async def astart(self): - """set up the resources in batch""" - logger.info("creating ModelWorker") - self.tasks = [ - asyncio.create_task( - asyncio.to_thread( - self._general_batch, - "preprocess", - self._shared_in_queue, - self._feature_queue, - self.model.encode_pre, - ) - ), - asyncio.create_task( - asyncio.to_thread( - self._general_batch, - "forward", - self._feature_queue, - self._postprocess_queue, - self.model.encode_core, - is_model_fn=True, - ) - ), - asyncio.create_task( - asyncio.to_thread( - self._general_batch, - "postprocess", - self._postprocess_queue, - self._shared_out_queue, - self.model.encode_post, - is_post_fn=True, - ) - ), - ] - - async def astop(self): - """ - set the shutdown event. - Blocking event, until shutdown. - """ - self._shutdown.set() - await asyncio.gather(*self.tasks) + def destruct(self): + """kill all tasks""" + for q in [ + self._shared_in_queue, + self._postprocess_queue, + self._feature_queue, + self._shared_out_queue, + ]: + while not q.empty(): + try: + res = q.get_nowait() + if res == QueueSignalMessages.KILL: + break + if type(q) == Queue: + q.task_done() + except Empty: + pass + for _ in range(10): + try: + q.put_nowait(QueueSignalMessages.KILL) + except Full: + pass diff --git a/libs/infinity_emb/infinity_emb/inference/queue.py b/libs/infinity_emb/infinity_emb/inference/queue.py index 744fe8c3..8edb1bd3 100644 --- a/libs/infinity_emb/infinity_emb/inference/queue.py +++ b/libs/infinity_emb/infinity_emb/inference/queue.py @@ -9,7 +9,6 @@ PrioritizedQueueItem, QueueItemInner, ) -from infinity_emb.log_handler import logger class QueueSignal(enum.Enum): @@ -94,30 +93,30 @@ def loop(self) -> asyncio.AbstractEventLoop: def __len__(self): return len(self._kv) - + async def register_item(self, item: QueueItemInner) -> None: """wait for future to return""" uuid = item.get_id() self._kv[uuid] = self.loop.create_future() return None - + async def wait_for_response(self, item: QueueItemInner) -> EmbeddingReturnType: """wait for future to return""" uuid = item.get_id() if self._cache: asyncio.create_task(self._cache.aget(item, self._kv[uuid])) - await self._kv[uuid] + item = await self._kv[uuid] return item.get_result() async def mark_item_ready(self, item: QueueItemInner) -> None: """mark item as ready. 
Item.get_result() must be set before calling this""" uuid = item.get_id() + # faster than .pop fut = self._kv[uuid] + del self._kv[uuid] try: - fut.set_result(None) - logger.debug(f"marked {uuid} as ready") + fut.set_result(item) + # logger.debug(f"marked {uuid} as ready with {len(item.get_result())}") except asyncio.InvalidStateError: pass - del self._kv[uuid] - diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py index 7297719d..15a013cb 100644 --- a/libs/infinity_emb/infinity_emb/primitives.py +++ b/libs/infinity_emb/infinity_emb/primitives.py @@ -9,10 +9,12 @@ EmbeddingReturnType = npt.NDArray[Union[np.float32, np.float32]] + class QueueSignalMessages(enum.Enum): - POISON_KILL = 1 + KILL = 1 WAKE_ON_NOTIFY = 2 + class Device(enum.Enum): cpu = "cpu" cuda = "cuda" diff --git a/libs/infinity_emb/tests/script_live.py b/libs/infinity_emb/tests/script_live.py index 159d13b8..30b7edde 100644 --- a/libs/infinity_emb/tests/script_live.py +++ b/libs/infinity_emb/tests/script_live.py @@ -45,18 +45,16 @@ def remote(json_data: bytes, iters=1): cosine_sim = np.dot(r, e) / (np.linalg.norm(e) * np.linalg.norm(r)) assert cosine_sim > 0.99 print("Both methods provide the identical output.") - - print("Measuring latency via SentenceTransformers") - latency_st = timeit.timeit("local(sample, iters=1)", number=2, globals=locals()) - print(f"SentenceTransformers latency: {latency_st}") - model = None - print("Measuring latency via requests") latency_request = timeit.timeit( - "remote(json_d, iters=1)", number=2, globals=locals() + "remote(json_d, iters=3)", number=2, globals=locals() ) print(f"Infinity request latency: {latency_request}") + print("Measuring latency via SentenceTransformers") + latency_st = timeit.timeit("local(sample, iters=3)", number=2, globals=locals()) + print(f"SentenceTransformers latency: {latency_st}") + def latency_single(): session = requests.Session() diff --git a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py index fa927130..63b5fe4a 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py @@ -1,9 +1,10 @@ import asyncio import copy +import logging import random import time from typing import Tuple -import logging + import numpy as np import pytest import torch @@ -15,7 +16,7 @@ from infinity_emb.transformer.utils import InferenceEngine BATCH_SIZE = 32 -N_TIMINGS = 3 +N_TIMINGS = 1 LIMIT_SLOWDOWN = 1.25 if torch.cuda.is_available() else 1.35 @@ -38,8 +39,7 @@ async def load_patched_bh() -> Tuple[SentenceTransformerPatched, BatchHandler]: @pytest.mark.anyio async def test_batch_performance_raw(get_sts_bechmark_dataset, load_patched_bh): model, bh = load_patched_bh - model: SentenceTransformerPatched = model - bh: BatchHandler = bh + assert bh.capabilities == {"embed"} try: sentences = [] @@ -91,13 +91,14 @@ def method_st(_sentences): # yappi.stop() method_st(sentences[::10]) await method_batch_handler(sentences[::10]) - time.sleep(2) + time.sleep(0.5) time_batch_handler = np.median( [(await method_batch_handler(sentences)) for _ in range(N_TIMINGS)] ) - time.sleep(2) + return await bh.astop() + time.sleep(0.5) time_st = np.median([method_st(sentences) for _ in range(N_TIMINGS)]) - time.sleep(2) + time.sleep(0.5) time_st_patched = np.median( [method_patched(sentences) for _ in range(N_TIMINGS)] ) diff --git 
a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py index f99d54b0..b00b0c12 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_caching_layer.py @@ -4,22 +4,19 @@ import numpy as np import pytest -from infinity_emb.inference import caching_layer +from infinity_emb.inference.caching_layer import Cache from infinity_emb.primitives import EmbeddingInner, EmbeddingSingle @pytest.mark.timeout(20) @pytest.mark.anyio async def test_cache(): - - loop = asyncio.get_event_loop() shutdown = threading.Event() try: - INFINITY_CACHE_VECTORS = True sentence = "dummy" embedding = np.random.random(5).tolist() - c = caching_layer.Cache( + c = Cache( cache_name=f"pytest_{hash((sentence, tuple(embedding)))}", shutdown=shutdown ) @@ -38,5 +35,4 @@ async def test_cache(): assert result is not None np.testing.assert_array_equal(result, embedding) finally: - INFINITY_CACHE_VECTORS = False shutdown.set() From bd1b5a01cf35883d7ab45e8ab4c40ef25af5a93d Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Mon, 8 Jan 2024 00:45:07 +0000 Subject: [PATCH 7/9] update defaults --- libs/infinity_emb/infinity_emb/inference/batch_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 0de7a4de..1893ea8d 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -112,7 +112,7 @@ def _register_model_kwargs(self): device=self.device, ) - def run_model_worker(self, num_workers=1, start_method="thread"): + def run_model_worker(self, num_workers=2, start_method="threads"): if start_method == "thread": self._register_model_kwargs() with ThreadPoolExecutor(max_workers=num_workers) as pool: From 8790f3f63cd6e39cf2b483be85bab7a9b27569e6 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Mon, 8 Jan 2024 01:00:09 +0000 Subject: [PATCH 8/9] update defaults --- libs/infinity_emb/infinity_emb/inference/batch_handler.py | 2 +- libs/infinity_emb/tests/script_live.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 1893ea8d..8fa539d6 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -112,7 +112,7 @@ def _register_model_kwargs(self): device=self.device, ) - def run_model_worker(self, num_workers=2, start_method="threads"): + def run_model_worker(self, num_workers=4, start_method="process"): if start_method == "thread": self._register_model_kwargs() with ThreadPoolExecutor(max_workers=num_workers) as pool: diff --git a/libs/infinity_emb/tests/script_live.py b/libs/infinity_emb/tests/script_live.py index 30b7edde..1d9a4705 100644 --- a/libs/infinity_emb/tests/script_live.py +++ b/libs/infinity_emb/tests/script_live.py @@ -47,12 +47,12 @@ def remote(json_data: bytes, iters=1): print("Both methods provide the identical output.") print("Measuring latency via requests") latency_request = timeit.timeit( - "remote(json_d, iters=3)", number=2, globals=locals() + "remote(json_d, iters=5)", number=3, globals=locals() ) print(f"Infinity request latency: {latency_request}") print("Measuring latency via SentenceTransformers") - latency_st = 
timeit.timeit("local(sample, iters=3)", number=2, globals=locals()) + latency_st = timeit.timeit("local(sample, iters=5)", number=3, globals=locals()) print(f"SentenceTransformers latency: {latency_st}") From da5d9852654e59541955ded1bb357dfa25eb0487 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Mon, 8 Jan 2024 01:20:42 +0000 Subject: [PATCH 9/9] improve script and format --- .../infinity_emb/inference/batch_handler.py | 50 ++++++++++--------- libs/infinity_emb/tests/script_live.py | 4 +- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 8fa539d6..e4117998 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -63,6 +63,8 @@ def __init__( verbose=False, lengths_via_tokenize: bool = False, device: Device = Device.auto, + num_workers: int = 1, + worker_method: str = "thread", ) -> None: """ performs batching around the model. @@ -77,6 +79,8 @@ def __init__( self.max_queue_wait = max_queue_wait self.lengths_via_tokenize = lengths_via_tokenize self.device = device + self.num_workers = num_workers + self.worker_method = worker_method self._shutdown = threading.Event() @@ -98,10 +102,17 @@ def __init__( engine=engine, ) self.model_capabilities = self.capable_engine.value.capabilities - self._shared_queue_model_out: Queue = Queue() - self._shared_queue_model_in: Queue = Queue() - def _register_model_kwargs(self): + if self.worker_method == "thread": + self._shared_queue_model_out: Queue = Queue() + self._shared_queue_model_in: Queue = Queue() + else: + set_start_method("spawn", force=True) + + mp_manager = get_mp_manager() + self._shared_queue_model_out = mp_manager.Queue() + self._shared_queue_model_in = mp_manager.Queue() + self.worker_args = dict( in_queue=self._shared_queue_model_in, out_queue=self._shared_queue_model_out, @@ -112,31 +123,24 @@ def _register_model_kwargs(self): device=self.device, ) - def run_model_worker(self, num_workers=4, start_method="process"): - if start_method == "thread": - self._register_model_kwargs() - with ThreadPoolExecutor(max_workers=num_workers) as pool: + def run_model_worker(self): + if self.worker_method == "thread": + with ThreadPoolExecutor(max_workers=self.num_workers) as pool: tasks = [ pool.submit(ModelWorker, **self.worker_args) - for _ in range(num_workers) + for _ in range(self.num_workers) ] for task in tasks: task.result() else: - set_start_method("spawn", force=True) - if 1: - mp_manager = get_mp_manager() - self._shared_queue_model_out = mp_manager.Queue() - self._shared_queue_model_in = mp_manager.Queue() - self._register_model_kwargs() - tasks = [ - Process(target=ModelWorker, kwargs=self.worker_args) - for _ in range(num_workers) - ] - for task in tasks: - task.start() - for task in tasks: - task.join() + tasks = [ + Process(target=ModelWorker, kwargs=self.worker_args) + for _ in range(self.num_workers) + ] + for task in tasks: + task.start() + for task in tasks: + task.join() logger.info("ModelWorker: all workers finished.") async def embed( @@ -386,8 +390,6 @@ async def astart(self): self.model_worker = asyncio.create_task( asyncio.to_thread(self.run_model_worker) ) - await asyncio.sleep(5) - asyncio.create_task(self._queue_finalizer()) asyncio.create_task(asyncio.to_thread(self._preprocess_batch)) asyncio.create_task(self._delayed_warmup()) diff --git a/libs/infinity_emb/tests/script_live.py 
b/libs/infinity_emb/tests/script_live.py
index 1d9a4705..8c7b9d53 100644
--- a/libs/infinity_emb/tests/script_live.py
+++ b/libs/infinity_emb/tests/script_live.py
@@ -60,7 +60,7 @@ def latency_single():
     session = requests.Session()
 
     def _post(i):
-        time.sleep(0.02)
+        time.sleep(0.05)
         json_d = json.dumps({"input": [str(i)], "model": "model"})
         s = time.perf_counter()
         res = session.post(f"{LIVE_URL}/embeddings", data=json_d)
@@ -74,4 +74,4 @@ def _post(i):
 
 
 if __name__ == "__main__":
-    embedding_live_performance()
+    latency_single()
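
Usage sketch (not part of the patch series above): across these commits the batch-handler lifecycle is renamed from spawn()/shutdown() to astart()/astop(), and BatchHandler gains num_workers and worker_method options for running ModelWorker instances either in threads or in spawned processes. The snippet below is a minimal illustration of the intended call pattern; the constructor keywords follow the patched __init__, while the import paths, the model name, the batch size, and the return shape of embed() are assumptions for illustration only.

import asyncio

from infinity_emb.inference.batch_handler import BatchHandler
from infinity_emb.primitives import Device
from infinity_emb.transformer.utils import InferenceEngine


async def main() -> None:
    # Keyword names follow the patched BatchHandler.__init__;
    # all other arguments keep their defaults.
    bh = BatchHandler(
        model_name_or_path="BAAI/bge-small-en-v1.5",  # assumed model, for illustration
        max_batch_size=32,
        engine=InferenceEngine.torch,
        device=Device.auto,
        num_workers=1,           # introduced in PATCH 9/9
        worker_method="thread",  # "process" selects the multiprocessing / Manager-queue path
    )
    await bh.astart()  # formerly spawn(): starts the ModelWorker, queue finalizer and warmup tasks
    try:
        # embed() schedules the sentences and awaits the result-store futures;
        # the exact return shape is not shown in the diff and is assumed here.
        result = await bh.embed(sentences=["An example sentence to embed."])
        print(result)
    finally:
        await bh.astop()  # formerly shutdown(): sets the shutdown event and sends KILL signals


if __name__ == "__main__":
    asyncio.run(main())

The thread/process switch is presumably a throughput trade-off: worker_method="thread" keeps plain queue.Queue handoff inside one interpreter, while worker_method="process" accepts the Manager-queue pickling overhead in exchange for running model workers outside the GIL.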