Transformers 4.41 bump #257

Merged · 4 commits · Jun 10, 2024
4 changes: 4 additions & 0 deletions libs/infinity_emb/infinity_emb/env.py
@@ -169,6 +169,10 @@ def infinity_cache_dir(self) -> Path:

return cache_dir

@cached_property
def queue_size(self) -> int:
return int(self._optional_infinity_var("queue_size", default="32000"))

@cached_property
def permissive_cors(self):
return self._to_bool(
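For context, a minimal sketch of what the new `queue_size` property provides, assuming `_optional_infinity_var("queue_size", ...)` resolves to the `INFINITY_QUEUE_SIZE` environment variable that the old batch_handler code read directly; the class below is an illustrative stand-in, not the real env manager.

```python
import os
from functools import cached_property


class EnvManagerSketch:
    """Illustrative stand-in for infinity_emb's env manager."""

    @cached_property
    def queue_size(self) -> int:
        # Read once per instance, then memoized by cached_property.
        return int(os.environ.get("INFINITY_QUEUE_SIZE", "32000"))


MANAGER = EnvManagerSketch()
print(MANAGER.queue_size)  # 32000 unless INFINITY_QUEUE_SIZE is set
```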
4 changes: 2 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -1,5 +1,4 @@
import asyncio
import os
import queue
import threading
import time
@@ -9,6 +8,7 @@

import numpy as np

from infinity_emb.env import MANAGER
from infinity_emb.inference.caching_layer import Cache
from infinity_emb.inference.queue import CustomFIFOQueue, ResultKVStoreFuture
from infinity_emb.inference.threading_asyncio import to_thread
@@ -52,7 +52,7 @@ def __init__(
self,
model: BaseTransformer,
max_batch_size: int,
max_queue_wait: int = int(os.environ.get("INFINITY_QUEUE_SIZE", 32_000)),
max_queue_wait: int = MANAGER.queue_size,
batch_delay: float = 5e-3,
vector_disk_cache_path: str = "",
verbose=False,
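Sketch of the effect of this change, with stand-in names: the queue limit now comes from the central env manager instead of a direct `os.environ` lookup, and since it is a default argument it is resolved once at import time.

```python
class _ManagerSketch:
    queue_size: int = 32_000  # stand-in for infinity_emb.env.MANAGER.queue_size


MANAGER = _ManagerSketch()


class BatchHandlerSketch:
    # The default below is evaluated once, at class-definition (import) time.
    def __init__(self, max_queue_wait: int = MANAGER.queue_size):
        self.max_queue_wait = max_queue_wait


print(BatchHandlerSketch().max_queue_wait)  # 32000
```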
21 changes: 7 additions & 14 deletions libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -1,12 +1,7 @@
import os
from typing import TYPE_CHECKING

from packaging.version import Version

from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TRANSFORMERS

if CHECK_TRANSFORMERS.is_available:
from transformers import __version__ as transformers_version # type: ignore
from infinity_emb._optional_imports import CHECK_OPTIMUM

if CHECK_OPTIMUM.is_available:
from optimum.bettertransformer import ( # type: ignore[import-untyped]
@@ -22,17 +17,15 @@
def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
if os.environ.get("INFINITY_DISABLE_OPTIMUM", False): # OLD VAR
logger.warning(
"No optimizations via BetterTransformer,"
" it is disabled via env `INFINITY_DISABLE_OPTIMUM` "
"DEPRECATED `INFINITY_DISABLE_OPTIMUM` - setting optimizations via BetterTransformer,"
"INFINITY_DISABLE_OPTIMUM is no longer supported, please use the CLI / ENV for that."
)
return model
CHECK_TRANSFORMERS.mark_required()
if Version(transformers_version) >= Version("4.40.3"):
logger.info(
"Disable optimizations via BetterTransformer, as torch.sdpa ships with transformers >= 4.41.0"
)
return model
if (
hasattr(model.config, "_attn_implementation")
and model.config._attn_implementation != "eager"
):
raise ValueError("BetterTransformer overwrite requires eager attention.")
CHECK_OPTIMUM.mark_required()
logger.info("Adding optimizations via Huggingface optimum. ")
try:
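A minimal sketch of the new guard, assuming only what the diff shows: with transformers >= 4.41 a model may load with SDPA attention, and the optimum BetterTransformer patch is only applied on top of the eager attention path, so anything else is rejected early. The function name is illustrative.

```python
def ensure_eager_attention(model) -> None:
    """Raise if the loaded model does not use the eager attention implementation."""
    attn = getattr(model.config, "_attn_implementation", None)
    if attn is not None and attn != "eager":
        raise ValueError("BetterTransformer overwrite requires eager attention.")
```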
@@ -16,13 +16,17 @@ def __init__(
engine_args: EngineArgs,
) -> None:
CHECK_TRANSFORMERS.mark_required()
model_kwargs = {}
if engine_args.bettertransformer:
model_kwargs["attn_implementation"] = "eager"
self._pipe = pipeline(
task="text-classification",
model=engine_args.model_name_or_path,
trust_remote_code=engine_args.trust_remote_code,
device=engine_args.device.resolve(),
top_k=None,
revision=engine_args.revision,
model_kwargs=model_kwargs,
)
if self._pipe.device.type != "cpu": # and engine_args.dtype == "float16":
self._pipe.model = self._pipe.model.half()
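Hedged usage sketch of the same pattern outside infinity_emb: when BetterTransformer is requested, the transformers pipeline is built with eager attention via `model_kwargs`; the model name below is illustrative only.

```python
from transformers import pipeline

model_kwargs = {"attn_implementation": "eager"}  # only set when bettertransformer is requested
classifier = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative model
    top_k=None,  # return scores for all labels
    model_kwargs=model_kwargs,
)
print(classifier("infinity makes serving embeddings fast"))
```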
@@ -35,11 +35,16 @@ class CrossEncoderPatched(CrossEncoder, BaseCrossEncoder):
def __init__(self, *, engine_args: EngineArgs):
CHECK_SENTENCE_TRANSFORMERS.mark_required()

model_kwargs = {}
if engine_args.bettertransformer:
model_kwargs["attn_implementation"] = "eager"

super().__init__(
engine_args.model_name_or_path,
revision=engine_args.revision,
device=engine_args.device.resolve(), # type: ignore
trust_remote_code=engine_args.trust_remote_code,
automodel_args=model_kwargs,
)
self.model.to(self._target_device) # type: ignore

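The reranker counterpart, sketched with an illustrative model: sentence-transformers' `CrossEncoder` forwards `automodel_args` to the underlying `from_pretrained` call (assuming a sentence-transformers release that accepts `automodel_args`), so eager attention can be requested the same way.

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",  # illustrative model
    automodel_args={"attn_implementation": "eager"},
)
print(reranker.predict([("what is infinity?", "infinity serves embedding and reranking models")]))
```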
@@ -47,11 +47,17 @@ class SentenceTransformerPatched(SentenceTransformer, BaseEmbedder):
def __init__(self, *, engine_args=EngineArgs):
CHECK_TORCH.mark_required()
CHECK_SENTENCE_TRANSFORMERS.mark_required()

model_kwargs = {}
if engine_args.bettertransformer:
model_kwargs["attn_implementation"] = "eager"

super().__init__(
engine_args.model_name_or_path,
revision=engine_args.revision,
trust_remote_code=engine_args.trust_remote_code,
device=engine_args.device.resolve(),
model_kwargs=model_kwargs,
)
self.to(self.device)
# make a copy of the tokenizer,
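And the embedder counterpart: `SentenceTransformer` forwards `model_kwargs` to `from_pretrained` (assuming a sentence-transformers release that supports `model_kwargs`); the model name is illustrative only.

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "BAAI/bge-small-en-v1.5",  # illustrative model
    model_kwargs={"attn_implementation": "eager"},
)
print(model.encode(["hello world"]).shape)
```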
@@ -125,7 +125,13 @@ def quant_embedding_decorator():
def decorator(func):
@wraps(func)
def wrapper(self: "BaseEmbedder", *args, **kwargs):
# Assume the first argument is the instance of BaseEmbedder or similar
"""
wraps a func called via func(self, *args, **kwargs) -> EmbeddingDtype(similar)

Special:
self has embedding_dtype: EmbeddingDtype
_internal_skip_quanitzation=True skips quantization
"""
skip_quanitzation = kwargs.pop("_internal_skip_quanitzation", False)
embeddings = func(self, *args, **kwargs)
if self.embedding_dtype == EmbeddingDtype.float32 or skip_quanitzation:
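A self-contained sketch of the contract documented in the new docstring: the wrapper pops the private `_internal_skip_quanitzation` flag and only quantizes when the embedder's `embedding_dtype` calls for it. The toy quantization and toy embedder below are illustrative, not infinity_emb's implementation.

```python
from functools import wraps

import numpy as np


def quant_embedding_sketch():
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            skip_quantization = kwargs.pop("_internal_skip_quanitzation", False)
            embeddings = func(self, *args, **kwargs)
            if self.embedding_dtype == "float32" or skip_quantization:
                return embeddings
            return np.sign(embeddings).astype(np.int8)  # toy stand-in for real quantization
        return wrapper
    return decorator


class ToyEmbedder:
    embedding_dtype = "int8"

    @quant_embedding_sketch()
    def encode_post(self, x):
        return np.asarray(x, dtype=np.float32)


emb = ToyEmbedder()
print(emb.encode_post([[0.5, -0.2]]))                                    # quantized int8
print(emb.encode_post([[0.5, -0.2]], _internal_skip_quanitzation=True))  # raw float32
```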