From c6b18ab80593c56a86014117ac0ec34dbd0229b0 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 21:20:07 -0700 Subject: [PATCH 01/11] auto-enable hf transfer --- libs/infinity_emb/infinity_emb/__init__.py | 39 +++++++++++++------ .../infinity_emb/transformer/__init__.py | 10 ----- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/__init__.py b/libs/infinity_emb/infinity_emb/__init__.py index 3df72736..2f076dfd 100644 --- a/libs/infinity_emb/infinity_emb/__init__.py +++ b/libs/infinity_emb/infinity_emb/__init__.py @@ -1,3 +1,31 @@ +import importlib.metadata +import os + +import huggingface_hub.constants # type: ignore + +### Check if HF_HUB_ENABLE_HF_TRANSFER is set, if not try to enable it +if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + + # Needs to be at the top of the file / before other + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass +huggingface_hub.constants.HF_HUB_DISABLE_PROGRESS_BARS = True + + +from infinity_emb import fastapi_schemas, inference, transformer # noqa: E402 +from infinity_emb.args import EngineArgs # noqa: E402 +from infinity_emb.engine import AsyncEmbeddingEngine # noqa: E402 + +# reexports +from infinity_emb.infinity_server import create_server # noqa: E402 +from infinity_emb.log_handler import logger # noqa: E402 + +__version__ = importlib.metadata.version("infinity_emb") + __all__ = [ "transformer", "inference", @@ -8,14 +36,3 @@ "EngineArgs", "__version__", ] -import importlib.metadata - -from infinity_emb import fastapi_schemas, inference, transformer -from infinity_emb.args import EngineArgs -from infinity_emb.engine import AsyncEmbeddingEngine - -# reexports -from infinity_emb.infinity_server import create_server -from infinity_emb.log_handler import logger - -__version__ = importlib.metadata.version("infinity_emb") diff --git a/libs/infinity_emb/infinity_emb/transformer/__init__.py b/libs/infinity_emb/infinity_emb/transformer/__init__.py index 49e1b773..8b137891 100644 --- a/libs/infinity_emb/infinity_emb/transformer/__init__.py +++ b/libs/infinity_emb/infinity_emb/transformer/__init__.py @@ -1,11 +1 @@ -__all__ = ["InferenceEngine"] -from infinity_emb._optional_imports import CHECK_HF_TRANSFER -from infinity_emb.transformer.utils import InferenceEngine -# place the enabling of hf hub transfer here -if CHECK_HF_TRANSFER.is_available: - # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa - import huggingface_hub.constants # type: ignore[import-untyped] - - huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True From 81d8da01c219dc07b707eb4848cfde121ce0e434 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 21:20:22 -0700 Subject: [PATCH 02/11] enable all INFINITY env variables --- .../infinity_emb/_optional_imports.py | 1 - libs/infinity_emb/infinity_emb/engine.py | 9 +- libs/infinity_emb/infinity_emb/env.py | 141 ++++++++++++++++++ .../infinity_emb/fastapi_schemas/docs.py | 2 +- .../infinity_emb/inference/caching_layer.py | 5 - .../infinity_emb/infinity_server.py | 122 ++++++++------- .../tests/end_to_end/test_authentication.py | 6 +- 7 files changed, 212 insertions(+), 74 deletions(-) create mode 100644 libs/infinity_emb/infinity_emb/env.py diff --git a/libs/infinity_emb/infinity_emb/_optional_imports.py b/libs/infinity_emb/infinity_emb/_optional_imports.py index b22c990d..7f3651ac 100644 --- 
a/libs/infinity_emb/infinity_emb/_optional_imports.py +++ b/libs/infinity_emb/infinity_emb/_optional_imports.py @@ -55,7 +55,6 @@ def _raise_error(self) -> None: CHECK_DISKCACHE = OptionalImports("diskcache", "cache") CHECK_CTRANSLATE2 = OptionalImports("ctranslate2", "ctranslate2") CHECK_FASTAPI = OptionalImports("fastapi", "server") -CHECK_HF_TRANSFER = OptionalImports("hf_transfer", "hf_transfer") CHECK_ONNXRUNTIME = OptionalImports("optimum.onnxruntime", "optimum") CHECK_OPTIMUM = OptionalImports("optimum", "optimum") CHECK_OPTIMUM_NEURON = OptionalImports("optimum.neuron", "neuronx") diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index 1ea5ec06..86437434 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -224,12 +224,9 @@ def from_args(cls, engine_args_array: Iterable[EngineArgs]) -> "AsyncEngineArray Args: engine_args_array (list[EngineArgs]): EngineArgs object """ - return cls( - engines=tuple( - AsyncEmbeddingEngine.from_args(engine_args) - for engine_args in engine_args_array - ) - ) + engines = map(AsyncEmbeddingEngine.from_args, engine_args_array) + + return cls(engines=tuple(engines)) def __iter__(self) -> Iterator["AsyncEmbeddingEngine"]: return iter(self.engines_dict.values()) diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py new file mode 100644 index 00000000..5db6ccea --- /dev/null +++ b/libs/infinity_emb/infinity_emb/env.py @@ -0,0 +1,141 @@ +# cache +from __future__ import annotations + +import os +from functools import cached_property + + +class __Infinity_EnvManager: + def __init__(self): + for f_name in dir(self): + if isinstance(getattr(type(self), f_name, None), cached_property): + getattr(self, f_name) # pre-cache + + @staticmethod + def _optional_infinity_var(name: str, default: str = ""): + name = name.upper().replace("-", "_") + return os.getenv(f"INFINITY_{name}", default) + + @staticmethod + def _optional_infinity_var_multiple(name: str, default: list[str]) -> list[str]: + value = os.getenv(name) + if not value: + return default + if value.endswith(";"): + value = value[:-1] + return value.split(";") + + @staticmethod + def _to_bool(value: str) -> bool: + return value.lower() in {"true", "1"} + + @staticmethod + def _to_bool_multiple(value: list[str]) -> list[bool]: + return [v.lower() in {"true", "1"} for v in value] + + @staticmethod + def _to_int_multiple(value: list[str]) -> list[int]: + return [int(v) for v in value] + + @cached_property + def api_key(self): + return self._optional_infinity_var("api_key", default="") + + @cached_property + def model_id(self): + return self._optional_infinity_var_multiple( + "model_id", default=["michaelfeil/bge-small-en-v1.5"] + ) + + @cached_property + def served_model_name(self): + return self._optional_infinity_var_multiple("served_model_name", default=[""]) + + @cached_property + def batch_size(self): + return self._to_int_multiple( + self._optional_infinity_var_multiple("batch_size", default=["32"]) + ) + + @cached_property + def revision(self): + return self._optional_infinity_var_multiple("revision", default=[""]) + + @cached_property + def trust_remote_code(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple("trust_remote_code", default=["true"]) + ) + + @cached_property + def model_warmup(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple("model_warmup", default=["true"]) + ) + + @cached_property + def 
vector_disk_cache(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple("vector_disk_cache", default=["false"]) + ) + + @cached_property + def lengths_via_tokenize(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple( + "lengths_via_tokenize", default=["false"] + ) + ) + + @cached_property + def compile(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple("compile", default=["false"]) + ) + + @cached_property + def bettertransformer(self): + return self._to_bool_multiple( + self._optional_infinity_var_multiple("bettertransformer", default=["true"]) + ) + + @cached_property + def preload_only(self): + return self._to_bool( + self._optional_infinity_var("preload_only", default="false") + ) + + @cached_property + def permissive_cors(self): + return self._to_bool( + self._optional_infinity_var("permissive_cors", default="false") + ) + + @cached_property + def url_prefix(self): + return self._optional_infinity_var("url_prefix", default="") + + @cached_property + def port(self): + port = self._optional_infinity_var("port", default="7997") + assert port.isdigit(), "INFINITY_PORT must be a number" + return int(port) + + @cached_property + def host(self): + return self._optional_infinity_var("host", default="0.0.0.0") + + @cached_property + def redirect_slash(self): + route = self._optional_infinity_var("redirect_slash", default="/docs") + assert not route or route.startswith( + "/" + ), "INFINITY_REDIRECT_SLASH must start with /" + return route + + @cached_property + def log_level(self): + return self._optional_infinity_var("log_level", default="info") + + +MANAGER = __Infinity_EnvManager() diff --git a/libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py b/libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py index c47638d1..e6a596c8 100644 --- a/libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py +++ b/libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py @@ -9,7 +9,7 @@ FASTAPI_DESCRIPTION = "" -def startup_message(host: str, port: str, prefix: str) -> str: +def startup_message(host: str, port: int, prefix: str) -> str: from infinity_emb import __version__ return f""" diff --git a/libs/infinity_emb/infinity_emb/inference/caching_layer.py b/libs/infinity_emb/infinity_emb/inference/caching_layer.py index 82b77b17..9d98a0f3 100644 --- a/libs/infinity_emb/infinity_emb/inference/caching_layer.py +++ b/libs/infinity_emb/infinity_emb/inference/caching_layer.py @@ -13,11 +13,6 @@ if CHECK_DISKCACHE.is_available: import diskcache as dc # type: ignore[import-untyped] -INFINITY_CACHE_VECTORS = ( - bool(os.environ.get("INFINITY_CACHE_VECTORS", False)) - and CHECK_DISKCACHE.is_available -) - class Cache: def __init__(self, cache_name: str, shutdown: threading.Event) -> None: diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 4a73156a..9105e7dc 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -1,12 +1,13 @@ import sys import time from contextlib import asynccontextmanager -from typing import Optional +from typing import Any, Optional import infinity_emb from infinity_emb._optional_imports import CHECK_TYPER, CHECK_UVICORN from infinity_emb.args import EngineArgs from infinity_emb.engine import AsyncEmbeddingEngine, AsyncEngineArray +from infinity_emb.env import MANAGER from infinity_emb.fastapi_schemas import docs, errors from infinity_emb.fastapi_schemas.pymodels import ( 
ClassifyInput, @@ -17,7 +18,6 @@ RerankInput, ReRankResult, ) -from infinity_emb.inference.caching_layer import INFINITY_CACHE_VECTORS from infinity_emb.log_handler import UVICORN_LOG_LEVELS, logger from infinity_emb.primitives import ( Device, @@ -31,18 +31,16 @@ def create_server( *, engine_args_list: list[EngineArgs], - url_prefix: str = "", - doc_extra: dict = {}, - redirect_slash: str = "/docs", - preload_only: bool = False, - permissive_cors: bool = False, - auth_token: Optional[str] = None, + url_prefix: str = MANAGER.url_prefix, + doc_extra: dict[str, Any] = {}, + redirect_slash: str = MANAGER.redirect_slash, + preload_only: bool = MANAGER.preload_only, + permissive_cors: bool = MANAGER.permissive_cors, + api_key: str = MANAGER.api_key, ): """ creates the FastAPI App """ - import os - from fastapi import Depends, FastAPI, HTTPException, responses, status from fastapi.middleware.cors import CORSMiddleware from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -57,8 +55,8 @@ async def lifespan(app: FastAPI): logger.info( docs.startup_message( - host=doc_extra.pop("host", "localhost"), - port=doc_extra.pop("port", "PORT"), + host=doc_extra.pop("host", None), + port=doc_extra.pop("port", None), prefix=url_prefix, ) ) @@ -98,14 +96,13 @@ async def lifespan(app: FastAPI): allow_methods=["*"], allow_headers=["*"], ) - token = auth_token or os.environ.get("INFINITY_API_KEY") - if token: + if api_key: oauth2_scheme = HTTPBearer(auto_error=False) async def validate_token( credential: Optional[HTTPAuthorizationCredentials] = Depends(oauth2_scheme), ): - if credential and credential.credentials != token: + if credential and credential.credentials != api_key: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized", @@ -358,27 +355,28 @@ def __iter__(self): @tp.command("v1") def v1( - model_name_or_path: str = "michaelfeil/bge-small-en-v1.5", - served_model_name: str = "", - batch_size: int = 32, - revision: str = "", - trust_remote_code: bool = True, - url_prefix: str = "", - host: str = "0.0.0.0", - port: int = 7997, - redirect_slash: str = "/docs", - log_level: UVICORN_LOG_LEVELS = UVICORN_LOG_LEVELS.info.name, # type: ignore + model_name_or_path: str = MANAGER.model_id[0], + served_model_name: str = MANAGER.served_model_name[0], + batch_size: int = MANAGER.batch_size[0], + revision: str = MANAGER.revision[0], + trust_remote_code: bool = MANAGER.trust_remote_code[0], + redirect_slash: str = MANAGER.redirect_slash, engine: InferenceEngine = InferenceEngine.default_value(), # type: ignore # noqa - model_warmup: bool = True, - vector_disk_cache: bool = INFINITY_CACHE_VECTORS, + model_warmup: bool = MANAGER.model_warmup[0], + vector_disk_cache: bool = MANAGER.vector_disk_cache[0], device: Device = Device.default_value(), # type: ignore - lengths_via_tokenize: bool = False, + lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0], dtype: Dtype = Dtype.default_value(), # type: ignore pooling_method: PoolingMethod = PoolingMethod.default_value(), # type: ignore - compile: bool = False, - bettertransformer: bool = True, - preload_only: bool = False, - permissive_cors: bool = False, + compile: bool = MANAGER.compile[0], + bettertransformer: bool = MANAGER.bettertransformer[0], + preload_only: bool = MANAGER.preload_only, + permissive_cors: bool = MANAGER.permissive_cors, + api_key: str = MANAGER.api_key, + url_prefix: str = MANAGER.url_prefix, + host: str = MANAGER.host, + port: int = MANAGER.port, + log_level: UVICORN_LOG_LEVELS = 
MANAGER.log_level, # type: ignore ): """Infinity Embedding API ♾️ cli v1 to start a uvicorn-server instance; MIT License; Copyright (c) 2023-now Michael Feil @@ -390,9 +388,6 @@ def v1( batch_size, int: batch size for forward pass. revision: str: revision of the model. trust_remote_code, bool: trust remote code. - url_prefix, str: prefix for api. typically "". - host, str: host-url, typically either "0.0.0.0" or "127.0.0.1". - port, int: port that you want to expose. redirect_slash, str: redirect to of GET "/". Defaults to "/docs". Empty string to disable. log_level: logging level. For high performance, use "info" or higher levels. Defaults to "info". @@ -400,7 +395,7 @@ def v1( model_warmup, bool: perform model warmup before starting the server. Defaults to True. vector_disk_cache, bool: cache past embeddings in SQL. - Defaults to False or env-INFINITY_CACHE_VECTORS if set + Defaults to False device, Device: device to use for inference. Defaults to Device.auto or "auto" lengths_via_tokenize: bool: schedule by token usage. Defaults to False. dtype, Dtype: data type to use for inference. Defaults to Dtype.auto or "auto" @@ -409,7 +404,17 @@ def v1( use_bettertransformer, bool: use bettertransformer. Defaults to True. preload_only, bool: only preload the model and exit. Defaults to False. permissive_cors, bool: add permissive CORS headers to enable consumption from a browser. Defaults to False. + api_key, str: optional Bearer token for authentication. Defaults to "", which disables authentication. + url_prefix, str: prefix for api. typically "". + host, str: host-url, typically either "0.0.0.0" or "127.0.0.1". + port, int: port that you want to expose. """ + if api_key: + raise ValueError("api_key is not supported in v1") + logger.warning( + "CLI v1 is deprecated and might be removed in the future. Please use CLI v2, by specifying `v2` as the command." 
+ ) + time.sleep(5) v2( model_id=[model_name_or_path], served_model_name=[served_model_name], # type: ignore @@ -433,35 +438,35 @@ def v1( redirect_slash=redirect_slash, log_level=log_level, permissive_cors=permissive_cors, + api_key=api_key, ) @tp.command("v2") def v2( # arguments for engine - model_id: list[str] = [ - "michaelfeil/bge-small-en-v1.5", - ], - served_model_name: list[str] = [""], - batch_size: list[int] = [32], - revision: list[str] = [""], - trust_remote_code: list[bool] = [True], + model_id: list[str] = MANAGER.model_id, + served_model_name: list[str] = MANAGER.served_model_name, + batch_size: list[int] = MANAGER.batch_size, + revision: list[str] = MANAGER.revision, + trust_remote_code: list[bool] = MANAGER.trust_remote_code, engine: list[InferenceEngine] = [InferenceEngine.default_value()], # type: ignore # noqa - model_warmup: list[bool] = [True], - vector_disk_cache: list[bool] = [INFINITY_CACHE_VECTORS], + model_warmup: list[bool] = MANAGER.model_warmup, + vector_disk_cache: list[bool] = MANAGER.vector_disk_cache, device: list[Device] = [Device.default_value()], # type: ignore - lengths_via_tokenize: list[bool] = [False], + lengths_via_tokenize: list[bool] = MANAGER.lengths_via_tokenize, dtype: list[Dtype] = [Dtype.default_value()], # type: ignore pooling_method: list[PoolingMethod] = [PoolingMethod.default_value()], # type: ignore - compile: list[bool] = [False], - bettertransformer: list[bool] = [True], + compile: list[bool] = MANAGER.compile, + bettertransformer: list[bool] = MANAGER.bettertransformer, # arguments for uvicorn / server - preload_only: bool = False, - host: str = "0.0.0.0", - port: int = 7997, - url_prefix: str = "", - redirect_slash: str = "/docs", - log_level: UVICORN_LOG_LEVELS = UVICORN_LOG_LEVELS.info.name, # type: ignore + preload_only: bool = MANAGER.preload_only, + host: str = MANAGER.host, + port: int = MANAGER.port, + url_prefix: str = MANAGER.url_prefix, + redirect_slash: str = MANAGER.redirect_slash, + log_level: UVICORN_LOG_LEVELS = MANAGER.log_level, # type: ignore permissive_cors: bool = False, + api_key: str = "", ): """Infinity Embedding API ♾️ cli v2 to start a uvicorn-server instance; MIT License; Copyright (c) 2023-now Michael Feil @@ -492,6 +497,7 @@ def v2( use_bettertransformer, bool: use bettertransformer. Defaults to True. preload_only, bool: only preload the model and exit. Defaults to False. permissive_cors, bool: add permissive CORS headers to enable consumption from a browser. Defaults to False. + api_key, str: optional Bearer token for authentication. Defaults to "", which disables authentication. """ logger.setLevel(log_level.to_int()) padder = AutoPadding( @@ -523,6 +529,7 @@ def v2( redirect_slash=redirect_slash, preload_only=preload_only, permissive_cors=permissive_cors, + api_key=api_key, ) uvicorn.run(app, host=host, port=port, log_level=log_level.name) @@ -530,13 +537,12 @@ def cli(): if len(sys.argv) == 1 or sys.argv[1] not in ["v1", "v2", "help", "--help"]: for _ in range(3): logger.error( - "WARNING: No command given. Defaulting to `v1`." - "This will be deprecated in the future, and will require usage of a `v1` or `v2`" - "Specify the version of the CLI you want to use." + "WARNING: No command given. Defaulting to `v1`. " + "This will be deprecated in the future, and will require usage of a `v1` or `v2`. " + "Specify the version of the CLI you want to use. 
" ) time.sleep(1) sys.argv.insert(1, "v1") - print(sys.argv) tp() if __name__ == "__main__": diff --git a/libs/infinity_emb/tests/end_to_end/test_authentication.py b/libs/infinity_emb/tests/end_to_end/test_authentication.py index 58b3e20b..fdd125c7 100644 --- a/libs/infinity_emb/tests/end_to_end/test_authentication.py +++ b/libs/infinity_emb/tests/end_to_end/test_authentication.py @@ -10,7 +10,7 @@ MODEL_NAME = "dummy/model1" MODEL_NAME_2 = "dummy/model2" BATCH_SIZE = 16 -AUTH_TOKEN = "dummy-password" +API_KEY = "dummy-password" app = create_server( url_prefix=PREFIX, @@ -26,7 +26,7 @@ engine=InferenceEngine.debugengine, ), ], - auth_token=AUTH_TOKEN, + api_key=API_KEY, ) @@ -54,7 +54,7 @@ async def test_authentication(client): ], ]: for authenticated in [False, True]: - headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if authenticated else {} + headers = {"Authorization": f"Bearer {API_KEY}"} if authenticated else {} response = await client.post( route, headers=headers, From 40315299edf4e5bd0bd17166bdf9e9d9b00b30e0 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 21:55:00 -0700 Subject: [PATCH 03/11] update env print --- libs/infinity_emb/infinity_emb/env.py | 39 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py index 5db6ccea..fb211f27 100644 --- a/libs/infinity_emb/infinity_emb/env.py +++ b/libs/infinity_emb/infinity_emb/env.py @@ -4,26 +4,47 @@ import os from functools import cached_property - class __Infinity_EnvManager: def __init__(self): + self._debug(f"Loading Infinity ENV variables.\nCONFIG:\n{'-'*10}") for f_name in dir(self): if isinstance(getattr(type(self), f_name, None), cached_property): getattr(self, f_name) # pre-cache - + self._debug(f"{'-'*10}\nENV variables loaded.") + + def _debug(self, message: str): + if "API_KEY" in message: + print("INFINITY_API_KEY=not_shown") + print(f"INFINITY_LOG_LEVEL={self.log_level}") + elif "LOG_LEVEL" in message: + return # recursion + elif self.log_level in {"debug", "trace"}: + print(message) + @staticmethod - def _optional_infinity_var(name: str, default: str = ""): - name = name.upper().replace("-", "_") - return os.getenv(f"INFINITY_{name}", default) + def _to_name(name: str) -> str: + return "INFINITY_" + name.upper().replace("-", "_") + + def _optional_infinity_var(self, name: str, default: str = ""): + name = self._to_name(name) + value = os.getenv(name) + if value is None: + self._debug(f"{name}=`{default}`(default)") + return default + self._debug(f"{name}=`{value}`") + return value - @staticmethod - def _optional_infinity_var_multiple(name: str, default: list[str]) -> list[str]: + def _optional_infinity_var_multiple(self, name: str, default: list[str]) -> list[str]: + name = self._to_name(name) value = os.getenv(name) - if not value: + if value is None: + self._debug(f"{name}=`{';'.join(default)}`(default)") return default if value.endswith(";"): value = value[:-1] - return value.split(";") + value = value.split(";") + self._debug(f"{name}=`{';'.join(value)}`") + return value @staticmethod def _to_bool(value: str) -> bool: From 0eda8f5f684253d1360923fc7fd20b59a783101a Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 21:59:41 -0700 Subject: [PATCH 04/11] fmt --- libs/infinity_emb/infinity_emb/env.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py index 
fb211f27..508919d4 100644 --- a/libs/infinity_emb/infinity_emb/env.py +++ b/libs/infinity_emb/infinity_emb/env.py @@ -4,6 +4,7 @@ import os from functools import cached_property + class __Infinity_EnvManager: def __init__(self): self._debug(f"Loading Infinity ENV variables.\nCONFIG:\n{'-'*10}") @@ -11,20 +12,20 @@ def __init__(self): if isinstance(getattr(type(self), f_name, None), cached_property): getattr(self, f_name) # pre-cache self._debug(f"{'-'*10}\nENV variables loaded.") - + def _debug(self, message: str): if "API_KEY" in message: print("INFINITY_API_KEY=not_shown") print(f"INFINITY_LOG_LEVEL={self.log_level}") elif "LOG_LEVEL" in message: - return # recursion + return # recursion elif self.log_level in {"debug", "trace"}: print(message) - + @staticmethod def _to_name(name: str) -> str: return "INFINITY_" + name.upper().replace("-", "_") - + def _optional_infinity_var(self, name: str, default: str = ""): name = self._to_name(name) value = os.getenv(name) @@ -34,7 +35,9 @@ def _optional_infinity_var(self, name: str, default: str = ""): self._debug(f"{name}=`{value}`") return value - def _optional_infinity_var_multiple(self, name: str, default: list[str]) -> list[str]: + def _optional_infinity_var_multiple( + self, name: str, default: list[str] + ) -> list[str]: name = self._to_name(name) value = os.getenv(name) if value is None: From a69d5b7681112dd0a1e7fbc6ac68328ce9dc76b1 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:00:13 -0700 Subject: [PATCH 05/11] fmt2 --- libs/infinity_emb/infinity_emb/env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py index 508919d4..9902c732 100644 --- a/libs/infinity_emb/infinity_emb/env.py +++ b/libs/infinity_emb/infinity_emb/env.py @@ -45,9 +45,9 @@ def _optional_infinity_var_multiple( return default if value.endswith(";"): value = value[:-1] - value = value.split(";") - self._debug(f"{name}=`{';'.join(value)}`") - return value + value_list = value.split(";") + self._debug(f"{name}=`{';'.join(value_list)}`") + return value_list @staticmethod def _to_bool(value: str) -> bool: From 2c57736d453339d010c282cb6228578af74586d3 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:06:42 -0700 Subject: [PATCH 06/11] update args default --- libs/infinity_emb/infinity_emb/args.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/args.py b/libs/infinity_emb/infinity_emb/args.py index 15eda976..a25d877b 100644 --- a/libs/infinity_emb/infinity_emb/args.py +++ b/libs/infinity_emb/infinity_emb/args.py @@ -1,9 +1,9 @@ -import os import sys from dataclasses import asdict, dataclass from typing import Optional from infinity_emb._optional_imports import CHECK_PYDANTIC +from infinity_emb.env import MANAGER from infinity_emb.primitives import ( Device, Dtype, @@ -41,21 +41,21 @@ class EngineArgs: served_model_name, str: Defaults to readable name of model_name_or_path. 
""" - model_name_or_path: str = "michaelfeil/bge-small-en-v1.5" - batch_size: int = 32 - revision: Optional[str] = None - trust_remote_code: bool = True + model_name_or_path: str = MANAGER.model_id[0] + batch_size: int = MANAGER.batch_size[0] + revision: Optional[str] = MANAGER.revision[0] + trust_remote_code: bool = MANAGER.trust_remote_code[0] engine: InferenceEngine = InferenceEngine.torch - model_warmup: bool = False + model_warmup: bool = MANAGER.model_warmup[0] vector_disk_cache_path: str = "" device: Device = Device.auto - compile: bool = not os.environ.get("INFINITY_DISABLE_COMPILE", "Disable") - bettertransformer: bool = True + compile: bool = MANAGER.compile[0] + bettertransformer: bool = MANAGER.bettertransformer[0] dtype: Dtype = Dtype.auto pooling_method: PoolingMethod = PoolingMethod.auto - lengths_via_tokenize: bool = False + lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0] embedding_dtype: EmbeddingDtype = EmbeddingDtype.float32 - served_model_name: str = "" + served_model_name: str = MANAGER.served_model_name[0] def __post_init__(self): # convert the following strings to enums From c4b9d7ae25d778e5cc3ed782da124779a8a88f6f Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:30:26 -0700 Subject: [PATCH 07/11] another env update --- libs/infinity_emb/infinity_emb/env.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py index 9902c732..784b56b4 100644 --- a/libs/infinity_emb/infinity_emb/env.py +++ b/libs/infinity_emb/infinity_emb/env.py @@ -4,6 +4,14 @@ import os from functools import cached_property +from infinity_emb.primitives import ( + Device, + Dtype, + EmbeddingDtype, + InferenceEngine, + PoolingMethod, +) + class __Infinity_EnvManager: def __init__(self): @@ -161,5 +169,50 @@ def redirect_slash(self): def log_level(self): return self._optional_infinity_var("log_level", default="info") + @cached_property + def dtype(self) -> list[Dtype]: + return [ + Dtype(v) + for v in self._optional_infinity_var_multiple( + "dtype", default=[Dtype.default_value()] + ) + ] + + @cached_property + def engine(self) -> list[InferenceEngine]: + return [ + InferenceEngine(v) + for v in self._optional_infinity_var_multiple( + "engine", default=[InferenceEngine.default_value()] + ) + ] + + @cached_property + def pooling_method(self) -> list[PoolingMethod]: + return [ + PoolingMethod(v) + for v in self._optional_infinity_var_multiple( + "pooling_method", default=[PoolingMethod.default_value()] + ) + ] + + @cached_property + def device(self) -> list[Device]: + return [ + Device(v) + for v in self._optional_infinity_var_multiple( + "device", default=[Device.default_value()] + ) + ] + + @cached_property + def embedding_dtype(self) -> list[EmbeddingDtype]: + return [ + EmbeddingDtype(v) + for v in self._optional_infinity_var_multiple( + "embedding_dtype", default=[EmbeddingDtype.default_value()] + ) + ] + MANAGER = __Infinity_EnvManager() From 425c8d027237e26b3111bb5fa888b99f25e090da Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:34:14 -0700 Subject: [PATCH 08/11] update torch code --- .../tests/end_to_end/test_torch_reranker.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py index 721a2147..9dffd2ea 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py +++ 
b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py @@ -88,3 +88,23 @@ async def test_reranker(client, model_base, helpers): assert len(rdata_results) == len(predictions) for i, pred in enumerate(predictions): assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01 + + +@pytest.mark.anyio +async def test_reranker_cant_embed_or_classify(client): + documents = [ + "The Eiffel Tower is located in Paris, France", + "The Eiffel Tower is located in the United States.", + "The Eiffel Tower is located in the United Kingdom.", + ] + response = await client.post( + f"{PREFIX}/embeddings", + json={"model": MODEL, "input": documents}, + ) + assert response.status_code == 400 + + response = await client.post( + f"{PREFIX}/classify", + json={"model": MODEL, "input": documents}, + ) + assert response.status_code == 400 \ No newline at end of file From bdc679bdc58c294f7806cc951f3556ed5d319bbb Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:36:52 -0700 Subject: [PATCH 09/11] fmt --- libs/infinity_emb/tests/end_to_end/test_torch_reranker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py index 9dffd2ea..74cd8dbf 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py @@ -107,4 +107,4 @@ async def test_reranker_cant_embed_or_classify(client): f"{PREFIX}/classify", json={"model": MODEL, "input": documents}, ) - assert response.status_code == 400 \ No newline at end of file + assert response.status_code == 400 From 95aacb62ef8bb6239d1d841602fb36768b776940 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:49:14 -0700 Subject: [PATCH 10/11] add emotions classifier --- README.md | 1 - docs/docs/python_engine.md | 1 - libs/infinity_emb/tests/conftest.py | 1 + .../tests/end_to_end/test_torch_classify.py | 84 +++++++++++++++++++ 4 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 libs/infinity_emb/tests/end_to_end/test_torch_classify.py diff --git a/README.md b/README.md index 4e618eca..43e5b5d3 100644 --- a/README.md +++ b/README.md @@ -171,7 +171,6 @@ engine = AsyncEmbeddingEngine.from_args(engine_args) async def main(): async with engine: predictions, usage = await engine.classify(sentences=sentences) - return predictions, usage # or handle the async start / stop yourself. await engine.astart() predictions, usage = await engine.classify(sentences=sentences) diff --git a/docs/docs/python_engine.md b/docs/docs/python_engine.md index d7d04d79..e717b273 100644 --- a/docs/docs/python_engine.md +++ b/docs/docs/python_engine.md @@ -86,7 +86,6 @@ engine = AsyncEmbeddingEngine.from_args(engine_args) async def main(): async with engine: predictions, usage = await engine.classify(sentences=sentences) - return predictions, usage # or handle the async start / stop yourself. 
await engine.astart() predictions, usage = await engine.classify(sentences=sentences) diff --git a/libs/infinity_emb/tests/conftest.py b/libs/infinity_emb/tests/conftest.py index 60542b98..c4460778 100644 --- a/libs/infinity_emb/tests/conftest.py +++ b/libs/infinity_emb/tests/conftest.py @@ -7,6 +7,7 @@ pytest.DEFAULT_BERT_MODEL = "michaelfeil/bge-small-en-v1.5" pytest.DEFAULT_RERANKER_MODEL = "BAAI/bge-reranker-base" +pytest.DEFAULT_CLASSIFIER_MODEL = "SamLowe/roberta-base-go_emotions" @pytest.fixture diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_classify.py b/libs/infinity_emb/tests/end_to_end/test_torch_classify.py new file mode 100644 index 00000000..4ebffc32 --- /dev/null +++ b/libs/infinity_emb/tests/end_to_end/test_torch_classify.py @@ -0,0 +1,84 @@ +import pytest +import torch +from asgi_lifespan import LifespanManager +from httpx import AsyncClient +from transformers import pipeline # type: ignore[import-untyped] + +from infinity_emb import create_server +from infinity_emb.args import EngineArgs +from infinity_emb.primitives import Device, InferenceEngine + +PREFIX = "/v1_classify" +MODEL: str = pytest.DEFAULT_CLASSIFIER_MODEL # type: ignore[assignment] +batch_size = 32 if torch.cuda.is_available() else 8 + +app = create_server( + url_prefix=PREFIX, + engine_args_list=[ + EngineArgs( + model_name_or_path=MODEL, + batch_size=batch_size, + engine=InferenceEngine.torch, + device=Device.auto if not torch.backends.mps.is_available() else Device.cpu, + ) + ], +) + + +@pytest.fixture +def model_base() -> pipeline: + return pipeline(model=MODEL, task="text-classification") + + +@pytest.fixture() +async def client(): + async with AsyncClient( + app=app, base_url="http://test", timeout=20 + ) as client, LifespanManager(app): + yield client + + +def test_load_model(model_base): + # this makes sure that the error below is not based on a slow download + # or internal pytorch errors + model_base.predict( + { + "text": "I love fries!", + } + ) + + +@pytest.mark.anyio +async def test_model_route(client): + response = await client.get(f"{PREFIX}/models") + assert response.status_code == 200 + rdata = response.json() + assert "data" in rdata + assert rdata["data"][0].get("id", "") == MODEL + assert isinstance(rdata["data"][0].get("stats"), dict) + + +@pytest.mark.anyio +async def test_classifier(client, model_base): + documents = [ + "I love fries!", + "I hate fries!", + "I am jealous of fries!", + ] + response = await client.post( + f"{PREFIX}/classify", + json={"model": MODEL, "input": documents}, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + # rdata_results = rdata["results"] + + # predictions = [ + # model_base.predict({"text": query, "text_pair": doc}) for doc in documents + # ] + + # assert len(rdata_results) == len(predictions) + # for i, pred in enumerate(predictions): + # assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01 From 4343ff94dd89e107b5b33071d000e8e5eea28e19 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 30 May 2024 22:49:28 -0700 Subject: [PATCH 11/11] fmt --- libs/infinity_emb/tests/end_to_end/test_torch_classify.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_classify.py b/libs/infinity_emb/tests/end_to_end/test_torch_classify.py index 4ebffc32..6f8ec274 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_classify.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_classify.py @@ 
-42,10 +42,10 @@ def test_load_model(model_base): # this makes sure that the error below is not based on a slow download # or internal pytorch errors model_base.predict( - { - "text": "I love fries!", - } - ) + { + "text": "I love fries!", + } + ) @pytest.mark.anyio
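---
Usage note (an illustrative sketch, not part of any patch above): patches 02 through 07 move the server and engine defaults behind the INFINITY_* environment layer in infinity_emb/env.py. The sketch below shows how that layer is expected to be driven, using only names that appear in the diffs; the second model id, the batch sizes, and the key value are placeholders. Because __Infinity_EnvManager.__init__ pre-caches every cached_property, the variables must be exported before infinity_emb is first imported.

import os

# Multi-value variables are ";"-separated and map positionally across models
# (see _optional_infinity_var_multiple); booleans accept "true"/"1" (see _to_bool).
os.environ["INFINITY_MODEL_ID"] = "michaelfeil/bge-small-en-v1.5;BAAI/bge-reranker-base"
os.environ["INFINITY_BATCH_SIZE"] = "16;8"
os.environ["INFINITY_MODEL_WARMUP"] = "false;false"
# A non-empty INFINITY_API_KEY enables Bearer authentication
# (patch 02 renames auth_token to api_key).
os.environ["INFINITY_API_KEY"] = "dummy-password"

# Importing infinity_emb also auto-enables hf_transfer if it is installed (patch 01).
from infinity_emb import create_server  # noqa: E402
from infinity_emb.args import EngineArgs  # noqa: E402
from infinity_emb.env import MANAGER  # noqa: E402

# MANAGER now reflects the variables exported above.
assert MANAGER.model_id == ["michaelfeil/bge-small-en-v1.5", "BAAI/bge-reranker-base"]
assert MANAGER.batch_size == [16, 8]

# create_server() now defaults url_prefix, api_key, preload_only, etc. to MANAGER
# values, so only the per-model engine arguments are spelled out here.
app = create_server(
    engine_args_list=[
        EngineArgs(model_name_or_path=m, batch_size=b, model_warmup=w)
        for m, b, w in zip(MANAGER.model_id, MANAGER.batch_size, MANAGER.model_warmup)
    ],
)

With the key set, clients authenticate by presenting it as a Bearer token, mirroring tests/end_to_end/test_authentication.py. A hedged example against a server assumed to be running via uvicorn on the default host and port from MANAGER (0.0.0.0:7997):

import httpx

# Succeeds with the token configured above; a wrong token is rejected with 401.
response = httpx.post(
    "http://localhost:7997/embeddings",
    headers={"Authorization": "Bearer dummy-password"},
    json={"model": "michaelfeil/bge-small-en-v1.5", "input": ["hello world"]},
)
response.raise_for_status()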