diff --git a/Makefile b/Makefile index 2e5f4bff4..c1029f3e9 100644 --- a/Makefile +++ b/Makefile @@ -38,21 +38,41 @@ fmt: setup ## Format Python code # https://github.com/PyCQA/isort # $(VENV_BIN)/isort . $(VENV_BIN)/isort dbgpt/agent/ + $(VENV_BIN)/isort dbgpt/app/ + $(VENV_BIN)/isort dbgpt/cli/ + $(VENV_BIN)/isort dbgpt/configs/ $(VENV_BIN)/isort dbgpt/core/ + $(VENV_BIN)/isort dbgpt/datasource/ + $(VENV_BIN)/isort dbgpt/model/ + # TODO: $(VENV_BIN)/isort dbgpt/serve $(VENV_BIN)/isort dbgpt/serve/core/ $(VENV_BIN)/isort dbgpt/serve/agent/ $(VENV_BIN)/isort dbgpt/serve/conversation/ $(VENV_BIN)/isort dbgpt/serve/utils/_template_files + $(VENV_BIN)/isort dbgpt/storage/ + $(VENV_BIN)/isort dbgpt/train/ + $(VENV_BIN)/isort dbgpt/util/ + $(VENV_BIN)/isort dbgpt/vis/ + $(VENV_BIN)/isort dbgpt/__init__.py + $(VENV_BIN)/isort dbgpt/component.py $(VENV_BIN)/isort --extend-skip="examples/notebook" examples # https://github.com/psf/black $(VENV_BIN)/black --extend-exclude="examples/notebook" . # TODO: Use blackdoc to format Python doctests. # https://blackdoc.readthedocs.io/en/latest/ # $(VENV_BIN)/blackdoc . - $(VENV_BIN)/blackdoc dbgpt/core/ $(VENV_BIN)/blackdoc dbgpt/agent/ + $(VENV_BIN)/blackdoc dbgpt/app/ + $(VENV_BIN)/blackdoc dbgpt/cli/ + $(VENV_BIN)/blackdoc dbgpt/configs/ + $(VENV_BIN)/blackdoc dbgpt/core/ + $(VENV_BIN)/blackdoc dbgpt/datasource/ $(VENV_BIN)/blackdoc dbgpt/model/ $(VENV_BIN)/blackdoc dbgpt/serve/ + # TODO: $(VENV_BIN)/blackdoc dbgpt/storage/ + $(VENV_BIN)/blackdoc dbgpt/train/ + $(VENV_BIN)/blackdoc dbgpt/util/ + $(VENV_BIN)/blackdoc dbgpt/vis/ $(VENV_BIN)/blackdoc examples # TODO: Type checking of Python code. # https://github.com/python/mypy diff --git a/dbgpt/__init__.py b/dbgpt/__init__.py index a23c1ec9e..1660c5ddc 100644 --- a/dbgpt/__init__.py +++ b/dbgpt/__init__.py @@ -1,5 +1,4 @@ -from dbgpt.component import SystemApp, BaseComponent - +from dbgpt.component import BaseComponent, SystemApp __ALL__ = ["SystemApp", "BaseComponent"] diff --git a/dbgpt/app/_cli.py b/dbgpt/app/_cli.py index f48e57111..b0d0cb294 100644 --- a/dbgpt/app/_cli.py +++ b/dbgpt/app/_cli.py @@ -1,11 +1,13 @@ +import functools +import os from typing import Optional + import click -import os -import functools + from dbgpt.app.base import WebServerParameters from dbgpt.configs.model_config import LOGDIR -from dbgpt.util.parameter_utils import EnvArgumentParser from dbgpt.util.command_utils import _run_current_with_daemon, _stop_service +from dbgpt.util.parameter_utils import EnvArgumentParser @click.command(name="webserver") @@ -117,8 +119,8 @@ def migrate(alembic_ini_path: str, script_location: str, message: str): def upgrade(alembic_ini_path: str, script_location: str, sql_output: str): """Upgrade database to target version""" from dbgpt.util._db_migration_utils import ( - upgrade_database, generate_sql_for_upgrade, + upgrade_database, ) alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) @@ -210,9 +212,8 @@ def clean( @add_migration_options def list(alembic_ini_path: str, script_location: str): """List all versions in the migration history, marking the current one""" - from alembic.script import ScriptDirectory - from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) @@ -269,12 +270,12 @@ def show(alembic_ini_path: str, script_location: str, revision: str): def _get_migration_config( alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None ): - from dbgpt.storage.metadata.db_manager import db as db_manager - from dbgpt.util._db_migration_utils import create_alembic_config + from dbgpt.app.base import _initialize_db # Import all models to make sure they are registered with SQLAlchemy. from dbgpt.app.initialization.db_model_initialization import _MODELS - from dbgpt.app.base import _initialize_db + from dbgpt.storage.metadata.db_manager import db as db_manager + from dbgpt.util._db_migration_utils import create_alembic_config # initialize db default_meta_data_path = _initialize_db() diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index fde70098e..54c6f5122 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -1,16 +1,15 @@ -import signal +import logging import os -import threading +import signal import sys -import logging -from typing import Optional +import threading from dataclasses import dataclass, field +from typing import Optional from dbgpt._private.config import Config from dbgpt.component import SystemApp from dbgpt.util.parameter_utils import BaseParameters - ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -35,7 +34,6 @@ def server_init(param: "WebServerParameters", system_app: SystemApp): from dbgpt.agent.plugin.commands.command_mange import CommandRegistry # logger.info(f"args: {args}") - # init config cfg = Config() cfg.SYSTEM_APP = system_app @@ -100,13 +98,12 @@ def _migration_db_storage(param: "WebServerParameters"): """Migration the db storage.""" # Import all models to make sure they are registered with SQLAlchemy. from dbgpt.app.initialization.db_model_initialization import _MODELS - from dbgpt.configs.model_config import PILOT_PATH default_meta_data_path = os.path.join(PILOT_PATH, "meta_data") if not param.disable_alembic_upgrade: - from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade from dbgpt.storage.metadata.db_manager import db + from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade # try to create all tables try: @@ -123,9 +120,11 @@ def _initialize_db( Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. """ + from urllib.parse import quote + from urllib.parse import quote_plus as urlquote + from dbgpt.configs.model_config import PILOT_PATH from dbgpt.storage.metadata.db_manager import initialize_db - from urllib.parse import quote_plus as urlquote, quote CFG = Config() db_name = CFG.LOCAL_DB_NAME @@ -170,8 +169,8 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F Raises: Exception: Raise exception if database operation failed """ - from sqlalchemy import create_engine, DDL - from sqlalchemy.exc import SQLAlchemyError, OperationalError + from sqlalchemy import DDL, create_engine + from sqlalchemy.exc import OperationalError, SQLAlchemyError if not try_to_create_db: logger.info(f"Skipping creation of database {db_name}") diff --git a/dbgpt/app/chat_adapter.py b/dbgpt/app/chat_adapter.py index c1cb192b1..d0c5e84dc 100644 --- a/dbgpt/app/chat_adapter.py +++ b/dbgpt/app/chat_adapter.py @@ -6,9 +6,10 @@ # -*- coding: utf-8 -*- from functools import cache -from typing import List, Dict, Tuple -from dbgpt.model.conversation import Conversation, get_conv_template +from typing import Dict, List, Tuple + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.conversation import Conversation, get_conv_template class BaseChatAdpter: diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 8025a4e09..0d6eabde8 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -2,12 +2,11 @@ import logging -from dbgpt.component import SystemApp from dbgpt._private.config import Config +from dbgpt.app.base import WebServerParameters +from dbgpt.component import SystemApp from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR from dbgpt.util.executor_utils import DefaultExecutorFactory -from dbgpt.app.base import WebServerParameters - logger = logging.getLogger(__name__) @@ -21,9 +20,9 @@ def initialize_components( embedding_model_path: str, ): # Lazy import to avoid high time cost - from dbgpt.model.cluster.controller.controller import controller from dbgpt.app.initialization.embedding_component import _initialize_embedding_model from dbgpt.app.initialization.serve_initialization import register_serve_apps + from dbgpt.model.cluster.controller.controller import controller # Register global default executor factory first system_app.register(DefaultExecutorFactory) @@ -47,6 +46,7 @@ def initialize_components( ) _initialize_model_cache(system_app) _initialize_awel(system_app, param) + _initialize_openapi(system_app) # Register serve apps register_serve_apps(system_app, CFG) @@ -65,8 +65,8 @@ def _initialize_model_cache(system_app: SystemApp): def _initialize_awel(system_app: SystemApp, param: WebServerParameters): - from dbgpt.core.awel import initialize_awel from dbgpt.configs.model_config import _DAG_DEFINITION_DIR + from dbgpt.core.awel import initialize_awel # Add default dag definition dir dag_dirs = [_DAG_DEFINITION_DIR] @@ -75,3 +75,9 @@ def _initialize_awel(system_app: SystemApp, param: WebServerParameters): dag_dirs = [x.strip() for x in dag_dirs] initialize_awel(system_app, dag_dirs) + + +def _initialize_openapi(system_app: SystemApp): + from dbgpt.app.openapi.api_v1.editor.service import EditorService + + system_app.register(EditorService) diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index ff7598c01..b4328d257 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -1,46 +1,45 @@ -import os import argparse +import os import sys from typing import List ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from dbgpt.configs.model_config import ( - LLM_MODEL_CONFIG, - EMBEDDING_MODEL_CONFIG, - LOGDIR, - ROOT_PATH, -) -from dbgpt._private.config import Config -from dbgpt.component import SystemApp +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import get_swagger_ui_html + +# fastapi import time cost about 0.05s +from fastapi.staticfiles import StaticFiles +from dbgpt._private.config import Config from dbgpt.app.base import ( - server_init, - _migration_db_storage, WebServerParameters, _create_model_start_listener, + _migration_db_storage, + server_init, ) # initialize_components import time cost about 0.1s from dbgpt.app.component_configs import initialize_components - -# fastapi import time cost about 0.05s -from fastapi.staticfiles import StaticFiles -from fastapi import FastAPI -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.docs import get_swagger_ui_html - from dbgpt.app.openapi.base import validation_exception_handler +from dbgpt.component import SystemApp +from dbgpt.configs.model_config import ( + EMBEDDING_MODEL_CONFIG, + LLM_MODEL_CONFIG, + LOGDIR, + ROOT_PATH, +) +from dbgpt.util.parameter_utils import _get_dict_from_obj +from dbgpt.util.system_utils import get_system_info +from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer from dbgpt.util.utils import ( - setup_logging, _get_logging_level, logging_str_to_uvicorn_level, setup_http_service_logging, + setup_logging, ) -from dbgpt.util.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName -from dbgpt.util.parameter_utils import _get_dict_from_obj -from dbgpt.util.system_utils import get_system_info static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static") @@ -89,9 +88,7 @@ async def custom_swagger_ui_html(): def mount_routers(app: FastAPI): """Lazy import to avoid high time cost""" from dbgpt.app.knowledge.api import router as knowledge_router - from dbgpt.app.llm_manage.api import router as llm_manage_api - from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1 from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import ( router as api_editor_route_v1, @@ -107,9 +104,7 @@ def mount_routers(app: FastAPI): def mount_static_files(app: FastAPI): - from dbgpt.agent.plugin.commands.built_in.disply_type import ( - static_message_img_path, - ) + from dbgpt.agent.plugin.commands.built_in.disply_type import static_message_img_path os.makedirs(static_message_img_path, exist_ok=True) app.mount( diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index 1d925201a..f2243dfc1 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -1,14 +1,13 @@ """Import all models to make sure they are registered with SQLAlchemy. """ -from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity -from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity - -from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity +from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity +from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity +from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.storage.chat_history.chat_history_db import ( ChatHistoryEntity, ChatHistoryMessageEntity, diff --git a/dbgpt/app/initialization/embedding_component.py b/dbgpt/app/initialization/embedding_component.py index a20bd5220..83552a681 100644 --- a/dbgpt/app/initialization/embedding_component.py +++ b/dbgpt/app/initialization/embedding_component.py @@ -1,12 +1,14 @@ from __future__ import annotations import logging -from typing import Any, Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Type + from dbgpt.component import ComponentType, SystemApp from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory if TYPE_CHECKING: from langchain.embeddings.base import Embeddings + from dbgpt.app.base import WebServerParameters logger = logging.getLogger(__name__) diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 35fe26844..7c5a3a465 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -1,5 +1,5 @@ -from dbgpt.component import SystemApp from dbgpt._private.config import Config +from dbgpt.component import SystemApp def register_serve_apps(system_app: SystemApp, cfg: Config): @@ -8,9 +8,9 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # ################################ Prompt Serve Register Begin ###################################### from dbgpt.serve.prompt.serve import ( - Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX, ) + from dbgpt.serve.prompt.serve import Serve as PromptServe # Replace old prompt serve # Set config @@ -21,8 +21,15 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # ################################ Prompt Serve Register End ######################################## # ################################ Conversation Serve Register Begin ###################################### + from dbgpt.serve.conversation.serve import ( + SERVE_CONFIG_KEY_PREFIX as CONVERSATION_SERVE_CONFIG_KEY_PREFIX, + ) from dbgpt.serve.conversation.serve import Serve as ConversationServe + # Set config + system_app.config.set( + f"{CONVERSATION_SERVE_CONFIG_KEY_PREFIX}default_model", cfg.LLM_MODEL + ) # Register serve app - system_app.register(ConversationServe) + system_app.register(ConversationServe, api_prefix="/api/v1/chat/dialogue") # ################################ Conversation Serve Register End ######################################## diff --git a/dbgpt/app/knowledge/_cli/knowledge_cli.py b/dbgpt/app/knowledge/_cli/knowledge_cli.py index dd7c74b2c..38f24ede2 100644 --- a/dbgpt/app/knowledge/_cli/knowledge_cli.py +++ b/dbgpt/app/knowledge/_cli/knowledge_cli.py @@ -1,7 +1,8 @@ -import click +import functools import logging import os -import functools + +import click from dbgpt.configs.model_config import DATASETS_DIR diff --git a/dbgpt/app/knowledge/_cli/knowledge_client.py b/dbgpt/app/knowledge/_cli/knowledge_client.py index 5ab24d368..7f2b5ea1b 100644 --- a/dbgpt/app/knowledge/_cli/knowledge_client.py +++ b/dbgpt/app/knowledge/_cli/knowledge_client.py @@ -1,22 +1,20 @@ -import os -import requests import json import logging - -from urllib.parse import urljoin +import os from concurrent.futures import ThreadPoolExecutor, as_completed +from urllib.parse import urljoin + +import requests -from dbgpt.app.openapi.api_view_model import Result from dbgpt.app.knowledge.request.request import ( - KnowledgeQueryRequest, - KnowledgeDocumentRequest, ChunkQueryRequest, DocumentQueryRequest, + DocumentSyncRequest, + KnowledgeDocumentRequest, + KnowledgeQueryRequest, + KnowledgeSpaceRequest, ) - -from dbgpt.app.knowledge.request.request import DocumentSyncRequest - -from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest +from dbgpt.app.openapi.api_view_model import Result from dbgpt.rag.knowledge.base import KnowledgeType HTTP_HEADERS = {"Content-Type": "application/json"} diff --git a/dbgpt/app/knowledge/api.py b/dbgpt/app/knowledge/api.py index 8d14e0e51..ee8f1f767 100644 --- a/dbgpt/app/knowledge/api.py +++ b/dbgpt/app/knowledge/api.py @@ -1,42 +1,39 @@ +import logging import os import shutil import tempfile -import logging from typing import List -from fastapi import APIRouter, File, UploadFile, Form +from fastapi import APIRouter, File, Form, UploadFile from dbgpt._private.config import Config -from dbgpt.configs.model_config import ( - EMBEDDING_MODEL_CONFIG, - KNOWLEDGE_UPLOAD_ROOT_PATH, -) -from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator - -from dbgpt.app.openapi.api_view_model import Result -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory - -from dbgpt.app.knowledge.service import KnowledgeService -from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.app.knowledge.request.request import ( - KnowledgeQueryRequest, - KnowledgeQueryResponse, - KnowledgeDocumentRequest, - DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, - SpaceArgumentRequest, - EntityExtractRequest, DocumentSummaryRequest, + DocumentSyncRequest, + EntityExtractRequest, + KnowledgeDocumentRequest, + KnowledgeQueryRequest, + KnowledgeQueryResponse, + KnowledgeSpaceRequest, KnowledgeSyncRequest, + SpaceArgumentRequest, ) - -from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest +from dbgpt.app.knowledge.service import KnowledgeService +from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator +from dbgpt.app.openapi.api_view_model import Result +from dbgpt.configs.model_config import ( + EMBEDDING_MODEL_CONFIG, + KNOWLEDGE_UPLOAD_ROOT_PATH, +) +from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import ChunkStrategy +from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.retriever.embedding import EmbeddingRetriever from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector -from dbgpt.util.tracer import root_tracer, SpanType +from dbgpt.util.tracer import SpanType, root_tracer logger = logging.getLogger(__name__) @@ -312,9 +309,10 @@ async def document_summary(request: DocumentSummaryRequest): async def entity_extract(request: EntityExtractRequest): logger.info(f"Received params: {request}") try: + import uuid + from dbgpt.app.scene import ChatScene from dbgpt.util.chat_util import llm_chat_response_nostream - import uuid chat_param = { "chat_session_id": uuid.uuid1(), diff --git a/dbgpt/app/knowledge/chunk_db.py b/dbgpt/app/knowledge/chunk_db.py index e8b6137ef..e6c3018e2 100644 --- a/dbgpt/app/knowledge/chunk_db.py +++ b/dbgpt/app/knowledge/chunk_db.py @@ -1,10 +1,10 @@ from datetime import datetime from typing import List -from sqlalchemy import Column, String, DateTime, Integer, Text, func +from sqlalchemy import Column, DateTime, Integer, String, Text, func -from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config +from dbgpt.storage.metadata import BaseDao, Model CFG = Config() diff --git a/dbgpt/app/knowledge/document_db.py b/dbgpt/app/knowledge/document_db.py index d101d8ee3..7e08d0733 100644 --- a/dbgpt/app/knowledge/document_db.py +++ b/dbgpt/app/knowledge/document_db.py @@ -1,10 +1,10 @@ from datetime import datetime from typing import List -from sqlalchemy import Column, String, DateTime, Integer, Text, func +from sqlalchemy import Column, DateTime, Integer, String, Text, func -from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config +from dbgpt.storage.metadata import BaseDao, Model CFG = Config() diff --git a/dbgpt/app/knowledge/request/request.py b/dbgpt/app/knowledge/request/request.py index e87ca14ad..0b0e2795f 100644 --- a/dbgpt/app/knowledge/request/request.py +++ b/dbgpt/app/knowledge/request/request.py @@ -1,8 +1,8 @@ from typing import List, Optional -from dbgpt._private.pydantic import BaseModel from fastapi import UploadFile +from dbgpt._private.pydantic import BaseModel from dbgpt.rag.chunk_manager import ChunkParameters diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index ca348e07c..3db67c026 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -1,59 +1,48 @@ import json import logging from datetime import datetime +from enum import Enum from typing import List -from dbgpt.model import DefaultLLMClient -from dbgpt.rag.chunk import Chunk -from dbgpt.rag.chunk_manager import ChunkParameters -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.knowledge.base import KnowledgeType, ChunkStrategy -from dbgpt.rag.knowledge.factory import KnowledgeFactory -from dbgpt.rag.text_splitter.text_splitter import ( - RecursiveCharacterTextSplitter, - SpacyTextSplitter, -) -from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler -from dbgpt.serve.rag.assembler.summary import SummaryAssembler -from dbgpt.storage.vector_store.base import VectorStoreConfig -from dbgpt.storage.vector_store.connector import VectorStoreConnector - from dbgpt._private.config import Config -from dbgpt.configs.model_config import ( - EMBEDDING_MODEL_CONFIG, -) -from dbgpt.component import ComponentType -from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async - -from dbgpt.app.knowledge.chunk_db import ( - DocumentChunkEntity, - DocumentChunkDao, -) +from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity from dbgpt.app.knowledge.document_db import ( KnowledgeDocumentDao, KnowledgeDocumentEntity, ) -from dbgpt.app.knowledge.space_db import ( - KnowledgeSpaceDao, - KnowledgeSpaceEntity, -) from dbgpt.app.knowledge.request.request import ( - KnowledgeSpaceRequest, - KnowledgeDocumentRequest, - DocumentQueryRequest, ChunkQueryRequest, - SpaceArgumentRequest, - DocumentSyncRequest, + DocumentQueryRequest, DocumentSummaryRequest, + DocumentSyncRequest, + KnowledgeDocumentRequest, + KnowledgeSpaceRequest, KnowledgeSyncRequest, + SpaceArgumentRequest, ) -from enum import Enum - from dbgpt.app.knowledge.request.response import ( ChunkQueryResponse, DocumentQueryResponse, SpaceQueryResponse, ) +from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity +from dbgpt.component import ComponentType +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG +from dbgpt.model import DefaultLLMClient +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.chunk_manager import ChunkParameters +from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory +from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType +from dbgpt.rag.knowledge.factory import KnowledgeFactory +from dbgpt.rag.text_splitter.text_splitter import ( + RecursiveCharacterTextSplitter, + SpacyTextSplitter, +) +from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler +from dbgpt.serve.rag.assembler.summary import SummaryAssembler +from dbgpt.storage.vector_store.base import VectorStoreConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector +from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() @@ -572,8 +561,8 @@ def async_doc_embedding(self, assembler, chunk_docs, doc): def _build_default_context(self): from dbgpt.app.scene.chat_knowledge.v1.prompt import ( - PROMPT_SCENE_DEFINE, _DEFAULT_TEMPLATE, + PROMPT_SCENE_DEFINE, ) context_template = { diff --git a/dbgpt/app/knowledge/space_db.py b/dbgpt/app/knowledge/space_db.py index 4c958c613..bc283db1d 100644 --- a/dbgpt/app/knowledge/space_db.py +++ b/dbgpt/app/knowledge/space_db.py @@ -1,10 +1,10 @@ from datetime import datetime -from sqlalchemy import Column, Integer, Text, String, DateTime +from sqlalchemy import Column, DateTime, Integer, String, Text -from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest +from dbgpt.storage.metadata import BaseDao, Model CFG = Config() diff --git a/dbgpt/app/llm_manage/api.py b/dbgpt/app/llm_manage/api.py index 66496c927..fdc20474b 100644 --- a/dbgpt/app/llm_manage/api.py +++ b/dbgpt/app/llm_manage/api.py @@ -1,12 +1,10 @@ from fastapi import APIRouter -from dbgpt.component import ComponentType from dbgpt._private.config import Config - -from dbgpt.model.cluster import WorkerStartupRequest, WorkerManagerFactory -from dbgpt.app.openapi.api_view_model import Result - from dbgpt.app.llm_manage.request.request import ModelResponse +from dbgpt.app.openapi.api_view_model import Result +from dbgpt.component import ComponentType +from dbgpt.model.cluster import WorkerManagerFactory, WorkerStartupRequest CFG = Config() router = APIRouter() diff --git a/dbgpt/app/llmserver.py b/dbgpt/app/llmserver.py index bf8719af8..be07c1a02 100644 --- a/dbgpt/app/llmserver.py +++ b/dbgpt/app/llmserver.py @@ -8,7 +8,7 @@ sys.path.append(ROOT_PATH) from dbgpt._private.config import Config -from dbgpt.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG from dbgpt.model.cluster import run_worker_manager CFG = Config() diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py index 607642f8c..9c7137757 100644 --- a/dbgpt/app/openapi/api_v1/api_v1.py +++ b/dbgpt/app/openapi/api_v1/api_v1.py @@ -1,50 +1,39 @@ -import json -import uuid import asyncio -import os -import aiofiles import logging -from fastapi import ( - APIRouter, - File, - UploadFile, - Body, - Depends, -) +import os +import uuid +from concurrent.futures import Executor +from typing import List, Optional +import aiofiles +from fastapi import APIRouter, Body, Depends, File, UploadFile from fastapi.responses import StreamingResponse -from typing import List, Optional -from concurrent.futures import Executor -from dbgpt.component import ComponentType +from dbgpt._private.config import Config +from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest +from dbgpt.app.knowledge.service import KnowledgeService from dbgpt.app.openapi.api_view_model import ( - Result, - ConversationVo, - MessageVo, - ChatSceneVo, ChatCompletionResponseStreamChoice, - DeltaMessage, ChatCompletionStreamResponse, + ChatSceneVo, + ConversationVo, + DeltaMessage, + MessageVo, + Result, ) -from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo -from dbgpt._private.config import Config -from dbgpt.app.knowledge.service import KnowledgeService -from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest - -from dbgpt.app.scene import BaseChat, ChatScene, ChatFactory -from dbgpt.core.interface.message import OnceConversation +from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene +from dbgpt.component import ComponentType from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH -from dbgpt.rag.summary.db_summary_client import DBSummaryClient -from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory -from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory +from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo from dbgpt.model.base import FlatSupportedModel -from dbgpt.util.tracer import root_tracer, SpanType +from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory +from dbgpt.rag.summary.db_summary_client import DBSummaryClient from dbgpt.util.executor_utils import ( + DefaultExecutorFactory, ExecutorFactory, blocking_func_to_async, - DefaultExecutorFactory, ) - +from dbgpt.util.tracer import SpanType, root_tracer router = APIRouter() CFG = Config() @@ -201,47 +190,6 @@ async def db_support_types(): return Result[DbTypeInfo].succ(db_type_infos) -@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) -async def dialogue_list( - user_name: str = None, user_id: str = None, sys_code: str = None -): - dialogues: List = [] - chat_history_service = ChatHistory() - # TODO Change the synchronous call to the asynchronous call - user_name = user_name or user_id - datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code) - for item in datas: - conv_uid = item.get("conv_uid") - summary = item.get("summary") - chat_mode = item.get("chat_mode") - model_name = item.get("model_name", CFG.LLM_MODEL) - user_name = item.get("user_name") - sys_code = item.get("sys_code") - if not item.get("messages"): - # Skip the empty messages - # TODO support new conversation and message mode - continue - - messages = json.loads(item.get("messages")) - last_round = max(messages, key=lambda x: x["chat_order"]) - if "param_value" in last_round: - select_param = last_round["param_value"] - else: - select_param = "" - conv_vo: ConversationVo = ConversationVo( - conv_uid=conv_uid, - user_input=summary, - chat_mode=chat_mode, - model_name=model_name, - select_param=select_param, - user_name=user_name, - sys_code=sys_code, - ) - dialogues.append(conv_vo) - - return Result[ConversationVo].succ(dialogues[:10]) - - @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] @@ -265,19 +213,6 @@ async def dialogue_scenes(): return Result.succ(scene_vos) -@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) -async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value(), - user_name: str = None, - # TODO remove user id - user_id: str = None, - sys_code: str = None, -): - user_name = user_name or user_id - conv_vo = __new_conversation(chat_mode, user_name, sys_code) - return Result.succ(conv_vo) - - @router.post("/v1/chat/mode/params/list", response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): if ChatScene.ChatWithDbQA.value() == chat_mode: @@ -334,37 +269,11 @@ async def params_load( return Result.failed(code="E000X", msg=f"File Load Error {str(e)}") -@router.post("/v1/chat/dialogue/delete") -async def dialogue_delete(con_uid: str): - history_fac = ChatHistory() - history_mem = history_fac.get_store_instance(con_uid) - # TODO Change the synchronous call to the asynchronous call - history_mem.delete() - return Result.succ(None) - - def get_hist_messages(conv_uid: str): - message_vos: List[MessageVo] = [] - history_fac = ChatHistory() - history_mem = history_fac.get_store_instance(conv_uid) - - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - for once in history_messages: - model_name = once.get("model_name", CFG.LLM_MODEL) - once_message_vos = [ - message2Vo(element, once["chat_order"], model_name) - for element in once["messages"] - ] - message_vos.extend(once_message_vos) - return message_vos - - -@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo]) -async def dialogue_history_messages(con_uid: str): - print(f"dialogue_history_messages:{con_uid}") - # TODO Change the synchronous call to the asynchronous call - return Result.succ(get_hist_messages(con_uid)) + from dbgpt.serve.conversation.serve import Service as ConversationService + + instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP) + return instance.get_history_messages({"conv_uid": conv_uid}) async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: @@ -378,9 +287,7 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: dialogue.conv_uid = conv_vo.conv_uid if not ChatScene.is_valid_mode(dialogue.chat_mode): - raise StopAsyncIteration( - Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!") - ) + raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!") chat_param = { "chat_session_id": dialogue.conv_uid, @@ -405,7 +312,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()): logger.info(f"chat_prepare:{dialogue}") ## check conv_uid chat: BaseChat = await get_chat_instance(dialogue) - if len(chat.history_message) > 0: + if chat.has_history_messages(): return Result.succ(None) resp = await chat.prepare() return Result.succ(resp) diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index 740c0bf57..1e1ec9662 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -1,37 +1,30 @@ import json +import logging import time -from fastapi import ( - APIRouter, - Body, -) - from typing import List -import logging -from dbgpt._private.config import Config +from fastapi import APIRouter, Body, Depends -from dbgpt.app.scene import ChatFactory - -from dbgpt.app.openapi.api_view_model import ( - Result, +from dbgpt._private.config import Config +from dbgpt.app.openapi.api_v1.editor.service import EditorService +from dbgpt.app.openapi.api_v1.editor.sql_editor import ( + ChartRunData, + DataNode, + SqlRunData, ) +from dbgpt.app.openapi.api_view_model import Result from dbgpt.app.openapi.editor_view_model import ( - ChatDbRounds, - ChartList, ChartDetail, + ChartList, ChatChartEditContext, + ChatDbRounds, ChatSqlEditContext, DbTable, ) - -from dbgpt.app.openapi.api_v1.editor.sql_editor import ( - DataNode, - ChartRunData, - SqlRunData, -) -from dbgpt.core.interface.message import OnceConversation +from dbgpt.app.scene import ChatFactory from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader -from dbgpt.app.scene.chat_db.data_loader import DbDataLoader +from dbgpt.core.interface.message import OnceConversation +from dbgpt.serve.conversation.serve import Serve as ConversationServe from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory router = APIRouter() @@ -41,6 +34,14 @@ logger = logging.getLogger(__name__) +def get_conversation_serve() -> ConversationServe: + return ConversationServe.get_instance(CFG.SYSTEM_APP) + + +def get_edit_service() -> EditorService: + return EditorService.get_instance(CFG.SYSTEM_APP) + + @router.get("/v1/editor/db/tables", response_model=Result[DbTable]) async def get_editor_tables( db_name: str, page_index: int, page_size: int, search_str: str = "" @@ -69,48 +70,21 @@ async def get_editor_tables( @router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds]) -async def get_editor_sql_rounds(con_uid: str): +async def get_editor_sql_rounds( + con_uid: str, editor_service: EditorService = Depends(get_edit_service) +): logger.info("get_editor_sql_rounds:{con_uid}") - chat_history_fac = ChatHistory() - history_mem = chat_history_fac.get_store_instance(con_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - result: List = [] - for once in history_messages: - round_name: str = "" - for element in once["messages"]: - if element["type"] == "human": - round_name = element["data"]["content"] - if once.get("param_value"): - round: ChatDbRounds = ChatDbRounds( - round=once["chat_order"], - db_name=once["param_value"], - round_name=round_name, - ) - result.append(round) - return Result.succ(result) + return Result.succ(editor_service.get_editor_sql_rounds(con_uid)) @router.get("/v1/editor/sql", response_model=Result[dict]) -async def get_editor_sql(con_uid: str, round: int): +async def get_editor_sql( + con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service) +): logger.info(f"get_editor_sql:{con_uid},{round}") - chat_history_fac = ChatHistory() - history_mem = chat_history_fac.get_store_instance(con_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - for once in history_messages: - if int(once["chat_order"]) == round: - for element in once["messages"]: - if element["type"] == "ai": - logger.info( - f'history ai json resp:{element["data"]["content"]}' - ) - context = ( - element["data"]["content"] - .replace("\\n", " ") - .replace("\n", " ") - ) - return Result.succ(json.loads(context)) + context = editor_service.get_editor_sql_by_round(con_uid, round) + if context: + return Result.succ(context) return Result.failed(msg="not have sql!") @@ -120,7 +94,7 @@ async def editor_sql_run(run_param: dict = Body()): db_name = run_param["db_name"] sql = run_param["sql"] if not db_name and not sql: - return Result.failed("SQL run param error!") + return Result.failed(msg="SQL run param error!") conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) try: @@ -145,102 +119,43 @@ async def editor_sql_run(run_param: dict = Body()): @router.post("/v1/sql/editor/submit") -async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): +async def sql_editor_submit( + sql_edit_context: ChatSqlEditContext = Body(), + editor_service: EditorService = Depends(get_edit_service), +): logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}") - chat_history_fac = ChatHistory() - history_mem = chat_history_fac.get_store_instance(sql_edit_context.conv_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) - - edit_round = list( - filter( - lambda x: x["chat_order"] == sql_edit_context.conv_round, - history_messages, - ) - )[0] - if edit_round: - for element in edit_round["messages"]: - if element["type"] == "ai": - db_resp = json.loads(element["data"]["content"]) - db_resp["thoughts"] = sql_edit_context.new_speak - db_resp["sql"] = sql_edit_context.new_sql - element["data"]["content"] = json.dumps(db_resp) - if element["type"] == "view": - data_loader = DbDataLoader() - element["data"]["content"] = data_loader.get_table_view_by_conn( - conn.run_to_df(sql_edit_context.new_sql), - sql_edit_context.new_speak, - ) - history_mem.update(history_messages) - return Result.succ(None) - return Result.failed(msg="Edit Failed!") + conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) + try: + editor_service.sql_editor_submit_and_save(sql_edit_context, conn) + return Result.succ(None) + except Exception as e: + logger.error(f"edit sql exception!{str(e)}") + return Result.failed(msg=f"Edit sql exception!{str(e)}") @router.get("/v1/editor/chart/list", response_model=Result[ChartList]) -async def get_editor_chart_list(con_uid: str): +async def get_editor_chart_list( + con_uid: str, + editor_service: EditorService = Depends(get_edit_service), +): logger.info( f"get_editor_sql_rounds:{con_uid}", ) - chat_history_fac = ChatHistory() - history_mem = chat_history_fac.get_store_instance(con_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - last_round = max(history_messages, key=lambda x: x["chat_order"]) - db_name = last_round["param_value"] - for element in last_round["messages"]: - if element["type"] == "ai": - chart_list: ChartList = ChartList( - round=last_round["chat_order"], - db_name=db_name, - charts=json.loads(element["data"]["content"]), - ) - return Result.succ(chart_list) + chart_list = editor_service.get_editor_chart_list(con_uid) + if chart_list: + return Result.succ(chart_list) return Result.failed(msg="Not have charts!") @router.post("/v1/editor/chart/info", response_model=Result[ChartDetail]) -async def get_editor_chart_info(param: dict = Body()): +async def get_editor_chart_info( + param: dict = Body(), editor_service: EditorService = Depends(get_edit_service) +): logger.info(f"get_editor_chart_info:{param}") conv_uid = param["con_uid"] chart_title = param["chart_title"] - - chat_history_fac = ChatHistory() - history_mem = chat_history_fac.get_store_instance(conv_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() - if history_messages: - last_round = max(history_messages, key=lambda x: x["chat_order"]) - db_name = last_round["param_value"] - if not db_name: - logger.error( - "this dashboard dialogue version too old, can't support editor!" - ) - return Result.failed( - msg="this dashboard dialogue version too old, can't support editor!" - ) - for element in last_round["messages"]: - if element["type"] == "view": - view_data: dict = json.loads(element["data"]["content"]) - charts: List = view_data.get("charts") - find_chart = list( - filter(lambda x: x["chart_name"] == chart_title, charts) - )[0] - - conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) - detail: ChartDetail = ChartDetail( - chart_uid=find_chart["chart_uid"], - chart_type=find_chart["chart_type"], - chart_desc=find_chart["chart_desc"], - chart_sql=find_chart["chart_sql"], - db_name=db_name, - chart_name=find_chart["chart_name"], - chart_value=find_chart["values"], - table_value=conn.run(find_chart["chart_sql"]), - ) - - return Result.succ(detail) - return Result.failed(msg="Can't Find Chart Detail Info!") + return editor_service.get_editor_chart_info(conv_uid, chart_title, CFG) @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) diff --git a/dbgpt/app/openapi/api_v1/editor/service.py b/dbgpt/app/openapi/api_v1/editor/service.py new file mode 100644 index 000000000..ef22adebd --- /dev/null +++ b/dbgpt/app/openapi/api_v1/editor/service.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Dict, List, Optional + +from dbgpt._private.config import Config +from dbgpt.app.openapi.api_view_model import Result +from dbgpt.app.openapi.editor_view_model import ( + ChartDetail, + ChartList, + ChatDbRounds, + ChatSqlEditContext, +) +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.core import BaseOutputParser +from dbgpt.core.interface.message import ( + MessageStorageItem, + StorageConversation, + _split_messages_by_round, +) +from dbgpt.serve.conversation.serve import Serve as ConversationServe + +if TYPE_CHECKING: + from dbgpt.datasource.base import BaseConnect + +logger = logging.getLogger(__name__) + + +class EditorService(BaseComponent): + name = "dbgpt_app_editor_service" + + def __init__(self, system_app: SystemApp): + self._system_app: SystemApp = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + self._system_app = system_app + + def conv_serve(self) -> ConversationServe: + return ConversationServe.get_instance(self._system_app) + + def get_storage_conv(self, conv_uid: str) -> StorageConversation: + conv_serve: ConversationServe = self.conv_serve() + return StorageConversation( + conv_uid, + conv_storage=conv_serve.conv_storage, + message_storage=conv_serve.message_storage, + ) + + def get_editor_sql_rounds(self, conv_uid: str) -> List[ChatDbRounds]: + storage_conv: StorageConversation = self.get_storage_conv(conv_uid) + messages_by_round = _split_messages_by_round(storage_conv.messages) + result: List[ChatDbRounds] = [] + for one_round_message in messages_by_round: + if not one_round_message: + continue + for message in one_round_message: + if message.type == "human": + round_name = message.content + if message.additional_kwargs.get("param_value"): + chat_db_round: ChatDbRounds = ChatDbRounds( + round=message.round_index, + db_name=message.additional_kwargs.get("param_value"), + round_name=round_name, + ) + result.append(chat_db_round) + + return result + + def get_editor_sql_by_round( + self, conv_uid: str, round_index: int + ) -> Optional[Dict]: + storage_conv: StorageConversation = self.get_storage_conv(conv_uid) + messages_by_round = _split_messages_by_round(storage_conv.messages) + for one_round_message in messages_by_round: + if not one_round_message: + continue + for message in one_round_message: + if message.type == "ai" and message.round_index == round_index: + content = message.content + logger.info(f"history ai json resp: {content}") + # context = content.replace("\\n", " ").replace("\n", " ") + context_dict = _parse_pure_dict(content) + return context_dict + return None + + def sql_editor_submit_and_save( + self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect + ): + storage_conv: StorageConversation = self.get_storage_conv( + sql_edit_context.conv_uid + ) + if not storage_conv.save_message_independent: + raise ValueError( + "Submit sql and save just support independent conversation mode(after v0.4.6)" + ) + conv_serve: ConversationServe = self.conv_serve() + messages_by_round = _split_messages_by_round(storage_conv.messages) + to_update_messages = [] + for one_round_message in messages_by_round: + if not one_round_message: + continue + if one_round_message[0].round_index == sql_edit_context.conv_round: + for message in one_round_message: + if message.type == "ai": + db_resp = _parse_pure_dict(message.content) + db_resp["thoughts"] = sql_edit_context.new_speak + db_resp["sql"] = sql_edit_context.new_sql + message.content = json.dumps(db_resp, ensure_ascii=False) + to_update_messages.append( + MessageStorageItem( + storage_conv.conv_uid, message.index, message.to_dict() + ) + ) + # TODO not support update view message now + # if message.type == "view": + # data_loader = DbDataLoader() + # message.content = data_loader.get_table_view_by_conn( + # connection.run_to_df(sql_edit_context.new_sql), + # sql_edit_context.new_speak, + # ) + # to_update_messages.append( + # MessageStorageItem( + # storage_conv.conv_uid, message.index, message.to_dict() + # ) + # ) + if to_update_messages: + conv_serve.message_storage.save_or_update_list(to_update_messages) + return + + def get_editor_chart_list(self, conv_uid: str) -> Optional[ChartList]: + storage_conv: StorageConversation = self.get_storage_conv(conv_uid) + messages_by_round = _split_messages_by_round(storage_conv.messages) + for one_round_message in messages_by_round: + if not one_round_message: + continue + for message in one_round_message: + if message.type == "ai": + context_dict = _parse_pure_dict(message.content) + chart_list: ChartList = ChartList( + round=message.round_index, + db_name=message.additional_kwargs.get("param_value"), + charts=context_dict, + ) + return chart_list + + def get_editor_chart_info( + self, conv_uid: str, chart_title: str, cfg: Config + ) -> Result[ChartDetail]: + storage_conv: StorageConversation = self.get_storage_conv(conv_uid) + messages_by_round = _split_messages_by_round(storage_conv.messages) + for one_round_message in messages_by_round: + if not one_round_message: + continue + for message in one_round_message: + db_name = message.additional_kwargs.get("param_value") + if not db_name: + logger.error( + "this dashboard dialogue version too old, can't support editor!" + ) + return Result.failed( + msg="this dashboard dialogue version too old, can't support editor!" + ) + if message.type == "view": + view_data: dict = _parse_pure_dict(message.content) + charts: List = view_data.get("charts") + find_chart = list( + filter(lambda x: x["chart_name"] == chart_title, charts) + )[0] + + conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name) + detail: ChartDetail = ChartDetail( + chart_uid=find_chart["chart_uid"], + chart_type=find_chart["chart_type"], + chart_desc=find_chart["chart_desc"], + chart_sql=find_chart["chart_sql"], + db_name=db_name, + chart_name=find_chart["chart_name"], + chart_value=find_chart["values"], + table_value=conn.run(find_chart["chart_sql"]), + ) + return Result.succ(detail) + return Result.failed(msg="Can't Find Chart Detail Info!") + + +def _parse_pure_dict(res_str: str) -> Dict: + output_parser = BaseOutputParser() + context = output_parser.parse_prompt_response(res_str) + return json.loads(context) diff --git a/dbgpt/app/openapi/api_v1/editor/sql_editor.py b/dbgpt/app/openapi/api_v1/editor/sql_editor.py index e31db2c41..5a05199e0 100644 --- a/dbgpt/app/openapi/api_v1/editor/sql_editor.py +++ b/dbgpt/app/openapi/api_v1/editor/sql_editor.py @@ -1,4 +1,5 @@ from typing import List + from dbgpt._private.pydantic import BaseModel from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem diff --git a/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py b/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py index 365b399b3..4dc4178d7 100644 --- a/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py +++ b/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py @@ -1,9 +1,7 @@ from fastapi import APIRouter, Body, Request +from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody -from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ( - ChatFeedBackDao, -) from dbgpt.app.openapi.api_view_model import Result router = APIRouter() diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py index 434830625..2b358f8c5 100644 --- a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py +++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py @@ -1,10 +1,9 @@ from datetime import datetime -from sqlalchemy import Column, Integer, Text, String, DateTime - -from dbgpt.storage.metadata import BaseDao, Model +from sqlalchemy import Column, DateTime, Integer, String, Text from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody +from dbgpt.storage.metadata import BaseDao, Model class ChatFeedBackEntity(Model): diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py index 38f437b4a..7f4067418 100644 --- a/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py +++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py @@ -1,6 +1,7 @@ -from dbgpt._private.pydantic import BaseModel from typing import Optional +from dbgpt._private.pydantic import BaseModel + class FeedBackBody(BaseModel): """conv_uid: conversation id""" diff --git a/dbgpt/app/openapi/api_view_model.py b/dbgpt/app/openapi/api_view_model.py index 25f8cb5dc..cd4ccadfd 100644 --- a/dbgpt/app/openapi/api_view_model.py +++ b/dbgpt/app/openapi/api_view_model.py @@ -1,7 +1,8 @@ -from dbgpt._private.pydantic import BaseModel, Field -from typing import TypeVar, Generic, Any, Optional, Literal, List -import uuid import time +import uuid +from typing import Any, Generic, List, Literal, Optional, TypeVar + +from dbgpt._private.pydantic import BaseModel, Field T = TypeVar("T") @@ -17,11 +18,7 @@ def succ(cls, data: T): return Result(success=True, err_code=None, err_msg=None, data=data) @classmethod - def failed(cls, msg): - return Result(success=False, err_code="E000X", err_msg=msg, data=None) - - @classmethod - def failed(cls, code, msg): + def failed(cls, code: str = "E000X", msg=None): return Result(success=False, err_code=code, err_msg=msg, data=None) diff --git a/dbgpt/app/openapi/base.py b/dbgpt/app/openapi/base.py index 10cb1d231..d4cb7679b 100644 --- a/dbgpt/app/openapi/base.py +++ b/dbgpt/app/openapi/base.py @@ -1,5 +1,6 @@ from fastapi import Request from fastapi.exceptions import RequestValidationError + from dbgpt.app.openapi.api_view_model import Result diff --git a/dbgpt/app/openapi/editor_view_model.py b/dbgpt/app/openapi/editor_view_model.py index 82d97092c..a915805fd 100644 --- a/dbgpt/app/openapi/editor_view_model.py +++ b/dbgpt/app/openapi/editor_view_model.py @@ -1,5 +1,6 @@ +from typing import Any, List + from dbgpt._private.pydantic import BaseModel, Field -from typing import List, Any class DbField(BaseModel): diff --git a/dbgpt/app/scene/__init__.py b/dbgpt/app/scene/__init__.py index 494db2a31..d6b804f1a 100644 --- a/dbgpt/app/scene/__init__.py +++ b/dbgpt/app/scene/__init__.py @@ -1,3 +1,3 @@ +from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.base_chat import BaseChat from dbgpt.app.scene.chat_factory import ChatFactory -from dbgpt.app.scene.base import ChatScene diff --git a/dbgpt/app/scene/base.py b/dbgpt/app/scene/base.py index 893bccdb0..a2b5f03c4 100644 --- a/dbgpt/app/scene/base.py +++ b/dbgpt/app/scene/base.py @@ -1,5 +1,9 @@ from enum import Enum -from typing import List +from typing import List, Optional + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core import BaseOutputParser, ChatPromptTemplate +from dbgpt.core._private.example_base import ExampleSelector class Scene: @@ -135,3 +139,49 @@ def show_disable(self): def is_inner(self): return self._value_.is_inner + + +class AppScenePromptTemplateAdapter(BaseModel): + """The template of the scene. + + Include some fields that in :class:`dbgpt.core.PromptTemplate` + """ + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + prompt: ChatPromptTemplate = Field(..., description="The prompt of this scene") + template_scene: Optional[str] = Field( + default=None, description="The scene of this template" + ) + template_is_strict: Optional[bool] = Field( + default=True, description="Whether strict" + ) + + output_parser: Optional[BaseOutputParser] = Field( + default=None, description="The output parser of this scene" + ) + sep: Optional[str] = Field( + default="###", description="The default separator of this scene" + ) + + stream_out: Optional[bool] = Field( + default=True, description="Whether to stream out" + ) + example_selector: Optional[ExampleSelector] = Field( + default=None, description="Example selector" + ) + need_historical_messages: Optional[bool] = Field( + default=False, description="Whether to need historical messages" + ) + temperature: Optional[float] = Field( + default=0.6, description="The default temperature of this scene" + ) + max_new_tokens: Optional[int] = Field( + default=1024, description="The default max new tokens of this scene" + ) + str_history: Optional[bool] = Field( + default=False, description="Whether transform history to str" + ) diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index 776c460cc..b9a22f7e8 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -1,30 +1,56 @@ -import datetime -import traceback -import warnings import asyncio +import datetime import logging +import traceback from abc import ABC, abstractmethod -from typing import Any, List, Dict +from typing import Any, Dict from dbgpt._private.config import Config +from dbgpt._private.pydantic import Extra +from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene +from dbgpt.app.scene.operator.app_operator import ( + AppChatComposerOperator, + ChatComposerInput, +) from dbgpt.component import ComponentType -from dbgpt.core.interface.prompt import PromptTemplate -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType -from dbgpt.core.interface.message import OnceConversation +from dbgpt.core.awel import DAG, BaseOperator, InputOperator, SimpleCallDataInputSource +from dbgpt.core.interface.message import StorageConversation from dbgpt.model.cluster import WorkerManagerFactory +from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator +from dbgpt.serve.conversation.serve import Serve as ConversationServe from dbgpt.util import get_or_create_event_loop from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async from dbgpt.util.tracer import root_tracer, trace -from dbgpt._private.pydantic import Extra -from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory -from dbgpt.core.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG -from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator logger = logging.getLogger(__name__) -headers = {"User-Agent": "dbgpt Client"} CFG = Config() +def _build_conversation( + chat_mode: ChatScene, + chat_param: Dict[str, Any], + model_name: str, + conv_serve: ConversationServe, +) -> StorageConversation: + param_type = "" + param_value = "" + if chat_param["select_param"]: + if len(chat_mode.param_types()) > 0: + param_type = chat_mode.param_types()[0] + param_value = chat_param["select_param"] + return StorageConversation( + chat_param["chat_session_id"], + chat_mode=chat_mode.value(), + user_name=chat_param.get("user_name"), + sys_code=chat_param.get("sys_code"), + model_name=model_name, + param_type=param_type, + param_value=param_value, + conv_storage=conv_serve.conv_storage, + message_storage=conv_serve.message_storage, + ) + + class BaseChat(ABC): """DB-GPT Chat Service Base Module Include: @@ -35,7 +61,8 @@ class BaseChat(ABC): chat_scene: str = None llm_model: Any = None # By default, keep the last two rounds of conversation records as the context - chat_retention_rounds: int = 0 + keep_start_rounds: int = 0 + keep_end_rounds: int = 0 class Config: """Configuration for this pydantic object.""" @@ -68,7 +95,7 @@ def __init__(self, chat_param: Dict): # self.prompt_template: PromptTemplate = CFG.prompt_templates[ # self.chat_mode.value() # ] - self.prompt_template: PromptTemplate = ( + self.prompt_template: AppScenePromptTemplateAdapter = ( CFG.prompt_template_registry.get_prompt_template( self.chat_mode.value(), language=CFG.LANGUAGE, @@ -76,21 +103,21 @@ def __init__(self, chat_param: Dict): proxyllm_backend=CFG.PROXYLLM_BACKEND, ) ) - chat_history_fac = ChatHistory() + self._conv_serve = ConversationServe.get_instance(CFG.SYSTEM_APP) + # chat_history_fac = ChatHistory() ### can configurable storage methods - self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"]) - - self.history_message: List[OnceConversation] = self.memory.messages() - self.current_message: OnceConversation = OnceConversation( - self.chat_mode.value(), - user_name=chat_param.get("user_name"), - sys_code=chat_param.get("sys_code"), + # self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"]) + + # self.history_message: List[OnceConversation] = self.memory.messages() + # self.current_message: OnceConversation = OnceConversation( + # self.chat_mode.value(), + # user_name=chat_param.get("user_name"), + # sys_code=chat_param.get("sys_code"), + # ) + self.current_message: StorageConversation = _build_conversation( + self.chat_mode, chat_param, self.llm_model, self._conv_serve ) - self.current_message.model_name = self.llm_model - if chat_param["select_param"]: - if len(self.chat_mode.param_types()) > 0: - self.current_message.param_type = self.chat_mode.param_types()[0] - self.current_message.param_value = chat_param["select_param"] + self.history_messages = self.current_message.get_history_message() self.current_tokens_used: int = 0 # The executor to submit blocking function self._executor = CFG.SYSTEM_APP.get_component( @@ -102,10 +129,9 @@ def __init__(self, chat_param: Dict): is_stream=True, dag_name="llm_stream_model_dag" ) - # Get the message version, default is v1 in app # In v1, we will transform the message to compatible format of specific model # In the future, we will upgrade the message version to v2, and the message will be compatible with all models - self._message_version = chat_param.get("message_version", "v1") + self._message_version = chat_param.get("message_version", "v2") class Config: """Configuration for this pydantic object.""" @@ -133,6 +159,14 @@ def do_action(self, prompt_response): def message_adjust(self): pass + def has_history_messages(self) -> bool: + """Whether there is a history messages + + Returns: + bool: True if there is a history message, False otherwise + """ + return len(self.history_messages) > 0 + def get_llm_speak(self, prompt_define_response): if hasattr(prompt_define_response, "thoughts"): if isinstance(prompt_define_response.thoughts, dict): @@ -153,31 +187,51 @@ def get_llm_speak(self, prompt_define_response): async def __call_base(self): input_values = await self.generate_input_values() - ### Chat sequence advance - self.current_message.chat_order = len(self.history_message) + 1 - self.current_message.add_user_message( - self.current_user_input, check_duplicate_type=True - ) + # Load history + self.history_messages = self.current_message.get_history_message() + self.current_message.start_new_round() + self.current_message.add_user_message(self.current_user_input) self.current_message.start_date = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) self.current_message.tokens = 0 - if self.prompt_template.template: - metadata = { - "template_scene": self.prompt_template.template_scene, - "input_values": input_values, - } - with root_tracer.start_span( - "BaseChat.__call_base.prompt_template.format", metadata=metadata - ): - current_prompt = self.prompt_template.format(**input_values) - ### prompt context token adapt according to llm max context length - current_prompt = await self.prompt_context_token_adapt( - prompt=current_prompt - ) - self.current_message.add_system_message(current_prompt) - - llm_messages = self.generate_llm_messages() + # TODO: support handle history message by tokens + # if self.prompt_template.template: + # metadata = { + # "template_scene": self.prompt_template.template_scene, + # "input_values": input_values, + # } + # with root_tracer.start_span( + # "BaseChat.__call_base.prompt_template.format", metadata=metadata + # ): + # current_prompt = self.prompt_template.format(**input_values) + # ### prompt context token adapt according to llm max context length + # current_prompt = await self.prompt_context_token_adapt( + # prompt=current_prompt + # ) + # self.current_message.add_system_message(current_prompt) + + keep_start_rounds = ( + self.keep_start_rounds + if self.prompt_template.need_historical_messages + else 0 + ) + keep_end_rounds = ( + self.keep_end_rounds if self.prompt_template.need_historical_messages else 0 + ) + node = AppChatComposerOperator( + prompt=self.prompt_template.prompt, + keep_start_rounds=keep_start_rounds, + keep_end_rounds=keep_end_rounds, + str_history=self.prompt_template.str_history, + ) + node_input = { + "data": ChatComposerInput( + messages=self.history_messages, prompt_dict=input_values + ) + } + # llm_messages = self.generate_llm_messages() + llm_messages = await node.call(call_data=node_input) if not CFG.NEW_SERVER_MODE: # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post @@ -185,7 +239,7 @@ async def __call_base(self): payload = { "model": self.llm_model, - "prompt": self.generate_llm_text(), + "prompt": "", "messages": llm_messages, "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), @@ -238,7 +292,7 @@ async def stream_call(self): view_msg = self.stream_plugin_call(msg) view_msg = view_msg.replace("\n", "\\n") yield view_msg - self.current_message.add_ai_message(msg, update_if_exist=True) + self.current_message.add_ai_message(msg) view_msg = self.stream_call_reinforce_fn(view_msg) self.current_message.add_view_message(view_msg) span.end() @@ -250,7 +304,8 @@ async def stream_call(self): ) ### store current conversation span.end(metadata={"error": str(e)}) - self.memory.append(self.current_message) + # self.memory.append(self.current_message) + self.current_message.end_current_round() async def nostream_call(self): payload = await self.__call_base() @@ -274,7 +329,7 @@ async def nostream_call(self): ) ) ### model result deal - self.current_message.add_ai_message(ai_response_text, update_if_exist=True) + self.current_message.add_ai_message(ai_response_text) prompt_define_response = ( self.prompt_template.output_parser.parse_prompt_response( ai_response_text @@ -320,7 +375,8 @@ async def nostream_call(self): ) span.end(metadata={"error": str(e)}) ### store dialogue - self.memory.append(self.current_message) + # self.memory.append(self.current_message) + self.current_message.end_current_round() return self.current_ai_response() async def get_llm_response(self): @@ -328,6 +384,7 @@ async def get_llm_response(self): logger.info(f"Request: \n{payload}") ai_response_text = "" payload["model_cache_enable"] = self.model_cache_enable + prompt_define_response = None try: model_output = await self._model_operator.call(call_data={"data": payload}) ### output parse @@ -337,8 +394,7 @@ async def get_llm_response(self): ) ) ### model result deal - self.current_message.add_ai_message(ai_response_text, update_if_exist=True) - prompt_define_response = None + self.current_message.add_ai_message(ai_response_text) prompt_define_response = ( self.prompt_template.output_parser.parse_prompt_response( ai_response_text @@ -384,64 +440,8 @@ def call(self): async def prepare(self): pass - def generate_llm_text(self) -> str: - warnings.warn("This method is deprecated - please use `generate_llm_messages`.") - text = "" - ### Load scene setting or character definition - if self.prompt_template.template_define: - text += self.prompt_template.template_define + self.prompt_template.sep - ### Load prompt - text += _load_system_message(self.current_message, self.prompt_template) - - ### Load examples - text += _load_example_messages(self.prompt_template) - - ### Load History - text += _load_history_messages( - self.prompt_template, self.history_message, self.chat_retention_rounds - ) - - ### Load User Input - text += _load_user_message(self.current_message, self.prompt_template) - return text - - def generate_llm_messages(self) -> List[ModelMessage]: - """ - Structured prompt messages interaction between dbgpt-server and llm-server - See https://github.com/csunny/DB-GPT/issues/328 - """ - messages = [] - ### Load scene setting or character definition as system message - if self.prompt_template.template_define: - messages.append( - ModelMessage( - role=ModelMessageRoleType.SYSTEM, - content=self.prompt_template.template_define, - ) - ) - ### Load prompt - messages += _load_system_message( - self.current_message, self.prompt_template, str_message=False - ) - ### Load examples - messages += _load_example_messages(self.prompt_template, str_message=False) - - ### Load History - messages += _load_history_messages( - self.prompt_template, - self.history_message, - self.chat_retention_rounds, - str_message=False, - ) - - ### Load User Input - messages += _load_user_message( - self.current_message, self.prompt_template, str_message=False - ) - return messages - def current_ai_response(self) -> str: - for message in self.current_message.messages: + for message in self.current_message.messages[-1:]: if message.type == "view": return message.content return None @@ -539,16 +539,6 @@ def _generate_numbered_list(self) -> str: }, ] - # command_strings = [] - # if CFG.command_disply: - # for name, item in CFG.command_disply.commands.items(): - # if item.enabled: - # command_strings.append(f"{name}:{item.description}") - # command_strings += [ - # str(item) - # for item in CFG.command_disply.commands.values() - # if item.enabled - # ] return "\n".join( f"{key}:{value}" for dict_item in antv_charts @@ -566,6 +556,7 @@ def _build_model_operator( data using the model. It supports both streaming and non-streaming modes. .. code-block:: python + input_node >> cache_check_branch_node cache_check_branch_node >> model_node >> save_cached_node >> join_node cache_check_branch_node >> cached_node >> join_node @@ -585,12 +576,12 @@ def _build_model_operator( Returns: BaseOperator: The final operator in the constructed DAG, typically a join node. """ - from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.core.awel import JoinOperator + from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.model.operator.model_operator import ( - ModelCacheBranchOperator, - CachedModelStreamOperator, CachedModelOperator, + CachedModelStreamOperator, + ModelCacheBranchOperator, ModelSaveCacheOperator, ModelStreamSaveCacheOperator, ) @@ -639,127 +630,3 @@ def _build_model_operator( cache_check_branch_node >> cached_node >> join_node return join_node - - -def _load_system_message( - current_message: OnceConversation, - prompt_template: PromptTemplate, - str_message: bool = True, -): - system_convs = current_message.get_system_messages() - system_text = "" - system_messages = [] - for system_conv in system_convs: - system_text += ( - system_conv.type + ":" + system_conv.content + prompt_template.sep - ) - system_messages.append( - ModelMessage(role=system_conv.type, content=system_conv.content) - ) - return system_text if str_message else system_messages - - -def _load_user_message( - current_message: OnceConversation, - prompt_template: PromptTemplate, - str_message: bool = True, -): - user_conv = current_message.get_latest_user_message() - user_messages = [] - if user_conv: - user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep - user_messages.append( - ModelMessage(role=user_conv.type, content=user_conv.content) - ) - return user_text if str_message else user_messages - else: - raise ValueError("Hi! What do you want to talk about?") - - -def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True): - example_text = "" - example_messages = [] - if prompt_template.example_selector: - for round_conv in prompt_template.example_selector.examples(): - for round_message in round_conv["messages"]: - if not round_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = round_message["type"] - message_content = round_message["data"]["content"] - example_text += ( - message_type + ":" + message_content + prompt_template.sep - ) - example_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - return example_text if str_message else example_messages - - -def _load_history_messages( - prompt_template: PromptTemplate, - history_message: List[OnceConversation], - chat_retention_rounds: int, - str_message: bool = True, -): - history_text = "" - history_messages = [] - if prompt_template.need_historical_messages: - if history_message: - logger.info( - f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!" - ) - if len(history_message) > chat_retention_rounds: - for first_message in history_message[0]["messages"]: - if not first_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = first_message["type"] - message_content = first_message["data"]["content"] - history_text += ( - message_type + ":" + message_content + prompt_template.sep - ) - history_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - if chat_retention_rounds > 1: - index = chat_retention_rounds - 1 - for round_conv in history_message[-index:]: - for round_message in round_conv["messages"]: - if not round_message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = round_message["type"] - message_content = round_message["data"]["content"] - history_text += ( - message_type - + ":" - + message_content - + prompt_template.sep - ) - history_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - - else: - ### user all history - for conversation in history_message: - for message in conversation["messages"]: - ### histroy message not have promot and view info - if not message["type"] in [ - ModelMessageRoleType.VIEW, - ModelMessageRoleType.SYSTEM, - ]: - message_type = message["type"] - message_content = message["data"]["content"] - history_text += ( - message_type + ":" + message_content + prompt_template.sep - ) - history_messages.append( - ModelMessage(role=message_type, content=message_content) - ) - - return history_text if str_message else history_messages diff --git a/dbgpt/app/scene/chat_agent/chat.py b/dbgpt/app/scene/chat_agent/chat.py index a46d70f37..ded1db216 100644 --- a/dbgpt/app/scene/chat_agent/chat.py +++ b/dbgpt/app/scene/chat_agent/chat.py @@ -1,10 +1,10 @@ -from typing import List, Dict import logging +from typing import Dict, List -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config from dbgpt.agent.plugin.commands.command_mange import ApiCall from dbgpt.agent.plugin.generator import PluginPromptGenerator +from dbgpt.app.scene import BaseChat, ChatScene from dbgpt.component import ComponentType from dbgpt.serve.agent.hub.controller import ModuleAgent from dbgpt.util.tracer import root_tracer, trace @@ -18,7 +18,7 @@ class ChatAgent(BaseChat): """Chat With Agent through plugin""" chat_scene: str = ChatScene.ChatAgent.value() - chat_retention_rounds = 0 + keep_end_rounds = 0 def __init__(self, chat_param: Dict): """Chat Agent Module Initialization diff --git a/dbgpt/app/scene/chat_agent/out_parser.py b/dbgpt/app/scene/chat_agent/out_parser.py index 5cb348ae9..e1bc399e7 100644 --- a/dbgpt/app/scene/chat_agent/out_parser.py +++ b/dbgpt/app/scene/chat_agent/out_parser.py @@ -1,4 +1,5 @@ from typing import Dict, NamedTuple + from dbgpt.core.interface.output_parser import BaseOutputParser diff --git a/dbgpt/app/scene/chat_agent/prompt.py b/dbgpt/app/scene/chat_agent/prompt.py index f716606a0..7764b0c43 100644 --- a/dbgpt/app/scene/chat_agent/prompt.py +++ b/dbgpt/app/scene/chat_agent/prompt.py @@ -1,8 +1,7 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser +from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate CFG = Config() @@ -65,16 +64,19 @@ ### Whether the model service is streaming output PROMPT_NEED_STREAM_OUT = True -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(_PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE), + HumanPromptTemplate.from_template("{user_goal}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatAgent.value(), - input_variables=["tool_list", "expand_constraints", "user_goal"], - response_format=None, - template_define=_PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - temperature=1 - # example_selector=plugin_example, + need_historical_messages=False, + temperature=1, ) - -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py index 7fe91b516..3748ce32e 100644 --- a/dbgpt/app/scene/chat_dashboard/chat.py +++ b/dbgpt/app/scene/chat_dashboard/chat.py @@ -1,15 +1,15 @@ import json import os import uuid -from typing import List, Dict +from typing import Dict, List -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config +from dbgpt.app.scene import BaseChat, ChatScene +from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, ReportData, ) -from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.tracer import trace diff --git a/dbgpt/app/scene/chat_dashboard/data_loader.py b/dbgpt/app/scene/chat_dashboard/data_loader.py index baf001faa..05cc55970 100644 --- a/dbgpt/app/scene/chat_dashboard/data_loader.py +++ b/dbgpt/app/scene/chat_dashboard/data_loader.py @@ -1,6 +1,6 @@ -from typing import List -from decimal import Decimal import logging +from decimal import Decimal +from typing import List from dbgpt._private.config import Config from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem diff --git a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py index 19e4d4923..8a8c12749 100644 --- a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py +++ b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py @@ -1,5 +1,6 @@ +from typing import Any, List + from dbgpt._private.pydantic import BaseModel -from typing import List, Any class ValueItem(BaseModel): diff --git a/dbgpt/app/scene/chat_dashboard/out_parser.py b/dbgpt/app/scene/chat_dashboard/out_parser.py index 27fc27aef..71955a120 100644 --- a/dbgpt/app/scene/chat_dashboard/out_parser.py +++ b/dbgpt/app/scene/chat_dashboard/out_parser.py @@ -1,9 +1,9 @@ import json import logging +from typing import List, NamedTuple -from typing import NamedTuple, List -from dbgpt.core.interface.output_parser import BaseOutputParser from dbgpt.app.scene import ChatScene +from dbgpt.core.interface.output_parser import BaseOutputParser class ChartItem(NamedTuple): diff --git a/dbgpt/app/scene/chat_dashboard/prompt.py b/dbgpt/app/scene/chat_dashboard/prompt.py index 566646cbe..ad53f8d26 100644 --- a/dbgpt/app/scene/chat_dashboard/prompt.py +++ b/dbgpt/app/scene/chat_dashboard/prompt.py @@ -1,8 +1,9 @@ import json -from dbgpt.core.interface.prompt import PromptTemplate + from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_dashboard.out_parser import ChatDashboardOutputParser +from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate CFG = Config() @@ -39,16 +40,23 @@ } ] - PROMPT_NEED_STREAM_OUT = False -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE, + response_format=json.dumps(RESPONSE_FORMAT, indent=4), + ), + HumanPromptTemplate.from_template("{input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatDashboard.value(), - input_variables=["input", "table_info", "dialect", "supported_chat_type"], - response_format=json.dumps(RESPONSE_FORMAT, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=ChatDashboardOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py index 4ace7e0b5..fa96b7a76 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -1,14 +1,14 @@ -import os import logging - +import os from typing import Dict -from dbgpt.app.scene import BaseChat, ChatScene + from dbgpt._private.config import Config from dbgpt.agent.plugin.commands.command_mange import ApiCall -from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader +from dbgpt.app.scene import BaseChat, ChatScene from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning -from dbgpt.util.path_utils import has_path +from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH +from dbgpt.util.path_utils import has_path from dbgpt.util.tracer import root_tracer, trace CFG = Config() @@ -20,7 +20,8 @@ class ChatExcel(BaseChat): """a Excel analyzer to analyze Excel Data""" chat_scene: str = ChatScene.ChatExcel.value() - chat_retention_rounds = 2 + keep_start_rounds = 1 + keep_end_rounds = 2 def __init__(self, chat_param: Dict): """Chat Excel Module Initialization @@ -58,7 +59,7 @@ async def generate_input_values(self) -> Dict: async def prepare(self): logger.info(f"{self.chat_mode} prepare start!") - if len(self.history_message) > 0: + if self.has_history_messages(): return None chat_param = { "chat_session_id": self.chat_session_id, diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py index 9a2e7af1c..34a47f493 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py @@ -1,8 +1,9 @@ import json import logging from typing import NamedTuple -from dbgpt.core.interface.output_parser import BaseOutputParser + from dbgpt._private.config import Config +from dbgpt.core.interface.output_parser import BaseOutputParser CFG = Config() diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py index b295288a3..49eebd37e 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py @@ -1,9 +1,14 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.out_parser import ( ChatExcelOutputParser, ) +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) CFG = Config() @@ -59,15 +64,20 @@ # For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0. PROMPT_TEMPERATURE = 0.3 -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(_PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{user_input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatExcel.value(), - input_variables=["user_input", "table_name", "disply_type"], - template_define=_PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), need_historical_messages=True, - # example_selector=sql_data_example, temperature=PROMPT_TEMPERATURE, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py index db9927fd3..d28b3f79c 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py @@ -1,10 +1,10 @@ import json from typing import Any, Dict -from dbgpt.core.interface.message import ViewMessage, AIMessage from dbgpt.app.scene import BaseChat, ChatScene -from dbgpt.util.json_utils import EnhancedJSONEncoder +from dbgpt.core.interface.message import AIMessage, ViewMessage from dbgpt.util.executor_utils import blocking_func_to_async +from dbgpt.util.json_utils import EnhancedJSONEncoder from dbgpt.util.tracer import trace @@ -52,6 +52,7 @@ async def generate_input_values(self) -> Dict: def message_adjust(self): ### adjust learning result in messages + # TODO: Can't work in multi-rounds chat view_message = "" for message in self.current_message.messages: if message.type == ViewMessage.type: diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py index 9cd89c34b..311e349e7 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -1,6 +1,7 @@ import json import logging -from typing import NamedTuple, List +from typing import List, NamedTuple + from dbgpt.core.interface.output_parser import BaseOutputParser diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py index 603799e10..120901809 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py @@ -1,10 +1,17 @@ import json -from dbgpt.core.interface.prompt import PromptTemplate + from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_data.chat_excel.excel_learning.out_parser import ( LearningExcelOutputParser, ) +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) +from dbgpt.core.interface.prompt import PromptTemplate CFG = Config() @@ -72,15 +79,24 @@ # For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0. PROMPT_TEMPERATURE = 0.8 -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE, + response_format=json.dumps( + RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4 + ), + ), + HumanPromptTemplate.from_template("{file_name}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ExcelLearning.value(), - input_variables=["data_example"], - response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - # example_selector=sql_data_example, + need_historical_messages=False, temperature=PROMPT_TEMPERATURE, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py index 854c2bfe7..f05393db8 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py @@ -1,4 +1,5 @@ import re + import sqlparse diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py b/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py index 0ad2c70d5..f234ab2f6 100644 --- a/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py +++ b/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py @@ -1,20 +1,20 @@ import logging - -import duckdb import os -import sqlparse -import pandas as pd + import chardet +import duckdb import numpy as np +import pandas as pd +import sqlparse from pyparsing import ( CaselessKeyword, - Word, - alphanums, - delimitedList, Forward, - Optional, Literal, + Optional, Regex, + Word, + alphanums, + delimitedList, ) from dbgpt.util.pd_utils import csv_colunm_foramt diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 3a3269d7e..8b59aeb42 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -1,8 +1,8 @@ from typing import Dict +from dbgpt._private.config import Config from dbgpt.agent.plugin.commands.command_mange import ApiCall from dbgpt.app.scene import BaseChat, ChatScene -from dbgpt._private.config import Config from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.tracer import root_tracer, trace diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py index dc1f50c66..d2e1eae96 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py +++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py @@ -1,11 +1,13 @@ import json -from typing import Dict, NamedTuple import logging -import sqlparse import xml.etree.ElementTree as ET -from dbgpt.util.json_utils import serialize -from dbgpt.core.interface.output_parser import BaseOutputParser +from typing import Dict, NamedTuple + +import sqlparse + from dbgpt._private.config import Config +from dbgpt.core.interface.output_parser import BaseOutputParser +from dbgpt.util.json_utils import serialize CFG = Config() diff --git a/dbgpt/app/scene/chat_db/auto_execute/prompt.py b/dbgpt/app/scene/chat_db/auto_execute/prompt.py index 6fb0d7de6..52b153215 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/prompt.py +++ b/dbgpt/app/scene/chat_db/auto_execute/prompt.py @@ -1,8 +1,14 @@ import json -from dbgpt.core.interface.prompt import PromptTemplate + from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_db.auto_execute.out_parser import DbChatOutputParser +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) CFG = Config() @@ -77,16 +83,25 @@ # For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0. PROMPT_TEMPERATURE = 0.5 -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + _DEFAULT_TEMPLATE, + response_format=json.dumps( + RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4 + ), + ), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{user_input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatWithDbExecute.value(), - input_variables=["table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=DbChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - # example_selector=sql_data_example, temperature=PROMPT_TEMPERATURE, - need_historical_messages=True, + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py b/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py index 25471da55..e994f51ec 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py +++ b/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py @@ -2,10 +2,11 @@ # -*- coding: utf-8 -*- import json -from dbgpt.core.interface.prompt import PromptTemplate + from dbgpt._private.config import Config from dbgpt.app.scene import ChatScene from dbgpt.app.scene.chat_db.auto_execute.out_parser import DbChatOutputParser +from dbgpt.core.interface.prompt import PromptTemplate CFG = Config() diff --git a/dbgpt/app/scene/chat_db/data_loader.py b/dbgpt/app/scene/chat_db/data_loader.py index 7abaf8b53..5004febd8 100644 --- a/dbgpt/app/scene/chat_db/data_loader.py +++ b/dbgpt/app/scene/chat_db/data_loader.py @@ -1,24 +1,12 @@ -import xml.etree.ElementTree as ET import json import logging +import xml.etree.ElementTree as ET from dbgpt.util.json_utils import serialize class DbDataLoader: def get_table_view_by_conn(self, data, speak, sql: str = None): - # import pandas as pd - # - # ### tool out data to table view - # if len(data) < 1: - # data.insert(0, ["result"]) - # df = pd.DataFrame(data[1:], columns=data[0]) - # html_table = df.to_html(index=False, escape=False, sparsify=False) - # table_str = "".join(html_table.split()) - # html = f"""
{table_str}
""" - # view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") - # return view_text - param = {} api_call_element = ET.Element("chart-view") err_msg = None diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py index 01d154da9..fb616cf50 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/chat.py +++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py @@ -1,7 +1,7 @@ from typing import Dict -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config +from dbgpt.app.scene import BaseChat, ChatScene from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.tracer import trace @@ -11,6 +11,8 @@ class ChatWithDbQA(BaseChat): chat_scene: str = ChatScene.ChatWithDbQA.value() + keep_end_rounds = 5 + """As a DBA, Chat DB Module, chat with combine DB meta schema """ def __init__(self, chat_param: Dict): diff --git a/dbgpt/app/scene/chat_db/professional_qa/prompt.py b/dbgpt/app/scene/chat_db/professional_qa/prompt.py index 6c3efc252..d7a9846c7 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/prompt.py +++ b/dbgpt/app/scene/chat_db/professional_qa/prompt.py @@ -1,29 +1,15 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser - -CFG = Config() - -PROMPT_SCENE_DEFINE = ( - """You are an assistant that answers user specialized database questions. """ +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, ) -# PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info: -# {table_info} -# -# Question: {input} -# -# """ +CFG = Config() -# _DEFAULT_TEMPLATE = """ -# You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. -# Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. -# You can order the results by a relevant column to return the most interesting examples in the database. -# Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. -# Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. -# -# """ _DEFAULT_TEMPLATE_EN = """ Provide professional answers to requests and questions. If you can't get an answer from what you've provided, say: "Insufficient information in the knowledge base is available to answer this question." Feel free to fudge information. @@ -43,7 +29,7 @@ 问题: {input} -一步步思考 +一步步思考。 """ _DEFAULT_TEMPLATE = ( @@ -53,14 +39,24 @@ PROMPT_NEED_STREAM_OUT = True -prompt = PromptTemplate( + +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatWithDbQA.value(), - input_variables=["input", "table_info"], - response_format=None, - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), + need_historical_messages=True, ) -CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) + +CFG.prompt_template_registry.register( + prompt_adapter, language=CFG.LANGUAGE, is_default=True +) diff --git a/dbgpt/app/scene/chat_execution/chat.py b/dbgpt/app/scene/chat_execution/chat.py index b1c4f75c1..44b9e586e 100644 --- a/dbgpt/app/scene/chat_execution/chat.py +++ b/dbgpt/app/scene/chat_execution/chat.py @@ -1,9 +1,9 @@ -from typing import List, Dict +from typing import Dict, List -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config from dbgpt.agent.plugin.commands.command import execute_command from dbgpt.agent.plugin.generator import PluginPromptGenerator +from dbgpt.app.scene import BaseChat, ChatScene from dbgpt.util.tracer import trace CFG = Config() diff --git a/dbgpt/app/scene/chat_execution/out_parser.py b/dbgpt/app/scene/chat_execution/out_parser.py index 37e71766d..bd4da7bc9 100644 --- a/dbgpt/app/scene/chat_execution/out_parser.py +++ b/dbgpt/app/scene/chat_execution/out_parser.py @@ -1,8 +1,8 @@ import json import logging from typing import Dict, NamedTuple -from dbgpt.core.interface.output_parser import BaseOutputParser, T +from dbgpt.core.interface.output_parser import BaseOutputParser, T logger = logging.getLogger(__name__) diff --git a/dbgpt/app/scene/chat_execution/prompt.py b/dbgpt/app/scene/chat_execution/prompt.py index 2292292f3..9ada61f60 100644 --- a/dbgpt/app/scene/chat_execution/prompt.py +++ b/dbgpt/app/scene/chat_execution/prompt.py @@ -1,9 +1,14 @@ import json -from dbgpt.core.interface.prompt import PromptTemplate -from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt._private.config import Config +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) CFG = Config() @@ -34,15 +39,23 @@ ### Whether the model service is streaming output PROMPT_NEED_STREAM_OUT = False -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE, + response_format=json.dumps(RESPONSE_FORMAT, indent=4), + ), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatExecution.value(), - input_variables=["input", "constraints", "commands_infos", "response"], - response_format=json.dumps(RESPONSE_FORMAT, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - # example_selector=plugin_example, + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_factory.py b/dbgpt/app/scene/chat_factory.py index 258020fac..73470f65f 100644 --- a/dbgpt/app/scene/chat_factory.py +++ b/dbgpt/app/scene/chat_factory.py @@ -1,4 +1,5 @@ from dbgpt.app.scene.base_chat import BaseChat +from dbgpt.core import PromptTemplate from dbgpt.util.singleton import Singleton from dbgpt.util.tracer import root_tracer @@ -7,31 +8,31 @@ class ChatFactory(metaclass=Singleton): @staticmethod def get_implementation(chat_mode, **kwargs): # Lazy loading - from dbgpt.app.scene.chat_execution.chat import ChatWithPlugin - from dbgpt.app.scene.chat_execution.prompt import prompt - from dbgpt.app.scene.chat_normal.chat import ChatNormal - from dbgpt.app.scene.chat_normal.prompt import prompt - from dbgpt.app.scene.chat_db.professional_qa.chat import ChatWithDbQA - from dbgpt.app.scene.chat_db.professional_qa.prompt import prompt - from dbgpt.app.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute - from dbgpt.app.scene.chat_db.auto_execute.prompt import prompt + from dbgpt.app.scene.chat_agent.chat import ChatAgent + from dbgpt.app.scene.chat_agent.prompt import prompt from dbgpt.app.scene.chat_dashboard.chat import ChatDashboard from dbgpt.app.scene.chat_dashboard.prompt import prompt - from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge - from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt - from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet - from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt + from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel + from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt + from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt + from dbgpt.app.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute + from dbgpt.app.scene.chat_db.auto_execute.prompt import prompt + from dbgpt.app.scene.chat_db.professional_qa.chat import ChatWithDbQA + from dbgpt.app.scene.chat_db.professional_qa.prompt import prompt + from dbgpt.app.scene.chat_execution.chat import ChatWithPlugin + from dbgpt.app.scene.chat_execution.prompt import prompt from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt + from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet + from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt from dbgpt.app.scene.chat_knowledge.refine_summary.chat import ( ExtractRefineSummary, ) from dbgpt.app.scene.chat_knowledge.refine_summary.prompt import prompt - from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel - from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt - from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt - from dbgpt.app.scene.chat_agent.chat import ChatAgent - from dbgpt.app.scene.chat_agent.prompt import prompt + from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge + from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt + from dbgpt.app.scene.chat_normal.chat import ChatNormal + from dbgpt.app.scene.chat_normal.prompt import prompt chat_classes = BaseChat.__subclasses__() implementation = None diff --git a/dbgpt/app/scene/chat_knowledge/__init__.py b/dbgpt/app/scene/chat_knowledge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py b/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py index 9551c1a4e..64e296dea 100644 --- a/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py +++ b/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py @@ -1,7 +1,7 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser +from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate CFG = Config() @@ -21,20 +21,21 @@ """ PROMPT_RESPONSE = """""" - -RESPONSE_FORMAT = """""" - - PROMPT_NEED_NEED_STREAM_OUT = False -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE), + HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ExtractEntity.value(), - input_variables=["text"], - response_format="", - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, output_parser=ExtractEntityParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT), + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py b/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py index 4ee5e967a..49d0d4ed3 100644 --- a/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py +++ b/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py @@ -1,12 +1,9 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_knowledge.extract_triplet.out_parser import ( ExtractTripleParser, ) - +from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate CFG = Config() @@ -33,19 +30,22 @@ PROMPT_RESPONSE = """""" -RESPONSE_FORMAT = """""" +PROMPT_NEED_NEED_STREAM_OUT = False -PROMPT_NEED_NEED_STREAM_OUT = False +prompt = ChatPromptTemplate( + messages=[ + # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE), + HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE), + ] +) -prompt = PromptTemplate( +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ExtractTriplet.value(), - input_variables=["text"], - response_format="", - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, output_parser=ExtractTripleParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT), + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py b/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py index 7e94168ce..ec2f1e9c1 100644 --- a/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py +++ b/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py @@ -1,10 +1,9 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_knowledge.refine_summary.out_parser import ( ExtractRefineSummaryParser, ) +from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate CFG = Config() @@ -27,17 +26,21 @@ PROMPT_RESPONSE = """""" - PROMPT_NEED_NEED_STREAM_OUT = True -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE), + HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ExtractRefineSummary.value(), - input_variables=["existing_answer"], - response_format=None, - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, output_parser=ExtractRefineSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT), + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, is_default=True) +CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py index 1d4854254..4090d9238 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/chat.py +++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py @@ -3,20 +3,22 @@ from functools import reduce from typing import Dict, List -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config -from dbgpt.component import ComponentType - -from dbgpt.configs.model_config import ( - EMBEDDING_MODEL_CONFIG, -) - from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity from dbgpt.app.knowledge.document_db import ( KnowledgeDocumentDao, KnowledgeDocumentEntity, ) from dbgpt.app.knowledge.service import KnowledgeService +from dbgpt.app.scene import BaseChat, ChatScene +from dbgpt.component import ComponentType +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.rag.retriever.rewrite import QueryRewrite @@ -130,8 +132,19 @@ def stream_call_reinforce_fn(self, text): @trace() async def generate_input_values(self) -> Dict: if self.space_context and self.space_context.get("prompt"): - self.prompt_template.template_define = self.space_context["prompt"]["scene"] - self.prompt_template.template = self.space_context["prompt"]["template"] + # Not use template_define + # self.prompt_template.template_define = self.space_context["prompt"]["scene"] + # self.prompt_template.template = self.space_context["prompt"]["template"] + # Replace the template with the prompt template + self.prompt_template.prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + self.space_context["prompt"]["template"] + ), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{question}"), + ] + ) from dbgpt.util.chat_util import run_async_tasks tasks = [self.execute_similar_search(self.current_user_input)] diff --git a/dbgpt/app/scene/chat_knowledge/v1/out_parser.py b/dbgpt/app/scene/chat_knowledge/v1/out_parser.py index 8d30e98fe..2660cabea 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/out_parser.py +++ b/dbgpt/app/scene/chat_knowledge/v1/out_parser.py @@ -1,6 +1,6 @@ import logging -from dbgpt.core.interface.output_parser import BaseOutputParser, T +from dbgpt.core.interface.output_parser import BaseOutputParser, T logger = logging.getLogger(__name__) diff --git a/dbgpt/app/scene/chat_knowledge/v1/prompt.py b/dbgpt/app/scene/chat_knowledge/v1/prompt.py index ee6a166a5..b32e8f5e4 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/prompt.py +++ b/dbgpt/app/scene/chat_knowledge/v1/prompt.py @@ -1,8 +1,12 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) CFG = Config() @@ -28,17 +32,23 @@ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH ) - PROMPT_NEED_STREAM_OUT = True +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{question}"), + ] +) -prompt = PromptTemplate( +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatKnowledge.value(), - input_variables=["context", "question"], - response_format=None, - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), + need_historical_messages=False, ) -CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) +CFG.prompt_template_registry.register( + prompt_adapter, language=CFG.LANGUAGE, is_default=True +) diff --git a/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py b/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py index bfbdaed1f..ed76f69c6 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py +++ b/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py @@ -1,9 +1,12 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser - +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) CFG = Config() @@ -32,18 +35,24 @@ PROMPT_NEED_STREAM_OUT = True -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{question}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatKnowledge.value(), - input_variables=["context", "question"], - response_format=None, - template_define=None, - template=_DEFAULT_TEMPLATE, - stream_out=PROMPT_NEED_STREAM_OUT, + stream_out=True, output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), + need_historical_messages=False, ) CFG.prompt_template_registry.register( - prompt, + prompt_adapter, language=CFG.LANGUAGE, is_default=False, model_names=["chatglm-6b-int4", "chatglm-6b", "chatglm2-6b", "chatglm2-6b-int4"], diff --git a/dbgpt/app/scene/chat_normal/chat.py b/dbgpt/app/scene/chat_normal/chat.py index 702a190b9..30b22ebd9 100644 --- a/dbgpt/app/scene/chat_normal/chat.py +++ b/dbgpt/app/scene/chat_normal/chat.py @@ -1,8 +1,7 @@ from typing import Dict -from dbgpt.app.scene import BaseChat, ChatScene from dbgpt._private.config import Config - +from dbgpt.app.scene import BaseChat, ChatScene from dbgpt.util.tracer import trace CFG = Config() @@ -11,6 +10,8 @@ class ChatNormal(BaseChat): chat_scene: str = ChatScene.ChatNormal.value() + keep_end_rounds: int = 10 + """Number of results to return from the query""" def __init__(self, chat_param: Dict): diff --git a/dbgpt/app/scene/chat_normal/out_parser.py b/dbgpt/app/scene/chat_normal/out_parser.py index 8d30e98fe..2660cabea 100644 --- a/dbgpt/app/scene/chat_normal/out_parser.py +++ b/dbgpt/app/scene/chat_normal/out_parser.py @@ -1,6 +1,6 @@ import logging -from dbgpt.core.interface.output_parser import BaseOutputParser, T +from dbgpt.core.interface.output_parser import BaseOutputParser, T logger = logging.getLogger(__name__) diff --git a/dbgpt/app/scene/chat_normal/prompt.py b/dbgpt/app/scene/chat_normal/prompt.py index 752193936..04dd28a51 100644 --- a/dbgpt/app/scene/chat_normal/prompt.py +++ b/dbgpt/app/scene/chat_normal/prompt.py @@ -1,25 +1,40 @@ -from dbgpt.core.interface.prompt import PromptTemplate from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene - +from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) -PROMPT_SCENE_DEFINE = None +PROMPT_SCENE_DEFINE_EN = "You are a helpful AI assistant." +PROMPT_SCENE_DEFINE_ZH = "你是一个有用的 AI 助手。" CFG = Config() +PROMPT_SCENE_DEFINE = ( + PROMPT_SCENE_DEFINE_ZH if CFG.LANGUAGE == "zh" else PROMPT_SCENE_DEFINE_EN +) PROMPT_NEED_STREAM_OUT = True -prompt = PromptTemplate( +prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{input}"), + ] +) + +prompt_adapter = AppScenePromptTemplateAdapter( + prompt=prompt, template_scene=ChatScene.ChatNormal.value(), - input_variables=["input"], - response_format=None, - template_define=PROMPT_SCENE_DEFINE, - template=None, stream_out=PROMPT_NEED_STREAM_OUT, output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), + need_historical_messages=True, ) -# CFG.prompt_templates.update({prompt.template_scene: prompt}) -CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True) +CFG.prompt_template_registry.register( + prompt_adapter, language=CFG.LANGUAGE, is_default=True +) diff --git a/dbgpt/app/scene/operator/_experimental.py b/dbgpt/app/scene/operator/_experimental.py deleted file mode 100644 index c422846c0..000000000 --- a/dbgpt/app/scene/operator/_experimental.py +++ /dev/null @@ -1,262 +0,0 @@ -from typing import Dict, Optional, List -from dataclasses import dataclass -import datetime -import os - -from dbgpt.configs.model_config import PILOT_PATH -from dbgpt.core.awel import MapOperator -from dbgpt.core.interface.prompt import PromptTemplate -from dbgpt._private.config import Config -from dbgpt.app.scene import ChatScene -from dbgpt.core.interface.message import OnceConversation -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType -from dbgpt.rag.retriever.embedding import EmbeddingRetriever - -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory -from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory -from dbgpt.storage.vector_store.base import VectorStoreConfig -from dbgpt.storage.vector_store.connector import VectorStoreConnector - -# TODO move global config -CFG = Config() - - -@dataclass -class ChatContext: - current_user_input: str - model_name: Optional[str] - chat_session_id: Optional[str] = None - select_param: Optional[str] = None - chat_scene: Optional[ChatScene] = ChatScene.ChatNormal - prompt_template: Optional[PromptTemplate] = None - chat_retention_rounds: Optional[int] = 0 - history_storage: Optional[BaseChatHistoryMemory] = None - history_manager: Optional["ChatHistoryManager"] = None - # The input values for prompt template - input_values: Optional[Dict] = None - echo: Optional[bool] = False - - def build_model_payload(self) -> Dict: - if not self.input_values: - raise ValueError("The input value can't be empty") - llm_messages = self.history_manager._new_chat(self.input_values) - return { - "model": self.model_name, - "prompt": "", - "messages": llm_messages, - "temperature": float(self.prompt_template.temperature), - "max_new_tokens": int(self.prompt_template.max_new_tokens), - "echo": self.echo, - } - - -class ChatHistoryManager: - def __init__( - self, - chat_ctx: ChatContext, - prompt_template: PromptTemplate, - history_storage: BaseChatHistoryMemory, - chat_retention_rounds: Optional[int] = 0, - ) -> None: - self._chat_ctx = chat_ctx - self.chat_retention_rounds = chat_retention_rounds - self.current_message: OnceConversation = OnceConversation( - chat_ctx.chat_scene.value() - ) - self.prompt_template = prompt_template - self.history_storage: BaseChatHistoryMemory = history_storage - self.history_message: List[OnceConversation] = history_storage.messages() - self.current_message.model_name = chat_ctx.model_name - if chat_ctx.select_param: - if len(chat_ctx.chat_scene.param_types()) > 0: - self.current_message.param_type = chat_ctx.chat_scene.param_types()[0] - self.current_message.param_value = chat_ctx.select_param - - def _new_chat(self, input_values: Dict) -> List[ModelMessage]: - self.current_message.chat_order = len(self.history_message) + 1 - self.current_message.add_user_message( - self._chat_ctx.current_user_input, check_duplicate_type=True - ) - self.current_message.start_date = datetime.datetime.now().strftime( - "%Y-%m-%d %H:%M:%S" - ) - self.current_message.tokens = 0 - if self.prompt_template.template: - current_prompt = self.prompt_template.format(**input_values) - self.current_message.add_system_message(current_prompt) - return self._generate_llm_messages() - - def _generate_llm_messages(self) -> List[ModelMessage]: - from dbgpt.app.scene.base_chat import ( - _load_system_message, - _load_example_messages, - _load_history_messages, - _load_user_message, - ) - - messages = [] - ### Load scene setting or character definition as system message - if self.prompt_template.template_define: - messages.append( - ModelMessage( - role=ModelMessageRoleType.SYSTEM, - content=self.prompt_template.template_define, - ) - ) - ### Load prompt - messages += _load_system_message( - self.current_message, self.prompt_template, str_message=False - ) - ### Load examples - messages += _load_example_messages(self.prompt_template, str_message=False) - - ### Load History - messages += _load_history_messages( - self.prompt_template, - self.history_message, - self.chat_retention_rounds, - str_message=False, - ) - - ### Load User Input - messages += _load_user_message( - self.current_message, self.prompt_template, str_message=False - ) - return messages - - -class PromptManagerOperator(MapOperator[ChatContext, ChatContext]): - def __init__(self, prompt_template: PromptTemplate = None, **kwargs): - super().__init__(**kwargs) - self._prompt_template = prompt_template - - async def map(self, input_value: ChatContext) -> ChatContext: - if not self._prompt_template: - self._prompt_template: PromptTemplate = ( - CFG.prompt_template_registry.get_prompt_template( - input_value.chat_scene.value(), - language=CFG.LANGUAGE, - model_name=input_value.model_name, - proxyllm_backend=CFG.PROXYLLM_BACKEND, - ) - ) - input_value.prompt_template = self._prompt_template - return input_value - - -class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]): - def __init__(self, history: BaseChatHistoryMemory = None, **kwargs): - super().__init__(**kwargs) - self._history = history - - async def map(self, input_value: ChatContext) -> ChatContext: - if self._history: - return self._history - chat_history_fac = ChatHistory() - input_value.history_storage = chat_history_fac.get_store_instance( - input_value.chat_session_id - ) - return input_value - - -class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]): - def __init__(self, history: BaseChatHistoryMemory = None, **kwargs): - super().__init__(**kwargs) - self._history = history - - async def map(self, input_value: ChatContext) -> ChatContext: - history_storage = self._history or input_value.history_storage - if not history_storage: - from dbgpt.storage.chat_history.store_type.mem_history import ( - MemHistoryMemory, - ) - - history_storage = MemHistoryMemory(input_value.chat_session_id) - input_value.history_storage = history_storage - input_value.history_manager = ChatHistoryManager( - input_value, - input_value.prompt_template, - history_storage, - input_value.chat_retention_rounds, - ) - return input_value - - -class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def map(self, input_value: ChatContext) -> ChatContext: - from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG - from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory - - # TODO, decompose the current operator into some atomic operators - knowledge_space = input_value.select_param - embedding_factory = self.system_app.get_component( - "embedding_factory", EmbeddingFactory - ) - - space_context = await self._get_space_context(knowledge_space) - top_k = ( - CFG.KNOWLEDGE_SEARCH_TOP_SIZE - if space_context is None - else int(space_context["embedding"]["topk"]) - ) - max_token = ( - CFG.KNOWLEDGE_SEARCH_MAX_TOKEN - if space_context is None or space_context.get("prompt") is None - else int(space_context["prompt"]["max_token"]) - ) - input_value.prompt_template.template_is_strict = False - if space_context and space_context.get("prompt"): - input_value.prompt_template.template_define = space_context["prompt"][ - "scene" - ] - input_value.prompt_template.template = space_context["prompt"]["template"] - - config = VectorStoreConfig( - name=knowledge_space, - embedding_fn=embedding_factory.create( - EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] - ), - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=CFG.VECTOR_STORE_TYPE, - vector_store_config=config, - ) - embedding_retriever = EmbeddingRetriever( - top_k=top_k, vector_store_connector=vector_store_connector - ) - docs = await self.blocking_func_to_async( - embedding_retriever.retrieve, - input_value.current_user_input, - ) - if not docs or len(docs) == 0: - print("no relevant docs to retrieve") - context = "no relevant docs to retrieve" - else: - context = [d.content for d in docs] - context = context[:max_token] - relations = list( - set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs]) - ) - input_value.input_values = { - "context": context, - "question": input_value.current_user_input, - "relations": relations, - } - return input_value - - async def _get_space_context(self, space_name): - from dbgpt.app.knowledge.service import KnowledgeService - - service = KnowledgeService() - return await self.blocking_func_to_async(service.get_space_context, space_name) - - -class BaseChatOperator(MapOperator[ChatContext, Dict]): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def map(self, input_value: ChatContext) -> Dict: - return input_value.build_model_payload() diff --git a/dbgpt/app/scene/operator/app_operator.py b/dbgpt/app/scene/operator/app_operator.py new file mode 100644 index 000000000..799b89505 --- /dev/null +++ b/dbgpt/app/scene/operator/app_operator.py @@ -0,0 +1,85 @@ +import dataclasses +from typing import Any, Dict, List, Optional + +from dbgpt.core import BaseMessage, ChatPromptTemplate, ModelMessage +from dbgpt.core.awel import ( + DAG, + BaseOperator, + InputOperator, + MapOperator, + SimpleCallDataInputSource, +) +from dbgpt.core.operator import ( + BufferedConversationMapperOperator, + HistoryPromptBuilderOperator, +) + + +@dataclasses.dataclass +class ChatComposerInput: + """The composer input.""" + + messages: List[BaseMessage] + prompt_dict: Dict[str, Any] + + +class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]): + """App chat composer operator. + + TODO: Support more history merge mode. + """ + + def __init__( + self, + prompt: ChatPromptTemplate, + history_key: str = "chat_history", + history_merge_mode: str = "window", + keep_start_rounds: Optional[int] = None, + keep_end_rounds: Optional[int] = None, + str_history: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self._prompt_template = prompt + self._history_key = history_key + self._history_merge_mode = history_merge_mode + self._keep_start_rounds = keep_start_rounds + self._keep_end_rounds = keep_end_rounds + self._str_history = str_history + self._sub_compose_dag = self._build_composer_dag() + + async def map(self, input_value: ChatComposerInput) -> List[ModelMessage]: + end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + # Sub dag, use the same dag context in the parent dag + return await end_node.call( + call_data={"data": input_value}, dag_ctx=self.current_dag_context + ) + + def _build_composer_dag(self) -> DAG: + with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + # History transform task + history_transform_task = BufferedConversationMapperOperator( + keep_start_rounds=self._keep_start_rounds, + keep_end_rounds=self._keep_end_rounds, + ) + history_prompt_build_task = HistoryPromptBuilderOperator( + prompt=self._prompt_template, + history_key=self._history_key, + check_storage=False, + str_history=self._str_history, + ) + # Build composer dag + ( + input_task + >> MapOperator(lambda x: x.messages) + >> history_transform_task + >> history_prompt_build_task + ) + ( + input_task + >> MapOperator(lambda x: x.prompt_dict) + >> history_prompt_build_task + ) + + return composer_dag diff --git a/dbgpt/app/tests/test_base.py b/dbgpt/app/tests/test_base.py index 4f5a21a9f..cefea1cb9 100644 --- a/dbgpt/app/tests/test_base.py +++ b/dbgpt/app/tests/test_base.py @@ -1,5 +1,6 @@ +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from sqlalchemy.exc import OperationalError, SQLAlchemyError from dbgpt.app.base import _create_mysql_database diff --git a/dbgpt/cli/cli_scripts.py b/dbgpt/cli/cli_scripts.py index fcd8edc51..e45a32819 100644 --- a/dbgpt/cli/cli_scripts.py +++ b/dbgpt/cli/cli_scripts.py @@ -1,7 +1,8 @@ -import click import copy import logging +import click + logging.basicConfig( level=logging.WARNING, encoding="utf-8", @@ -82,14 +83,14 @@ def stop_all(): try: from dbgpt.model.cli import ( + _stop_all_model_server, model_cli_group, + start_apiserver, start_model_controller, - stop_model_controller, start_model_worker, - stop_model_worker, - start_apiserver, stop_apiserver, - _stop_all_model_server, + stop_model_controller, + stop_model_worker, ) add_command_alias(model_cli_group, name="model", parent_group=cli) @@ -107,10 +108,10 @@ def stop_all(): try: from dbgpt.app._cli import ( - start_webserver, - stop_webserver, _stop_all_dbgpt_server, migration, + start_webserver, + stop_webserver, ) add_command_alias(start_webserver, name="webserver", parent_group=start) diff --git a/dbgpt/component.py b/dbgpt/component.py index 466d01c80..f094213a1 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -1,13 +1,14 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import asyncio +import logging import sys -from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING +from abc import ABC, abstractmethod from enum import Enum -import logging -import asyncio -from dbgpt.util.annotations import PublicAPI +from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union + from dbgpt.util import AppConfig +from dbgpt.util.annotations import PublicAPI # Checking for type hints during runtime if TYPE_CHECKING: @@ -80,6 +81,9 @@ class ComponentType(str, Enum): UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory" +_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" + + @PublicAPI(stability="beta") class BaseComponent(LifeCycle, ABC): """Abstract Base Component class. All custom components should extend this.""" @@ -98,11 +102,37 @@ def init_app(self, system_app: SystemApp): with the main system app. """ + @classmethod + def get_instance( + cls, + system_app: SystemApp, + default_component=_EMPTY_DEFAULT_COMPONENT, + or_register_component: Type[BaseComponent] = None, + *args, + **kwargs, + ) -> BaseComponent: + """Get the current component instance. + + Args: + system_app (SystemApp): The system app + default_component : The default component instance if not retrieve by name + or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name + + Returns: + BaseComponent: The component instance + """ + return system_app.get_component( + cls.name, + cls, + default_component=default_component, + or_register_component=or_register_component, + *args, + **kwargs, + ) + T = TypeVar("T", bound=BaseComponent) -_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" - @PublicAPI(stability="beta") class SystemApp(LifeCycle): diff --git a/dbgpt/core/_private/prompt_registry.py b/dbgpt/core/_private/prompt_registry.py index e35282b6f..5684b1f6a 100644 --- a/dbgpt/core/_private/prompt_registry.py +++ b/dbgpt/core/_private/prompt_registry.py @@ -20,8 +20,9 @@ def register( self, prompt_template, language: str = "en", - is_default=False, + is_default: bool = False, model_names: List[str] = None, + scene_name: str = None, ) -> None: """Register prompt template with scene name, language registry dict format: @@ -37,7 +38,8 @@ def register( } } """ - scene_name = prompt_template.template_scene + if not scene_name: + scene_name = prompt_template.template_scene if not scene_name: raise ValueError("Prompt template scene name cannot be empty") if not model_names: diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index cded3f18a..25ba8b665 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -45,6 +45,18 @@ def to_dict(self) -> Dict: "round_index": self.round_index, } + @staticmethod + def messages_to_string(messages: List["BaseMessage"]) -> str: + """Convert messages to str + + Args: + messages (List[BaseMessage]): The messages + + Returns: + str: The str messages + """ + return _messages_to_str(messages) + class HumanMessage(BaseMessage): """Type of message that is spoken by the human.""" @@ -251,6 +263,41 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]: return [_message_to_dict(m) for m in messages] +def _messages_to_str( + messages: List[BaseMessage], + human_prefix: str = "Human", + ai_prefix: str = "AI", + system_prefix: str = "System", +) -> str: + """Convert messages to str + + Args: + messages (List[BaseMessage]): The messages + human_prefix (str): The human prefix + ai_prefix (str): The ai prefix + system_prefix (str): The system prefix + + Returns: + str: The str messages + """ + str_messages = [] + for message in messages: + role = None + if isinstance(message, HumanMessage): + role = human_prefix + elif isinstance(message, AIMessage): + role = ai_prefix + elif isinstance(message, SystemMessage): + role = system_prefix + elif isinstance(message, ViewMessage): + pass + else: + raise ValueError(f"Got unsupported message type: {message}") + if role: + str_messages.append(f"{role}: {message.content}") + return "\n".join(str_messages) + + def _message_from_dict(message: Dict) -> BaseMessage: _type = message["type"] if _type == "human": @@ -382,6 +429,9 @@ def _append_message(self, message: BaseMessage) -> None: self._message_index += 1 message.index = index message.round_index = self.chat_order + message.additional_kwargs["param_type"] = self.param_type + message.additional_kwargs["param_value"] = self.param_value + message.additional_kwargs["model_name"] = self.model_name self.messages.append(message) def start_new_round(self) -> None: @@ -504,9 +554,12 @@ def from_conversation(self, conversation: OnceConversation) -> None: self.messages = conversation.messages self.start_date = conversation.start_date self.chat_order = conversation.chat_order - self.model_name = conversation.model_name - self.param_type = conversation.param_type - self.param_value = conversation.param_value + if not self.model_name and conversation.model_name: + self.model_name = conversation.model_name + if not self.param_type and conversation.param_type: + self.param_type = conversation.param_type + if not self.param_value and conversation.param_value: + self.param_value = conversation.param_value self.cost = conversation.cost self.tokens = conversation.tokens self.user_name = conversation.user_name @@ -801,6 +854,7 @@ def __init__( save_message_independent: Optional[bool] = True, conv_storage: StorageInterface = None, message_storage: StorageInterface = None, + load_message: bool = True, **kwargs, ): super().__init__(chat_mode, user_name, sys_code, summary, **kwargs) @@ -811,6 +865,8 @@ def __init__( self._has_stored_message_index = ( len(kwargs["messages"]) - 1 if "messages" in kwargs else -1 ) + # Whether to load the message from the storage + self._load_message = load_message self.save_message_independent = save_message_independent self._id = ConversationIdentifier(conv_uid) if conv_storage is None: @@ -853,7 +909,9 @@ def save_to_storage(self) -> None: ] messages_to_save = message_list[self._has_stored_message_index + 1 :] self._has_stored_message_index = len(message_list) - 1 - self.message_storage.save_list(messages_to_save) + if self.save_message_independent: + # Save messages independently + self.message_storage.save_list(messages_to_save) # Save conversation self.conv_storage.save_or_update(self) @@ -876,23 +934,71 @@ def load_from_storage( return message_ids = conversation._message_ids or [] - # Load messages - message_list = message_storage.load_list( - [ - MessageIdentifier.from_str_identifier(message_id) - for message_id in message_ids - ], - MessageStorageItem, - ) - messages = [message.to_message() for message in message_list] - conversation.messages = messages + if self._load_message: + # Load messages + message_list = message_storage.load_list( + [ + MessageIdentifier.from_str_identifier(message_id) + for message_id in message_ids + ], + MessageStorageItem, + ) + messages = [message.to_message() for message in message_list] + else: + messages = [] + real_messages = messages or conversation.messages + conversation.messages = real_messages # This index is used to save the message to the storage(Has not been saved) # The new message append to the messages, so the index is len(messages) - conversation._message_index = len(messages) + conversation._message_index = len(real_messages) + conversation.chat_order = ( + max(m.round_index for m in real_messages) if real_messages else 0 + ) + self._append_additional_kwargs(conversation, real_messages) self._message_ids = message_ids - self._has_stored_message_index = len(messages) - 1 + self._has_stored_message_index = len(real_messages) - 1 + self.save_message_independent = conversation.save_message_independent self.from_conversation(conversation) + def _append_additional_kwargs( + self, conversation: StorageConversation, messages: List[BaseMessage] + ) -> None: + """Parse the additional kwargs and append to the conversation + + Args: + conversation (StorageConversation): The conversation + messages (List[BaseMessage]): The messages + """ + param_type = None + param_value = None + for message in messages[::-1]: + if message.additional_kwargs: + param_type = message.additional_kwargs.get("param_type") + param_value = message.additional_kwargs.get("param_value") + break + if not conversation.param_type: + conversation.param_type = param_type + if not conversation.param_value: + conversation.param_value = param_value + + def delete(self) -> None: + """Delete all the messages and conversation from the storage""" + # Delete messages first + message_list = self._get_message_items() + message_ids = [message.identifier for message in message_list] + self.message_storage.delete_list(message_ids) + # Delete conversation + self.conv_storage.delete(self.identifier) + # Overwrite the current conversation with empty conversation + self.from_conversation( + StorageConversation( + self.conv_uid, + save_message_independent=self.save_message_independent, + conv_storage=self.conv_storage, + message_storage=self.message_storage, + ) + ) + def _conversation_to_dict(once: OnceConversation) -> Dict: start_str: str = "" @@ -937,3 +1043,61 @@ def _conversation_from_dict(once: dict) -> OnceConversation: print(once.get("messages")) conversation.messages = _messages_from_dict(once.get("messages", [])) return conversation + + +def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessage]]: + """Split the messages by round index. + + Args: + messages (List[BaseMessage]): The messages. + + Returns: + List[List[BaseMessage]]: The messages split by round. + """ + messages_by_round: List[List[BaseMessage]] = [] + last_round_index = 0 + for message in messages: + if not message.round_index: + # Round index must bigger than 0 + raise ValueError("Message round_index is not set") + if message.round_index > last_round_index: + last_round_index = message.round_index + messages_by_round.append([]) + messages_by_round[-1].append(message) + return messages_by_round + + +def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]: + """Append the view message to the messages + Just for show in DB-GPT-Web. + If already have view message, do nothing. + + Args: + messages (List[BaseMessage]): The messages + + Returns: + List[BaseMessage]: The messages with view message + """ + messages_by_round = _split_messages_by_round(messages) + for current_round in messages_by_round: + ai_message = None + view_message = None + for message in current_round: + if message.type == "ai": + ai_message = message + elif message.type == "view": + view_message = message + if view_message: + # Already have view message, do nothing + continue + if ai_message: + view_message = ViewMessage( + content=ai_message.content, + index=ai_message.index, + round_index=ai_message.round_index, + additional_kwargs=ai_message.additional_kwargs.copy() + if ai_message.additional_kwargs + else {}, + ) + current_round.append(view_message) + return sum(messages_by_round, []) diff --git a/dbgpt/core/interface/operator/composer_operator.py b/dbgpt/core/interface/operator/composer_operator.py index 7c0093777..36eb97fbc 100644 --- a/dbgpt/core/interface/operator/composer_operator.py +++ b/dbgpt/core/interface/operator/composer_operator.py @@ -45,7 +45,8 @@ def __init__( self, prompt_template: ChatPromptTemplate, history_key: str = "chat_history", - last_k_round: int = 2, + keep_start_rounds: Optional[int] = None, + keep_end_rounds: Optional[int] = None, storage: Optional[StorageInterface[StorageConversation, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, **kwargs, @@ -53,7 +54,8 @@ def __init__( super().__init__(**kwargs) self._prompt_template = prompt_template self._history_key = history_key - self._last_k_round = last_k_round + self._keep_start_rounds = keep_start_rounds + self._keep_end_rounds = keep_end_rounds self._storage = storage self._message_storage = message_storage self._sub_compose_dag = self._build_composer_dag() @@ -74,7 +76,8 @@ def _build_composer_dag(self) -> DAG: ) # History transform task, here we keep last 5 round messages history_transform_task = BufferedConversationMapperOperator( - last_k_round=self._last_k_round + keep_start_rounds=self._keep_start_rounds, + keep_end_rounds=self._keep_end_rounds, ) history_prompt_build_task = HistoryPromptBuilderOperator( prompt=self._prompt_template, history_key=self._history_key diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py index f6eb1b24b..fea21d0f9 100644 --- a/dbgpt/core/interface/operator/message_operator.py +++ b/dbgpt/core/interface/operator/message_operator.py @@ -1,8 +1,9 @@ import uuid from abc import ABC -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from dbgpt.core import ( + LLMClient, MessageStorageItem, ModelMessage, ModelMessageRoleType, @@ -11,7 +12,12 @@ StorageInterface, ) from dbgpt.core.awel import BaseOperator, MapOperator -from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper +from dbgpt.core.interface.message import ( + BaseMessage, + _messages_to_str, + _MultiRoundMessageMapper, + _split_messages_by_round, +) class BaseConversationOperator(BaseOperator, ABC): @@ -31,7 +37,6 @@ def __init__( **kwargs, ): self._check_storage = check_storage - super().__init__(**kwargs) self._storage = storage self._message_storage = message_storage @@ -167,12 +172,10 @@ def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs): self._message_mapper = message_mapper async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]: - return self.map_messages(input_value) + return await self.map_messages(input_value) - def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: - messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round( - messages - ) + async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages) message_mapper = self._message_mapper or self.map_multi_round_messages return message_mapper(messages_by_round) @@ -233,93 +236,66 @@ def map_multi_round_messages( Args: """ # Just merge and return - # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6] - return sum(messages_by_round, []) - - def _split_messages_by_round( - self, messages: List[BaseMessage] - ) -> List[List[BaseMessage]]: - """Split the messages by round index. - - Args: - messages (List[BaseMessage]): The messages. - - Returns: - List[List[BaseMessage]]: The messages split by round. - """ - messages_by_round: List[List[BaseMessage]] = [] - last_round_index = 0 - for message in messages: - if not message.round_index: - # Round index must bigger than 0 - raise ValueError("Message round_index is not set") - if message.round_index > last_round_index: - last_round_index = message.round_index - messages_by_round.append([]) - messages_by_round[-1].append(message) - return messages_by_round + return _merge_multi_round_messages(messages_by_round) class BufferedConversationMapperOperator(ConversationMapperOperator): - """The buffered conversation mapper operator. + """ + The buffered conversation mapper operator which can be configured to keep + a certain number of starting and/or ending rounds of a conversation. - This Operator must be used after the PreChatHistoryLoadOperator, - and it will map the messages in the storage conversation. + Args: + keep_start_rounds (Optional[int]): Number of initial rounds to keep. + keep_end_rounds (Optional[int]): Number of final rounds to keep. Examples: - - Transform no history messages - - .. code-block:: python - - from dbgpt.core import ModelMessage - from dbgpt.core.operator import BufferedConversationMapperOperator - - # No history - messages = [ModelMessage(role="human", content="Hello", round_index=1)] - operator = BufferedConversationMapperOperator(last_k_round=1) - assert operator.map_messages(messages) == [ - ModelMessage(role="human", content="Hello", round_index=1) - ] - - Transform with history messages - - .. code-block:: python - - # With history - messages = [ - ModelMessage(role="human", content="Hi", round_index=1), - ModelMessage(role="ai", content="Hello!", round_index=1), - ModelMessage(role="system", content="Error 404", round_index=2), - ModelMessage(role="human", content="What's the error?", round_index=2), - ModelMessage(role="ai", content="Just a joke.", round_index=2), - ModelMessage(role="human", content="Funny!", round_index=3), - ] - operator = BufferedConversationMapperOperator(last_k_round=1) - # Just keep the last one round, so the first round messages will be removed - # Note: The round index 3 is not a complete round - assert operator.map_messages(messages) == [ - ModelMessage(role="system", content="Error 404", round_index=2), - ModelMessage(role="human", content="What's the error?", round_index=2), - ModelMessage(role="ai", content="Just a joke.", round_index=2), - ModelMessage(role="human", content="Funny!", round_index=3), - ] + # Keeping the first 2 and the last 1 rounds of a conversation + import asyncio + from dbgpt.core.interface.message import AIMessage, HumanMessage + from dbgpt.core.operator import BufferedConversationMapperOperator + + operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1) + messages = [ + # Assume each HumanMessage and AIMessage belongs to separate rounds + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] + # This will keep rounds 1, 2, and 3 + assert asyncio.run(operator.map_messages(messages)) == [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] """ def __init__( self, - last_k_round: Optional[int] = 2, + keep_start_rounds: Optional[int] = None, + keep_end_rounds: Optional[int] = None, message_mapper: _MultiRoundMessageMapper = None, **kwargs, ): - self._last_k_round = last_k_round + # Validate the input parameters + if keep_start_rounds is not None and keep_start_rounds < 0: + raise ValueError("keep_start_rounds must be non-negative") + if keep_end_rounds is not None and keep_end_rounds < 0: + raise ValueError("keep_end_rounds must be non-negative") + + self._keep_start_rounds = keep_start_rounds + self._keep_end_rounds = keep_end_rounds if message_mapper: def new_message_mapper( messages_by_round: List[List[BaseMessage]], ) -> List[BaseMessage]: - # Apply keep k round messages first, then apply the custom message mapper - messages_by_round = self._keep_last_round_messages(messages_by_round) + messages_by_round = self._filter_round_messages(messages_by_round) return message_mapper(messages_by_round) else: @@ -327,21 +303,189 @@ def new_message_mapper( def new_message_mapper( messages_by_round: List[List[BaseMessage]], ) -> List[BaseMessage]: - messages_by_round = self._keep_last_round_messages(messages_by_round) - return sum(messages_by_round, []) + messages_by_round = self._filter_round_messages(messages_by_round) + return _merge_multi_round_messages(messages_by_round) super().__init__(new_message_mapper, **kwargs) - def _keep_last_round_messages( + def _filter_round_messages( self, messages_by_round: List[List[BaseMessage]] ) -> List[List[BaseMessage]]: - """Keep the last k round messages. + """Filters the messages to keep only the specified starting and/or ending rounds. + + Examples: + + >>> from dbgpt.core import AIMessage, HumanMessage + >>> from dbgpt.core.operator import BufferedConversationMapperOperator + >>> messages = [ + ... [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... ], + ... [ + ... HumanMessage(content="How are you?", round_index=2), + ... AIMessage(content="I'm good, thanks!", round_index=2), + ... ], + ... [ + ... HumanMessage(content="What's new today?", round_index=3), + ... AIMessage(content="Lots of things!", round_index=3), + ... ], + ... ] + + # Test keeping only the first 2 rounds + >>> operator = BufferedConversationMapperOperator(keep_start_rounds=2) + >>> assert operator._filter_round_messages(messages) == [ + ... [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... ], + ... [ + ... HumanMessage(content="How are you?", round_index=2), + ... AIMessage(content="I'm good, thanks!", round_index=2), + ... ], + ... ] + + # Test keeping only the last 2 rounds + >>> operator = BufferedConversationMapperOperator(keep_end_rounds=2) + >>> assert operator._filter_round_messages(messages) == [ + ... [ + ... HumanMessage(content="How are you?", round_index=2), + ... AIMessage(content="I'm good, thanks!", round_index=2), + ... ], + ... [ + ... HumanMessage(content="What's new today?", round_index=3), + ... AIMessage(content="Lots of things!", round_index=3), + ... ], + ... ] + + # Test keeping the first 2 and last 1 rounds + >>> operator = BufferedConversationMapperOperator( + ... keep_start_rounds=2, keep_end_rounds=1 + ... ) + >>> assert operator._filter_round_messages(messages) == [ + ... [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... ], + ... [ + ... HumanMessage(content="How are you?", round_index=2), + ... AIMessage(content="I'm good, thanks!", round_index=2), + ... ], + ... [ + ... HumanMessage(content="What's new today?", round_index=3), + ... AIMessage(content="Lots of things!", round_index=3), + ... ], + ... ] + + # Test without specifying start or end rounds (keep all rounds) + >>> operator = BufferedConversationMapperOperator() + >>> assert operator._filter_round_messages(messages) == [ + ... [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... ], + ... [ + ... HumanMessage(content="How are you?", round_index=2), + ... AIMessage(content="I'm good, thanks!", round_index=2), + ... ], + ... [ + ... HumanMessage(content="What's new today?", round_index=3), + ... AIMessage(content="Lots of things!", round_index=3), + ... ], + ... ] + + Args: + messages_by_round (List[List[BaseMessage]]): The messages grouped by round. + + Returns: + List[List[BaseMessage]]: Filtered list of messages. + """ + total_rounds = len(messages_by_round) + if self._keep_start_rounds is not None and self._keep_end_rounds is not None: + if self._keep_start_rounds + self._keep_end_rounds > total_rounds: + # Avoid overlapping when the sum of start and end rounds exceeds total rounds + return messages_by_round + return ( + messages_by_round[: self._keep_start_rounds] + + messages_by_round[-self._keep_end_rounds :] + ) + elif self._keep_start_rounds is not None: + return messages_by_round[: self._keep_start_rounds] + elif self._keep_end_rounds is not None: + return messages_by_round[-self._keep_end_rounds :] + else: + return messages_by_round + + +EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]] + + +class TokenBufferedConversationMapperOperator(ConversationMapperOperator): + """The token buffered conversation mapper operator. + + If the token count of the messages is greater than the max token limit, we will evict the messages by round. + + Args: + model (str): The model name. + llm_client (LLMClient): The LLM client. + max_token_limit (int): The max token limit. + eviction_policy (EvictionPolicyType): The eviction policy. + message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after all messages are handled. + """ + + def __init__( + self, + model: str, + llm_client: LLMClient, + max_token_limit: int = 2000, + eviction_policy: EvictionPolicyType = None, + message_mapper: _MultiRoundMessageMapper = None, + **kwargs, + ): + if max_token_limit < 0: + raise ValueError("Max token limit can't be negative") + self._model = model + self._llm_client = llm_client + self._max_token_limit = max_token_limit + self._eviction_policy = eviction_policy + self._message_mapper = message_mapper + super().__init__(**kwargs) + + async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + eviction_policy = self._eviction_policy or self.eviction_policy + messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages) + messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round)) + # Fist time, we count the token of the messages + current_tokens = await self._llm_client.count_token(self._model, messages_str) + + while current_tokens > self._max_token_limit: + # Evict the messages by round after all tokens are not greater than the max token limit + # TODO: We should find a high performance way to do this + messages_by_round = eviction_policy(messages_by_round) + messages_str = _messages_to_str( + _merge_multi_round_messages(messages_by_round) + ) + current_tokens = await self._llm_client.count_token( + self._model, messages_str + ) + message_mapper = self._message_mapper or self.map_multi_round_messages + return message_mapper(messages_by_round) + + def eviction_policy( + self, messages_by_round: List[List[BaseMessage]] + ) -> List[List[BaseMessage]]: + """Evict the messages by round, default is FIFO. Args: messages_by_round (List[List[BaseMessage]]): The messages by round. Returns: - List[List[BaseMessage]]: The latest round messages. + List[List[BaseMessage]]: The evicted messages by round. """ - index = self._last_k_round + 1 - return messages_by_round[-index:] + messages_by_round.pop(0) + return messages_by_round + + +def _merge_multi_round_messages(messages: List[List[BaseMessage]]) -> List[BaseMessage]: + # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6] + return sum(messages, []) diff --git a/dbgpt/core/interface/operator/prompt_operator.py b/dbgpt/core/interface/operator/prompt_operator.py index 18c727d14..7cdde4349 100644 --- a/dbgpt/core/interface/operator/prompt_operator.py +++ b/dbgpt/core/interface/operator/prompt_operator.py @@ -216,18 +216,27 @@ class HistoryPromptBuilderOperator( BasePromptBuilderOperator, JoinOperator[List[ModelMessage]] ): def __init__( - self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs + self, + prompt: ChatPromptTemplate, + history_key: Optional[str] = None, + check_storage: bool = True, + str_history: bool = False, + **kwargs, ): self._prompt = prompt self._history_key = history_key - + self._str_history = str_history + BasePromptBuilderOperator.__init__(self, check_storage=check_storage) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type async def merge_history( self, history: List[BaseMessage], prompt_dict: Dict[str, Any] ) -> List[ModelMessage]: - prompt_dict[self._history_key] = history + if self._str_history: + prompt_dict[self._history_key] = BaseMessage.messages_to_string(history) + else: + prompt_dict[self._history_key] = history return await self.format_prompt(self._prompt, prompt_dict) @@ -239,9 +248,16 @@ class HistoryDynamicPromptBuilderOperator( The prompt template is dynamic, and it created by parent operator. """ - def __init__(self, history_key: Optional[str] = None, **kwargs): + def __init__( + self, + history_key: Optional[str] = None, + check_storage: bool = True, + str_history: bool = False, + **kwargs, + ): self._history_key = history_key - + self._str_history = str_history + BasePromptBuilderOperator.__init__(self, check_storage=check_storage) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type @@ -251,5 +267,8 @@ async def merge_history( history: List[BaseMessage], prompt_dict: Dict[str, Any], ) -> List[ModelMessage]: - prompt_dict[self._history_key] = history + if self._str_history: + prompt_dict[self._history_key] = BaseMessage.messages_to_string(history) + else: + prompt_dict[self._history_key] = history return await self.format_prompt(prompt, prompt_dict) diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index 21fda87f6..1788357c4 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -49,13 +49,25 @@ class BasePromptTemplate(BaseModel): """The prompt template.""" template_format: Optional[str] = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + + response_format: Optional[str] = None + + response_key: Optional[str] = "response" + + template_is_strict: Optional[bool] = True + """strict template will check template args""" def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs.""" if self.template: - return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)( - self.template, **kwargs - ) + if self.response_format: + kwargs[self.response_key] = json.dumps( + self.response_format, ensure_ascii=False, indent=4 + ) + return _DEFAULT_FORMATTER_MAPPING[self.template_format]( + self.template_is_strict + )(self.template, **kwargs) @classmethod def from_template( @@ -75,10 +87,6 @@ class PromptTemplate(BasePromptTemplate): template_scene: Optional[str] template_define: Optional[str] """this template define""" - """strict template will check template args""" - template_is_strict: bool = True - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - response_format: Optional[str] """default use stream out""" stream_out: bool = True """""" @@ -103,17 +111,6 @@ def _prompt_type(self) -> str: """Return the prompt type key.""" return "prompt" - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs.""" - if self.template: - if self.response_format: - kwargs["response"] = json.dumps( - self.response_format, ensure_ascii=False, indent=4 - ) - return _DEFAULT_FORMATTER_MAPPING[self.template_format]( - self.template_is_strict - )(self.template, **kwargs) - class BaseChatPromptTemplate(BaseModel, ABC): prompt: BasePromptTemplate @@ -129,10 +126,22 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: @classmethod def from_template( - cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any + cls, + template: str, + template_format: Optional[str] = "f-string", + response_format: Optional[str] = None, + response_key: Optional[str] = "response", + template_is_strict: bool = True, + **kwargs: Any, ) -> BaseChatPromptTemplate: """Create a prompt template from a template string.""" - prompt = BasePromptTemplate.from_template(template, template_format) + prompt = BasePromptTemplate.from_template( + template, + template_format, + response_format=response_format, + response_key=response_key, + template_is_strict=template_is_strict, + ) return cls(prompt=prompt, **kwargs) diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index d05258e9b..6a486cab7 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -284,6 +284,15 @@ def delete(self, resource_id: ResourceIdentifier) -> None: resource_id (ResourceIdentifier): The resource identifier of the data """ + def delete_list(self, resource_id: List[ResourceIdentifier]) -> None: + """Delete the data from the storage. + + Args: + resource_id (ResourceIdentifier): The resource identifier of the data + """ + for r in resource_id: + self.delete(r) + @abstractmethod def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: """Query data from the storage. diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py index 7221dadb2..72023c36f 100755 --- a/dbgpt/core/interface/tests/test_message.py +++ b/dbgpt/core/interface/tests/test_message.py @@ -138,14 +138,14 @@ def test_clear_messages(basic_conversation, human_message): def test_get_latest_user_message(basic_conversation, human_message): basic_conversation.add_user_message(human_message.content) latest_message = basic_conversation.get_latest_user_message() - assert latest_message == human_message + assert latest_message.content == human_message.content def test_get_system_messages(basic_conversation, system_message): basic_conversation.add_system_message(system_message.content) system_messages = basic_conversation.get_system_messages() assert len(system_messages) == 1 - assert system_messages[0] == system_message + assert system_messages[0].content == system_message.content def test_from_conversation(basic_conversation): @@ -324,6 +324,35 @@ def test_load_from_storage(storage_conversation, in_memory_storage): assert isinstance(new_conversation.messages[1], AIMessage) +def test_delete(storage_conversation, in_memory_storage): + # Set storage + storage_conversation.conv_storage = in_memory_storage + storage_conversation.message_storage = in_memory_storage + + # Add messages and save to storage + storage_conversation.start_new_round() + storage_conversation.add_user_message("User message") + storage_conversation.add_ai_message("AI response") + storage_conversation.end_current_round() + + # Create a new StorageConversation instance to load the data + new_conversation = StorageConversation( + "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage + ) + + # Delete the conversation + new_conversation.delete() + + # Check if the conversation is deleted + assert new_conversation.conv_uid == storage_conversation.conv_uid + assert len(new_conversation.messages) == 0 + + no_messages_conv = StorageConversation( + "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage + ) + assert len(no_messages_conv.messages) == 0 + + def test_parse_model_messages_no_history_messages(): messages = [ ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"), diff --git a/dbgpt/core/operator/__init__.py b/dbgpt/core/operator/__init__.py index 952b89143..ffc60d2c7 100644 --- a/dbgpt/core/operator/__init__.py +++ b/dbgpt/core/operator/__init__.py @@ -14,6 +14,7 @@ BufferedConversationMapperOperator, ConversationMapperOperator, PreChatHistoryLoadOperator, + TokenBufferedConversationMapperOperator, ) from dbgpt.core.interface.operator.prompt_operator import ( DynamicPromptBuilderOperator, @@ -30,6 +31,7 @@ "BaseStreamingLLMOperator", "BaseConversationOperator", "BufferedConversationMapperOperator", + "TokenBufferedConversationMapperOperator", "ConversationMapperOperator", "PreChatHistoryLoadOperator", "PromptBuilderOperator", diff --git a/dbgpt/datasource/__init__.py b/dbgpt/datasource/__init__.py index 8cc9799db..5c619dc41 100644 --- a/dbgpt/datasource/__init__.py +++ b/dbgpt/datasource/__init__.py @@ -1 +1 @@ -from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao +from .manages.connect_config_db import ConnectConfigDao, ConnectConfigEntity diff --git a/dbgpt/datasource/base.py b/dbgpt/datasource/base.py index a5bf666b4..d95e173a0 100644 --- a/dbgpt/datasource/base.py +++ b/dbgpt/datasource/base.py @@ -3,7 +3,7 @@ """We need to design a base class. That other connector can Write with this""" from abc import ABC -from typing import Iterable, List, Optional, Any, Dict +from typing import Any, Dict, Iterable, List, Optional class BaseConnect(ABC): diff --git a/dbgpt/datasource/conn_spark.py b/dbgpt/datasource/conn_spark.py index fa7d9555e..cc8107f4b 100644 --- a/dbgpt/datasource/conn_spark.py +++ b/dbgpt/datasource/conn_spark.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Any, Optional from dbgpt.datasource.base import BaseConnect diff --git a/dbgpt/datasource/manages/connect_config_db.py b/dbgpt/datasource/manages/connect_config_db.py index 13f265211..7c2cc2ff6 100644 --- a/dbgpt/datasource/manages/connect_config_db.py +++ b/dbgpt/datasource/manages/connect_config_db.py @@ -1,5 +1,4 @@ -from sqlalchemy import Column, Integer, String, Index, Text, text -from sqlalchemy import UniqueConstraint +from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text from dbgpt.storage.metadata import BaseDao, Model diff --git a/dbgpt/datasource/manages/connect_storage_duckdb.py b/dbgpt/datasource/manages/connect_storage_duckdb.py index 59be27dc7..a08d33094 100644 --- a/dbgpt/datasource/manages/connect_storage_duckdb.py +++ b/dbgpt/datasource/manages/connect_storage_duckdb.py @@ -1,5 +1,7 @@ import os + import duckdb + from dbgpt.configs.model_config import PILOT_PATH default_db_path = os.path.join(PILOT_PATH, "message") diff --git a/dbgpt/datasource/manages/connection_manager.py b/dbgpt/datasource/manages/connection_manager.py index 5594b5207..01c4c5ec6 100644 --- a/dbgpt/datasource/manages/connection_manager.py +++ b/dbgpt/datasource/manages/connection_manager.py @@ -1,22 +1,22 @@ from typing import List, Type -from dbgpt.datasource import ConnectConfigDao -from dbgpt.storage.schema import DBType -from dbgpt.component import SystemApp, ComponentType -from dbgpt.util.executor_utils import ExecutorFactory -from dbgpt.datasource.db_conn_info import DBConfig -from dbgpt.rag.summary.db_summary_client import DBSummaryClient +from dbgpt.component import ComponentType, SystemApp +from dbgpt.datasource import ConnectConfigDao from dbgpt.datasource.base import BaseConnect -from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect -from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect -from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect +from dbgpt.datasource.conn_spark import SparkConnect +from dbgpt.datasource.db_conn_info import DBConfig from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect +from dbgpt.datasource.rdbms.conn_doris import DorisConnect +from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect +from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect +from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect from dbgpt.datasource.rdbms.conn_postgresql import PostgreSQLDatabase +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect -from dbgpt.datasource.conn_spark import SparkConnect -from dbgpt.datasource.rdbms.conn_doris import DorisConnect +from dbgpt.rag.summary.db_summary_client import DBSummaryClient +from dbgpt.storage.schema import DBType +from dbgpt.util.executor_utils import ExecutorFactory class ConnectManager: diff --git a/dbgpt/datasource/operator/datasource_operator.py b/dbgpt/datasource/operator/datasource_operator.py index 752465d18..ef0a03e65 100644 --- a/dbgpt/datasource/operator/datasource_operator.py +++ b/dbgpt/datasource/operator/datasource_operator.py @@ -1,4 +1,5 @@ from typing import Any + from dbgpt.core.awel import MapOperator from dbgpt.core.awel.task.base import IN, OUT from dbgpt.datasource.base import BaseConnect diff --git a/dbgpt/datasource/rdbms/_base_dao.py b/dbgpt/datasource/rdbms/_base_dao.py index ec19c9634..9a4221bec 100644 --- a/dbgpt/datasource/rdbms/_base_dao.py +++ b/dbgpt/datasource/rdbms/_base_dao.py @@ -1,9 +1,11 @@ import logging + from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker + from dbgpt._private.config import Config -from dbgpt.storage.schema import DBType from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.storage.schema import DBType logger = logging.getLogger(__name__) CFG = Config() diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index d38330988..5eb415e62 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -1,26 +1,21 @@ from __future__ import annotations -import sqlparse -import regex as re + +from typing import Any, Dict, Iterable, List, Optional from urllib.parse import quote from urllib.parse import quote_plus as urlquote -from typing import Any, Iterable, List, Optional, Dict + +import regex as re import sqlalchemy -from sqlalchemy import ( - MetaData, - Table, - create_engine, - inspect, - select, - text, -) +import sqlparse +from sqlalchemy import MetaData, Table, create_engine, inspect, select, text from sqlalchemy.engine import CursorResult from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.schema import CreateTable -from sqlalchemy.orm import sessionmaker, scoped_session -from dbgpt.storage.schema import DBType -from dbgpt.datasource.base import BaseConnect from dbgpt._private.config import Config +from dbgpt.datasource.base import BaseConnect +from dbgpt.storage.schema import DBType CFG = Config() diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py index 27719e73c..06277050e 100644 --- a/dbgpt/datasource/rdbms/conn_clickhouse.py +++ b/dbgpt/datasource/rdbms/conn_clickhouse.py @@ -1,12 +1,11 @@ import re +from typing import Any, Dict, Iterable, List, Optional + import sqlparse -from typing import List, Optional, Any, Iterable, Dict -from sqlalchemy import text +from sqlalchemy import MetaData, text + from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.storage.schema import DBType -from sqlalchemy import ( - MetaData, -) class ClickhouseConnect(RDBMSDatabase): diff --git a/dbgpt/datasource/rdbms/conn_doris.py b/dbgpt/datasource/rdbms/conn_doris.py index 1e4d9d8aa..10b290af0 100644 --- a/dbgpt/datasource/rdbms/conn_doris.py +++ b/dbgpt/datasource/rdbms/conn_doris.py @@ -1,7 +1,9 @@ -from typing import Iterable, Optional, Any -from sqlalchemy import text +from typing import Any, Iterable, Optional from urllib.parse import quote from urllib.parse import quote_plus as urlquote + +from sqlalchemy import text + from dbgpt.datasource.rdbms.base import RDBMSDatabase diff --git a/dbgpt/datasource/rdbms/conn_duckdb.py b/dbgpt/datasource/rdbms/conn_duckdb.py index e9893d181..f2f53a2a8 100644 --- a/dbgpt/datasource/rdbms/conn_duckdb.py +++ b/dbgpt/datasource/rdbms/conn_duckdb.py @@ -1,8 +1,6 @@ -from typing import Optional, Any, Iterable -from sqlalchemy import ( - create_engine, - text, -) +from typing import Any, Iterable, Optional + +from sqlalchemy import create_engine, text from dbgpt.datasource.rdbms.base import RDBMSDatabase diff --git a/dbgpt/datasource/rdbms/conn_mssql.py b/dbgpt/datasource/rdbms/conn_mssql.py index 12e1d79fa..961ff20d1 100644 --- a/dbgpt/datasource/rdbms/conn_mssql.py +++ b/dbgpt/datasource/rdbms/conn_mssql.py @@ -1,15 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Optional, Any, Iterable +from typing import Any, Iterable, Optional + +from sqlalchemy import MetaData, Table, create_engine, inspect, select, text -from sqlalchemy import ( - MetaData, - Table, - create_engine, - inspect, - select, - text, -) from dbgpt.datasource.rdbms.base import RDBMSDatabase diff --git a/dbgpt/datasource/rdbms/conn_postgresql.py b/dbgpt/datasource/rdbms/conn_postgresql.py index ad169118d..7470d5d90 100644 --- a/dbgpt/datasource/rdbms/conn_postgresql.py +++ b/dbgpt/datasource/rdbms/conn_postgresql.py @@ -1,7 +1,9 @@ -from typing import Iterable, Optional, Any -from sqlalchemy import text +from typing import Any, Iterable, Optional from urllib.parse import quote from urllib.parse import quote_plus as urlquote + +from sqlalchemy import text + from dbgpt.datasource.rdbms.base import RDBMSDatabase diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index bed71bfd6..9eb9f735f 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import os -from typing import Optional, Any, Iterable -from sqlalchemy import create_engine, text import tempfile -import logging +from typing import Any, Iterable, Optional + +from sqlalchemy import create_engine, text + from dbgpt.datasource.rdbms.base import RDBMSDatabase logger = logging.getLogger(__name__) @@ -160,6 +162,7 @@ def create_temporary_db( Examples: .. code-block:: python + with SQLiteTempConnect.create_temporary_db() as db: db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") db.run(db.session, "insert into test(id) values (1)") @@ -201,6 +204,7 @@ def create_temp_tables(self, tables_info): Examples: .. code-block:: python + tables_info = { "test": { "columns": { diff --git a/dbgpt/datasource/rdbms/conn_starrocks.py b/dbgpt/datasource/rdbms/conn_starrocks.py index 380d8eca8..0c79b5d2a 100644 --- a/dbgpt/datasource/rdbms/conn_starrocks.py +++ b/dbgpt/datasource/rdbms/conn_starrocks.py @@ -1,7 +1,9 @@ -from typing import Iterable, Optional, Any -from sqlalchemy import text +from typing import Any, Iterable, Optional from urllib.parse import quote from urllib.parse import quote_plus as urlquote + +from sqlalchemy import text + from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import * diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py index e8542f940..4f479e79b 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py @@ -1,8 +1,8 @@ import logging import re -from typing import Optional, List, Any, Type, Dict +from typing import Any, Dict, List, Optional, Type -from sqlalchemy import Numeric, Integer, Float +from sqlalchemy import Float, Integer, Numeric from sqlalchemy.sql import sqltypes from sqlalchemy.sql.type_api import TypeEngine diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py index 049db13a5..d563b9603 100644 --- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py +++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py @@ -15,7 +15,7 @@ import logging from typing import Any, Dict, List -from sqlalchemy import log, exc, text +from sqlalchemy import exc, log, text from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql from sqlalchemy.engine import Connection diff --git a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py index 015296b8b..1b08f5506 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py @@ -2,9 +2,10 @@ Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_duckdb.py """ -import pytest import tempfile +import pytest + from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect diff --git a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py index 267ee1575..f741158e4 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py @@ -1,9 +1,11 @@ """ Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_sqlite.py """ -import pytest -import tempfile import os +import tempfile + +import pytest + from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py index e3874393c..793f63c8a 100644 --- a/dbgpt/model/adapter/base.py +++ b/dbgpt/model/adapter/base.py @@ -1,19 +1,20 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Any, Tuple, Type, Callable import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.adapter.template import ( + ConversationAdapter, + ConversationAdapterFactory, + get_conv_template, +) from dbgpt.model.base import ModelType from dbgpt.model.parameter import ( BaseModelParameters, - ModelParameters, LlamaCppModelParameters, + ModelParameters, ProxyModelParameters, ) -from dbgpt.model.adapter.template import ( - get_conv_template, - ConversationAdapter, - ConversationAdapterFactory, -) logger = logging.getLogger(__name__) diff --git a/dbgpt/model/adapter/fschat_adapter.py b/dbgpt/model/adapter/fschat_adapter.py index ca5413645..956cccf66 100644 --- a/dbgpt/model/adapter/fschat_adapter.py +++ b/dbgpt/model/adapter/fschat_adapter.py @@ -2,17 +2,17 @@ You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it. """ +import logging import os import threading -import logging from functools import cache -from typing import TYPE_CHECKING, Callable, Tuple, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple try: from fastchat.conversation import ( Conversation, - register_conv_template, SeparatorStyle, + register_conv_template, ) except ImportError as exc: raise ValueError( @@ -20,8 +20,8 @@ "Please install fastchat by command `pip install fschat` " ) from exc -from dbgpt.model.adapter.template import ConversationAdapter, PromptType from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.template import ConversationAdapter, PromptType if TYPE_CHECKING: from fastchat.model.model_adapter import BaseModelAdapter diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py index 204133b68..7e3faabcb 100644 --- a/dbgpt/model/adapter/hf_adapter.py +++ b/dbgpt/model/adapter/hf_adapter.py @@ -1,10 +1,10 @@ -from abc import ABC, abstractmethod -from typing import Dict, Optional, List, Any import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional from dbgpt.core import ModelMessage -from dbgpt.model.base import ModelType from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter +from dbgpt.model.base import ModelType logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None): def load(self, model_path: str, from_pretrained_kwargs: dict): try: import transformers - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel + from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer except ImportError as exc: raise ValueError( "Could not import depend python package " diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py index efaf9f1b9..b4d06a432 100644 --- a/dbgpt/model/adapter/model_adapter.py +++ b/dbgpt/model/adapter/model_adapter.py @@ -1,21 +1,15 @@ from __future__ import annotations -from typing import ( - List, - Type, - Optional, -) import logging -import threading import os +import threading from functools import cache +from typing import List, Optional, Type + +from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter +from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory from dbgpt.model.base import ModelType from dbgpt.model.parameter import BaseModelParameters -from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter -from dbgpt.model.adapter.template import ( - ConversationAdapter, - ConversationAdapterFactory, -) logger = logging.getLogger(__name__) @@ -64,9 +58,9 @@ def get_llm_model_adapter( if use_fastchat and not must_use_old: logger.info("Use fastcat adapter") from dbgpt.model.adapter.fschat_adapter import ( - _get_fastchat_model_adapter, - _fastchat_get_adapter_monkey_patch, FastChatLLMModelAdapterWrapper, + _fastchat_get_adapter_monkey_patch, + _get_fastchat_model_adapter, ) adapter = _get_fastchat_model_adapter( @@ -79,11 +73,11 @@ def get_llm_model_adapter( result_adapter = FastChatLLMModelAdapterWrapper(adapter) else: + from dbgpt.app.chat_adapter import get_llm_chat_adapter + from dbgpt.model.adapter.old_adapter import OldLLMModelAdapterWrapper from dbgpt.model.adapter.old_adapter import ( get_llm_model_adapter as _old_get_llm_model_adapter, - OldLLMModelAdapterWrapper, ) - from dbgpt.app.chat_adapter import get_llm_chat_adapter logger.info("Use DB-GPT old adapter") result_adapter = OldLLMModelAdapterWrapper( @@ -139,12 +133,12 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]: Returns: Optional[List[Type[BaseModelParameters]]]: The model parameters class list. """ - from dbgpt.util.parameter_utils import _SimpleArgParser from dbgpt.model.parameter import ( + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, EmbeddingModelParameters, WorkerType, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, ) + from dbgpt.util.parameter_utils import _SimpleArgParser pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type") pre_args.parse() diff --git a/dbgpt/model/adapter/old_adapter.py b/dbgpt/model/adapter/old_adapter.py index a63054695..e9ed01af4 100644 --- a/dbgpt/model/adapter/old_adapter.py +++ b/dbgpt/model/adapter/old_adapter.py @@ -5,31 +5,26 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import os import re -import logging -from pathlib import Path -from typing import List, Tuple, TYPE_CHECKING, Optional from functools import cache -from transformers import ( - AutoModel, - AutoModelForCausalLM, - AutoTokenizer, - LlamaTokenizer, -) +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Tuple + +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer +from dbgpt._private.config import Config +from dbgpt.configs.model_config import get_device from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.template import ConversationAdapter, PromptType from dbgpt.model.base import ModelType - +from dbgpt.model.conversation import Conversation from dbgpt.model.parameter import ( - ModelParameters, LlamaCppModelParameters, + ModelParameters, ProxyModelParameters, ) -from dbgpt.model.conversation import Conversation -from dbgpt.configs.model_config import get_device -from dbgpt._private.config import Config if TYPE_CHECKING: from dbgpt.app.chat_adapter import BaseChatAdpter diff --git a/dbgpt/model/adapter/template.py b/dbgpt/model/adapter/template.py index 3fb9a6ec1..421e98a4e 100644 --- a/dbgpt/model/adapter/template.py +++ b/dbgpt/model/adapter/template.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Optional, Tuple, Union, List +from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: from fastchat.conversation import Conversation @@ -124,6 +124,7 @@ def get_conv_template(name: str) -> ConversationAdapter: Conversation: The conversation template. """ from fastchat.conversation import get_conv_template + from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter conv_template = get_conv_template(name) diff --git a/dbgpt/model/adapter/vllm_adapter.py b/dbgpt/model/adapter/vllm_adapter.py index 2ffe0c764..268ce82c6 100644 --- a/dbgpt/model/adapter/vllm_adapter.py +++ b/dbgpt/model/adapter/vllm_adapter.py @@ -1,12 +1,13 @@ import dataclasses import logging -from dbgpt.model.base import ModelType + from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory +from dbgpt.model.base import ModelType from dbgpt.model.parameter import BaseModelParameters from dbgpt.util.parameter_utils import ( - _extract_parameter_details, _build_parameter_class, + _extract_parameter_details, _get_dataclass_print_str, ) @@ -27,6 +28,7 @@ def model_type(self) -> str: def model_param_class(self, model_type: str = None) -> BaseModelParameters: import argparse + from vllm.engine.arg_utils import AsyncEngineArgs parser = argparse.ArgumentParser() @@ -56,9 +58,9 @@ def model_param_class(self, model_type: str = None) -> BaseModelParameters: return _build_parameter_class(descs) def load_from_params(self, params): + import torch from vllm import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs - import torch num_gpus = torch.cuda.device_count() if num_gpus > 1 and hasattr(params, "tensor_parallel_size"): diff --git a/dbgpt/model/base.py b/dbgpt/model/base.py index a662c8001..1bd5a3734 100644 --- a/dbgpt/model/base.py +++ b/dbgpt/model/base.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from enum import Enum -from typing import TypedDict, Optional, Dict, List, Any - -from dataclasses import dataclass, asdict import time +from dataclasses import asdict, dataclass from datetime import datetime -from dbgpt.util.parameter_utils import ParameterDescription +from enum import Enum +from typing import Any, Dict, List, Optional, TypedDict + from dbgpt.util.model_utils import GPUInfo +from dbgpt.util.parameter_utils import ParameterDescription class ModelType: diff --git a/dbgpt/model/cli.py b/dbgpt/model/cli.py index 67e12d7eb..24a367fb1 100644 --- a/dbgpt/model/cli.py +++ b/dbgpt/model/cli.py @@ -1,30 +1,30 @@ -import click import functools import logging import os -from typing import Callable, List, Type, Optional +from typing import Callable, List, Optional, Type + +import click from dbgpt.configs.model_config import LOGDIR from dbgpt.model.base import WorkerApplyType from dbgpt.model.parameter import ( - ModelControllerParameters, + BaseParameters, ModelAPIServerParameters, - ModelWorkerParameters, + ModelControllerParameters, ModelParameters, - BaseParameters, + ModelWorkerParameters, ) from dbgpt.util import get_or_create_event_loop +from dbgpt.util.command_utils import ( + _detect_controller_address, + _run_current_with_daemon, + _stop_service, +) from dbgpt.util.parameter_utils import ( EnvArgumentParser, _build_parameter_class, build_lazy_click_command, ) -from dbgpt.util.command_utils import ( - _run_current_with_daemon, - _stop_service, - _detect_controller_address, -) - MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000" @@ -32,7 +32,7 @@ def _get_worker_manager(address: str): - from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient + from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager registry = ModelRegistryClient(address) worker_manager = RemoteWorkerManager(registry) @@ -70,6 +70,7 @@ def model_cli_group(address: str): def list(model_name: str, model_type: str): """List model instances""" from prettytable import PrettyTable + from dbgpt.model.cluster import ModelRegistryClient loop = get_or_create_event_loop() @@ -152,7 +153,7 @@ def wrapper(*args, **kwargs): ) def stop(model_name: str, model_type: str, host: str, port: int): """Stop model instances""" - from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager + from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS) req = WorkerStartupRequest( @@ -168,10 +169,11 @@ def stop(model_name: str, model_type: str, host: str, port: int): def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]: - from dbgpt.util.parameter_utils import _SimpleArgParser + from dataclasses import dataclass, field + from dbgpt.model.cluster import RemoteWorkerManager from dbgpt.model.parameter import WorkerType - from dataclasses import dataclass, field + from dbgpt.util.parameter_utils import _SimpleArgParser pre_args = _SimpleArgParser("model_name", "address", "host", "port") pre_args.parse() @@ -270,7 +272,7 @@ class RemoteModelWorkerParameters(BaseParameters): ) def start(**kwargs): """Start model instances""" - from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager + from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS) req = WorkerStartupRequest( @@ -339,8 +341,8 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None): async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None): - from dbgpt.model.cluster import PromptRequest from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType + from dbgpt.model.cluster import PromptRequest print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.") hist = [] diff --git a/dbgpt/model/cluster/__init__.py b/dbgpt/model/cluster/__init__.py index 8249ebd6f..aed894a6f 100644 --- a/dbgpt/model/cluster/__init__.py +++ b/dbgpt/model/cluster/__init__.py @@ -1,3 +1,4 @@ +from dbgpt.model.cluster.apiserver.api import run_apiserver from dbgpt.model.cluster.base import ( EmbeddingsRequest, PromptRequest, @@ -5,25 +6,21 @@ WorkerParameterRequest, WorkerStartupRequest, ) +from dbgpt.model.cluster.controller.controller import ( + BaseModelController, + ModelRegistryClient, + run_model_controller, +) from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory -from dbgpt.model.cluster.worker_base import ModelWorker +from dbgpt.model.cluster.registry import ModelRegistry from dbgpt.model.cluster.worker.default_worker import DefaultModelWorker - from dbgpt.model.cluster.worker.manager import ( initialize_worker_manager_in_client, run_worker_manager, worker_manager, ) - -from dbgpt.model.cluster.registry import ModelRegistry -from dbgpt.model.cluster.controller.controller import ( - ModelRegistryClient, - run_model_controller, - BaseModelController, -) -from dbgpt.model.cluster.apiserver.api import run_apiserver - from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager +from dbgpt.model.cluster.worker_base import ModelWorker __all__ = [ "EmbeddingsRequest", diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py index 43fa62646..1a508fddb 100644 --- a/dbgpt/model/cluster/apiserver/api.py +++ b/dbgpt/model/cluster/apiserver/api.py @@ -3,45 +3,42 @@ Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py """ -from typing import Optional, List, Dict, Any, Generator - -import logging import asyncio -import shortuuid import json -from fastapi import APIRouter, FastAPI -from fastapi import Depends, HTTPException +import logging +from typing import Any, Dict, Generator, List, Optional + +import shortuuid +from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer - - +from fastchat.constants import ErrorCode +from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse from fastchat.protocol.openai_api_protocol import ( ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, - ChatCompletionResponseChoice, DeltaMessage, ModelCard, ModelList, ModelPermission, UsageInfo, ) -from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse -from fastchat.constants import ErrorCode +from dbgpt._private.pydantic import BaseModel from dbgpt.component import BaseComponent, ComponentType, SystemApp -from dbgpt.util.parameter_utils import EnvArgumentParser from dbgpt.core import ModelOutput from dbgpt.core.interface.message import ModelMessage from dbgpt.model.base import ModelInstance -from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType -from dbgpt.model.cluster import ModelRegistry from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory +from dbgpt.model.cluster.registry import ModelRegistry +from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType +from dbgpt.util.parameter_utils import EnvArgumentParser from dbgpt.util.utils import setup_logging -from dbgpt._private.pydantic import BaseModel logger = logging.getLogger(__name__) @@ -393,8 +390,9 @@ async def create_chat_completion( def _initialize_all(controller_addr: str, system_app: SystemApp): - from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient + from dbgpt.model.cluster.controller.controller import ModelRegistryClient from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory + from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager if not system_app.get_component( ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None diff --git a/dbgpt/model/cluster/apiserver/tests/test_api.py b/dbgpt/model/cluster/apiserver/tests/test_api.py index 0e3188d6c..00a374780 100644 --- a/dbgpt/model/cluster/apiserver/tests/test_api.py +++ b/dbgpt/model/cluster/apiserver/tests/test_api.py @@ -1,29 +1,28 @@ +import importlib.metadata as metadata + import pytest import pytest_asyncio from aioresponses import aioresponses from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from httpx import AsyncClient, HTTPError -import importlib.metadata as metadata from dbgpt.component import SystemApp -from dbgpt.util.openai_utils import chat_completion_stream, chat_completion - from dbgpt.model.cluster.apiserver.api import ( - api_settings, - initialize_apiserver, - ModelList, - UsageInfo, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, - ChatCompletionResponseChoice, DeltaMessage, + ModelList, + UsageInfo, + api_settings, + initialize_apiserver, ) from dbgpt.model.cluster.tests.conftest import _new_cluster - from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory +from dbgpt.util.openai_utils import chat_completion, chat_completion_stream app = FastAPI() app.add_middleware( diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py index 97b76fd30..f60fb0e54 100644 --- a/dbgpt/model/cluster/base.py +++ b/dbgpt/model/cluster/base.py @@ -1,9 +1,9 @@ from typing import Dict, List +from dbgpt._private.pydantic import BaseModel +from dbgpt.core.interface.message import ModelMessage from dbgpt.model.base import WorkerApplyType from dbgpt.model.parameter import WorkerType -from dbgpt.core.interface.message import ModelMessage -from dbgpt._private.pydantic import BaseModel WORKER_MANAGER_SERVICE_TYPE = "service" WORKER_MANAGER_SERVICE_NAME = "WorkerManager" diff --git a/dbgpt/model/cluster/controller/controller.py b/dbgpt/model/cluster/controller/controller.py index e52d6f264..1591c120b 100644 --- a/dbgpt/model/cluster/controller/controller.py +++ b/dbgpt/model/cluster/controller/controller.py @@ -1,19 +1,17 @@ -from abc import ABC, abstractmethod - import logging +from abc import ABC, abstractmethod from typing import List from fastapi import APIRouter, FastAPI + from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.model.base import ModelInstance -from dbgpt.model.parameter import ModelControllerParameters from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry +from dbgpt.model.parameter import ModelControllerParameters +from dbgpt.util.api_utils import _api_remote as api_remote +from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote from dbgpt.util.parameter_utils import EnvArgumentParser -from dbgpt.util.api_utils import ( - _api_remote as api_remote, - _sync_api_remote as sync_api_remote, -) -from dbgpt.util.utils import setup_logging, setup_http_service_logging +from dbgpt.util.utils import setup_http_service_logging, setup_logging logger = logging.getLogger(__name__) diff --git a/dbgpt/model/cluster/controller/tests/test_registry.py b/dbgpt/model/cluster/controller/tests/test_registry.py index 9596ef7a5..51494364d 100644 --- a/dbgpt/model/cluster/controller/tests/test_registry.py +++ b/dbgpt/model/cluster/controller/tests/test_registry.py @@ -1,7 +1,8 @@ -import pytest +import asyncio from datetime import datetime, timedelta -import asyncio +import pytest + from dbgpt.model.base import ModelInstance from dbgpt.model.cluster.registry import EmbeddedModelRegistry diff --git a/dbgpt/model/cluster/embedding/loader.py b/dbgpt/model/cluster/embedding/loader.py index aea9d0530..b4fba70b9 100644 --- a/dbgpt/model/cluster/embedding/loader.py +++ b/dbgpt/model/cluster/embedding/loader.py @@ -4,8 +4,8 @@ from dbgpt.model.parameter import BaseEmbeddingModelParameters from dbgpt.util.parameter_utils import _get_dict_from_obj -from dbgpt.util.tracer import root_tracer, SpanType, SpanTypeRunName from dbgpt.util.system_utils import get_system_info +from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer if TYPE_CHECKING: from langchain.embeddings.base import Embeddings diff --git a/dbgpt/model/cluster/embedding/remote_embedding.py b/dbgpt/model/cluster/embedding/remote_embedding.py index 277bc4d24..d45c9dd85 100644 --- a/dbgpt/model/cluster/embedding/remote_embedding.py +++ b/dbgpt/model/cluster/embedding/remote_embedding.py @@ -1,4 +1,5 @@ from typing import List + from langchain.embeddings.base import Embeddings from dbgpt.model.cluster.manager_base import WorkerManager diff --git a/dbgpt/model/cluster/manager_base.py b/dbgpt/model/cluster/manager_base.py index 636b2822f..37b09b0fd 100644 --- a/dbgpt/model/cluster/manager_base.py +++ b/dbgpt/model/cluster/manager_base.py @@ -1,15 +1,16 @@ import asyncio -from dataclasses import dataclass -from typing import List, Optional, Dict, Iterator, Callable from abc import ABC, abstractmethod -from datetime import datetime from concurrent.futures import Future +from dataclasses import dataclass +from datetime import datetime +from typing import Callable, Dict, Iterator, List, Optional + from dbgpt.component import BaseComponent, ComponentType, SystemApp -from dbgpt.core import ModelOutput, ModelMetadata -from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput +from dbgpt.core import ModelMetadata, ModelOutput +from dbgpt.model.base import WorkerApplyOutput, WorkerSupportedModel +from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest from dbgpt.model.cluster.worker_base import ModelWorker -from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest -from dbgpt.model.parameter import ModelWorkerParameters, ModelParameters +from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters from dbgpt.util.parameter_utils import ParameterDescription diff --git a/dbgpt/model/cluster/registry.py b/dbgpt/model/cluster/registry.py index 6d2d7cf86..6a5b12bf5 100644 --- a/dbgpt/model/cluster/registry.py +++ b/dbgpt/model/cluster/registry.py @@ -1,17 +1,16 @@ +import itertools +import logging import random import threading import time -import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple -import itertools from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.model.base import ModelInstance - logger = logging.getLogger(__name__) diff --git a/dbgpt/model/cluster/tests/conftest.py b/dbgpt/model/cluster/tests/conftest.py index eeba4826f..dc2288327 100644 --- a/dbgpt/model/cluster/tests/conftest.py +++ b/dbgpt/model/cluster/tests/conftest.py @@ -1,21 +1,22 @@ +from contextlib import asynccontextmanager, contextmanager +from typing import Dict, Iterator, List, Tuple + import pytest import pytest_asyncio -from contextlib import contextmanager, asynccontextmanager -from typing import List, Iterator, Dict, Tuple -from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from dbgpt.core import ModelOutput, ModelMetadata -from dbgpt.model.cluster.worker_base import ModelWorker + +from dbgpt.core import ModelMetadata, ModelOutput +from dbgpt.model.base import ModelInstance +from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from dbgpt.model.cluster.worker.manager import ( - WorkerManager, + ApplyFunction, + DeregisterFunc, LocalWorkerManager, RegisterFunc, - DeregisterFunc, SendHeartbeatFunc, - ApplyFunction, + WorkerManager, ) - -from dbgpt.model.base import ModelInstance -from dbgpt.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry +from dbgpt.model.cluster.worker_base import ModelWorker +from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType @pytest.fixture diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index 453071004..c1fe109ea 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -1,26 +1,25 @@ -import os import logging - -from typing import Dict, Iterator, List, Optional +import os import time import traceback +from typing import Dict, Iterator, List, Optional from dbgpt.configs.model_config import get_device -from dbgpt.model.adapter.base import LLMModelAdapter -from dbgpt.model.adapter.model_adapter import get_llm_model_adapter from dbgpt.core import ( - ModelOutput, + ModelExtraMedata, ModelInferenceMetrics, ModelMetadata, - ModelExtraMedata, + ModelOutput, ) +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.model_adapter import get_llm_model_adapter +from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.loader import ModelLoader, _get_model_real_path from dbgpt.model.parameter import ModelParameters -from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj -from dbgpt.util.tracer import root_tracer, SpanType, SpanTypeRunName from dbgpt.util.system_utils import get_system_info +from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer logger = logging.getLogger(__name__) diff --git a/dbgpt/model/cluster/worker/embedding_worker.py b/dbgpt/model/cluster/worker/embedding_worker.py index c8df5d779..06b6ce3bb 100644 --- a/dbgpt/model/cluster/worker/embedding_worker.py +++ b/dbgpt/model/cluster/worker/embedding_worker.py @@ -1,17 +1,17 @@ import logging -from typing import Dict, List, Type, Optional +from typing import Dict, List, Optional, Type from dbgpt.configs.model_config import get_device from dbgpt.core import ModelMetadata +from dbgpt.model.cluster.embedding.loader import EmbeddingLoader +from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.loader import _get_model_real_path from dbgpt.model.parameter import ( - EmbeddingModelParameters, + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, BaseEmbeddingModelParameters, + EmbeddingModelParameters, WorkerType, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, ) -from dbgpt.model.cluster.worker_base import ModelWorker -from dbgpt.model.cluster.embedding.loader import EmbeddingLoader from dbgpt.util.model_utils import _clear_model_cache from dbgpt.util.parameter_utils import EnvArgumentParser diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index 8a8af3b35..f775c9118 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -15,7 +15,7 @@ from dbgpt.component import SystemApp from dbgpt.configs.model_config import LOGDIR -from dbgpt.core import ModelOutput, ModelMetadata +from dbgpt.core import ModelMetadata, ModelOutput from dbgpt.model.base import ( ModelInstance, WorkerApplyOutput, @@ -38,9 +38,9 @@ _dict_to_command_args, _get_dict_from_obj, ) -from dbgpt.util.utils import setup_logging, setup_http_service_logging -from dbgpt.util.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName from dbgpt.util.system_utils import get_system_info +from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer +from dbgpt.util.utils import setup_http_service_logging, setup_logging logger = logging.getLogger(__name__) diff --git a/dbgpt/model/cluster/worker/remote_worker.py b/dbgpt/model/cluster/worker/remote_worker.py index 00e707768..c25a71ca7 100644 --- a/dbgpt/model/cluster/worker/remote_worker.py +++ b/dbgpt/model/cluster/worker/remote_worker.py @@ -1,10 +1,10 @@ import json -from typing import Dict, Iterator, List import logging -from dbgpt.core import ModelOutput, ModelMetadata -from dbgpt.model.parameter import ModelParameters -from dbgpt.model.cluster.worker_base import ModelWorker +from typing import Dict, Iterator, List +from dbgpt.core import ModelMetadata, ModelOutput +from dbgpt.model.cluster.worker_base import ModelWorker +from dbgpt.model.parameter import ModelParameters logger = logging.getLogger(__name__) diff --git a/dbgpt/model/cluster/worker/tests/test_manager.py b/dbgpt/model/cluster/worker/tests/test_manager.py index 0c02e71be..d8ee8bdc3 100644 --- a/dbgpt/model/cluster/worker/tests/test_manager.py +++ b/dbgpt/model/cluster/worker/tests/test_manager.py @@ -1,29 +1,30 @@ -from unittest.mock import patch, AsyncMock -import pytest -from typing import List, Iterator, Dict, Tuple from dataclasses import asdict -from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from dbgpt.model.base import WorkerApplyType, ModelInstance +from typing import Dict, Iterator, List, Tuple +from unittest.mock import AsyncMock, patch + +import pytest + +from dbgpt.model.base import ModelInstance, WorkerApplyType from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest -from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.cluster.manager_base import WorkerRunData -from dbgpt.model.cluster.worker.manager import ( - LocalWorkerManager, - RegisterFunc, - DeregisterFunc, - SendHeartbeatFunc, - ApplyFunction, -) from dbgpt.model.cluster.tests.conftest import ( MockModelWorker, - manager_2_workers, - manager_with_2_workers, - manager_2_embedding_workers, _create_workers, - _start_worker_manager, _new_worker_params, + _start_worker_manager, + manager_2_embedding_workers, + manager_2_workers, + manager_with_2_workers, ) - +from dbgpt.model.cluster.worker.manager import ( + ApplyFunction, + DeregisterFunc, + LocalWorkerManager, + RegisterFunc, + SendHeartbeatFunc, +) +from dbgpt.model.cluster.worker_base import ModelWorker +from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType _TEST_MODEL_NAME = "vicuna-13b-v1.5" _TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5" diff --git a/dbgpt/model/cluster/worker_base.py b/dbgpt/model/cluster/worker_base.py index 1f0005a8f..36413745f 100644 --- a/dbgpt/model/cluster/worker_base.py +++ b/dbgpt/model/cluster/worker_base.py @@ -1,12 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, Iterator, List, Type -from dbgpt.core import ModelOutput, ModelMetadata +from dbgpt.core import ModelMetadata, ModelOutput from dbgpt.model.parameter import ModelParameters, WorkerType -from dbgpt.util.parameter_utils import ( - ParameterDescription, - _get_parameter_descriptions, -) +from dbgpt.util.parameter_utils import ParameterDescription, _get_parameter_descriptions class ModelWorker(ABC): diff --git a/dbgpt/model/compression.py b/dbgpt/model/compression.py index dc626ea23..83b681c6f 100644 --- a/dbgpt/model/compression.py +++ b/dbgpt/model/compression.py @@ -52,7 +52,7 @@ def compress_module(module, target_device): def compress(tensor, config): - """Simulate team-wise quantization.""" + """Simulate group-wise quantization.""" if not config.enabled: return tensor @@ -105,7 +105,7 @@ def compress(tensor, config): def decompress(packed_data, config): - """Simulate team-wise dequantization.""" + """Simulate group-wise dequantization.""" if not config.enabled: return packed_data diff --git a/dbgpt/model/conversation.py b/dbgpt/model/conversation.py index 875299ccd..e3e465531 100644 --- a/dbgpt/model/conversation.py +++ b/dbgpt/model/conversation.py @@ -9,8 +9,8 @@ """ import dataclasses -from enum import auto, IntEnum -from typing import List, Dict, Callable +from enum import IntEnum, auto +from typing import Callable, Dict, List class SeparatorStyle(IntEnum): diff --git a/dbgpt/model/inference.py b/dbgpt/model/inference.py index 400f57ab9..fcb0d08f3 100644 --- a/dbgpt/model/inference.py +++ b/dbgpt/model/inference.py @@ -7,10 +7,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- import gc -from typing import Iterable, Dict +from typing import Dict, Iterable import torch - from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, @@ -19,7 +18,7 @@ TopPLogitsWarper, ) -from dbgpt.model.llm_utils import is_sentence_complete, is_partial_stop +from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete def prepare_logits_processor( diff --git a/dbgpt/model/llm/llama_cpp/llama_cpp.py b/dbgpt/model/llm/llama_cpp/llama_cpp.py index dc4851451..c7ee69205 100644 --- a/dbgpt/model/llm/llama_cpp/llama_cpp.py +++ b/dbgpt/model/llm/llama_cpp/llama_cpp.py @@ -1,11 +1,12 @@ """ Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py """ +import logging import re from typing import Dict -import logging -import torch + import llama_cpp +import torch from dbgpt.model.parameter import LlamaCppModelParameters diff --git a/dbgpt/model/llm_out/chatglm_llm.py b/dbgpt/model/llm_out/chatglm_llm.py index 0bb734181..acd6b226a 100644 --- a/dbgpt/model/llm_out/chatglm_llm.py +++ b/dbgpt/model/llm_out/chatglm_llm.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -from typing import List import re +from typing import List import torch diff --git a/dbgpt/model/llm_out/falcon_llm.py b/dbgpt/model/llm_out/falcon_llm.py index 12aebc6e9..4b56a754b 100644 --- a/dbgpt/model/llm_out/falcon_llm.py +++ b/dbgpt/model/llm_out/falcon_llm.py @@ -1,6 +1,7 @@ -import torch from threading import Thread -from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria + +import torch +from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer def falcon_generate_output(model, tokenizer, params, device, context_len=2048): diff --git a/dbgpt/model/llm_out/guanaco_llm.py b/dbgpt/model/llm_out/guanaco_llm.py index dd727d19a..4fae428fb 100644 --- a/dbgpt/model/llm_out/guanaco_llm.py +++ b/dbgpt/model/llm_out/guanaco_llm.py @@ -1,6 +1,7 @@ -import torch from threading import Thread -from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria + +import torch +from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): diff --git a/dbgpt/model/llm_out/hf_chat_llm.py b/dbgpt/model/llm_out/hf_chat_llm.py index b7e6caaac..e8fb58054 100644 --- a/dbgpt/model/llm_out/hf_chat_llm.py +++ b/dbgpt/model/llm_out/hf_chat_llm.py @@ -1,6 +1,7 @@ import logging -import torch from threading import Thread + +import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer logger = logging.getLogger(__name__) diff --git a/dbgpt/model/llm_out/llama_cpp_llm.py b/dbgpt/model/llm_out/llama_cpp_llm.py index 921670065..b25ea6845 100644 --- a/dbgpt/model/llm_out/llama_cpp_llm.py +++ b/dbgpt/model/llm_out/llama_cpp_llm.py @@ -1,4 +1,5 @@ from typing import Dict + import torch diff --git a/dbgpt/model/llm_out/proxy_llm.py b/dbgpt/model/llm_out/proxy_llm.py index 390ef4886..cbe9c42fc 100644 --- a/dbgpt/model/llm_out/proxy_llm.py +++ b/dbgpt/model/llm_out/proxy_llm.py @@ -2,16 +2,16 @@ # -*- coding: utf-8 -*- import time -from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream +from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream from dbgpt.model.proxy.llms.bard import bard_generate_stream +from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream from dbgpt.model.proxy.llms.claude import claude_generate_stream -from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream -from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream -from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream from dbgpt.model.proxy.llms.gemini import gemini_generate_stream -from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream -from dbgpt.model.proxy.llms.spark import spark_generate_stream from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.spark import spark_generate_stream +from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream +from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream +from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream def proxyllm_generate_stream( diff --git a/dbgpt/model/llm_out/vicuna_llm.py b/dbgpt/model/llm_out/vicuna_llm.py index 447254efb..3713946df 100644 --- a/dbgpt/model/llm_out/vicuna_llm.py +++ b/dbgpt/model/llm_out/vicuna_llm.py @@ -8,9 +8,9 @@ import requests from langchain.embeddings.base import Embeddings from langchain.llms.base import LLM -from dbgpt._private.pydantic import BaseModel from dbgpt._private.config import Config +from dbgpt._private.pydantic import BaseModel CFG = Config() diff --git a/dbgpt/model/llm_out/vllm_llm.py b/dbgpt/model/llm_out/vllm_llm.py index de108c87c..838bcc35a 100644 --- a/dbgpt/model/llm_out/vllm_llm.py +++ b/dbgpt/model/llm_out/vllm_llm.py @@ -1,9 +1,9 @@ -from typing import Dict import os +from typing import Dict + from vllm import AsyncLLMEngine -from vllm.utils import random_uuid from vllm.sampling_params import SamplingParams - +from vllm.utils import random_uuid _IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true" diff --git a/dbgpt/model/llm_utils.py b/dbgpt/model/llm_utils.py index 031896e86..fff11ba77 100644 --- a/dbgpt/model/llm_utils.py +++ b/dbgpt/model/llm_utils.py @@ -2,11 +2,11 @@ # -*- coding:utf-8 -*- from pathlib import Path +from typing import Dict, List -from typing import List, Dict import cachetools -from dbgpt.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG from dbgpt.model.base import SupportedModel from dbgpt.util.parameter_utils import _get_parameter_descriptions diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py index cfd6e0c3c..43d4e27f8 100644 --- a/dbgpt/model/loader.py +++ b/dbgpt/model/loader.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Optional, Dict, Any - import logging +from typing import Any, Dict, Optional + from dbgpt.configs.model_config import get_device -from dbgpt.model.base import ModelType from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.model_adapter import get_llm_model_adapter +from dbgpt.model.base import ModelType from dbgpt.model.parameter import ( - ModelParameters, LlamaCppModelParameters, + ModelParameters, ProxyModelParameters, ) from dbgpt.util import get_gpu_memory @@ -135,6 +135,7 @@ def loader_with_params( def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters): import torch + from dbgpt.model.compression import compress_module device = model_params.device @@ -178,7 +179,6 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete # NOTE: Recent transformers library seems to fix the mps issue, also # it has made some changes causing compatibility issues with our # original patch. So we only apply the patch for older versions. - # Avoid bugs in mps backend by not using in-place operations. replace_llama_attn_with_non_inplace_operations() @@ -274,18 +274,18 @@ def load_huggingface_quantization_model( import torch try: + import transformers from accelerate import init_empty_weights from accelerate.utils import infer_auto_device_map - import transformers from transformers import ( - BitsAndBytesConfig, AutoConfig, AutoModel, AutoModelForCausalLM, - LlamaForCausalLM, AutoModelForSeq2SeqLM, - LlamaTokenizer, AutoTokenizer, + BitsAndBytesConfig, + LlamaForCausalLM, + LlamaTokenizer, ) except ImportError as exc: raise ValueError( diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index afa77cb8b..bcaa46c4d 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -4,7 +4,7 @@ import os from dataclasses import dataclass, field from enum import Enum -from typing import Dict, Optional, Union, Tuple +from typing import Dict, Optional, Tuple, Union from dbgpt.model.conversation import conv_templates from dbgpt.util.parameter_utils import BaseParameters diff --git a/dbgpt/model/proxy/llms/baichuan.py b/dbgpt/model/proxy/llms/baichuan.py index ed641c72f..aed91e151 100644 --- a/dbgpt/model/proxy/llms/baichuan.py +++ b/dbgpt/model/proxy/llms/baichuan.py @@ -1,9 +1,11 @@ -import requests import json from typing import List -from dbgpt.model.proxy.llms.proxy_model import ProxyModel + +import requests + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.llms.proxy_model import ProxyModel BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k" diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py index 7e43661d8..6317cad6d 100755 --- a/dbgpt/model/proxy/llms/bard.py +++ b/dbgpt/model/proxy/llms/bard.py @@ -1,5 +1,7 @@ -import requests from typing import List + +import requests + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.model.proxy.llms.proxy_model import ProxyModel diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index 5d1882141..15932bfcc 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -1,15 +1,17 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import importlib.metadata as metadata +import logging import os from typing import List -import logging -import importlib.metadata as metadata -from dbgpt.model.proxy.llms.proxy_model import ProxyModel -from dbgpt.model.parameter import ProxyModelParameters -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType + import httpx +from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.llms.proxy_model import ProxyModel + logger = logging.getLogger(__name__) diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py index 04975817c..9a3b3b868 100644 --- a/dbgpt/model/proxy/llms/gemini.py +++ b/dbgpt/model/proxy/llms/gemini.py @@ -1,7 +1,7 @@ -from typing import List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple -from dbgpt.model.proxy.llms.proxy_model import ProxyModel from dbgpt.core.interface.message import ModelMessage, parse_model_messages +from dbgpt.model.proxy.llms.proxy_model import ProxyModel GEMINI_DEFAULT_MODEL = "gemini-pro" diff --git a/dbgpt/model/proxy/llms/proxy_model.py b/dbgpt/model/proxy/llms/proxy_model.py index 4e55ec3ea..b287ea88f 100644 --- a/dbgpt/model/proxy/llms/proxy_model.py +++ b/dbgpt/model/proxy/llms/proxy_model.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Union, List, Optional, TYPE_CHECKING import logging +from typing import TYPE_CHECKING, List, Optional, Union + from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper if TYPE_CHECKING: - from dbgpt.core.interface.message import ModelMessage, BaseMessage + from dbgpt.core.interface.message import BaseMessage, ModelMessage logger = logging.getLogger(__name__) diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py index aedf8a951..81a7cc2b4 100644 --- a/dbgpt/model/proxy/llms/spark.py +++ b/dbgpt/model/proxy/llms/spark.py @@ -1,14 +1,15 @@ -import json import base64 -import hmac import hashlib -from websockets.sync.client import connect +import hmac +import json from datetime import datetime -from typing import List from time import mktime -from urllib.parse import urlencode -from urllib.parse import urlparse +from typing import List +from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time + +from websockets.sync.client import connect + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.model.proxy.llms.proxy_model import ProxyModel diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py index bbcec2f42..e5d008de1 100644 --- a/dbgpt/model/proxy/llms/tongyi.py +++ b/dbgpt/model/proxy/llms/tongyi.py @@ -1,7 +1,8 @@ import logging from typing import List -from dbgpt.model.proxy.llms.proxy_model import ProxyModel + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.proxy.llms.proxy_model import ProxyModel logger = logging.getLogger(__name__) diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py index 28e5a65d0..74e44fa80 100644 --- a/dbgpt/model/proxy/llms/wenxin.py +++ b/dbgpt/model/proxy/llms/wenxin.py @@ -1,9 +1,11 @@ -import requests import json from typing import List -from dbgpt.model.proxy.llms.proxy_model import ProxyModel + +import requests +from cachetools import TTLCache, cached + from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType -from cachetools import cached, TTLCache +from dbgpt.model.proxy.llms.proxy_model import ProxyModel @cached(TTLCache(1, 1800)) diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py index 2f854e17a..8108ad5e0 100644 --- a/dbgpt/model/proxy/llms/zhipu.py +++ b/dbgpt/model/proxy/llms/zhipu.py @@ -1,7 +1,7 @@ from typing import List -from dbgpt.model.proxy.llms.proxy_model import ProxyModel from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.proxy.llms.proxy_model import ProxyModel CHATGLM_DEFAULT_MODEL = "chatglm_pro" diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 7d1344aec..f4d022a3a 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -1,42 +1,41 @@ from __future__ import annotations -import os +import importlib.metadata as metadata import logging -from dataclasses import dataclass +import os from abc import ABC -import importlib.metadata as metadata +from dataclasses import dataclass from typing import ( - List, - Dict, - Any, - Optional, TYPE_CHECKING, - Union, + Any, AsyncIterator, - Callable, Awaitable, + Callable, + Dict, + List, + Optional, + Union, ) +from dbgpt._private.pydantic import model_to_json from dbgpt.component import ComponentType -from dbgpt.core.operator import BaseLLM -from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator +from dbgpt.core.awel import BaseOperator, TransformStreamAbsOperator from dbgpt.core.interface.llm import ( - ModelOutput, - ModelRequest, - ModelMetadata, LLMClient, MessageConverter, + ModelMetadata, + ModelOutput, + ModelRequest, ) -from dbgpt.model.cluster.client import DefaultLLMClient +from dbgpt.core.operator import BaseLLM from dbgpt.model.cluster import WorkerManagerFactory -from dbgpt._private.pydantic import model_to_json +from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper if TYPE_CHECKING: import httpx from httpx._types import ProxiesTypes - from openai import AsyncAzureOpenAI - from openai import AsyncOpenAI + from openai import AsyncAzureOpenAI, AsyncOpenAI ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI] @@ -292,9 +291,10 @@ async def _to_openai_stream( model (Optional[str], optional): The model name. Defaults to None. model_caller (Callable[[None], Union[Awaitable[str], str]], optional): The model caller. Defaults to None. """ + import asyncio import json + import shortuuid - import asyncio from fastchat.protocol.openai_api_protocol import ( ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, diff --git a/dbgpt/model/utils/token_utils.py b/dbgpt/model/utils/token_utils.py index 281ed5eed..1c84490f4 100644 --- a/dbgpt/model/utils/token_utils.py +++ b/dbgpt/model/utils/token_utils.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Union, List, Optional, TYPE_CHECKING import logging +from typing import TYPE_CHECKING, List, Optional, Union if TYPE_CHECKING: - from dbgpt.core.interface.message import ModelMessage, BaseMessage + from dbgpt.core.interface.message import BaseMessage, ModelMessage logger = logging.getLogger(__name__) diff --git a/dbgpt/serve/conversation/api/endpoints.py b/dbgpt/serve/conversation/api/endpoints.py index 8ca494b98..59b155aaf 100644 --- a/dbgpt/serve/conversation/api/endpoints.py +++ b/dbgpt/serve/conversation/api/endpoints.py @@ -1,3 +1,4 @@ +import uuid from functools import cache from typing import List, Optional @@ -10,7 +11,7 @@ from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse +from .schemas import MessageVo, ServeRequest, ServerResponse router = APIRouter() @@ -95,12 +96,14 @@ async def test_auth(): @router.post( - "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], ) -async def create( +async def query( request: ServeRequest, service: Service = Depends(get_service) ) -> Result[ServerResponse]: - """Create a new Conversation entity + """Query Conversation entities Args: request (ServeRequest): The request @@ -108,43 +111,46 @@ async def create( Returns: ServerResponse: The response """ - return Result.succ(service.create(request)) + return Result.succ(service.get(request)) -@router.put( - "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +@router.post( + "/new", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], ) -async def update( - request: ServeRequest, service: Service = Depends(get_service) -) -> Result[ServerResponse]: - """Update a Conversation entity - - Args: - request (ServeRequest): The request - service (Service): The service - Returns: - ServerResponse: The response - """ - return Result.succ(service.update(request)) +async def dialogue_new( + chat_mode: str = "chat_normal", + user_name: str = None, + # TODO remove user id + user_id: str = None, + sys_code: str = None, +): + user_name = user_name or user_id + unique_id = uuid.uuid1() + res = ServerResponse( + user_input="", + conv_uid=str(unique_id), + chat_mode=chat_mode, + user_name=user_name, + sys_code=sys_code, + ) + return Result.succ(res) @router.post( - "/query", - response_model=Result[ServerResponse], + "/delete", dependencies=[Depends(check_api_key)], ) -async def query( - request: ServeRequest, service: Service = Depends(get_service) -) -> Result[ServerResponse]: - """Query Conversation entities +async def delete(con_uid: str, service: Service = Depends(get_service)): + """Delete a Conversation entity Args: - request (ServeRequest): The request + con_uid (str): The conversation UID service (Service): The service - Returns: - ServerResponse: The response """ - return Result.succ(service.get(request)) + service.delete(ServeRequest(conv_uid=con_uid)) + return Result.succ(None) @router.post( @@ -155,7 +161,7 @@ async def query( async def query_page( request: ServeRequest, page: Optional[int] = Query(default=1, description="current page"), - page_size: Optional[int] = Query(default=20, description="page size"), + page_size: Optional[int] = Query(default=10, description="page size"), service: Service = Depends(get_service), ) -> Result[PaginationResult[ServerResponse]]: """Query Conversation entities @@ -171,6 +177,37 @@ async def query_page( return Result.succ(service.get_list_by_page(request, page, page_size)) +@router.get( + "/list", + response_model=Result[List[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def list_latest_conv( + user_name: str = None, + user_id: str = None, + sys_code: str = None, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=10, description="page size"), + service: Service = Depends(get_service), +) -> Result[List[ServerResponse]]: + """Return latest conversations""" + request = ServeRequest( + user_name=user_name or user_id, + sys_code=sys_code, + ) + return Result.succ(service.get_list_by_page(request, page, page_size).items) + + +@router.get( + "/messages/history", + response_model=Result[List[MessageVo]], + dependencies=[Depends(check_api_key)], +) +async def get_history_messages(con_uid: str, service: Service = Depends(get_service)): + """Get the history messages of a conversation""" + return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid))) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" global global_system_app diff --git a/dbgpt/serve/conversation/api/schemas.py b/dbgpt/serve/conversation/api/schemas.py index d83c79e15..2558d8ad9 100644 --- a/dbgpt/serve/conversation/api/schemas.py +++ b/dbgpt/serve/conversation/api/schemas.py @@ -1,4 +1,6 @@ # Define your Pydantic schemas here +from typing import Any, Optional + from dbgpt._private.pydantic import BaseModel, Field from ..config import SERVE_APP_NAME_HUMP @@ -7,15 +9,133 @@ class ServeRequest(BaseModel): """Conversation request model""" - # TODO define your own fields here - class Config: title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + # Just for query + chat_mode: str = Field( + default=None, + description="The chat mode.", + examples=[ + "chat_normal", + ], + ) + conv_uid: Optional[str] = Field( + default=None, + description="The conversation uid.", + examples=[ + "5e7100bc-9017-11ee-9876-8fe019728d79", + ], + ) + user_name: Optional[str] = Field( + default=None, + description="The user name.", + examples=[ + "zhangsan", + ], + ) + sys_code: Optional[str] = Field( + default=None, + description="The system code.", + examples=[ + "dbgpt", + ], + ) + class ServerResponse(BaseModel): """Conversation response model""" - # TODO define your own fields here class Config: title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + + conv_uid: str = Field( + ..., + description="The conversation uid.", + examples=[ + "5e7100bc-9017-11ee-9876-8fe019728d79", + ], + ) + user_input: str = Field( + ..., + description="The user input, we return it as the summary the conversation.", + examples=[ + "Hello world", + ], + ) + chat_mode: str = Field( + ..., + description="The chat mode.", + examples=[ + "chat_normal", + ], + ) + select_param: Optional[str] = Field( + default=None, + description="The select param.", + examples=[ + "my_knowledge_space_name", + ], + ) + model_name: Optional[str] = Field( + default=None, + description="The model name.", + examples=[ + "vicuna-13b-v1.5", + ], + ) + user_name: Optional[str] = Field( + default=None, + description="The user name.", + examples=[ + "zhangsan", + ], + ) + sys_code: Optional[str] = Field( + default=None, + description="The system code.", + examples=[ + "dbgpt", + ], + ) + + +class MessageVo(BaseModel): + role: str = Field( + ..., + description="The role that sends out the current message.", + examples=["human", "ai", "view"], + ) + context: str = Field( + ..., + description="The current message content.", + examples=[ + "Hello", + "Hi, how are you?", + ], + ) + + order: int = Field( + ..., + description="The current message order.", + examples=[ + 1, + 2, + ], + ) + + time_stamp: Optional[Any] = Field( + default=None, + description="The current message time stamp.", + examples=[ + "2023-01-07 09:00:00", + ], + ) + + model_name: Optional[str] = Field( + default=None, + description="The model name.", + examples=[ + "vicuna-13b-v1.5", + ], + ) diff --git a/dbgpt/serve/conversation/config.py b/dbgpt/serve/conversation/config.py index 60819389c..77574d2fc 100644 --- a/dbgpt/serve/conversation/config.py +++ b/dbgpt/serve/conversation/config.py @@ -20,3 +20,8 @@ class ServeConfig(BaseServeConfig): api_keys: Optional[str] = field( default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} ) + + default_model: Optional[str] = field( + default=None, + metadata={"help": "Default model name"}, + ) diff --git a/dbgpt/serve/conversation/models/models.py b/dbgpt/serve/conversation/models/models.py index fb333b051..860f456f4 100644 --- a/dbgpt/serve/conversation/models/models.py +++ b/dbgpt/serve/conversation/models/models.py @@ -1,30 +1,20 @@ """This is an auto-generated model file You can define your own models and DAOs here """ +import json from datetime import datetime -from typing import Any, Dict, Union - -from sqlalchemy import Column, DateTime, Index, Integer, String, Text +from typing import Any, Dict, List, Optional, Union +from dbgpt.core import MessageStorageItem +from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity as ServeEntity +from dbgpt.storage.chat_history.chat_history_db import ChatHistoryMessageEntity from dbgpt.storage.metadata import BaseDao, Model, db +from dbgpt.util import PaginationResult from ..api.schemas import ServeRequest, ServerResponse from ..config import SERVER_APP_TABLE_NAME, ServeConfig -class ServeEntity(Model): - __tablename__ = SERVER_APP_TABLE_NAME - id = Column(Integer, primary_key=True, comment="Auto increment id") - - # TODO: define your own fields here - - gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") - gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") - - def __repr__(self): - return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" - - class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): """The DAO class for Conversation""" @@ -68,4 +58,95 @@ def to_response(self, entity: ServeEntity) -> ServerResponse: RES: The response """ # TODO implement your own logic here, transfer the entity to a response - return ServerResponse() + return ServerResponse( + conv_uid=entity.conv_uid, + user_input=entity.summary, + chat_mode=entity.chat_mode, + user_name=entity.user_name, + sys_code=entity.sys_code, + ) + + def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]: + """Get the latest message of a conversation + + Args: + conv_uid (str): The conversation UID + + Returns: + ChatHistoryMessageEntity: The latest message + """ + with self.session() as session: + entity: ChatHistoryMessageEntity = ( + session.query(ChatHistoryMessageEntity) + .filter(ChatHistoryMessageEntity.conv_uid == conv_uid) + .order_by(ChatHistoryMessageEntity.gmt_created.desc()) + .first() + ) + if not entity: + return None + message_detail = ( + json.loads(entity.message_detail) if entity.message_detail else {} + ) + return MessageStorageItem(entity.conv_uid, entity.index, message_detail) + + def _parse_old_messages(self, entity: ServeEntity) -> List[Dict[str, Any]]: + """Parse the old messages + + Args: + entity (ServeEntity): The entity + + Returns: + str: The old messages + """ + messages = json.loads(entity.messages) + return messages + + def get_conv_by_page( + self, req: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get conversation by page + + Args: + req (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ChatHistoryEntity]: The conversation list + """ + with self.session(commit=False) as session: + query = self._create_query_object(session, req) + query = query.order_by(ServeEntity.gmt_created.desc()) + total_count = query.count() + items = query.offset((page - 1) * page_size).limit(page_size) + total_pages = (total_count + page_size - 1) // page_size + result_items = [] + for item in items: + select_param, model_name = "", None + if item.messages: + messages = self._parse_old_messages(item) + last_round = max(messages, key=lambda x: x["chat_order"]) + if "param_value" in last_round: + select_param = last_round["param_value"] + else: + select_param = "" + else: + latest_message = self.get_latest_message(item.conv_uid) + if latest_message: + message = latest_message.to_message() + select_param = message.additional_kwargs.get("param_value") + model_name = message.additional_kwargs.get("model_name") + res_item = self.to_response(item) + res_item.select_param = select_param + res_item.model_name = model_name + result_items.append(res_item) + + result = PaginationResult( + items=result_items, + total_count=total_count, + total_pages=total_pages, + page=page, + page_size=page_size, + ) + + return result diff --git a/dbgpt/serve/conversation/serve.py b/dbgpt/serve/conversation/serve.py index 2dd63ca6a..643ef6432 100644 --- a/dbgpt/serve/conversation/serve.py +++ b/dbgpt/serve/conversation/serve.py @@ -8,6 +8,7 @@ from dbgpt.serve.core import BaseServe from dbgpt.storage.metadata import DatabaseManager +from .api.endpoints import init_endpoints, router from .config import ( APP_NAME, SERVE_APP_NAME, @@ -15,6 +16,7 @@ SERVE_CONFIG_KEY_PREFIX, ServeConfig, ) +from .service.service import Service logger = logging.getLogger(__name__) @@ -58,6 +60,10 @@ def init_app(self, system_app: SystemApp): if self._app_has_initiated: return self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) self._app_has_initiated = True def on_init(self): diff --git a/dbgpt/serve/conversation/service/service.py b/dbgpt/serve/conversation/service/service.py index 71592bd13..99fddfb8b 100644 --- a/dbgpt/serve/conversation/service/service.py +++ b/dbgpt/serve/conversation/service/service.py @@ -1,11 +1,20 @@ -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union from dbgpt.component import BaseComponent, SystemApp +from dbgpt.core import ( + InMemoryStorage, + MessageStorageItem, + QuerySpec, + StorageConversation, + StorageInterface, +) +from dbgpt.core.interface.message import _append_view_messages from dbgpt.serve.core import BaseService from dbgpt.storage.metadata import BaseDao +from dbgpt.storage.metadata._base_dao import REQ, RES from dbgpt.util.pagination_utils import PaginationResult -from ..api.schemas import ServeRequest, ServerResponse +from ..api.schemas import MessageVo, ServeRequest, ServerResponse from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -15,10 +24,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): name = SERVE_SERVICE_COMPONENT_NAME - def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + def __init__( + self, + system_app: SystemApp, + dao: Optional[ServeDao] = None, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + ): self._system_app = None self._serve_config: ServeConfig = None self._dao: ServeDao = dao + self._storage = storage + self._message_storage = message_storage super().__init__(system_app) def init_app(self, system_app: SystemApp) -> None: @@ -34,7 +51,7 @@ def init_app(self, system_app: SystemApp) -> None: self._system_app = system_app @property - def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + def dao(self) -> ServeDao: """Returns the internal DAO.""" return self._dao @@ -43,6 +60,54 @@ def config(self) -> ServeConfig: """Returns the internal ServeConfig.""" return self._serve_config + def create(self, request: REQ) -> RES: + raise NotImplementedError() + + @property + def conv_storage(self) -> StorageInterface: + """The conversation storage, store the conversation items.""" + if self._storage: + return self._storage + from ..serve import Serve + + return Serve.call_on_current_serve( + self._system_app, lambda serve: serve.conv_storage + ) + + @property + def message_storage(self) -> StorageInterface: + """The message storage, store the messages of one conversation.""" + if self._message_storage: + return self._message_storage + from ..serve import Serve + + return Serve.call_on_current_serve( + self._system_app, + lambda serve: serve.message_storage, + ) + + def create_storage_conv( + self, request: Union[ServeRequest, Dict[str, Any]], load_message: bool = True + ) -> StorageConversation: + conv_storage = self.conv_storage + message_storage = self.message_storage + if not conv_storage or not message_storage: + raise RuntimeError( + "Can't get the conversation storage or message storage from current serve component." + ) + if isinstance(request, dict): + request = ServeRequest(**request) + storage_conv: StorageConversation = StorageConversation( + conv_uid=request.conv_uid, + chat_mode=request.chat_mode, + user_name=request.user_name, + sys_code=request.sys_code, + conv_storage=conv_storage, + message_storage=message_storage, + load_message=load_message, + ) + return storage_conv + def update(self, request: ServeRequest) -> ServerResponse: """Update a Conversation entity @@ -74,18 +139,13 @@ def get(self, request: ServeRequest) -> Optional[ServerResponse]: return self.dao.get_one(query_request) def delete(self, request: ServeRequest) -> None: - """Delete a Conversation entity + """Delete current conversation and its messages Args: request (ServeRequest): The request """ - - # TODO: implement your own logic here - # Build the query request from the request - query_request = { - # "id": request.id - } - self.dao.delete(query_request) + conv: StorageConversation = self.create_storage_conv(request) + conv.delete() def get_list(self, request: ServeRequest) -> List[ServerResponse]: """Get a list of Conversation entities @@ -114,5 +174,29 @@ def get_list_by_page( Returns: List[ServerResponse]: The response """ - query_request = request - return self.dao.get_list_page(query_request, page, page_size) + return self.dao.get_conv_by_page(request, page, page_size) + + def get_history_messages( + self, request: Union[ServeRequest, Dict[str, Any]] + ) -> List[MessageVo]: + """Get a list of Conversation entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + conv: StorageConversation = self.create_storage_conv(request) + result = [] + messages = _append_view_messages(conv.messages) + for msg in messages: + result.append( + MessageVo( + role=msg.type, + context=msg.content, + order=msg.round_index, + model_name=self.config.default_model, + ) + ) + return result diff --git a/dbgpt/serve/conversation/tests/test_models.py b/dbgpt/serve/conversation/tests/test_models.py index 1d111644d..c26660d1e 100644 --- a/dbgpt/serve/conversation/tests/test_models.py +++ b/dbgpt/serve/conversation/tests/test_models.py @@ -29,7 +29,7 @@ def dao(server_config): @pytest.fixture def default_entity_dict(): # TODO: build your default entity dict - return {} + return {"conv_uid": "test_conv_uid", "summary": "hello", "chat_mode": "chat_normal"} def test_table_exist(): @@ -67,19 +67,6 @@ def test_entity_all(): pass -def test_dao_create(dao, default_entity_dict): - # TODO: implement your test case - req = ServeRequest(**default_entity_dict) - res: ServerResponse = dao.create(req) - assert res is not None - - -def test_dao_get_one(dao, default_entity_dict): - # TODO: implement your test case - req = ServeRequest(**default_entity_dict) - res: ServerResponse = dao.create(req) - - def test_get_dao_get_list(dao): # TODO: implement your test case pass diff --git a/dbgpt/serve/core/serve.py b/dbgpt/serve/core/serve.py index e909ad4cb..ff0203b65 100644 --- a/dbgpt/serve/core/serve.py +++ b/dbgpt/serve/core/serve.py @@ -73,7 +73,7 @@ def get_current_serve(cls, system_app: SystemApp) -> Optional["BaseServe"]: Returns: Optional[BaseServe]: The current serve component. """ - return system_app.get_component(cls.name, cls, default_component=None) + return cls.get_instance(system_app, default_component=None) @classmethod def call_on_current_serve( diff --git a/dbgpt/storage/cache/__init__.py b/dbgpt/storage/cache/__init__.py index 906c18c32..80b23bf25 100644 --- a/dbgpt/storage/cache/__init__.py +++ b/dbgpt/storage/cache/__init__.py @@ -1,6 +1,6 @@ +from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue from dbgpt.storage.cache.manager import CacheManager, initialize_cache from dbgpt.storage.cache.storage.base import MemoryCacheStorage -from dbgpt.storage.cache.llm_cache import LLMCacheKey, LLMCacheValue, LLMCacheClient __all__ = [ "LLMCacheKey", diff --git a/dbgpt/storage/cache/llm_cache.py b/dbgpt/storage/cache/llm_cache.py index 682276349..441c67a6c 100644 --- a/dbgpt/storage/cache/llm_cache.py +++ b/dbgpt/storage/cache/llm_cache.py @@ -1,16 +1,11 @@ -from typing import Optional, Dict, Any, Union, List -from dataclasses import dataclass, asdict import hashlib +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Union -from dbgpt.core.interface.cache import ( - CacheKey, - CacheValue, - CacheClient, - CacheConfig, -) -from dbgpt.storage.cache.manager import CacheManager from dbgpt.core import ModelOutput, Serializer +from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue from dbgpt.model.base import ModelType +from dbgpt.storage.cache.manager import CacheManager @dataclass diff --git a/dbgpt/storage/cache/manager.py b/dbgpt/storage/cache/manager.py index 8eebd0883..97063c347 100644 --- a/dbgpt/storage/cache/manager.py +++ b/dbgpt/storage/cache/manager.py @@ -1,17 +1,12 @@ -from abc import ABC, abstractmethod -from typing import Optional, Type import logging +from abc import ABC, abstractmethod from concurrent.futures import Executor -from dbgpt.storage.cache.storage.base import CacheStorage -from dbgpt.core.interface.cache import K, V -from dbgpt.core import ( - CacheKey, - CacheValue, - CacheConfig, - Serializer, - Serializable, -) +from typing import Optional, Type + from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer +from dbgpt.core.interface.cache import K, V +from dbgpt.storage.cache.storage.base import CacheStorage from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async logger = logging.getLogger(__name__) @@ -102,8 +97,8 @@ def serializer(self) -> Serializer: def initialize_cache( system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str ): - from dbgpt.util.serialization.json_serialization import JsonSerializer from dbgpt.storage.cache.storage.base import MemoryCacheStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer cache_storage = None if storage_type == "disk": diff --git a/dbgpt/storage/cache/storage/base.py b/dbgpt/storage/cache/storage/base.py index a31453b56..779552b98 100644 --- a/dbgpt/storage/cache/storage/base.py +++ b/dbgpt/storage/cache/storage/base.py @@ -1,18 +1,19 @@ +import logging from abc import ABC, abstractmethod -from typing import Optional -from dataclasses import dataclass from collections import OrderedDict +from dataclasses import dataclass +from typing import Optional + import msgpack -import logging from dbgpt.core.interface.cache import ( - K, - V, + CacheConfig, CacheKey, + CachePolicy, CacheValue, - CacheConfig, + K, RetrievalPolicy, - CachePolicy, + V, ) from dbgpt.util.memory_utils import _get_object_bytes diff --git a/dbgpt/storage/cache/storage/disk/disk_storage.py b/dbgpt/storage/cache/storage/disk/disk_storage.py index 0f4f32523..fa1c01003 100644 --- a/dbgpt/storage/cache/storage/disk/disk_storage.py +++ b/dbgpt/storage/cache/storage/disk/disk_storage.py @@ -1,16 +1,17 @@ -from typing import Optional import logging -from rocksdict import Rdict, Options +from typing import Optional + +from rocksdict import Options, Rdict from dbgpt.core.interface.cache import ( - K, - V, + CacheConfig, CacheKey, CacheValue, - CacheConfig, + K, RetrievalPolicy, + V, ) -from dbgpt.storage.cache.storage.base import StorageItem, CacheStorage +from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/cache/storage/tests/test_storage.py b/dbgpt/storage/cache/storage/tests/test_storage.py index 1a175b2e4..1ba2545f0 100644 --- a/dbgpt/storage/cache/storage/tests/test_storage.py +++ b/dbgpt/storage/cache/storage/tests/test_storage.py @@ -1,7 +1,9 @@ import pytest -from ..base import StorageItem + from dbgpt.util.memory_utils import _get_object_bytes +from ..base import StorageItem + def test_build_from(): key_hash = b"key_hash" diff --git a/dbgpt/storage/chat_history/base.py b/dbgpt/storage/chat_history/base.py index abc750bba..2f83f0c5d 100644 --- a/dbgpt/storage/chat_history/base.py +++ b/dbgpt/storage/chat_history/base.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, Dict from enum import Enum +from typing import Dict, List, Optional + from dbgpt.core.interface.message import OnceConversation diff --git a/dbgpt/storage/chat_history/chat_hisotry_factory.py b/dbgpt/storage/chat_history/chat_hisotry_factory.py index d07afb7d4..f202c20ff 100644 --- a/dbgpt/storage/chat_history/chat_hisotry_factory.py +++ b/dbgpt/storage/chat_history/chat_hisotry_factory.py @@ -1,9 +1,14 @@ +import logging from typing import Type -from .base import MemoryStoreType + from dbgpt._private.config import Config from dbgpt.storage.chat_history.base import BaseChatHistoryMemory +from .base import MemoryStoreType + +# TODO remove global variable CFG = Config() +logger = logging.getLogger(__name__) # Import first for auto create table from .store_type.meta_db_history import DbHistoryMemory @@ -13,14 +18,16 @@ class ChatHistory: def __init__(self): self.memory_type = MemoryStoreType.DB.value self.mem_store_class_map = {} - from .store_type.duckdb_history import DuckdbHistoryMemory - from .store_type.file_history import FileHistoryMemory - from .store_type.mem_history import MemHistoryMemory - self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory - self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory + # Just support db store type after v0.4.6 + # from .store_type.duckdb_history import DuckdbHistoryMemory + # from .store_type.file_history import FileHistoryMemory + # from .store_type.mem_history import MemHistoryMemory + # self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory + # self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory + # self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory + self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory - self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory: """New store instance for store chat histories @@ -31,9 +38,39 @@ def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory: Returns: BaseChatHistoryMemory: Store instance """ + self._check_store_type(CFG.CHAT_HISTORY_STORE_TYPE) return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)( chat_session_id ) def get_store_cls(self) -> Type[BaseChatHistoryMemory]: + self._check_store_type(CFG.CHAT_HISTORY_STORE_TYPE) return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE) + + def _check_store_type(self, store_type: str): + """Check store type + + Raises: + ValueError: Invalid store type + """ + from .store_type.duckdb_history import DuckdbHistoryMemory + from .store_type.file_history import FileHistoryMemory + from .store_type.mem_history import MemHistoryMemory + + if store_type == MemHistoryMemory.store_type: + logger.error( + "Not support memory store type, just support db store type now" + ) + raise ValueError(f"Invalid store type: {store_type}") + + if store_type == FileHistoryMemory.store_type: + logger.error("Not support file store type, just support db store type now") + raise ValueError(f"Invalid store type: {store_type}") + if store_type == DuckdbHistoryMemory.store_type: + link1 = "https://docs.dbgpt.site/docs/faq/install#q6-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-sqlitel" + link2 = "https://docs.dbgpt.site/docs/faq/install#q7-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-mysql" + logger.error( + "Not support duckdb store type after v0.4.6, just support db store type now, " + f"you can migrate your message according to {link1} or {link2}" + ) + raise ValueError(f"Invalid store type: {store_type}") diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index 029abadbf..8faaca8b2 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -1,7 +1,7 @@ -from typing import Optional from datetime import datetime -from sqlalchemy import Column, Integer, String, Index, Text, DateTime -from sqlalchemy import UniqueConstraint +from typing import Optional + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint from dbgpt.storage.metadata import BaseDao, Model diff --git a/dbgpt/storage/chat_history/storage_adapter.py b/dbgpt/storage/chat_history/storage_adapter.py index 4a7a3b96b..93302ef89 100644 --- a/dbgpt/storage/chat_history/storage_adapter.py +++ b/dbgpt/storage/chat_history/storage_adapter.py @@ -1,16 +1,19 @@ -from typing import List, Dict, Type import json +from typing import Dict, List, Type + from sqlalchemy.orm import Session -from dbgpt.core.interface.storage import StorageItemAdapter + from dbgpt.core.interface.message import ( - StorageConversation, + BaseMessage, ConversationIdentifier, MessageIdentifier, MessageStorageItem, - _messages_from_dict, + StorageConversation, _conversation_to_dict, - BaseMessage, + _messages_from_dict, ) +from dbgpt.core.interface.storage import StorageItemAdapter + from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity @@ -21,7 +24,8 @@ def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity: message_ids = ",".join(item.message_ids) messages = None if not item.save_message_independent and item.messages: - messages = _conversation_to_dict(item) + message_dict_list = [_conversation_to_dict(item)] + messages = json.dumps(message_dict_list, ensure_ascii=False) return ChatHistoryEntity( conv_uid=item.conv_uid, chat_mode=item.chat_mode, @@ -42,15 +46,10 @@ def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation: save_message_independent = True if old_conversations: # Load old messages from old conversations, in old design, we save messages to chat_history table - old_messages_dict = [] - for old_conversation in old_conversations: - old_messages_dict.extend( - old_conversation["messages"] - if "messages" in old_conversation - else [] - ) save_message_independent = False - old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict) + old_messages: List[BaseMessage] = _parse_old_conversations( + old_conversations + ) return StorageConversation( conv_uid=model.conv_uid, chat_mode=model.chat_mode, @@ -114,3 +113,24 @@ def get_query_for_identifier( ChatHistoryMessageEntity.conv_uid == resource_id.conv_uid, ChatHistoryMessageEntity.index == resource_id.index, ) + + +def _parse_old_conversations(old_conversations: List[Dict]) -> List[BaseMessage]: + old_messages_dict = [] + for old_conversation in old_conversations: + messages = ( + old_conversation["messages"] if "messages" in old_conversation else [] + ) + for message in messages: + if "data" in message: + message_data = message["data"] + additional_kwargs = message_data.get("additional_kwargs", {}) + additional_kwargs["param_value"] = old_conversation.get("param_value") + additional_kwargs["param_type"] = old_conversation.get("param_type") + additional_kwargs["model_name"] = old_conversation.get("model_name") + message_data["additional_kwargs"] = additional_kwargs + + old_messages_dict.extend(messages) + + old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict) + return old_messages diff --git a/dbgpt/storage/chat_history/store_type/duckdb_history.py b/dbgpt/storage/chat_history/store_type/duckdb_history.py index 2e99fc185..81644f346 100644 --- a/dbgpt/storage/chat_history/store_type/duckdb_history.py +++ b/dbgpt/storage/chat_history/store_type/duckdb_history.py @@ -1,15 +1,14 @@ import json import os +from typing import Dict, List, Optional + import duckdb -from typing import List, Dict, Optional from dbgpt._private.config import Config from dbgpt.configs.model_config import PILOT_PATH +from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict from dbgpt.storage.chat_history.base import BaseChatHistoryMemory -from dbgpt.core.interface.message import ( - OnceConversation, - _conversation_to_dict, -) + from ..base import MemoryStoreType default_db_path = os.path.join(PILOT_PATH, "message") diff --git a/dbgpt/storage/chat_history/store_type/file_history.py b/dbgpt/storage/chat_history/store_type/file_history.py index 013a889b0..efa9f4e69 100644 --- a/dbgpt/storage/chat_history/store_type/file_history.py +++ b/dbgpt/storage/chat_history/store_type/file_history.py @@ -1,9 +1,8 @@ -from typing import List +import datetime import json import os -import datetime -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory from pathlib import Path +from typing import List from dbgpt._private.config import Config from dbgpt.core.interface.message import ( @@ -11,7 +10,7 @@ _conversation_from_dict, _conversations_to_dict, ) -from dbgpt.storage.chat_history.base import MemoryStoreType +from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType CFG = Config() diff --git a/dbgpt/storage/chat_history/store_type/mem_history.py b/dbgpt/storage/chat_history/store_type/mem_history.py index 4f35c2cd0..3c5264627 100644 --- a/dbgpt/storage/chat_history/store_type/mem_history.py +++ b/dbgpt/storage/chat_history/store_type/mem_history.py @@ -1,10 +1,9 @@ from typing import List -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory from dbgpt._private.config import Config from dbgpt.core.interface.message import OnceConversation +from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType from dbgpt.util.custom_data_structure import FixedSizeDict -from dbgpt.storage.chat_history.base import MemoryStoreType CFG = Config() diff --git a/dbgpt/storage/chat_history/store_type/meta_db_history.py b/dbgpt/storage/chat_history/store_type/meta_db_history.py index 9676259d0..ca08c69fd 100644 --- a/dbgpt/storage/chat_history/store_type/meta_db_history.py +++ b/dbgpt/storage/chat_history/store_type/meta_db_history.py @@ -1,12 +1,11 @@ import json import logging -from typing import List, Dict, Optional +from typing import Dict, List, Optional + from dbgpt._private.config import Config from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory -from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao - -from dbgpt.storage.chat_history.base import MemoryStoreType +from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType +from dbgpt.storage.chat_history.chat_history_db import ChatHistoryDao, ChatHistoryEntity CFG = Config() logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/chat_history/tests/test_storage_adapter.py b/dbgpt/storage/chat_history/tests/test_storage_adapter.py index 1802a8fd0..3596adc9a 100644 --- a/dbgpt/storage/chat_history/tests/test_storage_adapter.py +++ b/dbgpt/storage/chat_history/tests/test_storage_adapter.py @@ -1,20 +1,21 @@ -import pytest from typing import List -from dbgpt.util.pagination_utils import PaginationResult -from dbgpt.util.serialization.json_serialization import JsonSerializer -from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage +import pytest + +from dbgpt.core.interface.message import AIMessage, HumanMessage, StorageConversation from dbgpt.core.interface.storage import QuerySpec -from dbgpt.storage.metadata import db -from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage from dbgpt.storage.chat_history.chat_history_db import ( ChatHistoryEntity, ChatHistoryMessageEntity, ) from dbgpt.storage.chat_history.storage_adapter import ( - DBStorageConversationItemAdapter, DBMessageStorageItemAdapter, + DBStorageConversationItemAdapter, ) +from dbgpt.storage.metadata import db +from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.serialization.json_serialization import JsonSerializer @pytest.fixture diff --git a/dbgpt/storage/metadata/__init__.py b/dbgpt/storage/metadata/__init__.py index ee409f6cb..8660866d9 100644 --- a/dbgpt/storage/metadata/__init__.py +++ b/dbgpt/storage/metadata/__init__.py @@ -1,12 +1,12 @@ +from dbgpt.storage.metadata._base_dao import BaseDao +from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory from dbgpt.storage.metadata.db_manager import ( - db, - Model, + BaseModel, DatabaseManager, + Model, create_model, - BaseModel, + db, ) -from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory -from dbgpt.storage.metadata._base_dao import BaseDao __ALL__ = [ "db", diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 770d82c93..96294c00b 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -1,6 +1,8 @@ from contextlib import contextmanager -from typing import TypeVar, Generic, Any, Optional, Dict, Union, List +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + from sqlalchemy.orm.session import Session + from dbgpt.util.pagination_utils import PaginationResult # The entity type @@ -10,8 +12,7 @@ # The response schema type RES = TypeVar("RES") -from .db_manager import db, DatabaseManager, BaseQuery - +from .db_manager import BaseQuery, DatabaseManager, db QUERY_SPEC = Union[REQ, Dict[str, Any]] @@ -21,6 +22,7 @@ class BaseDao(Generic[T, REQ, RES]): Examples: .. code-block:: python + class UserDao(BaseDao): def get_user_by_name(self, name: str) -> User: with self.session() as session: @@ -70,8 +72,9 @@ def session(self, commit: Optional[bool] = True) -> Session: Example: .. code-block:: python + with self.session() as session: - session.query(User).filter(User.name == 'Edward Snowden').first() + session.query(User).filter(User.name == "Edward Snowden").first() Args: commit (Optional[bool], optional): Whether to commit the session. Defaults to True. diff --git a/dbgpt/storage/metadata/db_factory.py b/dbgpt/storage/metadata/db_factory.py index 14cf0339a..c288a149d 100644 --- a/dbgpt/storage/metadata/db_factory.py +++ b/dbgpt/storage/metadata/db_factory.py @@ -1,4 +1,4 @@ -from dbgpt.component import SystemApp, BaseComponent, ComponentType +from dbgpt.component import BaseComponent, ComponentType, SystemApp from .db_manager import DatabaseManager diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index f222dd38a..67ec35954 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -1,29 +1,21 @@ from __future__ import annotations -from contextlib import contextmanager -from typing import ( - TypeVar, - Generic, - Union, - Dict, - Optional, - Type, - ClassVar, -) import logging -from sqlalchemy import create_engine, URL, Engine -from sqlalchemy import orm, inspect, MetaData +from contextlib import contextmanager +from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union + +from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm from sqlalchemy.orm import ( - scoped_session, - sessionmaker, + DeclarativeMeta, Session, declarative_base, - DeclarativeMeta, + scoped_session, + sessionmaker, ) - from sqlalchemy.pool import QueuePool -from dbgpt.util.string_utils import _to_str + from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.string_utils import _to_str logger = logging.getLogger(__name__) T = TypeVar("T", bound="BaseModel") @@ -314,7 +306,7 @@ def init_default_db( Examples: >>> db.init_default_db(sqlite_path) >>> with db.session() as session: - >>> session.query(...) + ... session.query(...) Args: sqlite_path (str): The sqlite path. diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py index daa87ceba..6b6e9716a 100644 --- a/dbgpt/storage/metadata/db_storage.py +++ b/dbgpt/storage/metadata/db_storage.py @@ -1,25 +1,28 @@ from contextlib import contextmanager +from typing import Dict, List, Optional, Type, Union + +from sqlalchemy import URL +from sqlalchemy.orm import DeclarativeMeta, Session -from typing import Type, List, Optional, Union, Dict from dbgpt.core import Serializer from dbgpt.core.interface.storage import ( - StorageInterface, QuerySpec, ResourceIdentifier, + StorageInterface, StorageItemAdapter, T, ) -from sqlalchemy import URL -from sqlalchemy.orm import Session, DeclarativeMeta -from .db_manager import BaseModel, DatabaseManager, BaseQuery +from .db_manager import BaseModel, BaseQuery, DatabaseManager def _copy_public_properties(src: BaseModel, dest: BaseModel): """Simple copy public properties from src to dest""" for column in src.__table__.columns: if column.name != "id": - setattr(dest, column.name, getattr(src, column.name)) + value = getattr(src, column.name) + if value is not None: + setattr(dest, column.name, value) class SQLAlchemyStorage(StorageInterface[T, BaseModel]): @@ -51,8 +54,16 @@ def save(self, data: T) -> None: def update(self, data: T) -> None: with self.session() as session: - model_instance = self.adapter.to_storage_format(data) - session.merge(model_instance) + query = self.adapter.get_query_for_identifier( + self._model_class, data.identifier, session=session + ) + exist_model_instance = query.with_session(session).first() + if exist_model_instance: + _copy_public_properties( + self.adapter.to_storage_format(data), exist_model_instance + ) + session.merge(exist_model_instance) + return def save_or_update(self, data: T) -> None: with self.session() as session: diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py index 9728fd482..cfa86ba3c 100644 --- a/dbgpt/storage/metadata/tests/test_base_dao.py +++ b/dbgpt/storage/metadata/tests/test_base_dao.py @@ -1,12 +1,15 @@ -from typing import Type, Optional, Union, Dict, Any +from typing import Any, Dict, Optional, Type, Union + import pytest from sqlalchemy import Column, Integer, String -from dbgpt._private.pydantic import BaseModel as PydanticBaseModel, Field + +from dbgpt._private.pydantic import BaseModel as PydanticBaseModel +from dbgpt._private.pydantic import Field from dbgpt.storage.metadata.db_manager import ( + BaseModel, DatabaseManager, PaginationResult, create_model, - BaseModel, ) from .._base_dao import BaseDao diff --git a/dbgpt/storage/metadata/tests/test_db_manager.py b/dbgpt/storage/metadata/tests/test_db_manager.py index aa67787fa..551f0d9d7 100644 --- a/dbgpt/storage/metadata/tests/test_db_manager.py +++ b/dbgpt/storage/metadata/tests/test_db_manager.py @@ -1,14 +1,17 @@ from __future__ import annotations -import pytest + import tempfile from typing import Type + +import pytest +from sqlalchemy import Column, Integer, String + from dbgpt.storage.metadata.db_manager import ( + BaseModel, DatabaseManager, PaginationResult, create_model, - BaseModel, ) -from sqlalchemy import Column, Integer, String @pytest.fixture diff --git a/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py index fcae83215..df4688c71 100644 --- a/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py +++ b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py @@ -1,21 +1,19 @@ from typing import Dict, Type -from sqlalchemy.orm import declarative_base, Session -from sqlalchemy import Column, Integer, String import pytest +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import Session, declarative_base from dbgpt.core.interface.storage import ( - StorageItem, + QuerySpec, ResourceIdentifier, + StorageItem, StorageItemAdapter, - QuerySpec, ) -from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage - from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier +from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage from dbgpt.util.serialization.json_serialization import JsonSerializer - Base = declarative_base() diff --git a/dbgpt/storage/schema.py b/dbgpt/storage/schema.py index ab6dc7b40..d5874ba8e 100644 --- a/dbgpt/storage/schema.py +++ b/dbgpt/storage/schema.py @@ -1,5 +1,5 @@ -from enum import Enum import os +from enum import Enum class DbInfo: diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py index 2f944fc71..2d2918f71 100644 --- a/dbgpt/storage/vector_store/base.py +++ b/dbgpt/storage/vector_store/base.py @@ -1,8 +1,8 @@ -from abc import ABC, abstractmethod import math -from typing import Optional, Callable, List, Any +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field from dbgpt.rag.chunk import Chunk diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py index 9149e3665..cd5b76e71 100644 --- a/dbgpt/storage/vector_store/chroma_store.py +++ b/dbgpt/storage/vector_store/chroma_store.py @@ -1,14 +1,14 @@ -import os import logging +import os from typing import Any, List -from chromadb.config import Settings from chromadb import PersistentClient +from chromadb.config import Settings from pydantic import Field +from dbgpt.configs.model_config import PILOT_PATH from dbgpt.rag.chunk import Chunk from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig -from dbgpt.configs.model_config import PILOT_PATH logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py index a256e6a87..27ee1b871 100644 --- a/dbgpt/storage/vector_store/connector.py +++ b/dbgpt/storage/vector_store/connector.py @@ -1,5 +1,5 @@ import os -from typing import Optional, List, Callable, Any +from typing import Any, Callable, List, Optional from dbgpt.rag.chunk import Chunk from dbgpt.storage import vector_store diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py index 9917ca54b..2563a2893 100644 --- a/dbgpt/storage/vector_store/pgvector_store.py +++ b/dbgpt/storage/vector_store/pgvector_store.py @@ -1,11 +1,11 @@ -from typing import Any, List import logging +from typing import Any, List from pydantic import Field +from dbgpt._private.config import Config from dbgpt.rag.chunk import Chunk from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig -from dbgpt._private.config import Config logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py index 92ded9980..77744489e 100644 --- a/dbgpt/storage/vector_store/weaviate_store.py +++ b/dbgpt/storage/vector_store/weaviate_store.py @@ -1,5 +1,5 @@ -import os import logging +import os from typing import List from langchain.schema import Document diff --git a/dbgpt/util/__init__.py b/dbgpt/util/__init__.py index 74945e5ae..d642b7dd4 100644 --- a/dbgpt/util/__init__.py +++ b/dbgpt/util/__init__.py @@ -1,10 +1,7 @@ -from .utils import ( - get_gpu_memory, - get_or_create_event_loop, -) -from .pagination_utils import PaginationResult -from .parameter_utils import BaseParameters, ParameterDescription, EnvArgumentParser from .config_utils import AppConfig +from .pagination_utils import PaginationResult +from .parameter_utils import BaseParameters, EnvArgumentParser, ParameterDescription +from .utils import get_gpu_memory, get_or_create_event_loop __ALL__ = [ "get_gpu_memory", diff --git a/dbgpt/util/_db_migration_utils.py b/dbgpt/util/_db_migration_utils.py index 13734960d..f8171f28a 100644 --- a/dbgpt/util/_db_migration_utils.py +++ b/dbgpt/util/_db_migration_utils.py @@ -1,12 +1,12 @@ -from typing import Optional -import os import logging -from sqlalchemy import Engine, text -from sqlalchemy.orm import Session, DeclarativeMeta +import os +from typing import Optional + from alembic import command -from alembic.util.exc import CommandError from alembic.config import Config as AlembicConfig - +from alembic.util.exc import CommandError +from sqlalchemy import Engine, text +from sqlalchemy.orm import DeclarativeMeta, Session logger = logging.getLogger(__name__) @@ -67,8 +67,8 @@ def create_migration_script( Returns: The path of the generated migration script. """ - from alembic.script import ScriptDirectory from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory # Check if the database is up-to-date script_dir = ScriptDirectory.from_config(alembic_cfg) @@ -244,8 +244,8 @@ def _check_database_migration_status(alembic_cfg: AlembicConfig, engine: Engine) Raises: Exception: If the database is not at the latest revision. """ - from alembic.script import ScriptDirectory from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory script = ScriptDirectory.from_config(alembic_cfg) diff --git a/dbgpt/util/api_utils.py b/dbgpt/util/api_utils.py index 1fd3499d2..175cd55de 100644 --- a/dbgpt/util/api_utils.py +++ b/dbgpt/util/api_utils.py @@ -1,7 +1,7 @@ -from inspect import signature import logging -from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple -from dataclasses import is_dataclass, asdict +from dataclasses import asdict, is_dataclass +from inspect import signature +from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints T = TypeVar("T") diff --git a/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py b/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py index 9c07f7a0a..fc0a847a1 100644 --- a/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py +++ b/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py @@ -4,9 +4,10 @@ """ import gc -from typing import Iterable, Dict +from typing import Dict, Iterable import torch +from fastchat.utils import get_context_length, is_partial_stop, is_sentence_complete from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, @@ -16,9 +17,6 @@ ) -from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length - - def prepare_logits_processor( temperature: float, repetition_penalty: float, top_p: float, top_k: int ) -> LogitsProcessorList: diff --git a/dbgpt/util/benchmarks/llm/llm_benchmarks.py b/dbgpt/util/benchmarks/llm/llm_benchmarks.py index 81e839f5a..a5ded5eff 100644 --- a/dbgpt/util/benchmarks/llm/llm_benchmarks.py +++ b/dbgpt/util/benchmarks/llm/llm_benchmarks.py @@ -1,25 +1,23 @@ -from typing import Dict, List +import argparse import asyncio +import csv +import logging import os import sys import time -import csv -import argparse -import logging import traceback -from dbgpt.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG from datetime import datetime +from typing import Dict, List +from dbgpt.configs.model_config import LLM_MODEL_CONFIG, ROOT_PATH +from dbgpt.core import ModelInferenceMetrics, ModelOutput +from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.model.cluster.worker.manager import ( - run_worker_manager, - initialize_worker_manager_in_client, WorkerManager, + initialize_worker_manager_in_client, + run_worker_manager, ) -from dbgpt.core import ModelOutput, ModelInferenceMetrics -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType - - model_name = "vicuna-7b-v1.5" model_path = LLM_MODEL_CONFIG[model_name] # or vllm diff --git a/dbgpt/util/chat_util.py b/dbgpt/util/chat_util.py index 1c76a43f5..0755ef706 100644 --- a/dbgpt/util/chat_util.py +++ b/dbgpt/util/chat_util.py @@ -1,5 +1,5 @@ import asyncio -from typing import Coroutine, List, Any +from typing import Any, Coroutine, List from dbgpt.app.scene import BaseChat, ChatFactory diff --git a/dbgpt/util/command_utils.py b/dbgpt/util/command_utils.py index e7aa241bc..de554119d 100644 --- a/dbgpt/util/command_utils.py +++ b/dbgpt/util/command_utils.py @@ -1,10 +1,11 @@ -import sys import os -import subprocess -from typing import List, Dict -import psutil import platform +import subprocess +import sys from functools import lru_cache +from typing import Dict, List + +import psutil def _get_abspath_of_current_command(command_path: str): diff --git a/dbgpt/util/custom_data_structure.py b/dbgpt/util/custom_data_structure.py index a8502143a..d86736c81 100644 --- a/dbgpt/util/custom_data_structure.py +++ b/dbgpt/util/custom_data_structure.py @@ -1,5 +1,4 @@ -from collections import OrderedDict -from collections import deque +from collections import OrderedDict, deque class FixedSizeDict(OrderedDict): diff --git a/dbgpt/util/executor_utils.py b/dbgpt/util/executor_utils.py index 1e9aa4b3e..d530967b4 100644 --- a/dbgpt/util/executor_utils.py +++ b/dbgpt/util/executor_utils.py @@ -1,9 +1,9 @@ -from typing import Callable, Awaitable, Any import asyncio import contextvars from abc import ABC, abstractmethod from concurrent.futures import Executor, ThreadPoolExecutor from functools import partial +from typing import Any, Awaitable, Callable from dbgpt.component import BaseComponent, ComponentType, SystemApp diff --git a/dbgpt/util/function_utils.py b/dbgpt/util/function_utils.py index 5bfd578ea..ccce14ddd 100644 --- a/dbgpt/util/function_utils.py +++ b/dbgpt/util/function_utils.py @@ -1,7 +1,7 @@ -from typing import Any, get_type_hints, get_origin, get_args -from functools import wraps -import inspect import asyncio +import inspect +from functools import wraps +from typing import Any, get_args, get_origin, get_type_hints def _is_instance_of_generic_type(obj, generic_type): @@ -65,10 +65,12 @@ def rearrange_args_by_type(func): from dbgpt.util.function_utils import rearrange_args_by_type + @rearrange_args_by_type def sync_regular_function(a: int, b: str, c: float): return a, b, c + assert instance.sync_class_method(1, "b", 3.0) == (1, "b", 3.0) assert instance.sync_class_method("b", 3.0, 1) == (1, "b", 3.0) diff --git a/dbgpt/util/global_helper.py b/dbgpt/util/global_helper.py index 617995373..c3972bba2 100644 --- a/dbgpt/util/global_helper.py +++ b/dbgpt/util/global_helper.py @@ -196,7 +196,7 @@ def truncate_text(text: str, max_length: int) -> str: def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: """Iterate over an iterable in batches. - >>> list(iter_batch([1,2,3,4,5], 3)) + >>> list(iter_batch([1, 2, 3, 4, 5], 3)) [[1, 2, 3], [4, 5]] """ source_iter = iter(iterable) diff --git a/dbgpt/util/json_utils.py b/dbgpt/util/json_utils.py index e861e7284..9e12e880a 100644 --- a/dbgpt/util/json_utils.py +++ b/dbgpt/util/json_utils.py @@ -1,10 +1,10 @@ """Utilities for the json_fixes package.""" import json -from datetime import date, datetime -from dataclasses import dataclass, asdict, is_dataclass +import logging import os.path import re -import logging +from dataclasses import asdict, dataclass, is_dataclass +from datetime import date, datetime from jsonschema import Draft7Validator diff --git a/dbgpt/util/memory_utils.py b/dbgpt/util/memory_utils.py index cb0427c08..07d69339d 100644 --- a/dbgpt/util/memory_utils.py +++ b/dbgpt/util/memory_utils.py @@ -1,4 +1,5 @@ from typing import Any + from pympler import asizeof diff --git a/dbgpt/util/model_utils.py b/dbgpt/util/model_utils.py index 5df73dff0..0bf77bb94 100644 --- a/dbgpt/util/model_utils.py +++ b/dbgpt/util/model_utils.py @@ -1,6 +1,6 @@ -from typing import List, Tuple -from dataclasses import dataclass import logging +from dataclasses import dataclass +from typing import List, Tuple logger = logging.getLogger(__name__) @@ -17,9 +17,10 @@ def _clear_model_cache(device="cuda"): def _clear_torch_cache(device="cuda"): - import torch import gc + import torch + gc.collect() if device != "cpu": if torch.has_mps: diff --git a/dbgpt/util/module_utils.py b/dbgpt/util/module_utils.py index c2d857440..d26a387d2 100644 --- a/dbgpt/util/module_utils.py +++ b/dbgpt/util/module_utils.py @@ -1,5 +1,5 @@ -from typing import Type from importlib import import_module +from typing import Type def import_from_string(module_path: str, ignore_import_error: bool = False): diff --git a/dbgpt/util/net_utils.py b/dbgpt/util/net_utils.py index 8fc803e6f..fc9fb3f86 100644 --- a/dbgpt/util/net_utils.py +++ b/dbgpt/util/net_utils.py @@ -1,5 +1,5 @@ -import socket import errno +import socket def _get_ip_address(address: str = "10.254.254.254:1") -> str: diff --git a/dbgpt/util/openai_utils.py b/dbgpt/util/openai_utils.py index 5e1673f08..0e66c727a 100644 --- a/dbgpt/util/openai_utils.py +++ b/dbgpt/util/openai_utils.py @@ -1,8 +1,9 @@ -from typing import Dict, Any, Awaitable, Callable, Optional, Iterator -import httpx import asyncio -import logging import json +import logging +from typing import Any, Awaitable, Callable, Dict, Iterator, Optional + +import httpx logger = logging.getLogger(__name__) MessageCaller = Callable[[str], Awaitable[None]] diff --git a/dbgpt/util/pagination_utils.py b/dbgpt/util/pagination_utils.py index cbe21dda0..4b9288cb8 100644 --- a/dbgpt/util/pagination_utils.py +++ b/dbgpt/util/pagination_utils.py @@ -1,4 +1,5 @@ -from typing import TypeVar, Generic, List +from typing import Generic, List, TypeVar + from dbgpt._private.pydantic import BaseModel, Field T = TypeVar("T") diff --git a/dbgpt/util/parameter_utils.py b/dbgpt/util/parameter_utils.py index 8baddffa3..202c5f8a6 100644 --- a/dbgpt/util/parameter_utils.py +++ b/dbgpt/util/parameter_utils.py @@ -1,8 +1,8 @@ import argparse import os -from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass -from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING from collections import OrderedDict +from dataclasses import MISSING, asdict, dataclass, field, fields, is_dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union if TYPE_CHECKING: from dbgpt._private.pydantic import BaseModel diff --git a/dbgpt/util/prompt_util.py b/dbgpt/util/prompt_util.py index e0c0a3846..bad972a7a 100644 --- a/dbgpt/util/prompt_util.py +++ b/dbgpt/util/prompt_util.py @@ -12,12 +12,11 @@ from string import Formatter from typing import Callable, List, Optional, Sequence, Set -from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel - -from dbgpt.util.global_helper import globals_helper -from dbgpt.core.interface.prompt import get_template_vars from dbgpt._private.llm_metadata import LLMMetadata +from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr +from dbgpt.core.interface.prompt import get_template_vars from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter +from dbgpt.util.global_helper import globals_helper DEFAULT_PADDING = 5 DEFAULT_CHUNK_OVERLAP_RATIO = 0.1 diff --git a/dbgpt/util/serialization/json_serialization.py b/dbgpt/util/serialization/json_serialization.py index 58811cae2..b66530b76 100644 --- a/dbgpt/util/serialization/json_serialization.py +++ b/dbgpt/util/serialization/json_serialization.py @@ -1,6 +1,6 @@ +import json from abc import ABC, abstractmethod from typing import Dict, Type -import json from dbgpt.core.interface.serialization import Serializable, Serializer diff --git a/dbgpt/util/speech/eleven_labs.py b/dbgpt/util/speech/eleven_labs.py index 6bafd5597..5173b63d9 100644 --- a/dbgpt/util/speech/eleven_labs.py +++ b/dbgpt/util/speech/eleven_labs.py @@ -1,6 +1,7 @@ """ElevenLabs speech module""" -import os import logging +import os + import requests from dbgpt._private.config import Config diff --git a/dbgpt/util/splitter_utils.py b/dbgpt/util/splitter_utils.py index 9c57f2111..8905aa3b4 100644 --- a/dbgpt/util/splitter_utils.py +++ b/dbgpt/util/splitter_utils.py @@ -25,7 +25,6 @@ def split_by_sentence_tokenizer() -> Callable[[str], List[str]]: import os import nltk - from llama_index.utils import get_cache_dir cache_dir = get_cache_dir() diff --git a/dbgpt/util/system_utils.py b/dbgpt/util/system_utils.py index 50b54be0d..463fe354f 100644 --- a/dbgpt/util/system_utils.py +++ b/dbgpt/util/system_utils.py @@ -1,11 +1,11 @@ -from dataclasses import dataclass, asdict -from enum import Enum -from typing import Tuple, Dict import os import platform -import subprocess import re +import subprocess +from dataclasses import asdict, dataclass +from enum import Enum from functools import cache +from typing import Dict, Tuple @dataclass diff --git a/dbgpt/util/tests/test_function_utils.py b/dbgpt/util/tests/test_function_utils.py index f245b8c14..aad1dd600 100644 --- a/dbgpt/util/tests/test_function_utils.py +++ b/dbgpt/util/tests/test_function_utils.py @@ -1,6 +1,7 @@ -from typing import List, Dict, Any +from typing import Any, Dict, List import pytest + from dbgpt.util.function_utils import rearrange_args_by_type diff --git a/dbgpt/util/tests/test_parameter_utils.py b/dbgpt/util/tests/test_parameter_utils.py index 79b672ef6..4749076ed 100644 --- a/dbgpt/util/tests/test_parameter_utils.py +++ b/dbgpt/util/tests/test_parameter_utils.py @@ -1,5 +1,7 @@ import argparse + import pytest + from dbgpt.util.parameter_utils import _extract_parameter_details diff --git a/dbgpt/util/tracer/__init__.py b/dbgpt/util/tracer/__init__.py index fd6c7b4de..25cbe372e 100644 --- a/dbgpt/util/tracer/__init__.py +++ b/dbgpt/util/tracer/__init__.py @@ -1,23 +1,23 @@ from dbgpt.util.tracer.base import ( - SpanType, Span, - SpanTypeRunName, - Tracer, SpanStorage, SpanStorageType, + SpanType, + SpanTypeRunName, + Tracer, TracerContext, ) from dbgpt.util.tracer.span_storage import ( - MemorySpanStorage, FileSpanStorage, + MemorySpanStorage, SpanStorageContainer, ) from dbgpt.util.tracer.tracer_impl import ( - root_tracer, - trace, - initialize_tracer, DefaultTracer, TracerManager, + initialize_tracer, + root_tracer, + trace, ) __all__ = [ diff --git a/dbgpt/util/tracer/base.py b/dbgpt/util/tracer/base.py index dcf988bdf..77f8049f1 100644 --- a/dbgpt/util/tracer/base.py +++ b/dbgpt/util/tracer/base.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import Dict, Callable, Optional, List -from dataclasses import dataclass -from abc import ABC, abstractmethod -from enum import Enum import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import datetime +from enum import Enum +from typing import Callable, Dict, List, Optional -from dbgpt.component import BaseComponent, SystemApp, ComponentType +from dbgpt.component import BaseComponent, ComponentType, SystemApp class SpanType(str, Enum): diff --git a/dbgpt/util/tracer/span_storage.py b/dbgpt/util/tracer/span_storage.py index 4872336dc..296480d9d 100644 --- a/dbgpt/util/tracer/span_storage.py +++ b/dbgpt/util/tracer/span_storage.py @@ -1,12 +1,12 @@ -import os -import json -import time import datetime -import threading -import queue +import json import logging -from typing import Optional, List +import os +import queue +import threading +import time from concurrent.futures import Executor, ThreadPoolExecutor +from typing import List, Optional from dbgpt.component import SystemApp from dbgpt.util.tracer.base import Span, SpanStorage diff --git a/dbgpt/util/tracer/tests/test_base.py b/dbgpt/util/tracer/tests/test_base.py index 2a0449afa..933f55b97 100644 --- a/dbgpt/util/tracer/tests/test_base.py +++ b/dbgpt/util/tracer/tests/test_base.py @@ -1,8 +1,7 @@ from typing import Dict -from dbgpt.component import SystemApp - -from dbgpt.util.tracer import Span, SpanType, SpanStorage, Tracer +from dbgpt.component import SystemApp +from dbgpt.util.tracer import Span, SpanStorage, SpanType, Tracer # Mock implementations diff --git a/dbgpt/util/tracer/tests/test_span_storage.py b/dbgpt/util/tracer/tests/test_span_storage.py index 6bf9a48ee..984c8dd1c 100644 --- a/dbgpt/util/tracer/tests/test_span_storage.py +++ b/dbgpt/util/tracer/tests/test_span_storage.py @@ -1,18 +1,19 @@ -import os -import pytest import asyncio import json +import os import tempfile import time -from unittest.mock import patch from datetime import datetime, timedelta +from unittest.mock import patch + +import pytest from dbgpt.util.tracer import ( - SpanStorage, FileSpanStorage, Span, - SpanType, + SpanStorage, SpanStorageContainer, + SpanType, ) diff --git a/dbgpt/util/tracer/tests/test_tracer_impl.py b/dbgpt/util/tracer/tests/test_tracer_impl.py index 9e09ea623..2ea077402 100644 --- a/dbgpt/util/tracer/tests/test_tracer_impl.py +++ b/dbgpt/util/tracer/tests/test_tracer_impl.py @@ -1,14 +1,15 @@ import pytest + +from dbgpt.component import SystemApp from dbgpt.util.tracer import ( + DefaultTracer, + MemorySpanStorage, Span, - SpanStorageType, SpanStorage, - DefaultTracer, - TracerManager, + SpanStorageType, Tracer, - MemorySpanStorage, + TracerManager, ) -from dbgpt.component import SystemApp @pytest.fixture diff --git a/dbgpt/util/tracer/tracer_cli.py b/dbgpt/util/tracer/tracer_cli.py index f902910e0..40fdf87e3 100644 --- a/dbgpt/util/tracer/tracer_cli.py +++ b/dbgpt/util/tracer/tracer_cli.py @@ -1,10 +1,12 @@ -import os -import click -import logging import glob import json +import logging +import os from datetime import datetime -from typing import Iterable, Dict, Callable +from typing import Callable, Dict, Iterable + +import click + from dbgpt.configs.model_config import LOGDIR from dbgpt.util.tracer import SpanType, SpanTypeRunName diff --git a/dbgpt/util/tracer/tracer_impl.py b/dbgpt/util/tracer/tracer_impl.py index ac9485e6c..cd7f87737 100644 --- a/dbgpt/util/tracer/tracer_impl.py +++ b/dbgpt/util/tracer/tracer_impl.py @@ -1,22 +1,21 @@ -from typing import Dict, Optional -from contextvars import ContextVar -from functools import wraps import asyncio import inspect import logging +from contextvars import ContextVar +from functools import wraps +from typing import Dict, Optional - -from dbgpt.component import SystemApp, ComponentType +from dbgpt.component import ComponentType, SystemApp +from dbgpt.util.module_utils import import_from_checked_string from dbgpt.util.tracer.base import ( - SpanType, Span, - Tracer, SpanStorage, SpanStorageType, + SpanType, + Tracer, TracerContext, ) from dbgpt.util.tracer.span_storage import MemorySpanStorage -from dbgpt.util.module_utils import import_from_checked_string logger = logging.getLogger(__name__) diff --git a/dbgpt/util/tracer/tracer_middleware.py b/dbgpt/util/tracer/tracer_middleware.py index 55920278c..6e0d35222 100644 --- a/dbgpt/util/tracer/tracer_middleware.py +++ b/dbgpt/util/tracer/tracer_middleware.py @@ -4,8 +4,8 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.types import ASGIApp -from dbgpt.util.tracer import TracerContext, Tracer +from dbgpt.util.tracer import Tracer, TracerContext _DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"] diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 4f7c2678d..bed637cba 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- +import asyncio import logging import logging.handlers -from typing import Any, List - import os -import asyncio +from typing import Any, List from dbgpt.configs.model_config import LOGDIR diff --git a/dbgpt/vis/__init__.py b/dbgpt/vis/__init__.py index 2cc1c3aa8..dab2559a8 100644 --- a/dbgpt/vis/__init__.py +++ b/dbgpt/vis/__init__.py @@ -1,7 +1,7 @@ +from .client import vis_client +from .tags.vis_agent_message import VisAgentMessages +from .tags.vis_agent_plans import VisAgentPlans from .tags.vis_chart import VisChart from .tags.vis_code import VisCode from .tags.vis_dashboard import VisDashboard -from .tags.vis_agent_plans import VisAgentPlans -from .tags.vis_agent_message import VisAgentMessages from .tags.vis_plugin import VisPlugin -from .client import vis_client diff --git a/dbgpt/vis/base.py b/dbgpt/vis/base.py index c5f87c4c6..a01b6859a 100644 --- a/dbgpt/vis/base.py +++ b/dbgpt/vis/base.py @@ -1,6 +1,7 @@ import json from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union + from dbgpt.util.json_utils import serialize diff --git a/dbgpt/vis/client.py b/dbgpt/vis/client.py index b106d6ca3..4f62fd452 100644 --- a/dbgpt/vis/client.py +++ b/dbgpt/vis/client.py @@ -1,11 +1,12 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union -from .tags.vis_code import VisCode + +from .base import Vis +from .tags.vis_agent_message import VisAgentMessages +from .tags.vis_agent_plans import VisAgentPlans from .tags.vis_chart import VisChart +from .tags.vis_code import VisCode from .tags.vis_dashboard import VisDashboard -from .tags.vis_agent_plans import VisAgentPlans -from .tags.vis_agent_message import VisAgentMessages from .tags.vis_plugin import VisPlugin -from .base import Vis class VisClient: diff --git a/dbgpt/vis/tags/vis_agent_message.py b/dbgpt/vis/tags/vis_agent_message.py index c33f6e618..f15e8c210 100644 --- a/dbgpt/vis/tags/vis_agent_message.py +++ b/dbgpt/vis/tags/vis_agent_message.py @@ -1,4 +1,5 @@ from typing import Optional + from ..base import Vis diff --git a/dbgpt/vis/tags/vis_agent_plans.py b/dbgpt/vis/tags/vis_agent_plans.py index e4a24955c..45f8b6b36 100644 --- a/dbgpt/vis/tags/vis_agent_plans.py +++ b/dbgpt/vis/tags/vis_agent_plans.py @@ -1,4 +1,5 @@ from typing import Optional + from ..base import Vis diff --git a/dbgpt/vis/tags/vis_chart.py b/dbgpt/vis/tags/vis_chart.py index 946504793..40237786e 100644 --- a/dbgpt/vis/tags/vis_chart.py +++ b/dbgpt/vis/tags/vis_chart.py @@ -1,6 +1,8 @@ +import json from typing import Optional + import yaml -import json + from ..base import Vis diff --git a/dbgpt/vis/tags/vis_code.py b/dbgpt/vis/tags/vis_code.py index bd2647ba0..672eb9cf8 100644 --- a/dbgpt/vis/tags/vis_code.py +++ b/dbgpt/vis/tags/vis_code.py @@ -1,4 +1,5 @@ from typing import Optional + from ..base import Vis diff --git a/dbgpt/vis/tags/vis_dashboard.py b/dbgpt/vis/tags/vis_dashboard.py index 44a95abe0..7f6dbc627 100644 --- a/dbgpt/vis/tags/vis_dashboard.py +++ b/dbgpt/vis/tags/vis_dashboard.py @@ -1,5 +1,6 @@ import json from typing import Optional + from ..base import Vis diff --git a/dbgpt/vis/tags/vis_plugin.py b/dbgpt/vis/tags/vis_plugin.py index 29117a746..eed6db9f3 100644 --- a/dbgpt/vis/tags/vis_plugin.py +++ b/dbgpt/vis/tags/vis_plugin.py @@ -1,4 +1,5 @@ from typing import Optional + from ..base import Vis diff --git a/examples/awel/data_analyst_assistant.py b/examples/awel/data_analyst_assistant.py index 7997a60e0..c80d06b8b 100644 --- a/examples/awel/data_analyst_assistant.py +++ b/examples/awel/data_analyst_assistant.py @@ -322,10 +322,12 @@ async def build_model_request( # Load and store chat history chat_history_load_task = ServePreChatHistoryLoadOperator() - last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5)) - # History transform task, here we keep last k round messages + keep_start_rounds = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_KEEP_START_ROUNDS", 0)) + keep_end_rounds = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_KEEP_END_ROUNDS", 5)) + # History transform task, here we keep `keep_start_rounds` round messages of history, + # and keep `keep_end_rounds` round messages of history. history_transform_task = BufferedConversationMapperOperator( - last_k_round=last_k_round + keep_start_rounds=keep_start_rounds, keep_end_rounds=keep_end_rounds ) history_prompt_build_task = HistoryDynamicPromptBuilderOperator( history_key="chat_history" diff --git a/examples/awel/simple_chat_history_example.py b/examples/awel/simple_chat_history_example.py index c1977117c..138cf73d9 100644 --- a/examples/awel/simple_chat_history_example.py +++ b/examples/awel/simple_chat_history_example.py @@ -137,7 +137,7 @@ async def build_model_request( composer_operator = ChatHistoryPromptComposerOperator( prompt_template=prompt, - last_k_round=5, + keep_end_rounds=5, storage=InMemoryStorage(), message_storage=InMemoryStorage(), ) diff --git a/examples/awel/simple_rag_example.py b/examples/awel/simple_rag_example.py deleted file mode 100644 index 6ea44889d..000000000 --- a/examples/awel/simple_rag_example.py +++ /dev/null @@ -1,81 +0,0 @@ -"""AWEL: Simple rag example - - DB-GPT will automatically load and execute the current file after startup. - - Example: - - .. code-block:: shell - - curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \ - -H "Content-Type: application/json" -d '{ - "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003", - "model_name": "proxyllm", - "chat_mode": "chat_knowledge", - "user_input": "What is DB-GPT?", - "select_param": "default" - }' - -""" - -from dbgpt.app.openapi.api_view_model import ConversationVo -from dbgpt.app.scene import ChatScene -from dbgpt.app.scene.operator._experimental import ( - BaseChatOperator, - ChatContext, - ChatHistoryOperator, - ChatHistoryStorageOperator, - EmbeddingEngingOperator, - PromptManagerOperator, -) -from dbgpt.core.awel import DAG, HttpTrigger, MapOperator -from dbgpt.model.operator.model_operator import ModelOperator - - -class RequestParseOperator(MapOperator[ConversationVo, ChatContext]): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def map(self, input_value: ConversationVo) -> ChatContext: - return ChatContext( - current_user_input=input_value.user_input, - model_name=input_value.model_name, - chat_session_id=input_value.conv_uid, - select_param=input_value.select_param, - chat_scene=ChatScene.ChatKnowledge, - ) - - -with DAG("simple_rag_example") as dag: - trigger_task = HttpTrigger( - "/examples/simple_rag", methods="POST", request_body=ConversationVo - ) - req_parse_task = RequestParseOperator() - # TODO should register prompt template first - prompt_task = PromptManagerOperator() - history_storage_task = ChatHistoryStorageOperator() - history_task = ChatHistoryOperator() - embedding_task = EmbeddingEngingOperator() - chat_task = BaseChatOperator() - model_task = ModelOperator() - output_parser_task = MapOperator(lambda out: out.to_dict()["text"]) - - ( - trigger_task - >> req_parse_task - >> prompt_task - >> history_storage_task - >> history_task - >> embedding_task - >> chat_task - >> model_task - >> output_parser_task - ) - - -if __name__ == "__main__": - if dag.leaf_nodes[0].dev_mode: - from dbgpt.core.awel import setup_dev_environment - - setup_dev_environment([dag]) - else: - pass