Skip to content

Commit

Permalink
feat: Add support for TEI API key authentication (#11006)
Browse files Browse the repository at this point in the history
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
  • Loading branch information
kenwoodjw and crazywoola authored Nov 23, 2024
1 parent 16c4158 commit 096c0ad
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
- variable: api_key
label:
en_US: API Key
type: secret-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,13 @@ def _invoke(

server_url = server_url.removesuffix("/")

headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

try:
results = TeiHelper.invoke_rerank(server_url, query, docs)
results = TeiHelper.invoke_rerank(server_url, query, docs, headers)

rerank_documents = []
for result in results:
Expand Down Expand Up @@ -80,7 +85,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
try:
server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
if extra_args.model_type != "reranker":
raise CredentialsValidateFailedError("Current model is not a rerank model")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ def __init__(self, model_type: str, max_input_length: int, max_client_batch_size

class TeiHelper:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
def get_tei_extra_parameter(
server_url: str, model_name: str, headers: Optional[dict] = None
) -> TeiModelExtraParameter:
TeiHelper._clean_cache()
with cache_lock:
if model_name not in cache:
cache[model_name] = {
"expires": time() + 300,
"value": TeiHelper._get_tei_extra_parameter(server_url),
"value": TeiHelper._get_tei_extra_parameter(server_url, headers),
}
return cache[model_name]["value"]

Expand All @@ -47,7 +49,7 @@ def _clean_cache() -> None:
pass

@staticmethod
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
"""
get tei model extra parameter like model_type, max_input_length, max_batch_requests
"""
Expand All @@ -61,7 +63,7 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
session.mount("https://", HTTPAdapter(max_retries=3))

try:
response = session.get(url, timeout=10)
response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
if response.status_code != 200:
Expand All @@ -86,7 +88,7 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
)

@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
"""
Invoke tokenize endpoint
Expand Down Expand Up @@ -114,15 +116,15 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
:param server_url: server url
:param texts: texts to tokenize
"""
resp = httpx.post(
f"{server_url}/tokenize",
json={"inputs": texts},
)
url = f"{server_url}/tokenize"
json_data = {"inputs": texts}
resp = httpx.post(url, json=json_data, headers=headers)

resp.raise_for_status()
return resp.json()

@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
"""
Invoke embeddings endpoint
Expand All @@ -147,15 +149,14 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
:param texts: texts to embed
"""
# Use OpenAI compatible API here, which has usage tracking
resp = httpx.post(
f"{server_url}/v1/embeddings",
json={"input": texts},
)
url = f"{server_url}/v1/embeddings"
json_data = {"input": texts}
resp = httpx.post(url, json=json_data, headers=headers)
resp.raise_for_status()
return resp.json()

@staticmethod
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
"""
Invoke rerank endpoint
Expand All @@ -173,10 +174,7 @@ def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
:param candidates: candidates to rerank
"""
params = {"query": query, "texts": docs, "return_text": True}

response = httpx.post(
server_url + "/rerank",
json=params,
)
url = f"{server_url}/rerank"
response = httpx.post(url, json=params, headers=headers)
response.raise_for_status()
return response.json()
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def _invoke(

server_url = server_url.removesuffix("/")

headers = {"Content-Type": "application/json"}
api_key = credentials["api_key"]
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
Expand All @@ -60,7 +64,7 @@ def _invoke(
used_tokens = 0

# get tokenized results from TEI
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)

for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
# Check if the number of tokens is larger than the context size
Expand Down Expand Up @@ -97,7 +101,7 @@ def _invoke(
used_tokens = 0
for i in _iter:
iter_texts = inputs[i : i + max_chunks]
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
embeddings = results["data"]
embeddings = [embedding["embedding"] for embedding in embeddings]
batched_embeddings.extend(embeddings)
Expand Down Expand Up @@ -127,7 +131,11 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int

server_url = server_url.removesuffix("/")

batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
}

batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
num_tokens = sum(len(tokens) for tokens in batch_tokens)
return num_tokens

Expand All @@ -141,7 +149,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
try:
server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
headers = {"Content-Type": "application/json"}

api_key = credentials.get("api_key")

if api_key:
headers["Authorization"] = f"Bearer {api_key}"

extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
print(extra_args)
if extra_args.model_type != "embedding":
raise CredentialsValidateFailedError("Current model is not a embedding model")
Expand Down
1 change: 1 addition & 0 deletions api/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ env =
OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
TEI_RERANK_SERVER_URL = http://a.abc.com:11451
TEI_API_KEY = ttttttttttttttt
UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
XINFERENCE_CHAT_MODEL_UID = chat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def test_validate_credentials(setup_tei_mock):
model="reranker",
credentials={
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
)

model.validate_credentials(
model=model_name,
credentials={
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
)

Expand All @@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
model=model_name,
credentials={
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
texts=["hello", "world"],
user="abc-123",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def test_validate_credentials(setup_tei_mock):
model="embedding",
credentials={
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
)

model.validate_credentials(
model=model_name,
credentials={
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
)

Expand All @@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
model=model_name,
credentials={
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
"api_key": os.environ.get("TEI_API_KEY", ""),
},
query="Who is Kasumi?",
docs=[
Expand Down

0 comments on commit 096c0ad

Please sign in to comment.