-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add evaluation service module for RAG and Agent (#2070)
- Loading branch information
Showing
29 changed files
with
1,263 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Evaluation.""" | ||
from typing import List | ||
|
||
from dbgpt.core.schema.api import Result | ||
|
||
from ..core.interface.evaluation import EvaluationResult | ||
from ..serve.evaluate.api.schemas import EvaluateServeRequest | ||
from .client import Client, ClientException | ||
|
||
|
||
async def run_evaluation(
    client: Client, request: EvaluateServeRequest
) -> List[EvaluationResult]:
    """Run an evaluation through the server's evaluation endpoint.

    The request is serialized with ``request.dict()`` and posted to
    ``/evaluate/evaluation``. When the decoded JSON response reports
    ``success``, its ``data`` field is returned as a list; the items are
    passed through exactly as decoded from JSON.

    Args:
        client (Client): The dbgpt client.
        request (EvaluateServeRequest): The evaluate request.

    Returns:
        List[EvaluationResult]: The items of the response's ``data`` field.

    Raises:
        ClientException: If the server reports a failure (carrying the
            response's ``err_code`` as ``status`` and the whole response as
            ``reason``), or if sending the request or reading the response
            fails for any other reason.
    """
    try:
        res = await client.post("/evaluate/evaluation", request.dict())
        result: Result = res.json()
        if result["success"]:
            return list(result["data"])
        raise ClientException(status=result["err_code"], reason=result)
    except ClientException:
        # Already the exception type callers expect: re-raise untouched so the
        # server-provided status/reason are not buried in a second wrapper.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays available.
        raise ClientException(f"Failed to run evaluation: {e}") from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import logging | ||
from functools import cache | ||
from typing import List, Optional | ||
|
||
from fastapi import APIRouter, Depends, HTTPException | ||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer | ||
|
||
from dbgpt.component import ComponentType, SystemApp | ||
from dbgpt.core.interface.evaluation import metric_manage | ||
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory | ||
from dbgpt.rag.evaluation.answer import AnswerRelevancyMetric | ||
from dbgpt.serve.core import Result | ||
from dbgpt.serve.evaluate.api.schemas import EvaluateServeRequest, EvaluateServeResponse | ||
from dbgpt.serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME | ||
from dbgpt.serve.evaluate.service.service import Service | ||
|
||
from ...prompt.service.service import Service as PromptService | ||
|
||
# Router that collects the endpoints declared in this module.
router = APIRouter()

# Add your API endpoints here

logger = logging.getLogger(__name__)
# System app used by the get_* lookups below; None until init_endpoints()
# assigns it.
global_system_app: Optional[SystemApp] = None
|
||
|
||
def get_service() -> Service:
    """Fetch the evaluate ``Service`` component from the global system app."""
    service = global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
    return service
|
||
|
||
def get_prompt_service() -> PromptService:
    """Fetch the prompt service component from the global system app."""
    component_name = "dbgpt_serve_prompt_service"
    return global_system_app.get_component(component_name, PromptService)
|
||
|
||
def get_worker_manager() -> WorkerManager:
    """Build a worker manager from the registered worker manager factory."""
    factory = global_system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    return factory.create()
|
||
|
||
def get_model_controller() -> BaseModelController:
    """Fetch the model controller component from the global system app."""
    return global_system_app.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
|
||
|
||
# Bearer-credentials dependency consumed by check_api_key.
# NOTE(review): auto_error=False is expected to hand back None instead of
# rejecting the request when no credentials are sent, leaving the decision to
# check_api_key (which handles ``auth is None``) — confirm against FastAPI docs.
get_bearer_token = HTTPBearer(auto_error=False)
|
||
|
||
@cache | ||
def _parse_api_keys(api_keys: str) -> List[str]: | ||
"""Parse the string api keys to a list | ||
Args: | ||
api_keys (str): The string api keys | ||
Returns: | ||
List[str]: The list of api keys | ||
""" | ||
if not api_keys: | ||
return [] | ||
return [key.strip() for key in api_keys.split(",")] | ||
|
||
|
||
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token of a request against the configured api keys.

    If no api keys are configured on the service, every request is allowed.

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Returns:
        Optional[str]: The accepted token, or None when no api keys are
            configured.

    Raises:
        HTTPException: With status 401 when api keys are configured and the
            request has no credentials or a token that is not among them.
    """
    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    accepted = _parse_api_keys(configured_keys)
    token = None if auth is None else auth.credentials
    if auth is None or token not in accepted:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
|
||
|
||
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Health check endpoint; runs the api key check as a dependency."""
    payload = {"status": "ok"}
    return payload
|
||
|
||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Auth test endpoint; answers only after the api key check has passed."""
    payload = {"status": "ok"}
    return payload
|
||
|
||
@router.get("/scenes")
async def get_scenes():
    """Return the available evaluation scenes wrapped in a success ``Result``.

    Each scene is a single-entry mapping from the scene key to its display
    label.
    """
    scenes = [
        {"recall": "召回评测"},  # label: "recall evaluation"
        {"app": "应用评测"},  # label: "application evaluation"
    ]
    return Result.succ(scenes)
|
||
|
||
@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
    service: Service = Depends(get_service),
) -> Result:
    """Evaluate results for the requested scene.

    Forwards the scene key, scene value, datasets, context and metrics of the
    request to ``service.run_evaluation`` and wraps what it returns.

    Args:
        request (EvaluateServeRequest): The request
        service (Service): The service

    Returns:
        Result: A success result carrying the evaluation output
    """
    evaluation_output = await service.run_evaluation(
        request.scene_key,
        request.scene_value,
        request.datasets,
        request.context,
        request.evaluate_metrics,
    )
    return Result.succ(evaluation_output)
|
||
|
||
def init_endpoints(system_app: SystemApp) -> None:
    """Register the evaluate service and publish the system app to this module.

    Args:
        system_app (SystemApp): The system app the ``Service`` class is
            registered on; afterwards it is stored in ``global_system_app``.
    """
    global global_system_app

    # Register first, so the module-level reference is only replaced once
    # registration has gone through.
    system_app.register(Service)
    global_system_app = system_app
Oops, something went wrong.