From a10d5f57b2e0541869d7647b1a2ef5d1c1ee4b0a Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 19 Dec 2023 13:41:02 +0800 Subject: [PATCH] feat(core): Add API authentication for serve template (#950) --- dbgpt/app/dbgpt_server.py | 3 - .../initialization/db_model_initialization.py | 1 - .../initialization/serve_initialization.py | 6 +- dbgpt/app/prompt/api.py | 46 ---- dbgpt/app/prompt/prompt_manage_db.py | 89 ------ dbgpt/app/prompt/request/request.py | 44 --- dbgpt/app/prompt/request/response.py | 26 -- dbgpt/app/prompt/service.py | 87 ------ dbgpt/{app/prompt => model/llm}/__init__.py | 0 dbgpt/serve/core/config.py | 2 + .../request => serve/core/tests}/__init__.py | 0 dbgpt/serve/core/tests/conftest.py | 59 ++++ dbgpt/serve/prompt/api/endpoints.py | 94 ++++++- dbgpt/serve/prompt/api/schemas.py | 131 ++++----- dbgpt/serve/prompt/config.py | 15 +- dbgpt/serve/prompt/serve.py | 8 +- dbgpt/serve/prompt/service/service.py | 22 +- dbgpt/serve/prompt/tests/__init__.py | 0 dbgpt/serve/prompt/tests/test_endpoints.py | 176 ++++++++++++ dbgpt/serve/prompt/tests/test_models.py | 257 ++++++++++++++++++ dbgpt/serve/prompt/tests/test_service.py | 154 +++++++++++ .../default_serve_template/api/endpoints.py | 88 +++++- .../default_serve_template/api/schemas.py | 6 + .../default_serve_template/config.py | 6 +- .../default_serve_template/serve.py | 8 +- .../default_serve_template/service/service.py | 6 +- .../default_serve_template/tests/__init__.py | 0 .../tests/test_endpoints.py | 124 +++++++++ .../tests/test_models.py | 109 ++++++++ .../tests/test_service.py | 76 ++++++ dbgpt/serve/utils/cli.py | 9 + dbgpt/storage/metadata/_base_dao.py | 3 +- dbgpt/storage/metadata/tests/test_base_dao.py | 22 ++ dbgpt/util/config_utils.py | 9 +- 34 files changed, 1301 insertions(+), 385 deletions(-) delete mode 100644 dbgpt/app/prompt/api.py delete mode 100644 dbgpt/app/prompt/prompt_manage_db.py delete mode 100644 dbgpt/app/prompt/request/request.py delete mode 100644 dbgpt/app/prompt/request/response.py delete mode 100644 dbgpt/app/prompt/service.py rename dbgpt/{app/prompt => model/llm}/__init__.py (100%) rename dbgpt/{app/prompt/request => serve/core/tests}/__init__.py (100%) create mode 100644 dbgpt/serve/core/tests/conftest.py create mode 100644 dbgpt/serve/prompt/tests/__init__.py create mode 100644 dbgpt/serve/prompt/tests/test_endpoints.py create mode 100644 dbgpt/serve/prompt/tests/test_models.py create mode 100644 dbgpt/serve/prompt/tests/test_service.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/tests/__init__.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/tests/test_endpoints.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py create mode 100644 dbgpt/serve/utils/_template_files/default_serve_template/tests/test_service.py diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index c23e24f79..388d014b2 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -77,8 +77,6 @@ def mount_routers(app: FastAPI): """Lazy import to avoid high time cost""" from dbgpt.app.knowledge.api import router as knowledge_router - # from dbgpt.app.prompt.api import router as prompt_router - # prompt has been removed to dbgpt.serve.prompt from dbgpt.app.llm_manage.api import router as llm_manage_api from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1 @@ -93,7 +91,6 @@ def mount_routers(app: FastAPI): app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) app.include_router(knowledge_router, tags=["Knowledge"]) - # app.include_router(prompt_router, tags=["Prompt"]) def mount_static_files(app: FastAPI): diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index 9095da3e0..f236ef649 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -7,7 +7,6 @@ from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity -# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity from dbgpt.storage.chat_history.chat_history_db import ( diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 5f132cc2d..50d098ad7 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -3,7 +3,11 @@ def register_serve_apps(system_app: SystemApp): """Register serve apps""" - from dbgpt.serve.prompt.serve import Serve as PromptServe + from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX # Replace old prompt serve + # Set config + system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt") + system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt") + # Register serve app system_app.register(PromptServe, api_prefix="/prompt") diff --git a/dbgpt/app/prompt/api.py b/dbgpt/app/prompt/api.py deleted file mode 100644 index f3d4b6d01..000000000 --- a/dbgpt/app/prompt/api.py +++ /dev/null @@ -1,46 +0,0 @@ -from fastapi import APIRouter - -from dbgpt.app.openapi.api_view_model import Result -from dbgpt.app.prompt.service import PromptManageService -from dbgpt.app.prompt.request.request import PromptManageRequest - -router = APIRouter() - -prompt_manage_service = PromptManageService() - - -@router.post("/prompt/add") -def prompt_add(request: PromptManageRequest): - print(f"/prompt/add params: {request}") - try: - prompt_manage_service.create_prompt(request) - return Result.succ([]) - except Exception as e: - return Result.failed(code="E010X", msg=f"prompt add error {e}") - - -@router.post("/prompt/list") -def prompt_list(request: PromptManageRequest): - print(f"/prompt/list params: {request}") - try: - return Result.succ(prompt_manage_service.get_prompts(request)) - except Exception as e: - return Result.failed(code="E010X", msg=f"prompt list error {e}") - - -@router.post("/prompt/update") -def prompt_update(request: PromptManageRequest): - print(f"/prompt/update params: {request}") - try: - return Result.succ(prompt_manage_service.update_prompt(request)) - except Exception as e: - return Result.failed(code="E010X", msg=f"prompt update error {e}") - - -@router.post("/prompt/delete") -def prompt_delete(request: PromptManageRequest): - print(f"/prompt/delete params: {request}") - try: - return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name)) - except Exception as e: - return Result.failed(code="E010X", msg=f"prompt delete error {e}") diff --git a/dbgpt/app/prompt/prompt_manage_db.py b/dbgpt/app/prompt/prompt_manage_db.py deleted file mode 100644 index 6ba281668..000000000 --- a/dbgpt/app/prompt/prompt_manage_db.py +++ /dev/null @@ -1,89 +0,0 @@ -from datetime import datetime - -from sqlalchemy import Column, Integer, Text, String, DateTime - -from dbgpt.storage.metadata import BaseDao, Model - -from dbgpt._private.config import Config - -from dbgpt.app.prompt.request.request import PromptManageRequest - -CFG = Config() - - -class PromptManageEntity(Model): - __tablename__ = "prompt_manage" - id = Column(Integer, primary_key=True) - chat_scene = Column(String(100)) - sub_chat_scene = Column(String(100)) - prompt_type = Column(String(100)) - prompt_name = Column(String(512)) - content = Column(Text) - user_name = Column(String(128)) - sys_code = Column(String(128), index=True, nullable=True, comment="System code") - gmt_created = Column(DateTime) - gmt_modified = Column(DateTime) - - def __repr__(self): - return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" - - -class PromptManageDao(BaseDao): - def create_prompt(self, prompt: PromptManageRequest): - session = self.get_raw_session() - prompt_manage = PromptManageEntity( - chat_scene=prompt.chat_scene, - sub_chat_scene=prompt.sub_chat_scene, - prompt_type=prompt.prompt_type, - prompt_name=prompt.prompt_name, - content=prompt.content, - user_name=prompt.user_name, - sys_code=prompt.sys_code, - gmt_created=datetime.now(), - gmt_modified=datetime.now(), - ) - session.add(prompt_manage) - session.commit() - session.close() - - def get_prompts(self, query: PromptManageEntity): - session = self.get_raw_session() - prompts = session.query(PromptManageEntity) - if query.chat_scene is not None: - prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) - if query.sub_chat_scene is not None: - prompts = prompts.filter( - PromptManageEntity.sub_chat_scene == query.sub_chat_scene - ) - if query.prompt_type is not None: - prompts = prompts.filter( - PromptManageEntity.prompt_type == query.prompt_type - ) - if query.prompt_type == "private" and query.user_name is not None: - prompts = prompts.filter( - PromptManageEntity.user_name == query.user_name - ) - if query.prompt_name is not None: - prompts = prompts.filter( - PromptManageEntity.prompt_name == query.prompt_name - ) - if query.sys_code is not None: - prompts = prompts.filter(PromptManageEntity.sys_code == query.sys_code) - - prompts = prompts.order_by(PromptManageEntity.gmt_created.desc()) - result = prompts.all() - session.close() - return result - - def update_prompt(self, prompt: PromptManageEntity): - session = self.get_raw_session() - session.merge(prompt) - session.commit() - session.close() - - def delete_prompt(self, prompt: PromptManageEntity): - session = self.get_raw_session() - if prompt: - session.delete(prompt) - session.commit() - session.close() diff --git a/dbgpt/app/prompt/request/request.py b/dbgpt/app/prompt/request/request.py deleted file mode 100644 index 9d7f37deb..000000000 --- a/dbgpt/app/prompt/request/request.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import List - -from dbgpt._private.pydantic import BaseModel -from typing import Optional -from dbgpt._private.pydantic import BaseModel - - -class PromptManageRequest(BaseModel): - """Model for managing prompts.""" - - chat_scene: Optional[str] = None - """ - The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa. - """ - - sub_chat_scene: Optional[str] = None - """ - The sub chat scene. - """ - - prompt_type: Optional[str] = None - """ - The prompt type, either common or private. - """ - - content: Optional[str] = None - """ - The prompt content. - """ - - user_name: Optional[str] = None - """ - The user name. - """ - - sys_code: Optional[str] = None - """ - System code - """ - - prompt_name: Optional[str] = None - """ - The prompt name. - """ diff --git a/dbgpt/app/prompt/request/response.py b/dbgpt/app/prompt/request/response.py deleted file mode 100644 index 8f9221c15..000000000 --- a/dbgpt/app/prompt/request/response.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import List -from dbgpt._private.pydantic import BaseModel - - -class PromptQueryResponse(BaseModel): - id: int = None - """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" - - chat_scene: str = None - - """sub_chat_scene: sub chat scene""" - sub_chat_scene: str = None - - """prompt_type: common or private""" - prompt_type: str = None - - """content: prompt content""" - content: str = None - - """user_name: user name""" - user_name: str = None - - """prompt_name: prompt name""" - prompt_name: str = None - gmt_created: str = None - gmt_modified: str = None diff --git a/dbgpt/app/prompt/service.py b/dbgpt/app/prompt/service.py deleted file mode 100644 index 0eec54cce..000000000 --- a/dbgpt/app/prompt/service.py +++ /dev/null @@ -1,87 +0,0 @@ -from datetime import datetime - -from dbgpt.app.prompt.request.request import PromptManageRequest -from dbgpt.app.prompt.request.response import PromptQueryResponse -from dbgpt.app.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity - -prompt_manage_dao = PromptManageDao() - - -class PromptManageService: - def __init__(self): - pass - - """create prompt""" - - def create_prompt(self, request: PromptManageRequest): - query = PromptManageRequest( - prompt_name=request.prompt_name, - ) - err_sys_str = "" - if query.sys_code: - query.sys_code = request.sys_code - err_sys_str = f" and sys_code: {request.sys_code}" - prompt_name = prompt_manage_dao.get_prompts(query) - if len(prompt_name) > 0: - raise Exception( - f"prompt name: {request.prompt_name}{err_sys_str} have already named" - ) - prompt_manage_dao.create_prompt(request) - return True - - """get prompts""" - - def get_prompts(self, request: PromptManageRequest): - query = PromptManageRequest( - chat_scene=request.chat_scene, - sub_chat_scene=request.sub_chat_scene, - prompt_type=request.prompt_type, - prompt_name=request.prompt_name, - user_name=request.user_name, - sys_code=request.sys_code, - ) - responses = [] - prompts = prompt_manage_dao.get_prompts(query) - for prompt in prompts: - res = PromptQueryResponse() - - res.id = prompt.id - res.chat_scene = prompt.chat_scene - res.sub_chat_scene = prompt.sub_chat_scene - res.prompt_type = prompt.prompt_type - res.content = prompt.content - res.user_name = prompt.user_name - res.prompt_name = prompt.prompt_name - res.gmt_created = prompt.gmt_created - res.gmt_modified = prompt.gmt_modified - responses.append(res) - return responses - - """update prompt""" - - def update_prompt(self, request: PromptManageRequest): - query = PromptManageEntity(prompt_name=request.prompt_name) - prompts = prompt_manage_dao.get_prompts(query) - if len(prompts) != 1: - raise Exception( - f"there are no or more than one space called {request.prompt_name}" - ) - prompt = prompts[0] - prompt.chat_scene = request.chat_scene - prompt.sub_chat_scene = request.sub_chat_scene - prompt.prompt_type = request.prompt_type - prompt.content = request.content - prompt.user_name = request.user_name - prompt.gmt_modified = datetime.now() - return prompt_manage_dao.update_prompt(prompt) - - """delete prompt""" - - def delete_prompt(self, prompt_name: str): - query = PromptManageEntity(prompt_name=prompt_name) - prompts = prompt_manage_dao.get_prompts(query) - if len(prompts) == 0: - raise Exception(f"delete error, no prompt name:{prompt_name} in database ") - # delete prompt - prompt = prompts[0] - return prompt_manage_dao.delete_prompt(prompt) diff --git a/dbgpt/app/prompt/__init__.py b/dbgpt/model/llm/__init__.py similarity index 100% rename from dbgpt/app/prompt/__init__.py rename to dbgpt/model/llm/__init__.py diff --git a/dbgpt/serve/core/config.py b/dbgpt/serve/core/config.py index 0793fc9a5..4fd6d2247 100644 --- a/dbgpt/serve/core/config.py +++ b/dbgpt/serve/core/config.py @@ -16,4 +16,6 @@ def from_app_config(cls, config: AppConfig, config_prefix: str): config_prefix (str): Configuration prefix """ config_dict = config.get_all_by_prefix(config_prefix) + # remove prefix + config_dict = {k[len(config_prefix) :]: v for k, v in config_dict.items()} return cls(**config_dict) diff --git a/dbgpt/app/prompt/request/__init__.py b/dbgpt/serve/core/tests/__init__.py similarity index 100% rename from dbgpt/app/prompt/request/__init__.py rename to dbgpt/serve/core/tests/__init__.py diff --git a/dbgpt/serve/core/tests/conftest.py b/dbgpt/serve/core/tests/conftest.py new file mode 100644 index 000000000..9c8c7c77e --- /dev/null +++ b/dbgpt/serve/core/tests/conftest.py @@ -0,0 +1,59 @@ +import pytest +import pytest_asyncio +from typing import Dict +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.util import AppConfig + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +def create_system_app(param: Dict) -> SystemApp: + app_config = param.get("app_config", {}) + if isinstance(app_config, dict): + app_config = AppConfig(configs=app_config) + elif not isinstance(app_config, AppConfig): + raise RuntimeError("app_config must be AppConfig or dict") + return SystemApp(app, app_config) + + +@pytest_asyncio.fixture +async def asystem_app(request): + param = getattr(request, "param", {}) + return create_system_app(param) + + +@pytest.fixture +def system_app(request): + param = getattr(request, "param", {}) + return create_system_app(param) + + +@pytest_asyncio.fixture +async def client(request, asystem_app: SystemApp): + param = getattr(request, "param", {}) + headers = param.get("headers", {}) + base_url = param.get("base_url", "http://test") + client_api_key = param.get("client_api_key") + routers = param.get("routers", []) + app_caller = param.get("app_caller") + if "api_keys" in param: + del param["api_keys"] + if client_api_key: + headers["Authorization"] = "Bearer " + client_api_key + async with AsyncClient(app=app, base_url=base_url, headers=headers) as client: + for router in routers: + app.include_router(router) + if app_caller: + app_caller(app, asystem_app) + yield client diff --git a/dbgpt/serve/prompt/api/endpoints.py b/dbgpt/serve/prompt/api/endpoints.py index 493ac2f85..c0dbca3df 100644 --- a/dbgpt/serve/prompt/api/endpoints.py +++ b/dbgpt/serve/prompt/api/endpoints.py @@ -1,5 +1,8 @@ from typing import Optional, List -from fastapi import APIRouter, Depends, Query +from functools import cache +from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + from dbgpt.component import SystemApp from dbgpt.serve.core import Result @@ -20,14 +23,79 @@ def get_service() -> Service: return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key } + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + @router.get("/health") async def health(): """Health check endpoint""" return {"status": "ok"} +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + # TODO: Compatible with old API, will be modified in the future -@router.post("/add", response_model=Result[ServerResponse]) +@router.post( + "/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) async def create( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: @@ -42,7 +110,11 @@ async def create( return Result.succ(service.create(request)) -@router.post("/update", response_model=Result[ServerResponse]) +@router.post( + "/update", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) async def update( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: @@ -57,7 +129,9 @@ async def update( return Result.succ(service.update(request)) -@router.post("/delete", response_model=Result[None]) +@router.post( + "/delete", response_model=Result[None], dependencies=[Depends(check_api_key)] +) async def delete( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[None]: @@ -72,7 +146,11 @@ async def delete( return Result.succ(service.delete(request)) -@router.post("/list", response_model=Result[List[ServerResponse]]) +@router.post( + "/list", + response_model=Result[List[ServerResponse]], + dependencies=[Depends(check_api_key)], +) async def query( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[List[ServerResponse]]: @@ -87,7 +165,11 @@ async def query( return Result.succ(service.get_list(request)) -@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]]) +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) async def query_page( request: ServeRequest, page: Optional[int] = Query(default=1, description="current page"), diff --git a/dbgpt/serve/prompt/api/schemas.py b/dbgpt/serve/prompt/api/schemas.py index a6131dea5..e5c7610d4 100644 --- a/dbgpt/serve/prompt/api/schemas.py +++ b/dbgpt/serve/prompt/api/schemas.py @@ -1,73 +1,78 @@ # Define your Pydantic schemas here from typing import Optional from dbgpt._private.pydantic import BaseModel, Field +from ..config import SERVE_APP_NAME_HUMP class ServeRequest(BaseModel): """Prompt request model""" - chat_scene: Optional[str] = None - """ - The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa. - """ - - sub_chat_scene: Optional[str] = None - """ - The sub chat scene. - """ - - prompt_type: Optional[str] = None - """ - The prompt type, either common or private. - """ - - content: Optional[str] = None - """ - The prompt content. - """ - - user_name: Optional[str] = None - """ - The user name. - """ - - sys_code: Optional[str] = None - """ - System code - """ - - prompt_name: Optional[str] = None - """ - The prompt name. - """ - - -class ServerResponse(BaseModel): + class Config: + title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + + chat_scene: Optional[str] = Field( + None, + description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.", + examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"], + ) + + sub_chat_scene: Optional[str] = Field( + None, + description="The sub chat scene.", + examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"], + ) + + prompt_type: Optional[str] = Field( + None, + description="The prompt type, either common or private.", + examples=["common", "private"], + ) + prompt_name: Optional[str] = Field( + None, + description="The prompt name.", + examples=["code_assistant", "joker", "data_analysis_expert"], + ) + content: Optional[str] = Field( + None, + description="The prompt content.", + examples=[ + "Write a qsort function in python", + "Tell me a joke about AI", + "You are a data analysis expert.", + ], + ) + + user_name: Optional[str] = Field( + None, + description="The user name.", + examples=["zhangsan", "lisi", "wangwu"], + ) + + sys_code: Optional[str] = Field( + None, + description="The system code.", + examples=["dbgpt", "auth_manager", "data_platform"], + ) + + +class ServerResponse(ServeRequest): """Prompt response model""" - id: int = None - """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" - - chat_scene: str = None - - """sub_chat_scene: sub chat scene""" - sub_chat_scene: str = None - - """prompt_type: common or private""" - prompt_type: str = None - - """content: prompt content""" - content: str = None - - """user_name: user name""" - user_name: str = None - - sys_code: Optional[str] = None - """ - System code - """ - - """prompt_name: prompt name""" - prompt_name: str = None - gmt_created: str = None - gmt_modified: str = None + class Config: + title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + + id: Optional[int] = Field( + None, + description="The prompt id.", + examples=[1, 2, 3], + ) + gmt_created: Optional[str] = Field( + None, + description="The prompt created time.", + examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"], + ) + gmt_modified: Optional[str] = Field( + None, + description="The prompt modified time.", + examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"], + ) diff --git a/dbgpt/serve/prompt/config.py b/dbgpt/serve/prompt/config.py index 6d033e3eb..c304eaf08 100644 --- a/dbgpt/serve/prompt/config.py +++ b/dbgpt/serve/prompt/config.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from typing import Optional +from dataclasses import dataclass, field from dbgpt.serve.core import BaseServeConfig @@ -17,3 +18,15 @@ class ServeConfig(BaseServeConfig): """Parameters for the serve command""" # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) + + default_user: Optional[str] = field( + default=None, + metadata={"help": "Default user name for prompt"}, + ) + default_sys_code: Optional[str] = field( + default=None, + metadata={"help": "Default system code for prompt"}, + ) diff --git a/dbgpt/serve/prompt/serve.py b/dbgpt/serve/prompt/serve.py index 9a19adf7b..74db1a996 100644 --- a/dbgpt/serve/prompt/serve.py +++ b/dbgpt/serve/prompt/serve.py @@ -2,7 +2,13 @@ from dbgpt.component import BaseComponent, SystemApp from .api.endpoints import router, init_endpoints -from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME +from .config import ( + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + APP_NAME, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) class Serve(BaseComponent): diff --git a/dbgpt/serve/prompt/service/service.py b/dbgpt/serve/prompt/service/service.py index 7d1273161..ce90f7f3c 100644 --- a/dbgpt/serve/prompt/service/service.py +++ b/dbgpt/serve/prompt/service/service.py @@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): name = SERVE_SERVICE_COMPONENT_NAME - def __init__(self, system_app: SystemApp): + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): self._system_app = None self._serve_config: ServeConfig = None - self._dao: ServeDao = None + self._dao: ServeDao = dao super().__init__(system_app) def init_app(self, system_app: SystemApp) -> None: @@ -28,7 +28,7 @@ def init_app(self, system_app: SystemApp) -> None: self._serve_config = ServeConfig.from_app_config( system_app.config, SERVE_CONFIG_KEY_PREFIX ) - self._dao = ServeDao(self._serve_config) + self._dao = self._dao or ServeDao(self._serve_config) self._system_app = system_app @property @@ -41,6 +41,22 @@ def config(self) -> ServeConfig: """Returns the internal ServeConfig.""" return self._serve_config + def create(self, request: ServeRequest) -> ServerResponse: + """Create a new Prompt entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + + if not request.user_name: + request.user_name = self.config.default_user + if not request.sys_code: + request.sys_code = self.config.default_sys_code + return super().create(request) + def update(self, request: ServeRequest) -> ServerResponse: """Update a Prompt entity diff --git a/dbgpt/serve/prompt/tests/__init__.py b/dbgpt/serve/prompt/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/prompt/tests/test_endpoints.py b/dbgpt/serve/prompt/tests/test_endpoints.py new file mode 100644 index 000000000..e45a56745 --- /dev/null +++ b/dbgpt/serve/prompt/tests/test_endpoints.py @@ -0,0 +1,176 @@ +import pytest +from httpx import AsyncClient + +from fastapi import FastAPI +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult +from ..config import SERVE_CONFIG_KEY_PREFIX +from ..api.endpoints import router, init_endpoints +from ..api.schemas import ServeRequest, ServerResponse + +from dbgpt.serve.core.tests.conftest import client, asystem_app + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +async def _create_and_validate( + client: AsyncClient, sys_code: str, content: str, expect_id: int = 1, **kwargs +): + req_json = {"sys_code": sys_code, "content": content} + req_json.update(kwargs) + response = await client.post("/add", json=req_json) + assert response.status_code == 200 + json_res = response.json() + assert "success" in json_res and json_res["success"] + assert "data" in json_res and json_res["data"] + data = json_res["data"] + res_obj = ServerResponse(**data) + assert res_obj.id == expect_id + assert res_obj.sys_code == sys_code + assert res_obj.content == content + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_auth(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + await _create_and_validate(client, "test", "test") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + await _create_and_validate(client, "test", "test") + + response = await client.post("/update", json={"id": 1, "content": "test2"}) + assert response.status_code == 200 + json_res = response.json() + assert "success" in json_res and json_res["success"] + assert "data" in json_res and json_res["data"] + data = json_res["data"] + res_obj = ServerResponse(**data) + assert res_obj.id == 1 + assert res_obj.sys_code == "test" + assert res_obj.content == "test2" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + for i in range(10): + await _create_and_validate( + client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}" + ) + response = await client.post("/list", json={"sys_code": "test"}) + assert response.status_code == 200 + json_res = response.json() + assert "success" in json_res and json_res["success"] + assert "data" in json_res and json_res["data"] + data = json_res["data"] + assert len(data) == 10 + res_obj = ServerResponse(**data[0]) + assert res_obj.id == 1 + assert res_obj.sys_code == "test" + assert res_obj.content == "test0" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + for i in range(10): + await _create_and_validate( + client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}" + ) + response = await client.post( + "/query_page", params={"page": 1, "page_size": 5}, json={"sys_code": "test"} + ) + assert response.status_code == 200 + json_res = response.json() + assert "success" in json_res and json_res["success"] + assert "data" in json_res and json_res["data"] + data = json_res["data"] + page_result: PaginationResult = PaginationResult(**data) + assert page_result.total_count == 10 + assert page_result.total_pages == 2 + assert page_result.page == 1 + assert page_result.page_size == 5 + assert len(page_result.items) == 5 diff --git a/dbgpt/serve/prompt/tests/test_models.py b/dbgpt/serve/prompt/tests/test_models.py new file mode 100644 index 000000000..c5ef89f8f --- /dev/null +++ b/dbgpt/serve/prompt/tests/test_models.py @@ -0,0 +1,257 @@ +from typing import List +import pytest +from dbgpt.storage.metadata import db +from ..config import ServeConfig +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity, ServeDao + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + return { + "chat_scene": "chat_data", + "sub_chat_scene": "excel", + "prompt_type": "common", + "prompt_name": "my_prompt_1", + "content": "Write a qsort function in python.", + "user_name": "zhangsan", + "sys_code": "dbgpt", + } + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_entity_unique_key(default_entity_dict): + ServeEntity.create(**default_entity_dict) + with pytest.raises(Exception): + ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}) + + +def test_entity_get(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_entity_update(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + entity.update(prompt_name="my_prompt_2") + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_2" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_entity_delete(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + entity.delete() + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity is None + + +def test_entity_all(): + for i in range(10): + ServeEntity.create( + chat_scene="chat_data", + sub_chat_scene="excel", + prompt_type="common", + prompt_name=f"my_prompt_{i}", + content="Write a qsort function in python.", + user_name="zhangsan", + sys_code="dbgpt", + ) + entities = ServeEntity.all() + assert len(entities) == 10 + for entity in entities: + assert entity.chat_scene == "chat_data" + assert entity.sub_chat_scene == "excel" + assert entity.prompt_type == "common" + assert entity.content == "Write a qsort function in python." + assert entity.user_name == "zhangsan" + assert entity.sys_code == "dbgpt" + assert entity.gmt_created is not None + assert entity.gmt_modified is not None + + +def test_dao_create(dao, default_entity_dict): + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + assert res is not None + assert res.id == 1 + assert res.chat_scene == "chat_data" + assert res.sub_chat_scene == "excel" + assert res.prompt_type == "common" + assert res.prompt_name == "my_prompt_1" + assert res.content == "Write a qsort function in python." + assert res.user_name == "zhangsan" + assert res.sys_code == "dbgpt" + + +def test_dao_get_one(dao, default_entity_dict): + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + res: ServerResponse = dao.get_one( + {"prompt_name": "my_prompt_1", "sys_code": "dbgpt"} + ) + assert res is not None + assert res.id == 1 + assert res.chat_scene == "chat_data" + assert res.sub_chat_scene == "excel" + assert res.prompt_type == "common" + assert res.prompt_name == "my_prompt_1" + assert res.content == "Write a qsort function in python." + assert res.user_name == "zhangsan" + assert res.sys_code == "dbgpt" + + +def test_get_dao_get_list(dao): + for i in range(10): + dao.create( + ServeRequest( + chat_scene="chat_data", + sub_chat_scene="excel", + prompt_type="common", + prompt_name=f"my_prompt_{i}", + content="Write a qsort function in python.", + user_name="zhangsan" if i % 2 == 0 else "lisi", + sys_code="dbgpt", + ) + ) + res: List[ServerResponse] = dao.get_list({"sys_code": "dbgpt"}) + assert len(res) == 10 + for i, r in enumerate(res): + assert r.id == i + 1 + assert r.chat_scene == "chat_data" + assert r.sub_chat_scene == "excel" + assert r.prompt_type == "common" + assert r.prompt_name == f"my_prompt_{i}" + assert r.content == "Write a qsort function in python." + assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi" + assert r.sys_code == "dbgpt" + + half_res: List[ServerResponse] = dao.get_list({"user_name": "zhangsan"}) + assert len(half_res) == 5 + + +def test_dao_update(dao, default_entity_dict): + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + res: ServerResponse = dao.update( + {"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}, + ServeRequest(prompt_name="my_prompt_2"), + ) + assert res is not None + assert res.id == 1 + assert res.chat_scene == "chat_data" + assert res.sub_chat_scene == "excel" + assert res.prompt_type == "common" + assert res.prompt_name == "my_prompt_2" + assert res.content == "Write a qsort function in python." + assert res.user_name == "zhangsan" + assert res.sys_code == "dbgpt" + + +def test_dao_delete(dao, default_entity_dict): + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + dao.delete({"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}) + res: ServerResponse = dao.get_one( + {"prompt_name": "my_prompt_1", "sys_code": "dbgpt"} + ) + assert res is None + + +def test_dao_get_list_page(dao): + for i in range(20): + dao.create( + ServeRequest( + chat_scene="chat_data", + sub_chat_scene="excel", + prompt_type="common", + prompt_name=f"my_prompt_{i}", + content="Write a qsort function in python.", + user_name="zhangsan" if i % 2 == 0 else "lisi", + sys_code="dbgpt", + ) + ) + res = dao.get_list_page({"sys_code": "dbgpt"}, page=1, page_size=8) + assert res.total_count == 20 + assert res.total_pages == 3 + assert res.page == 1 + assert res.page_size == 8 + assert len(res.items) == 8 + for i, r in enumerate(res.items): + assert r.id == i + 1 + assert r.chat_scene == "chat_data" + assert r.sub_chat_scene == "excel" + assert r.prompt_type == "common" + assert r.prompt_name == f"my_prompt_{i}" + assert r.content == "Write a qsort function in python." + assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi" + assert r.sys_code == "dbgpt" + + res_half = dao.get_list_page({"user_name": "zhangsan"}, page=2, page_size=8) + assert res_half.total_count == 10 + assert res_half.total_pages == 2 + assert res_half.page == 2 + assert res_half.page_size == 8 + assert len(res_half.items) == 2 + for i, r in enumerate(res_half.items): + assert r.chat_scene == "chat_data" + assert r.sub_chat_scene == "excel" + assert r.prompt_type == "common" + assert r.content == "Write a qsort function in python." + assert r.user_name == "zhangsan" + assert r.sys_code == "dbgpt" diff --git a/dbgpt/serve/prompt/tests/test_service.py b/dbgpt/serve/prompt/tests/test_service.py new file mode 100644 index 000000000..a7c9cc686 --- /dev/null +++ b/dbgpt/serve/prompt/tests/test_service.py @@ -0,0 +1,154 @@ +from typing import List +import pytest +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.serve.core.tests.conftest import system_app + +from ..models.models import ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + return { + "chat_scene": "chat_data", + "sub_chat_scene": "excel", + "prompt_type": "common", + "prompt_name": "my_prompt_1", + "content": "Write a qsort function in python.", + "user_name": "zhangsan", + "sys_code": "dbgpt", + } + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +@pytest.mark.parametrize( + "system_app", + [ + { + "app_config": { + "DEBUG": True, + "dbgpt.serve.prompt.default_user": "dbgpt", + "dbgpt.serve.prompt.default_sys_code": "dbgpt", + } + } + ], + indirect=True, +) +def test_config_default_user(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.prompt.default_user") == "dbgpt" + assert service.config is not None + assert service.config.default_user == "dbgpt" + assert service.config.default_sys_code == "dbgpt" + + +def test_service_create(service: Service, default_entity_dict): + entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_service_update(service: Service, default_entity_dict): + service.create(ServeRequest(**default_entity_dict)) + entity: ServerResponse = service.update(ServeRequest(**default_entity_dict)) + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_service_get(service: Service, default_entity_dict): + service.create(ServeRequest(**default_entity_dict)) + entity: ServerResponse = service.get(ServeRequest(**default_entity_dict)) + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None + + +def test_service_delete(service: Service, default_entity_dict): + service.create(ServeRequest(**default_entity_dict)) + service.delete(ServeRequest(**default_entity_dict)) + entity: ServerResponse = service.get(ServeRequest(**default_entity_dict)) + assert entity is None + + +def test_service_get_list(service: Service): + for i in range(3): + service.create( + ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"}) + ) + entities: List[ServerResponse] = service.get_list(ServeRequest(sys_code="dbgpt")) + assert len(entities) == 3 + for i, entity in enumerate(entities): + assert entity.sys_code == "dbgpt" + assert entity.prompt_name == f"prompt_{i}" + + +def test_service_get_list_by_page(service: Service): + for i in range(3): + service.create( + ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"}) + ) + res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2) + assert res is not None + assert res.total_count == 3 + assert res.total_pages == 2 + assert len(res.items) == 2 + for i, entity in enumerate(res.items): + assert entity.sys_code == "dbgpt" + assert entity.prompt_name == f"prompt_{i}" diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py index fc55669cb..198798fa5 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/endpoints.py @@ -1,5 +1,8 @@ from typing import Optional, List -from fastapi import APIRouter, Depends, Query +from functools import cache +from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + from dbgpt.component import SystemApp from dbgpt.serve.core import Result @@ -20,13 +23,78 @@ def get_service() -> Service: return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key } + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + @router.get("/health") async def health(): """Health check endpoint""" return {"status": "ok"} -@router.post("/", response_model=Result[ServerResponse]) +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) async def create( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: @@ -41,7 +109,9 @@ async def create( return Result.succ(service.create(request)) -@router.put("/", response_model=Result[ServerResponse]) +@router.put( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) async def update( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: @@ -56,7 +126,11 @@ async def update( return Result.succ(service.update(request)) -@router.post("/query", response_model=Result[ServerResponse]) +@router.post( + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) async def query( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: @@ -71,7 +145,11 @@ async def query( return Result.succ(service.get(request)) -@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]]) +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) async def query_page( request: ServeRequest, page: Optional[int] = Query(default=1, description="current page"), diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py index d123aa159..548e78764 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py @@ -1,5 +1,6 @@ # Define your Pydantic schemas here from dbgpt._private.pydantic import BaseModel, Field +from ..config import SERVE_APP_NAME_HUMP class ServeRequest(BaseModel): @@ -7,8 +8,13 @@ class ServeRequest(BaseModel): # TODO define your own fields here + class Config: + title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + class ServerResponse(BaseModel): """{__template_app_name__hump__} response model""" # TODO define your own fields here + class Config: + title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/config.py b/dbgpt/serve/utils/_template_files/default_serve_template/config.py index 25a60a020..271546963 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/config.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/config.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from typing import Optional +from dataclasses import dataclass, field from dbgpt.serve.core import BaseServeConfig @@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig): """Parameters for the serve command""" # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/serve.py b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py index ef99ce936..1f6bb4811 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/serve.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py @@ -2,7 +2,13 @@ from dbgpt.component import BaseComponent, SystemApp from .api.endpoints import router, init_endpoints -from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME +from .config import ( + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + APP_NAME, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) class Serve(BaseComponent): diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py b/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py index 89f06b2a2..0e2767023 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/service/service.py @@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): name = SERVE_SERVICE_COMPONENT_NAME - def __init__(self, system_app: SystemApp): + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): self._system_app = None self._serve_config: ServeConfig = None - self._dao: ServeDao = None + self._dao: ServeDao = dao super().__init__(system_app) def init_app(self, system_app: SystemApp) -> None: @@ -28,7 +28,7 @@ def init_app(self, system_app: SystemApp) -> None: self._serve_config = ServeConfig.from_app_config( system_app.config, SERVE_CONFIG_KEY_PREFIX ) - self._dao = ServeDao(self._serve_config) + self._dao = self._dao or ServeDao(self._serve_config) self._system_app = system_app @property diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/tests/__init__.py b/dbgpt/serve/utils/_template_files/default_serve_template/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_endpoints.py b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_endpoints.py new file mode 100644 index 000000000..79f35c7a0 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from httpx import AsyncClient + +from fastapi import FastAPI +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult +from ..config import SERVE_CONFIG_KEY_PREFIX +from ..api.endpoints import router, init_endpoints +from ..api.schemas import ServeRequest, ServerResponse + +from dbgpt.serve.core.tests.conftest import client, asystem_app + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py new file mode 100644 index 000000000..c065909b2 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py @@ -0,0 +1,109 @@ +from typing import List +import pytest +from dbgpt.storage.metadata import db +from ..config import ServeConfig +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity, ServeDao + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + # TODO : build your server config + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + # TODO: implement your test case + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity.id == entity.id + # TODO: implement your test case + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + entity.delete() + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity is None + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + assert res is not None + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_service.py b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_service.py new file mode 100644 index 000000000..003286c82 --- /dev/null +++ b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_service.py @@ -0,0 +1,76 @@ +from typing import List +import pytest +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.serve.core.tests.conftest import system_app + +from ..models.models import ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/utils/cli.py b/dbgpt/serve/utils/cli.py index eb4ca359f..78cd742b7 100644 --- a/dbgpt/serve/utils/cli.py +++ b/dbgpt/serve/utils/cli.py @@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str): def copy_template_files(src_dir: str, dst_dir: str, app_name: str): for root, dirs, files in os.walk(src_dir): + dirs[:] = [d for d in dirs if not _should_ignore(d)] relative_path = os.path.relpath(root, src_dir) if relative_path == ".": relative_path = "" @@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str): os.makedirs(target_dir, exist_ok=True) for file in files: + if _should_ignore(file): + continue try: with open(os.path.join(root, file), "r") as f: content = f.read() @@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str): except Exception as e: click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}") raise e + + +def _should_ignore(file_or_dir: str): + """Return True if the given file or directory should be ignored.""" "" + ignore_patterns = [".pyc", "__pycache__"] + return any(pattern in file_or_dir for pattern in ignore_patterns) diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 78ef0c4aa..70f25197d 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -153,7 +153,8 @@ def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES: if entry is None: raise Exception("Invalid request") for key, value in update_request.dict().items(): - setattr(entry, key, value) + if value is not None: + setattr(entry, key, value) session.merge(entry) return self.get_one(self.to_request(entry)) diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py index e7563d935..a537188f2 100644 --- a/dbgpt/storage/metadata/tests/test_base_dao.py +++ b/dbgpt/storage/metadata/tests/test_base_dao.py @@ -104,6 +104,28 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_ assert user.age == 35 +def test_update_user_partial( + db: DatabaseManager, User: Type[BaseModel], user_dao, user_req +): + # Create a user + created_user_response = user_dao.create(user_req) + + # Update the user + updated_req = UserRequest(name=user_req.name, password="newpassword") + updated_req.age = None + updated_user = user_dao.update( + query_request={"name": user_req.name}, update_request=updated_req + ) + assert updated_user.id == created_user_response.id + assert updated_user.age == user_req.age + + # Verify that the user is updated in the database + with db.session() as session: + user = session.query(User).get(created_user_response.id) + assert user.age == user_req.age + assert user.password == "newpassword" + + def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req): # Create a user created_user_response = user_dao.create(user_req) diff --git a/dbgpt/util/config_utils.py b/dbgpt/util/config_utils.py index 5d07561ee..fd431f164 100644 --- a/dbgpt/util/config_utils.py +++ b/dbgpt/util/config_utils.py @@ -3,15 +3,18 @@ class AppConfig: - def __init__(self): - self.configs = {} + def __init__(self, configs: Optional[Dict[str, Any]] = None) -> None: + self.configs = configs or {} - def set(self, key: str, value: Any) -> None: + def set(self, key: str, value: Any, overwrite: bool = False) -> None: """Set config value by key Args: key (str): The key of config value (Any): The value of config + overwrite (bool, optional): Whether to overwrite the value if key exists. Defaults to False. """ + if key in self.configs and not overwrite: + raise KeyError(f"Config key {key} already exists") self.configs[key] = value def get(self, key, default: Optional[Any] = None) -> Any: