Skip to content

Commit

Permalink
TEI rerank microservice async support (#746)
Browse files Browse the repository at this point in the history
* tTEIrerank microservice support async

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lvliang-intel and pre-commit-ci[bot] authored Sep 30, 2024
1 parent bece9f4 commit 9df4b3c
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions comps/reranks/tei/reranking_tei.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import heapq
import json
import os
import re
import time
from typing import Union

import requests
import aiohttp

from comps import (
CustomLogger,
Expand All @@ -27,21 +25,21 @@
RerankingResponseData,
)

logger = CustomLogger("reranking_tgi_gaudi")
logger = CustomLogger("reranking_tei")
logflag = os.getenv("LOGFLAG", False)


@register_microservice(
name="opea_service@reranking_tgi_gaudi",
name="opea_service@reranking_tei",
service_type=ServiceType.RERANK,
endpoint="/v1/reranking",
host="0.0.0.0",
port=8000,
input_datatype=SearchedDoc,
output_datatype=LLMParamsDoc,
)
@register_statistics(names=["opea_service@reranking_tgi_gaudi"])
def reranking(
@register_statistics(names=["opea_service@reranking_tei"])
async def reranking(
input: Union[SearchedDoc, RerankingRequest, ChatCompletionRequest]
) -> Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest]:
if logflag:
Expand All @@ -58,15 +56,16 @@ def reranking(
query = input.input
data = {"query": query, "texts": docs}
headers = {"Content-Type": "application/json"}
response = requests.post(url, data=json.dumps(data), headers=headers)
response_data = response.json()
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(data), headers=headers) as response:
response_data = await response.json()

for best_response in response_data[: input.top_n]:
reranking_results.append(
{"text": input.retrieved_docs[best_response["index"]].text, "score": best_response["score"]}
)

statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
statistics_dict["opea_service@reranking_tei"].append_latency(time.time() - start, None)
if isinstance(input, SearchedDoc):
result = [doc["text"] for doc in reranking_results]
if logflag:
Expand All @@ -92,4 +91,4 @@ def reranking(

if __name__ == "__main__":
tei_reranking_endpoint = os.getenv("TEI_RERANKING_ENDPOINT", "http://localhost:8080")
opea_microservices["opea_service@reranking_tgi_gaudi"].start()
opea_microservices["opea_service@reranking_tei"].start()

0 comments on commit 9df4b3c

Please sign in to comment.