refactor -> select_model(functional) #468

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion libs/embed_package/embed/_infer.py
@@ -1,7 +1,7 @@
from concurrent.futures import Future
from typing import Collection, Literal, Union

-from infinity_emb import EngineArgs, SyncEngineArray # type: ignore
+from infinity_emb import EngineArgs, SyncEngineArray
from infinity_emb.infinity_server import AutoPadding

__all__ = ["BatchedInference"]
15 changes: 4 additions & 11 deletions libs/infinity_emb/infinity_emb/engine.py
@@ -52,9 +52,7 @@ def __init__(

self.running = False
self._running_sepamore: Optional[Semaphore] = None
-self._model_replicas, self._min_inference_t, self._max_inference_t = select_model(
-self._engine_args
-)
+self._model_replicas_functions = select_model(self._engine_args)

@classmethod
def from_args(
@@ -72,11 +70,7 @@ def from_args(
return engine

def __str__(self) -> str:
-return (
-f"AsyncEmbeddingEngine(running={self.running}, "
-f"inference_time={[self._min_inference_t, self._max_inference_t]}, "
-f"{self._engine_args})"
-)
+return f"AsyncEmbeddingEngine(running={self.running}, " f"{self._engine_args})"

async def astart(self):
"""startup engine"""
@@ -87,8 +81,7 @@ async def astart(self):
self.running = True
self._batch_handler = BatchHandler(
max_batch_size=self._engine_args.batch_size,
-model_replicas=self._model_replicas,
-# batch_delay=self._min_inference_t / 2,
+model_replicas=self._model_replicas_functions,
vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
verbose=logger.level <= 10,
lengths_via_tokenize=self._engine_args.lengths_via_tokenize,
@@ -124,7 +117,7 @@ def is_running(self) -> bool:

@property
def capabilities(self) -> set[ModelCapabilites]:
-return self._model_replicas[0].capabilities
+return self._batch_handler.capabilities

@property
def engine_args(self) -> EngineArgs:
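Note (illustrative sketch, not part of this diff): after this change the engine keeps the callables returned by select_model instead of already-loaded models, so replicas are only materialized once the batch handler spawns its workers. A minimal, self-contained Python illustration of that deferred-loading pattern; LazyReplicas and start are hypothetical names that do not exist in infinity_emb:

from functools import partial
from typing import Callable, TypeVar

T = TypeVar("T")


class LazyReplicas:
    """Holds zero-argument factories; instantiates them only on start()."""

    def __init__(self, factories: list[Callable[[], T]]) -> None:
        self._factories = factories  # nothing is loaded yet
        self._replicas: list[T] = []

    def start(self) -> None:
        # replicas are materialized here, mirroring what BatchHandler.spawn() now does
        self._replicas = [fn() for fn in self._factories]


# usage with a stand-in "model" factory
replicas = LazyReplicas([partial(str, f"replica-{i}") for i in range(2)])
replicas.start()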
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
@@ -66,7 +66,7 @@ class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):
),
Annotated[str, INPUT_STRING],
]
-modality: Literal[Modality.text] = Modality.text # type: ignore
+modality: Literal[Modality.text] = Modality.text


class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):
@@ -82,11 +82,11 @@ class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):


class OpenAIEmbeddingInput_Audio(_OpenAIEmbeddingInput_URI):
-modality: Literal[Modality.audio] = Modality.audio # type: ignore
+modality: Literal[Modality.audio] = Modality.audio


class OpenAIEmbeddingInput_Image(_OpenAIEmbeddingInput_URI):
-modality: Literal[Modality.image] = Modality.image # type: ignore
+modality: Literal[Modality.image] = Modality.image


def get_modality(obj: dict) -> str:
88 changes: 52 additions & 36 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -10,8 +10,8 @@
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Any, Optional, Sequence, Union, TYPE_CHECKING

import numpy as np
+from functools import cached_property

from infinity_emb.env import MANAGER
from infinity_emb.inference.caching_layer import Cache
@@ -39,7 +39,7 @@
from infinity_emb.transformer.vision.utils import resolve_images

if TYPE_CHECKING:
-from infinity_emb.transformer.abstract import BaseTypeHint
+from infinity_emb.transformer.abstract import CallableReturningBaseTypeHint


QUEUE_TIMEOUT = 0.5
@@ -64,7 +64,7 @@ def submit(self, *args, **kwargs):
class BatchHandler:
def __init__(
self,
model_replicas: list["BaseTypeHint"],
model_replicas: list["CallableReturningBaseTypeHint"],
max_batch_size: int,
max_queue_wait: int = MANAGER.queue_size,
batch_delay: float = 5e-3,
@@ -91,6 +91,9 @@ def __init__(

self._max_queue_wait = max_queue_wait
self._lengths_via_tokenize = lengths_via_tokenize
+self._max_batch_size = max_batch_size
+self._batch_delay = batch_delay
+self._verbose = verbose

self._shutdown = threading.Event()
self._threadpool = ThreadPoolExecutor()
@@ -114,18 +117,8 @@ def __init__(
self._result_store = ResultKVStoreFuture(cache)

# model
-self.model_worker = [
-ModelWorker(
-shutdown=ShutdownReadOnly(self._shutdown),
-model=model_replica,
-threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
-input_q=self._publish_to_model_queue,
-output_q=self._result_queue,
-verbose=self.batch_delay,
-batch_delay=batch_delay,
-)
-for model_replica in model_replicas
-]
+self.model_replica_fns = model_replicas
+self._capabilities = None

if batch_delay > 0.1:
logger.warning(f"high batch delay of {batch_delay}")
@@ -136,6 +129,12 @@ def __init__(
" Consider increasing queue size"
)

+@cached_property
+def _tiktoken_encoding(self):
+import tiktoken

+return tiktoken.encoding_for_model("gpt-3.5-turbo")

async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.

@@ -289,10 +288,7 @@ async def audio_embed(
f"Options are {self.capabilities}."
)

-items = await resolve_audios(
-audios,
-getattr(self.model_worker[0]._model, "sampling_rate", -42),
-)
+items = await resolve_audios(audios, self._extras.get("sampling_rate", -42))
embeddings, usage = await self._schedule(items)
return embeddings, usage

@@ -319,8 +315,8 @@ async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[lis

@property
def capabilities(self) -> set[ModelCapabilites]:
-# TODO: try to remove inheritance here and return upon init.
-return self.model_worker[0].capabilities
+assert self._capabilities is not None, "Model not loaded"
+return self._capabilities

def is_overloaded(self) -> bool:
"""checks if more items can be queued.
@@ -352,12 +348,11 @@ async def _get_prios_usage(self, items: Sequence[AbstractSingle]) -> tuple[list[
if not self._lengths_via_tokenize:
return get_lengths_with_tokenize([it.str_repr() for it in items])
else:
-return await to_thread(
-get_lengths_with_tokenize,
-self._threadpool,
-_sentences=[it.str_repr() for it in items],
-tokenize=self.model_worker[0].tokenize_lengths,
-)
+tokenized = [
+len(i)
+for i in self._tiktoken_encoding.encode_batch([it.str_repr() for it in items])
+]
+return tokenized, sum(tokenized)

def _publish_towards_model(
self,
@@ -452,8 +447,21 @@ async def spawn(self):
ShutdownReadOnly(self._shutdown), self._result_queue, self._threadpool
)
)
-for worker in self.model_worker:
-worker.spawn()

+def get_model_worker(model_replica_fn) -> tuple[set[ModelCapabilites], dict]:
+return ModelWorker(
+shutdown=ShutdownReadOnly(self._shutdown),
+model_fn=model_replica_fn,
+threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
+input_q=self._publish_to_model_queue,
+output_q=self._result_queue,
+verbose=self.batch_delay,
+batch_delay=self._batch_delay,
+).spawn()

+self._capabilities, self._extras = get_model_worker(self.model_replica_fns[0])
+if len(self.model_replica_fns) > 1:
+self._threadpool.map(get_model_worker, self.model_replica_fns[1:])

async def shutdown(self):
"""
@@ -473,15 +481,15 @@ class ModelWorker:
def __init__(
self,
shutdown: ShutdownReadOnly,
model: "BaseTypeHint",
model_fn: "CallableReturningBaseTypeHint",
threadpool: ThreadPoolExecutorReadOnly,
input_q: Queue,
output_q: Queue,
batch_delay: float = 5e-3,
verbose=False,
) -> None:
self._shutdown = shutdown
-self._model = model
+self._model_fn = model_fn
self._threadpool = threadpool
self._feature_queue: Queue = Queue(3)
self._postprocess_queue: Queue = Queue(5)
@@ -492,20 +500,28 @@ def __init__(
self._verbose = verbose
self._ready = False

-def spawn(self):
+def spawn(self) -> tuple[set[ModelCapabilites], dict]:
if self._ready:
raise ValueError("already spawned")
# start the threads
+self._model = self._model_fn()
self._threadpool.submit(self._preprocess_batch)
self._threadpool.submit(self._core_batch)
self._threadpool.submit(self._postprocess_batch)

+extras = {}
+if hasattr(self._model, "sampling_rate"):
+extras["sampling_rate"] = self._model.sampling_rate

+return self._model.capabilities, extras

@property
-def capabilities(self) -> set[ModelCapabilites]:
-return self._model.capabilities
+def model(self):
+assert self._model is not None, "Model not loaded"
+return self._model

def tokenize_lengths(self, *args, **kwargs):
-return self._model.tokenize_lengths(*args, **kwargs)
+return self.model.tokenize_lengths(*args, **kwargs)

def _preprocess_batch(self):
"""loops and checks if the _core_batch has worked on all items"""
@@ -560,7 +576,7 @@ def _core_batch(self):
if self._verbose:
logger.debug("[🧠] Inference on batch_size=%s", len(batch))
self._last_inference = time.perf_counter()
-embed = self._model.encode_core(feat)
+embed = self.model.encode_core(feat)

# while-loop just for shutdown
while not self._shutdown.is_set():
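Note (illustrative sketch, not part of this diff): when lengths_via_tokenize is enabled, _get_prios_usage now estimates prompt lengths with tiktoken instead of calling into the model's tokenizer through a worker. A standalone approximation of that estimate, assuming tiktoken is installed; LengthEstimator and lengths are hypothetical names used only for this example:

from functools import cached_property


class LengthEstimator:
    @cached_property
    def _tiktoken_encoding(self):
        import tiktoken  # imported lazily, as in the diff

        return tiktoken.encoding_for_model("gpt-3.5-turbo")

    def lengths(self, sentences: list[str]) -> tuple[list[int], int]:
        # one length per sentence plus the total, the same shape _get_prios_usage returns
        lengths = [len(ids) for ids in self._tiktoken_encoding.encode_batch(sentences)]
        return lengths, sum(lengths)


estimator = LengthEstimator()
print(estimator.lengths(["hello world", "infinity embeds text"]))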
88 changes: 44 additions & 44 deletions libs/infinity_emb/infinity_emb/inference/select_model.py
@@ -3,13 +3,11 @@

import json
from pathlib import Path
-from typing import Union
+from typing import Union, TYPE_CHECKING


-from infinity_emb.args import (
-EngineArgs,
-)
from infinity_emb.log_handler import logger
-from infinity_emb.transformer.abstract import BaseCrossEncoder, BaseEmbedder
+from functools import partial
from infinity_emb.transformer.utils import (
AudioEmbedEngine,
EmbedderEngine,
@@ -19,9 +17,15 @@
RerankEngine,
)

+if TYPE_CHECKING:
+from infinity_emb.transformer.abstract import BaseTypeHint # , CallableReturningBaseTypeHint
+from infinity_emb.args import (
+EngineArgs,
+)


def get_engine_type_from_config(
-engine_args: EngineArgs,
+engine_args: "EngineArgs",
) -> Union[EmbedderEngine, RerankEngine, PredictEngine, ImageEmbedEngine, AudioEmbedEngine]:
"""resolved the class of inference engine path from config.json of the repo."""
if engine_args.engine in [InferenceEngine.debugengine]:
@@ -57,55 +61,51 @@ def get_engine_type_from_config(
return EmbedderEngine.from_inference_engine(engine_args.engine)


+def _get_engine_replica(unloaded_engine, engine_args, device_map) -> "BaseTypeHint":
Contributor review comment: style: function lacks type hints for unloaded_engine and device_map parameters

+engine_args_copy = engine_args.copy()
+engine_args_copy._loading_strategy.device_placement = device_map
+loaded_engine = unloaded_engine.value(engine_args=engine_args_copy)

+if engine_args.model_warmup:
+# size one, warm up warm start timings.
+# loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
+# size one token
+min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(5))
+emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup(
+batch_size=engine_args.batch_size, n_tokens=1
+)

+logger.info(log_msg)
+# now warm up with max_token, max batch size
+loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512)
+emb_per_sec, _, log_msg = loaded_engine.warmup(
+batch_size=engine_args.batch_size, n_tokens=512
+)
+logger.info(log_msg)
+logger.info(
+f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}"
+f" embeddings/sec at batch_size={engine_args.batch_size}"
+)
+return loaded_engine


def select_model(
-engine_args: EngineArgs,
-) -> tuple[list[Union[BaseCrossEncoder, BaseEmbedder]], float, float]:
+engine_args: "EngineArgs",
+) -> list[partial["BaseTypeHint"]]:
"""based on engine args, fully instantiates the Engine."""
logger.info(
f"model=`{engine_args.model_name_or_path}` selected, "
f"using engine=`{engine_args.engine.value}`"
f" and device=`{engine_args.device.resolve()}`"
)
# engine_args.update_loading_strategy()

unloaded_engine = get_engine_type_from_config(engine_args)

engine_replicas = []
-min_inference_t = 4e-3
-max_inference_t = 4e-3

# TODO: Can be parallelized
for device_map in engine_args._loading_strategy.device_mapping: # type: ignore
Contributor review comment: style: type: ignore on device_mapping access should be replaced with proper type annotation

-engine_args_copy = engine_args.copy()
-engine_args_copy._loading_strategy.device_placement = device_map
-loaded_engine = unloaded_engine.value(engine_args=engine_args_copy)

-if engine_args.model_warmup:
-# size one, warm up warm start timings.
-# loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
-# size one token
-min_inference_t = min(
-min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(10)),
-min_inference_t,
-)
-loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
-emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup(
-batch_size=engine_args.batch_size, n_tokens=1
-)
-max_inference_t = max(max_inference_temp, max_inference_t)

-logger.info(log_msg)
-# now warm up with max_token, max batch size
-loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512)
-emb_per_sec, _, log_msg = loaded_engine.warmup(
-batch_size=engine_args.batch_size, n_tokens=512
-)
-logger.info(log_msg)
-logger.info(
-f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}"
-f" embeddings/sec at batch_size={engine_args.batch_size}"
-)
-engine_replicas.append(loaded_engine)
+engine_replicas.append(
+partial(_get_engine_replica, unloaded_engine, engine_args, device_map)
+)
assert len(engine_replicas) > 0, "No engine replicas were loaded"

-return engine_replicas, min_inference_t, max_inference_t
+return engine_replicas
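Note (illustrative sketch, not part of this diff): select_model now returns a list of zero-argument partials instead of loaded engines, leaving it to the caller to decide when, and on which thread, each replica is built and warmed up. A rough sketch of how such a return value can be consumed; load_replica, select_model_sketch, and the device list are hypothetical stand-ins, not infinity_emb APIs:

from concurrent.futures import ThreadPoolExecutor
from functools import partial


def load_replica(device: str) -> str:
    # stand-in for unloaded_engine.value(engine_args=...) plus warmup
    return f"model-on-{device}"


def select_model_sketch(devices: list[str]) -> list[partial]:
    # one deferred constructor per device placement, as in the refactored loop
    return [partial(load_replica, device) for device in devices]


replica_fns = select_model_sketch(["cuda:0", "cuda:1"])
first = replica_fns[0]()  # build the first replica eagerly (e.g., to read capabilities)
with ThreadPoolExecutor() as pool:
    rest = list(pool.map(lambda fn: fn(), replica_fns[1:]))  # build the rest in parallel
print(first, rest)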