diff --git a/Makefile b/Makefile
index 2e5f4bff4..c1029f3e9 100644
--- a/Makefile
+++ b/Makefile
@@ -38,21 +38,41 @@ fmt: setup ## Format Python code
# https://github.com/PyCQA/isort
# $(VENV_BIN)/isort .
$(VENV_BIN)/isort dbgpt/agent/
+ $(VENV_BIN)/isort dbgpt/app/
+ $(VENV_BIN)/isort dbgpt/cli/
+ $(VENV_BIN)/isort dbgpt/configs/
$(VENV_BIN)/isort dbgpt/core/
+ $(VENV_BIN)/isort dbgpt/datasource/
+ $(VENV_BIN)/isort dbgpt/model/
+ # TODO: $(VENV_BIN)/isort dbgpt/serve
$(VENV_BIN)/isort dbgpt/serve/core/
$(VENV_BIN)/isort dbgpt/serve/agent/
$(VENV_BIN)/isort dbgpt/serve/conversation/
$(VENV_BIN)/isort dbgpt/serve/utils/_template_files
+ $(VENV_BIN)/isort dbgpt/storage/
+ $(VENV_BIN)/isort dbgpt/train/
+ $(VENV_BIN)/isort dbgpt/util/
+ $(VENV_BIN)/isort dbgpt/vis/
+ $(VENV_BIN)/isort dbgpt/__init__.py
+ $(VENV_BIN)/isort dbgpt/component.py
$(VENV_BIN)/isort --extend-skip="examples/notebook" examples
# https://github.com/psf/black
$(VENV_BIN)/black --extend-exclude="examples/notebook" .
# TODO: Use blackdoc to format Python doctests.
# https://blackdoc.readthedocs.io/en/latest/
# $(VENV_BIN)/blackdoc .
- $(VENV_BIN)/blackdoc dbgpt/core/
$(VENV_BIN)/blackdoc dbgpt/agent/
+ $(VENV_BIN)/blackdoc dbgpt/app/
+ $(VENV_BIN)/blackdoc dbgpt/cli/
+ $(VENV_BIN)/blackdoc dbgpt/configs/
+ $(VENV_BIN)/blackdoc dbgpt/core/
+ $(VENV_BIN)/blackdoc dbgpt/datasource/
$(VENV_BIN)/blackdoc dbgpt/model/
$(VENV_BIN)/blackdoc dbgpt/serve/
+ # TODO: $(VENV_BIN)/blackdoc dbgpt/storage/
+ $(VENV_BIN)/blackdoc dbgpt/train/
+ $(VENV_BIN)/blackdoc dbgpt/util/
+ $(VENV_BIN)/blackdoc dbgpt/vis/
$(VENV_BIN)/blackdoc examples
# TODO: Type checking of Python code.
# https://github.com/python/mypy
diff --git a/dbgpt/__init__.py b/dbgpt/__init__.py
index a23c1ec9e..1660c5ddc 100644
--- a/dbgpt/__init__.py
+++ b/dbgpt/__init__.py
@@ -1,5 +1,4 @@
-from dbgpt.component import SystemApp, BaseComponent
-
+from dbgpt.component import BaseComponent, SystemApp
__ALL__ = ["SystemApp", "BaseComponent"]
diff --git a/dbgpt/app/_cli.py b/dbgpt/app/_cli.py
index f48e57111..b0d0cb294 100644
--- a/dbgpt/app/_cli.py
+++ b/dbgpt/app/_cli.py
@@ -1,11 +1,13 @@
+import functools
+import os
from typing import Optional
+
import click
-import os
-import functools
+
from dbgpt.app.base import WebServerParameters
from dbgpt.configs.model_config import LOGDIR
-from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.command_utils import _run_current_with_daemon, _stop_service
+from dbgpt.util.parameter_utils import EnvArgumentParser
@click.command(name="webserver")
@@ -117,8 +119,8 @@ def migrate(alembic_ini_path: str, script_location: str, message: str):
def upgrade(alembic_ini_path: str, script_location: str, sql_output: str):
"""Upgrade database to target version"""
from dbgpt.util._db_migration_utils import (
- upgrade_database,
generate_sql_for_upgrade,
+ upgrade_database,
)
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
@@ -210,9 +212,8 @@ def clean(
@add_migration_options
def list(alembic_ini_path: str, script_location: str):
"""List all versions in the migration history, marking the current one"""
- from alembic.script import ScriptDirectory
-
from alembic.runtime.migration import MigrationContext
+ from alembic.script import ScriptDirectory
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
@@ -269,12 +270,12 @@ def show(alembic_ini_path: str, script_location: str, revision: str):
def _get_migration_config(
alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None
):
- from dbgpt.storage.metadata.db_manager import db as db_manager
- from dbgpt.util._db_migration_utils import create_alembic_config
+ from dbgpt.app.base import _initialize_db
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS
- from dbgpt.app.base import _initialize_db
+ from dbgpt.storage.metadata.db_manager import db as db_manager
+ from dbgpt.util._db_migration_utils import create_alembic_config
# initialize db
default_meta_data_path = _initialize_db()
diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py
index fde70098e..54c6f5122 100644
--- a/dbgpt/app/base.py
+++ b/dbgpt/app/base.py
@@ -1,16 +1,15 @@
-import signal
+import logging
import os
-import threading
+import signal
import sys
-import logging
-from typing import Optional
+import threading
from dataclasses import dataclass, field
+from typing import Optional
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.util.parameter_utils import BaseParameters
-
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -35,7 +34,6 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
from dbgpt.agent.plugin.commands.command_mange import CommandRegistry
# logger.info(f"args: {args}")
-
# init config
cfg = Config()
cfg.SYSTEM_APP = system_app
@@ -100,13 +98,12 @@ def _migration_db_storage(param: "WebServerParameters"):
"""Migration the db storage."""
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS
-
from dbgpt.configs.model_config import PILOT_PATH
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
if not param.disable_alembic_upgrade:
- from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
from dbgpt.storage.metadata.db_manager import db
+ from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
# try to create all tables
try:
@@ -123,9 +120,11 @@ def _initialize_db(
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
+ from urllib.parse import quote
+ from urllib.parse import quote_plus as urlquote
+
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.storage.metadata.db_manager import initialize_db
- from urllib.parse import quote_plus as urlquote, quote
CFG = Config()
db_name = CFG.LOCAL_DB_NAME
@@ -170,8 +169,8 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
Raises:
Exception: Raise exception if database operation failed
"""
- from sqlalchemy import create_engine, DDL
- from sqlalchemy.exc import SQLAlchemyError, OperationalError
+ from sqlalchemy import DDL, create_engine
+ from sqlalchemy.exc import OperationalError, SQLAlchemyError
if not try_to_create_db:
logger.info(f"Skipping creation of database {db_name}")
diff --git a/dbgpt/app/chat_adapter.py b/dbgpt/app/chat_adapter.py
index c1cb192b1..d0c5e84dc 100644
--- a/dbgpt/app/chat_adapter.py
+++ b/dbgpt/app/chat_adapter.py
@@ -6,9 +6,10 @@
# -*- coding: utf-8 -*-
from functools import cache
-from typing import List, Dict, Tuple
-from dbgpt.model.conversation import Conversation, get_conv_template
+from typing import Dict, List, Tuple
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.conversation import Conversation, get_conv_template
class BaseChatAdpter:
diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py
index 8025a4e09..0d6eabde8 100644
--- a/dbgpt/app/component_configs.py
+++ b/dbgpt/app/component_configs.py
@@ -2,12 +2,11 @@
import logging
-from dbgpt.component import SystemApp
from dbgpt._private.config import Config
+from dbgpt.app.base import WebServerParameters
+from dbgpt.component import SystemApp
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
from dbgpt.util.executor_utils import DefaultExecutorFactory
-from dbgpt.app.base import WebServerParameters
-
logger = logging.getLogger(__name__)
@@ -21,9 +20,9 @@ def initialize_components(
embedding_model_path: str,
):
# Lazy import to avoid high time cost
- from dbgpt.model.cluster.controller.controller import controller
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.serve_initialization import register_serve_apps
+ from dbgpt.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
@@ -47,6 +46,7 @@ def initialize_components(
)
_initialize_model_cache(system_app)
_initialize_awel(system_app, param)
+ _initialize_openapi(system_app)
# Register serve apps
register_serve_apps(system_app, CFG)
@@ -65,8 +65,8 @@ def _initialize_model_cache(system_app: SystemApp):
def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
- from dbgpt.core.awel import initialize_awel
from dbgpt.configs.model_config import _DAG_DEFINITION_DIR
+ from dbgpt.core.awel import initialize_awel
# Add default dag definition dir
dag_dirs = [_DAG_DEFINITION_DIR]
@@ -75,3 +75,9 @@ def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
dag_dirs = [x.strip() for x in dag_dirs]
initialize_awel(system_app, dag_dirs)
+
+
+def _initialize_openapi(system_app: SystemApp):
+ from dbgpt.app.openapi.api_v1.editor.service import EditorService
+
+ system_app.register(EditorService)
diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py
index ff7598c01..b4328d257 100644
--- a/dbgpt/app/dbgpt_server.py
+++ b/dbgpt/app/dbgpt_server.py
@@ -1,46 +1,45 @@
-import os
import argparse
+import os
import sys
from typing import List
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
-from dbgpt.configs.model_config import (
- LLM_MODEL_CONFIG,
- EMBEDDING_MODEL_CONFIG,
- LOGDIR,
- ROOT_PATH,
-)
-from dbgpt._private.config import Config
-from dbgpt.component import SystemApp
+from fastapi import FastAPI
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.docs import get_swagger_ui_html
+
+# fastapi import time cost about 0.05s
+from fastapi.staticfiles import StaticFiles
+from dbgpt._private.config import Config
from dbgpt.app.base import (
- server_init,
- _migration_db_storage,
WebServerParameters,
_create_model_start_listener,
+ _migration_db_storage,
+ server_init,
)
# initialize_components import time cost about 0.1s
from dbgpt.app.component_configs import initialize_components
-
-# fastapi import time cost about 0.05s
-from fastapi.staticfiles import StaticFiles
-from fastapi import FastAPI
-from fastapi.exceptions import RequestValidationError
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.openapi.docs import get_swagger_ui_html
-
from dbgpt.app.openapi.base import validation_exception_handler
+from dbgpt.component import SystemApp
+from dbgpt.configs.model_config import (
+ EMBEDDING_MODEL_CONFIG,
+ LLM_MODEL_CONFIG,
+ LOGDIR,
+ ROOT_PATH,
+)
+from dbgpt.util.parameter_utils import _get_dict_from_obj
+from dbgpt.util.system_utils import get_system_info
+from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer
from dbgpt.util.utils import (
- setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
setup_http_service_logging,
+ setup_logging,
)
-from dbgpt.util.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
-from dbgpt.util.parameter_utils import _get_dict_from_obj
-from dbgpt.util.system_utils import get_system_info
static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static")
@@ -89,9 +88,7 @@ async def custom_swagger_ui_html():
def mount_routers(app: FastAPI):
"""Lazy import to avoid high time cost"""
from dbgpt.app.knowledge.api import router as knowledge_router
-
from dbgpt.app.llm_manage.api import router as llm_manage_api
-
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import (
router as api_editor_route_v1,
@@ -107,9 +104,7 @@ def mount_routers(app: FastAPI):
def mount_static_files(app: FastAPI):
- from dbgpt.agent.plugin.commands.built_in.disply_type import (
- static_message_img_path,
- )
+ from dbgpt.agent.plugin.commands.built_in.disply_type import static_message_img_path
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py
index 1d925201a..f2243dfc1 100644
--- a/dbgpt/app/initialization/db_model_initialization.py
+++ b/dbgpt/app/initialization/db_model_initialization.py
@@ -1,14 +1,13 @@
"""Import all models to make sure they are registered with SQLAlchemy.
"""
-from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
-from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
-
-from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
+from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
+from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
+from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
ChatHistoryMessageEntity,
diff --git a/dbgpt/app/initialization/embedding_component.py b/dbgpt/app/initialization/embedding_component.py
index a20bd5220..83552a681 100644
--- a/dbgpt/app/initialization/embedding_component.py
+++ b/dbgpt/app/initialization/embedding_component.py
@@ -1,12 +1,14 @@
from __future__ import annotations
import logging
-from typing import Any, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Type
+
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
+
from dbgpt.app.base import WebServerParameters
logger = logging.getLogger(__name__)
diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py
index 35fe26844..7c5a3a465 100644
--- a/dbgpt/app/initialization/serve_initialization.py
+++ b/dbgpt/app/initialization/serve_initialization.py
@@ -1,5 +1,5 @@
-from dbgpt.component import SystemApp
from dbgpt._private.config import Config
+from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp, cfg: Config):
@@ -8,9 +8,9 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# ################################ Prompt Serve Register Begin ######################################
from dbgpt.serve.prompt.serve import (
- Serve as PromptServe,
SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX,
)
+ from dbgpt.serve.prompt.serve import Serve as PromptServe
# Replace old prompt serve
# Set config
@@ -21,8 +21,15 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# ################################ Prompt Serve Register End ########################################
# ################################ Conversation Serve Register Begin ######################################
+ from dbgpt.serve.conversation.serve import (
+ SERVE_CONFIG_KEY_PREFIX as CONVERSATION_SERVE_CONFIG_KEY_PREFIX,
+ )
from dbgpt.serve.conversation.serve import Serve as ConversationServe
+ # Set config
+ system_app.config.set(
+ f"{CONVERSATION_SERVE_CONFIG_KEY_PREFIX}default_model", cfg.LLM_MODEL
+ )
# Register serve app
- system_app.register(ConversationServe)
+ system_app.register(ConversationServe, api_prefix="/api/v1/chat/dialogue")
# ################################ Conversation Serve Register End ########################################
diff --git a/dbgpt/app/knowledge/_cli/knowledge_cli.py b/dbgpt/app/knowledge/_cli/knowledge_cli.py
index dd7c74b2c..38f24ede2 100644
--- a/dbgpt/app/knowledge/_cli/knowledge_cli.py
+++ b/dbgpt/app/knowledge/_cli/knowledge_cli.py
@@ -1,7 +1,8 @@
-import click
+import functools
import logging
import os
-import functools
+
+import click
from dbgpt.configs.model_config import DATASETS_DIR
diff --git a/dbgpt/app/knowledge/_cli/knowledge_client.py b/dbgpt/app/knowledge/_cli/knowledge_client.py
index 5ab24d368..7f2b5ea1b 100644
--- a/dbgpt/app/knowledge/_cli/knowledge_client.py
+++ b/dbgpt/app/knowledge/_cli/knowledge_client.py
@@ -1,22 +1,20 @@
-import os
-import requests
import json
import logging
-
-from urllib.parse import urljoin
+import os
from concurrent.futures import ThreadPoolExecutor, as_completed
+from urllib.parse import urljoin
+
+import requests
-from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.knowledge.request.request import (
- KnowledgeQueryRequest,
- KnowledgeDocumentRequest,
ChunkQueryRequest,
DocumentQueryRequest,
+ DocumentSyncRequest,
+ KnowledgeDocumentRequest,
+ KnowledgeQueryRequest,
+ KnowledgeSpaceRequest,
)
-
-from dbgpt.app.knowledge.request.request import DocumentSyncRequest
-
-from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.openapi.api_view_model import Result
from dbgpt.rag.knowledge.base import KnowledgeType
HTTP_HEADERS = {"Content-Type": "application/json"}
diff --git a/dbgpt/app/knowledge/api.py b/dbgpt/app/knowledge/api.py
index 8d14e0e51..ee8f1f767 100644
--- a/dbgpt/app/knowledge/api.py
+++ b/dbgpt/app/knowledge/api.py
@@ -1,42 +1,39 @@
+import logging
import os
import shutil
import tempfile
-import logging
from typing import List
-from fastapi import APIRouter, File, UploadFile, Form
+from fastapi import APIRouter, File, Form, UploadFile
from dbgpt._private.config import Config
-from dbgpt.configs.model_config import (
- EMBEDDING_MODEL_CONFIG,
- KNOWLEDGE_UPLOAD_ROOT_PATH,
-)
-from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
-
-from dbgpt.app.openapi.api_view_model import Result
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-
-from dbgpt.app.knowledge.service import KnowledgeService
-from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.app.knowledge.request.request import (
- KnowledgeQueryRequest,
- KnowledgeQueryResponse,
- KnowledgeDocumentRequest,
- DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
- SpaceArgumentRequest,
- EntityExtractRequest,
DocumentSummaryRequest,
+ DocumentSyncRequest,
+ EntityExtractRequest,
+ KnowledgeDocumentRequest,
+ KnowledgeQueryRequest,
+ KnowledgeQueryResponse,
+ KnowledgeSpaceRequest,
KnowledgeSyncRequest,
+ SpaceArgumentRequest,
)
-
-from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.service import KnowledgeService
+from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
+from dbgpt.app.openapi.api_view_model import Result
+from dbgpt.configs.model_config import (
+ EMBEDDING_MODEL_CONFIG,
+ KNOWLEDGE_UPLOAD_ROOT_PATH,
+)
+from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy
+from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
-from dbgpt.util.tracer import root_tracer, SpanType
+from dbgpt.util.tracer import SpanType, root_tracer
logger = logging.getLogger(__name__)
@@ -312,9 +309,10 @@ async def document_summary(request: DocumentSummaryRequest):
async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
+ import uuid
+
from dbgpt.app.scene import ChatScene
from dbgpt.util.chat_util import llm_chat_response_nostream
- import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),
diff --git a/dbgpt/app/knowledge/chunk_db.py b/dbgpt/app/knowledge/chunk_db.py
index e8b6137ef..e6c3018e2 100644
--- a/dbgpt/app/knowledge/chunk_db.py
+++ b/dbgpt/app/knowledge/chunk_db.py
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import List
-from sqlalchemy import Column, String, DateTime, Integer, Text, func
+from sqlalchemy import Column, DateTime, Integer, String, Text, func
-from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
+from dbgpt.storage.metadata import BaseDao, Model
CFG = Config()
diff --git a/dbgpt/app/knowledge/document_db.py b/dbgpt/app/knowledge/document_db.py
index d101d8ee3..7e08d0733 100644
--- a/dbgpt/app/knowledge/document_db.py
+++ b/dbgpt/app/knowledge/document_db.py
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import List
-from sqlalchemy import Column, String, DateTime, Integer, Text, func
+from sqlalchemy import Column, DateTime, Integer, String, Text, func
-from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
+from dbgpt.storage.metadata import BaseDao, Model
CFG = Config()
diff --git a/dbgpt/app/knowledge/request/request.py b/dbgpt/app/knowledge/request/request.py
index e87ca14ad..0b0e2795f 100644
--- a/dbgpt/app/knowledge/request/request.py
+++ b/dbgpt/app/knowledge/request/request.py
@@ -1,8 +1,8 @@
from typing import List, Optional
-from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile
+from dbgpt._private.pydantic import BaseModel
from dbgpt.rag.chunk_manager import ChunkParameters
diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py
index ca348e07c..3db67c026 100644
--- a/dbgpt/app/knowledge/service.py
+++ b/dbgpt/app/knowledge/service.py
@@ -1,59 +1,48 @@
import json
import logging
from datetime import datetime
+from enum import Enum
from typing import List
-from dbgpt.model import DefaultLLMClient
-from dbgpt.rag.chunk import Chunk
-from dbgpt.rag.chunk_manager import ChunkParameters
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-from dbgpt.rag.knowledge.base import KnowledgeType, ChunkStrategy
-from dbgpt.rag.knowledge.factory import KnowledgeFactory
-from dbgpt.rag.text_splitter.text_splitter import (
- RecursiveCharacterTextSplitter,
- SpacyTextSplitter,
-)
-from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
-from dbgpt.serve.rag.assembler.summary import SummaryAssembler
-from dbgpt.storage.vector_store.base import VectorStoreConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
-
from dbgpt._private.config import Config
-from dbgpt.configs.model_config import (
- EMBEDDING_MODEL_CONFIG,
-)
-from dbgpt.component import ComponentType
-from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
-
-from dbgpt.app.knowledge.chunk_db import (
- DocumentChunkEntity,
- DocumentChunkDao,
-)
+from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
-from dbgpt.app.knowledge.space_db import (
- KnowledgeSpaceDao,
- KnowledgeSpaceEntity,
-)
from dbgpt.app.knowledge.request.request import (
- KnowledgeSpaceRequest,
- KnowledgeDocumentRequest,
- DocumentQueryRequest,
ChunkQueryRequest,
- SpaceArgumentRequest,
- DocumentSyncRequest,
+ DocumentQueryRequest,
DocumentSummaryRequest,
+ DocumentSyncRequest,
+ KnowledgeDocumentRequest,
+ KnowledgeSpaceRequest,
KnowledgeSyncRequest,
+ SpaceArgumentRequest,
)
-from enum import Enum
-
from dbgpt.app.knowledge.request.response import (
ChunkQueryResponse,
DocumentQueryResponse,
SpaceQueryResponse,
)
+from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity
+from dbgpt.component import ComponentType
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
+from dbgpt.model import DefaultLLMClient
+from dbgpt.rag.chunk import Chunk
+from dbgpt.rag.chunk_manager import ChunkParameters
+from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
+from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
+from dbgpt.rag.knowledge.factory import KnowledgeFactory
+from dbgpt.rag.text_splitter.text_splitter import (
+ RecursiveCharacterTextSplitter,
+ SpacyTextSplitter,
+)
+from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
+from dbgpt.serve.rag.assembler.summary import SummaryAssembler
+from dbgpt.storage.vector_store.base import VectorStoreConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
@@ -572,8 +561,8 @@ def async_doc_embedding(self, assembler, chunk_docs, doc):
def _build_default_context(self):
from dbgpt.app.scene.chat_knowledge.v1.prompt import (
- PROMPT_SCENE_DEFINE,
_DEFAULT_TEMPLATE,
+ PROMPT_SCENE_DEFINE,
)
context_template = {
diff --git a/dbgpt/app/knowledge/space_db.py b/dbgpt/app/knowledge/space_db.py
index 4c958c613..bc283db1d 100644
--- a/dbgpt/app/knowledge/space_db.py
+++ b/dbgpt/app/knowledge/space_db.py
@@ -1,10 +1,10 @@
from datetime import datetime
-from sqlalchemy import Column, Integer, Text, String, DateTime
+from sqlalchemy import Column, DateTime, Integer, String, Text
-from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.storage.metadata import BaseDao, Model
CFG = Config()
diff --git a/dbgpt/app/llm_manage/api.py b/dbgpt/app/llm_manage/api.py
index 66496c927..fdc20474b 100644
--- a/dbgpt/app/llm_manage/api.py
+++ b/dbgpt/app/llm_manage/api.py
@@ -1,12 +1,10 @@
from fastapi import APIRouter
-from dbgpt.component import ComponentType
from dbgpt._private.config import Config
-
-from dbgpt.model.cluster import WorkerStartupRequest, WorkerManagerFactory
-from dbgpt.app.openapi.api_view_model import Result
-
from dbgpt.app.llm_manage.request.request import ModelResponse
+from dbgpt.app.openapi.api_view_model import Result
+from dbgpt.component import ComponentType
+from dbgpt.model.cluster import WorkerManagerFactory, WorkerStartupRequest
CFG = Config()
router = APIRouter()
diff --git a/dbgpt/app/llmserver.py b/dbgpt/app/llmserver.py
index bf8719af8..be07c1a02 100644
--- a/dbgpt/app/llmserver.py
+++ b/dbgpt/app/llmserver.py
@@ -8,7 +8,7 @@
sys.path.append(ROOT_PATH)
from dbgpt._private.config import Config
-from dbgpt.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
from dbgpt.model.cluster import run_worker_manager
CFG = Config()
diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py
index 607642f8c..9c7137757 100644
--- a/dbgpt/app/openapi/api_v1/api_v1.py
+++ b/dbgpt/app/openapi/api_v1/api_v1.py
@@ -1,50 +1,39 @@
-import json
-import uuid
import asyncio
-import os
-import aiofiles
import logging
-from fastapi import (
- APIRouter,
- File,
- UploadFile,
- Body,
- Depends,
-)
+import os
+import uuid
+from concurrent.futures import Executor
+from typing import List, Optional
+import aiofiles
+from fastapi import APIRouter, Body, Depends, File, UploadFile
from fastapi.responses import StreamingResponse
-from typing import List, Optional
-from concurrent.futures import Executor
-from dbgpt.component import ComponentType
+from dbgpt._private.config import Config
+from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.openapi.api_view_model import (
- Result,
- ConversationVo,
- MessageVo,
- ChatSceneVo,
ChatCompletionResponseStreamChoice,
- DeltaMessage,
ChatCompletionStreamResponse,
+ ChatSceneVo,
+ ConversationVo,
+ DeltaMessage,
+ MessageVo,
+ Result,
)
-from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
-from dbgpt._private.config import Config
-from dbgpt.app.knowledge.service import KnowledgeService
-from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
-
-from dbgpt.app.scene import BaseChat, ChatScene, ChatFactory
-from dbgpt.core.interface.message import OnceConversation
+from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
+from dbgpt.component import ComponentType
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
-from dbgpt.rag.summary.db_summary_client import DBSummaryClient
-from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
-from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
+from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
from dbgpt.model.base import FlatSupportedModel
-from dbgpt.util.tracer import root_tracer, SpanType
+from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
+from dbgpt.rag.summary.db_summary_client import DBSummaryClient
from dbgpt.util.executor_utils import (
+ DefaultExecutorFactory,
ExecutorFactory,
blocking_func_to_async,
- DefaultExecutorFactory,
)
-
+from dbgpt.util.tracer import SpanType, root_tracer
router = APIRouter()
CFG = Config()
@@ -201,47 +190,6 @@ async def db_support_types():
return Result[DbTypeInfo].succ(db_type_infos)
-@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
-async def dialogue_list(
- user_name: str = None, user_id: str = None, sys_code: str = None
-):
- dialogues: List = []
- chat_history_service = ChatHistory()
- # TODO Change the synchronous call to the asynchronous call
- user_name = user_name or user_id
- datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
- for item in datas:
- conv_uid = item.get("conv_uid")
- summary = item.get("summary")
- chat_mode = item.get("chat_mode")
- model_name = item.get("model_name", CFG.LLM_MODEL)
- user_name = item.get("user_name")
- sys_code = item.get("sys_code")
- if not item.get("messages"):
- # Skip the empty messages
- # TODO support new conversation and message mode
- continue
-
- messages = json.loads(item.get("messages"))
- last_round = max(messages, key=lambda x: x["chat_order"])
- if "param_value" in last_round:
- select_param = last_round["param_value"]
- else:
- select_param = ""
- conv_vo: ConversationVo = ConversationVo(
- conv_uid=conv_uid,
- user_input=summary,
- chat_mode=chat_mode,
- model_name=model_name,
- select_param=select_param,
- user_name=user_name,
- sys_code=sys_code,
- )
- dialogues.append(conv_vo)
-
- return Result[ConversationVo].succ(dialogues[:10])
-
-
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
@@ -265,19 +213,6 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
-@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
-async def dialogue_new(
- chat_mode: str = ChatScene.ChatNormal.value(),
- user_name: str = None,
- # TODO remove user id
- user_id: str = None,
- sys_code: str = None,
-):
- user_name = user_name or user_id
- conv_vo = __new_conversation(chat_mode, user_name, sys_code)
- return Result.succ(conv_vo)
-
-
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
if ChatScene.ChatWithDbQA.value() == chat_mode:
@@ -334,37 +269,11 @@ async def params_load(
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
-@router.post("/v1/chat/dialogue/delete")
-async def dialogue_delete(con_uid: str):
- history_fac = ChatHistory()
- history_mem = history_fac.get_store_instance(con_uid)
- # TODO Change the synchronous call to the asynchronous call
- history_mem.delete()
- return Result.succ(None)
-
-
def get_hist_messages(conv_uid: str):
- message_vos: List[MessageVo] = []
- history_fac = ChatHistory()
- history_mem = history_fac.get_store_instance(conv_uid)
-
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- for once in history_messages:
- model_name = once.get("model_name", CFG.LLM_MODEL)
- once_message_vos = [
- message2Vo(element, once["chat_order"], model_name)
- for element in once["messages"]
- ]
- message_vos.extend(once_message_vos)
- return message_vos
-
-
-@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
-async def dialogue_history_messages(con_uid: str):
- print(f"dialogue_history_messages:{con_uid}")
- # TODO Change the synchronous call to the asynchronous call
- return Result.succ(get_hist_messages(con_uid))
+ from dbgpt.serve.conversation.serve import Service as ConversationService
+
+ instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
+ return instance.get_history_messages({"conv_uid": conv_uid})
async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
@@ -378,9 +287,7 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
dialogue.conv_uid = conv_vo.conv_uid
if not ChatScene.is_valid_mode(dialogue.chat_mode):
- raise StopAsyncIteration(
- Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!")
- )
+ raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!")
chat_param = {
"chat_session_id": dialogue.conv_uid,
@@ -405,7 +312,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = await get_chat_instance(dialogue)
- if len(chat.history_message) > 0:
+ if chat.has_history_messages():
return Result.succ(None)
resp = await chat.prepare()
return Result.succ(resp)
diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
index 740c0bf57..1e1ec9662 100644
--- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
+++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
@@ -1,37 +1,30 @@
import json
+import logging
import time
-from fastapi import (
- APIRouter,
- Body,
-)
-
from typing import List
-import logging
-from dbgpt._private.config import Config
+from fastapi import APIRouter, Body, Depends
-from dbgpt.app.scene import ChatFactory
-
-from dbgpt.app.openapi.api_view_model import (
- Result,
+from dbgpt._private.config import Config
+from dbgpt.app.openapi.api_v1.editor.service import EditorService
+from dbgpt.app.openapi.api_v1.editor.sql_editor import (
+ ChartRunData,
+ DataNode,
+ SqlRunData,
)
+from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.openapi.editor_view_model import (
- ChatDbRounds,
- ChartList,
ChartDetail,
+ ChartList,
ChatChartEditContext,
+ ChatDbRounds,
ChatSqlEditContext,
DbTable,
)
-
-from dbgpt.app.openapi.api_v1.editor.sql_editor import (
- DataNode,
- ChartRunData,
- SqlRunData,
-)
-from dbgpt.core.interface.message import OnceConversation
+from dbgpt.app.scene import ChatFactory
from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
-from dbgpt.app.scene.chat_db.data_loader import DbDataLoader
+from dbgpt.core.interface.message import OnceConversation
+from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
@@ -41,6 +34,14 @@
logger = logging.getLogger(__name__)
+def get_conversation_serve() -> ConversationServe:
+ return ConversationServe.get_instance(CFG.SYSTEM_APP)
+
+
+def get_edit_service() -> EditorService:
+ return EditorService.get_instance(CFG.SYSTEM_APP)
+
+
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
@@ -69,48 +70,21 @@ async def get_editor_tables(
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
-async def get_editor_sql_rounds(con_uid: str):
+async def get_editor_sql_rounds(
+ con_uid: str, editor_service: EditorService = Depends(get_edit_service)
+):
logger.info("get_editor_sql_rounds:{con_uid}")
- chat_history_fac = ChatHistory()
- history_mem = chat_history_fac.get_store_instance(con_uid)
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- result: List = []
- for once in history_messages:
- round_name: str = ""
- for element in once["messages"]:
- if element["type"] == "human":
- round_name = element["data"]["content"]
- if once.get("param_value"):
- round: ChatDbRounds = ChatDbRounds(
- round=once["chat_order"],
- db_name=once["param_value"],
- round_name=round_name,
- )
- result.append(round)
- return Result.succ(result)
+ return Result.succ(editor_service.get_editor_sql_rounds(con_uid))
@router.get("/v1/editor/sql", response_model=Result[dict])
-async def get_editor_sql(con_uid: str, round: int):
+async def get_editor_sql(
+ con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service)
+):
logger.info(f"get_editor_sql:{con_uid},{round}")
- chat_history_fac = ChatHistory()
- history_mem = chat_history_fac.get_store_instance(con_uid)
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- for once in history_messages:
- if int(once["chat_order"]) == round:
- for element in once["messages"]:
- if element["type"] == "ai":
- logger.info(
- f'history ai json resp:{element["data"]["content"]}'
- )
- context = (
- element["data"]["content"]
- .replace("\\n", " ")
- .replace("\n", " ")
- )
- return Result.succ(json.loads(context))
+ context = editor_service.get_editor_sql_by_round(con_uid, round)
+ if context:
+ return Result.succ(context)
return Result.failed(msg="not have sql!")
@@ -120,7 +94,7 @@ async def editor_sql_run(run_param: dict = Body()):
db_name = run_param["db_name"]
sql = run_param["sql"]
if not db_name and not sql:
- return Result.failed("SQL run param error!")
+ return Result.failed(msg="SQL run param error!")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
try:
@@ -145,102 +119,43 @@ async def editor_sql_run(run_param: dict = Body()):
@router.post("/v1/sql/editor/submit")
-async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
+async def sql_editor_submit(
+ sql_edit_context: ChatSqlEditContext = Body(),
+ editor_service: EditorService = Depends(get_edit_service),
+):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
- chat_history_fac = ChatHistory()
- history_mem = chat_history_fac.get_store_instance(sql_edit_context.conv_uid)
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
-
- edit_round = list(
- filter(
- lambda x: x["chat_order"] == sql_edit_context.conv_round,
- history_messages,
- )
- )[0]
- if edit_round:
- for element in edit_round["messages"]:
- if element["type"] == "ai":
- db_resp = json.loads(element["data"]["content"])
- db_resp["thoughts"] = sql_edit_context.new_speak
- db_resp["sql"] = sql_edit_context.new_sql
- element["data"]["content"] = json.dumps(db_resp)
- if element["type"] == "view":
- data_loader = DbDataLoader()
- element["data"]["content"] = data_loader.get_table_view_by_conn(
- conn.run_to_df(sql_edit_context.new_sql),
- sql_edit_context.new_speak,
- )
- history_mem.update(history_messages)
- return Result.succ(None)
- return Result.failed(msg="Edit Failed!")
+ conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
+ try:
+ editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
+ return Result.succ(None)
+ except Exception as e:
+ logger.error(f"edit sql exception!{str(e)}")
+ return Result.failed(msg=f"Edit sql exception!{str(e)}")
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
-async def get_editor_chart_list(con_uid: str):
+async def get_editor_chart_list(
+ con_uid: str,
+ editor_service: EditorService = Depends(get_edit_service),
+):
logger.info(
f"get_editor_sql_rounds:{con_uid}",
)
- chat_history_fac = ChatHistory()
- history_mem = chat_history_fac.get_store_instance(con_uid)
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- last_round = max(history_messages, key=lambda x: x["chat_order"])
- db_name = last_round["param_value"]
- for element in last_round["messages"]:
- if element["type"] == "ai":
- chart_list: ChartList = ChartList(
- round=last_round["chat_order"],
- db_name=db_name,
- charts=json.loads(element["data"]["content"]),
- )
- return Result.succ(chart_list)
+ chart_list = editor_service.get_editor_chart_list(con_uid)
+ if chart_list:
+ return Result.succ(chart_list)
return Result.failed(msg="Not have charts!")
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
-async def get_editor_chart_info(param: dict = Body()):
+async def get_editor_chart_info(
+ param: dict = Body(), editor_service: EditorService = Depends(get_edit_service)
+):
logger.info(f"get_editor_chart_info:{param}")
conv_uid = param["con_uid"]
chart_title = param["chart_title"]
-
- chat_history_fac = ChatHistory()
- history_mem = chat_history_fac.get_store_instance(conv_uid)
- history_messages: List[OnceConversation] = history_mem.get_messages()
- if history_messages:
- last_round = max(history_messages, key=lambda x: x["chat_order"])
- db_name = last_round["param_value"]
- if not db_name:
- logger.error(
- "this dashboard dialogue version too old, can't support editor!"
- )
- return Result.failed(
- msg="this dashboard dialogue version too old, can't support editor!"
- )
- for element in last_round["messages"]:
- if element["type"] == "view":
- view_data: dict = json.loads(element["data"]["content"])
- charts: List = view_data.get("charts")
- find_chart = list(
- filter(lambda x: x["chart_name"] == chart_title, charts)
- )[0]
-
- conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
- detail: ChartDetail = ChartDetail(
- chart_uid=find_chart["chart_uid"],
- chart_type=find_chart["chart_type"],
- chart_desc=find_chart["chart_desc"],
- chart_sql=find_chart["chart_sql"],
- db_name=db_name,
- chart_name=find_chart["chart_name"],
- chart_value=find_chart["values"],
- table_value=conn.run(find_chart["chart_sql"]),
- )
-
- return Result.succ(detail)
- return Result.failed(msg="Can't Find Chart Detail Info!")
+ return editor_service.get_editor_chart_info(conv_uid, chart_title, CFG)
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
diff --git a/dbgpt/app/openapi/api_v1/editor/service.py b/dbgpt/app/openapi/api_v1/editor/service.py
new file mode 100644
index 000000000..ef22adebd
--- /dev/null
+++ b/dbgpt/app/openapi/api_v1/editor/service.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+from dbgpt._private.config import Config
+from dbgpt.app.openapi.api_view_model import Result
+from dbgpt.app.openapi.editor_view_model import (
+ ChartDetail,
+ ChartList,
+ ChatDbRounds,
+ ChatSqlEditContext,
+)
+from dbgpt.component import BaseComponent, SystemApp
+from dbgpt.core import BaseOutputParser
+from dbgpt.core.interface.message import (
+ MessageStorageItem,
+ StorageConversation,
+ _split_messages_by_round,
+)
+from dbgpt.serve.conversation.serve import Serve as ConversationServe
+
+if TYPE_CHECKING:
+ from dbgpt.datasource.base import BaseConnect
+
+logger = logging.getLogger(__name__)
+
+
+class EditorService(BaseComponent):
+ name = "dbgpt_app_editor_service"
+
+ def __init__(self, system_app: SystemApp):
+ self._system_app: SystemApp = system_app
+ super().__init__(system_app)
+
+ def init_app(self, system_app: SystemApp):
+ self._system_app = system_app
+
+ def conv_serve(self) -> ConversationServe:
+ return ConversationServe.get_instance(self._system_app)
+
+ def get_storage_conv(self, conv_uid: str) -> StorageConversation:
+ conv_serve: ConversationServe = self.conv_serve()
+ return StorageConversation(
+ conv_uid,
+ conv_storage=conv_serve.conv_storage,
+ message_storage=conv_serve.message_storage,
+ )
+
+ def get_editor_sql_rounds(self, conv_uid: str) -> List[ChatDbRounds]:
+ storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
+ messages_by_round = _split_messages_by_round(storage_conv.messages)
+ result: List[ChatDbRounds] = []
+ for one_round_message in messages_by_round:
+ if not one_round_message:
+ continue
+ for message in one_round_message:
+ if message.type == "human":
+ round_name = message.content
+ if message.additional_kwargs.get("param_value"):
+ chat_db_round: ChatDbRounds = ChatDbRounds(
+ round=message.round_index,
+ db_name=message.additional_kwargs.get("param_value"),
+ round_name=round_name,
+ )
+ result.append(chat_db_round)
+
+ return result
+
+ def get_editor_sql_by_round(
+ self, conv_uid: str, round_index: int
+ ) -> Optional[Dict]:
+ storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
+ messages_by_round = _split_messages_by_round(storage_conv.messages)
+ for one_round_message in messages_by_round:
+ if not one_round_message:
+ continue
+ for message in one_round_message:
+ if message.type == "ai" and message.round_index == round_index:
+ content = message.content
+ logger.info(f"history ai json resp: {content}")
+ # context = content.replace("\\n", " ").replace("\n", " ")
+ context_dict = _parse_pure_dict(content)
+ return context_dict
+ return None
+
+ def sql_editor_submit_and_save(
+ self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
+ ):
+ storage_conv: StorageConversation = self.get_storage_conv(
+ sql_edit_context.conv_uid
+ )
+ if not storage_conv.save_message_independent:
+ raise ValueError(
+ "Submit sql and save just support independent conversation mode(after v0.4.6)"
+ )
+ conv_serve: ConversationServe = self.conv_serve()
+ messages_by_round = _split_messages_by_round(storage_conv.messages)
+ to_update_messages = []
+ for one_round_message in messages_by_round:
+ if not one_round_message:
+ continue
+ if one_round_message[0].round_index == sql_edit_context.conv_round:
+ for message in one_round_message:
+ if message.type == "ai":
+ db_resp = _parse_pure_dict(message.content)
+ db_resp["thoughts"] = sql_edit_context.new_speak
+ db_resp["sql"] = sql_edit_context.new_sql
+ message.content = json.dumps(db_resp, ensure_ascii=False)
+ to_update_messages.append(
+ MessageStorageItem(
+ storage_conv.conv_uid, message.index, message.to_dict()
+ )
+ )
+ # TODO: updating view messages is not supported yet
+ # if message.type == "view":
+ # data_loader = DbDataLoader()
+ # message.content = data_loader.get_table_view_by_conn(
+ # connection.run_to_df(sql_edit_context.new_sql),
+ # sql_edit_context.new_speak,
+ # )
+ # to_update_messages.append(
+ # MessageStorageItem(
+ # storage_conv.conv_uid, message.index, message.to_dict()
+ # )
+ # )
+ if to_update_messages:
+ conv_serve.message_storage.save_or_update_list(to_update_messages)
+ return
+
+ def get_editor_chart_list(self, conv_uid: str) -> Optional[ChartList]:
+ storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
+ messages_by_round = _split_messages_by_round(storage_conv.messages)
+ for one_round_message in messages_by_round:
+ if not one_round_message:
+ continue
+ for message in one_round_message:
+ if message.type == "ai":
+ context_dict = _parse_pure_dict(message.content)
+ chart_list: ChartList = ChartList(
+ round=message.round_index,
+ db_name=message.additional_kwargs.get("param_value"),
+ charts=context_dict,
+ )
+ return chart_list
+
+ def get_editor_chart_info(
+ self, conv_uid: str, chart_title: str, cfg: Config
+ ) -> Result[ChartDetail]:
+ storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
+ messages_by_round = _split_messages_by_round(storage_conv.messages)
+ for one_round_message in messages_by_round:
+ if not one_round_message:
+ continue
+ for message in one_round_message:
+ db_name = message.additional_kwargs.get("param_value")
+ if not db_name:
+ logger.error(
+ "this dashboard dialogue version too old, can't support editor!"
+ )
+ return Result.failed(
+ msg="this dashboard dialogue version too old, can't support editor!"
+ )
+ if message.type == "view":
+ view_data: dict = _parse_pure_dict(message.content)
+ charts: List = view_data.get("charts")
+ find_chart = list(
+ filter(lambda x: x["chart_name"] == chart_title, charts)
+ )[0]
+
+ conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
+ detail: ChartDetail = ChartDetail(
+ chart_uid=find_chart["chart_uid"],
+ chart_type=find_chart["chart_type"],
+ chart_desc=find_chart["chart_desc"],
+ chart_sql=find_chart["chart_sql"],
+ db_name=db_name,
+ chart_name=find_chart["chart_name"],
+ chart_value=find_chart["values"],
+ table_value=conn.run(find_chart["chart_sql"]),
+ )
+ return Result.succ(detail)
+ return Result.failed(msg="Can't Find Chart Detail Info!")
+
+
+def _parse_pure_dict(res_str: str) -> Dict:
+ output_parser = BaseOutputParser()
+ context = output_parser.parse_prompt_response(res_str)
+ return json.loads(context)
diff --git a/dbgpt/app/openapi/api_v1/editor/sql_editor.py b/dbgpt/app/openapi/api_v1/editor/sql_editor.py
index e31db2c41..5a05199e0 100644
--- a/dbgpt/app/openapi/api_v1/editor/sql_editor.py
+++ b/dbgpt/app/openapi/api_v1/editor/sql_editor.py
@@ -1,4 +1,5 @@
from typing import List
+
from dbgpt._private.pydantic import BaseModel
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
diff --git a/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py b/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py
index 365b399b3..4dc4178d7 100644
--- a/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py
+++ b/dbgpt/app/openapi/api_v1/feedback/api_fb_v1.py
@@ -1,9 +1,7 @@
from fastapi import APIRouter, Body, Request
+from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
-from dbgpt.app.openapi.api_v1.feedback.feed_back_db import (
- ChatFeedBackDao,
-)
from dbgpt.app.openapi.api_view_model import Result
router = APIRouter()
diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py
index 434830625..2b358f8c5 100644
--- a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py
+++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py
@@ -1,10 +1,9 @@
from datetime import datetime
-from sqlalchemy import Column, Integer, Text, String, DateTime
-
-from dbgpt.storage.metadata import BaseDao, Model
+from sqlalchemy import Column, DateTime, Integer, String, Text
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
+from dbgpt.storage.metadata import BaseDao, Model
class ChatFeedBackEntity(Model):
diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py
index 38f437b4a..7f4067418 100644
--- a/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py
+++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_model.py
@@ -1,6 +1,7 @@
-from dbgpt._private.pydantic import BaseModel
from typing import Optional
+from dbgpt._private.pydantic import BaseModel
+
class FeedBackBody(BaseModel):
"""conv_uid: conversation id"""
diff --git a/dbgpt/app/openapi/api_view_model.py b/dbgpt/app/openapi/api_view_model.py
index 25f8cb5dc..cd4ccadfd 100644
--- a/dbgpt/app/openapi/api_view_model.py
+++ b/dbgpt/app/openapi/api_view_model.py
@@ -1,7 +1,8 @@
-from dbgpt._private.pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any, Optional, Literal, List
-import uuid
import time
+import uuid
+from typing import Any, Generic, List, Literal, Optional, TypeVar
+
+from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
@@ -17,11 +18,7 @@ def succ(cls, data: T):
return Result(success=True, err_code=None, err_msg=None, data=data)
@classmethod
- def failed(cls, msg):
- return Result(success=False, err_code="E000X", err_msg=msg, data=None)
-
- @classmethod
- def failed(cls, code, msg):
+ def failed(cls, code: str = "E000X", msg=None):
return Result(success=False, err_code=code, err_msg=msg, data=None)
diff --git a/dbgpt/app/openapi/base.py b/dbgpt/app/openapi/base.py
index 10cb1d231..d4cb7679b 100644
--- a/dbgpt/app/openapi/base.py
+++ b/dbgpt/app/openapi/base.py
@@ -1,5 +1,6 @@
from fastapi import Request
from fastapi.exceptions import RequestValidationError
+
from dbgpt.app.openapi.api_view_model import Result
diff --git a/dbgpt/app/openapi/editor_view_model.py b/dbgpt/app/openapi/editor_view_model.py
index 82d97092c..a915805fd 100644
--- a/dbgpt/app/openapi/editor_view_model.py
+++ b/dbgpt/app/openapi/editor_view_model.py
@@ -1,5 +1,6 @@
+from typing import Any, List
+
from dbgpt._private.pydantic import BaseModel, Field
-from typing import List, Any
class DbField(BaseModel):
diff --git a/dbgpt/app/scene/__init__.py b/dbgpt/app/scene/__init__.py
index 494db2a31..d6b804f1a 100644
--- a/dbgpt/app/scene/__init__.py
+++ b/dbgpt/app/scene/__init__.py
@@ -1,3 +1,3 @@
+from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.base_chat import BaseChat
from dbgpt.app.scene.chat_factory import ChatFactory
-from dbgpt.app.scene.base import ChatScene
diff --git a/dbgpt/app/scene/base.py b/dbgpt/app/scene/base.py
index 893bccdb0..a2b5f03c4 100644
--- a/dbgpt/app/scene/base.py
+++ b/dbgpt/app/scene/base.py
@@ -1,5 +1,9 @@
from enum import Enum
-from typing import List
+from typing import List, Optional
+
+from dbgpt._private.pydantic import BaseModel, Field
+from dbgpt.core import BaseOutputParser, ChatPromptTemplate
+from dbgpt.core._private.example_base import ExampleSelector
class Scene:
@@ -135,3 +139,49 @@ def show_disable(self):
def is_inner(self):
return self._value_.is_inner
+
+
+class AppScenePromptTemplateAdapter(BaseModel):
+ """The template of the scene.
+
+ Include some fields that are in :class:`dbgpt.core.PromptTemplate`
+ """
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ arbitrary_types_allowed = True
+
+ prompt: ChatPromptTemplate = Field(..., description="The prompt of this scene")
+ template_scene: Optional[str] = Field(
+ default=None, description="The scene of this template"
+ )
+ template_is_strict: Optional[bool] = Field(
+ default=True, description="Whether strict"
+ )
+
+ output_parser: Optional[BaseOutputParser] = Field(
+ default=None, description="The output parser of this scene"
+ )
+ sep: Optional[str] = Field(
+ default="###", description="The default separator of this scene"
+ )
+
+ stream_out: Optional[bool] = Field(
+ default=True, description="Whether to stream out"
+ )
+ example_selector: Optional[ExampleSelector] = Field(
+ default=None, description="Example selector"
+ )
+ need_historical_messages: Optional[bool] = Field(
+ default=False, description="Whether to need historical messages"
+ )
+ temperature: Optional[float] = Field(
+ default=0.6, description="The default temperature of this scene"
+ )
+ max_new_tokens: Optional[int] = Field(
+ default=1024, description="The default max new tokens of this scene"
+ )
+ str_history: Optional[bool] = Field(
+ default=False, description="Whether transform history to str"
+ )
diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py
index 776c460cc..b9a22f7e8 100644
--- a/dbgpt/app/scene/base_chat.py
+++ b/dbgpt/app/scene/base_chat.py
@@ -1,30 +1,56 @@
-import datetime
-import traceback
-import warnings
import asyncio
+import datetime
import logging
+import traceback
from abc import ABC, abstractmethod
-from typing import Any, List, Dict
+from typing import Any, Dict
from dbgpt._private.config import Config
+from dbgpt._private.pydantic import Extra
+from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene
+from dbgpt.app.scene.operator.app_operator import (
+ AppChatComposerOperator,
+ ChatComposerInput,
+)
from dbgpt.component import ComponentType
-from dbgpt.core.interface.prompt import PromptTemplate
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from dbgpt.core.interface.message import OnceConversation
+from dbgpt.core.awel import DAG, BaseOperator, InputOperator, SimpleCallDataInputSource
+from dbgpt.core.interface.message import StorageConversation
from dbgpt.model.cluster import WorkerManagerFactory
+from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
+from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
-from dbgpt._private.pydantic import Extra
-from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
-from dbgpt.core.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
-from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
logger = logging.getLogger(__name__)
-headers = {"User-Agent": "dbgpt Client"}
CFG = Config()
+def _build_conversation(
+ chat_mode: ChatScene,
+ chat_param: Dict[str, Any],
+ model_name: str,
+ conv_serve: ConversationServe,
+) -> StorageConversation:
+ param_type = ""
+ param_value = ""
+ if chat_param["select_param"]:
+ if len(chat_mode.param_types()) > 0:
+ param_type = chat_mode.param_types()[0]
+ param_value = chat_param["select_param"]
+ return StorageConversation(
+ chat_param["chat_session_id"],
+ chat_mode=chat_mode.value(),
+ user_name=chat_param.get("user_name"),
+ sys_code=chat_param.get("sys_code"),
+ model_name=model_name,
+ param_type=param_type,
+ param_value=param_value,
+ conv_storage=conv_serve.conv_storage,
+ message_storage=conv_serve.message_storage,
+ )
+
+
class BaseChat(ABC):
"""DB-GPT Chat Service Base Module
Include:
@@ -35,7 +61,8 @@ class BaseChat(ABC):
chat_scene: str = None
llm_model: Any = None
# By default, keep the last two rounds of conversation records as the context
- chat_retention_rounds: int = 0
+ keep_start_rounds: int = 0
+ keep_end_rounds: int = 0
class Config:
"""Configuration for this pydantic object."""
@@ -68,7 +95,7 @@ def __init__(self, chat_param: Dict):
# self.prompt_template: PromptTemplate = CFG.prompt_templates[
# self.chat_mode.value()
# ]
- self.prompt_template: PromptTemplate = (
+ self.prompt_template: AppScenePromptTemplateAdapter = (
CFG.prompt_template_registry.get_prompt_template(
self.chat_mode.value(),
language=CFG.LANGUAGE,
@@ -76,21 +103,21 @@ def __init__(self, chat_param: Dict):
proxyllm_backend=CFG.PROXYLLM_BACKEND,
)
)
- chat_history_fac = ChatHistory()
+ self._conv_serve = ConversationServe.get_instance(CFG.SYSTEM_APP)
+ # chat_history_fac = ChatHistory()
### can configurable storage methods
- self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
-
- self.history_message: List[OnceConversation] = self.memory.messages()
- self.current_message: OnceConversation = OnceConversation(
- self.chat_mode.value(),
- user_name=chat_param.get("user_name"),
- sys_code=chat_param.get("sys_code"),
+ # self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
+
+ # self.history_message: List[OnceConversation] = self.memory.messages()
+ # self.current_message: OnceConversation = OnceConversation(
+ # self.chat_mode.value(),
+ # user_name=chat_param.get("user_name"),
+ # sys_code=chat_param.get("sys_code"),
+ # )
+ self.current_message: StorageConversation = _build_conversation(
+ self.chat_mode, chat_param, self.llm_model, self._conv_serve
)
- self.current_message.model_name = self.llm_model
- if chat_param["select_param"]:
- if len(self.chat_mode.param_types()) > 0:
- self.current_message.param_type = self.chat_mode.param_types()[0]
- self.current_message.param_value = chat_param["select_param"]
+ self.history_messages = self.current_message.get_history_message()
self.current_tokens_used: int = 0
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
@@ -102,10 +129,9 @@ def __init__(self, chat_param: Dict):
is_stream=True, dag_name="llm_stream_model_dag"
)
- # Get the message version, default is v1 in app
# In v1, we will transform the message to compatible format of specific model
# In the future, we will upgrade the message version to v2, and the message will be compatible with all models
- self._message_version = chat_param.get("message_version", "v1")
+ self._message_version = chat_param.get("message_version", "v2")
class Config:
"""Configuration for this pydantic object."""
@@ -133,6 +159,14 @@ def do_action(self, prompt_response):
def message_adjust(self):
pass
+ def has_history_messages(self) -> bool:
+ """Whether there are history messages
+
+ Returns:
+ bool: True if there is a history message, False otherwise
+ """
+ return len(self.history_messages) > 0
+
def get_llm_speak(self, prompt_define_response):
if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict):
@@ -153,31 +187,51 @@ def get_llm_speak(self, prompt_define_response):
async def __call_base(self):
input_values = await self.generate_input_values()
- ### Chat sequence advance
- self.current_message.chat_order = len(self.history_message) + 1
- self.current_message.add_user_message(
- self.current_user_input, check_duplicate_type=True
- )
+ # Load history
+ self.history_messages = self.current_message.get_history_message()
+ self.current_message.start_new_round()
+ self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.current_message.tokens = 0
- if self.prompt_template.template:
- metadata = {
- "template_scene": self.prompt_template.template_scene,
- "input_values": input_values,
- }
- with root_tracer.start_span(
- "BaseChat.__call_base.prompt_template.format", metadata=metadata
- ):
- current_prompt = self.prompt_template.format(**input_values)
- ### prompt context token adapt according to llm max context length
- current_prompt = await self.prompt_context_token_adapt(
- prompt=current_prompt
- )
- self.current_message.add_system_message(current_prompt)
-
- llm_messages = self.generate_llm_messages()
+ # TODO: support handle history message by tokens
+ # if self.prompt_template.template:
+ # metadata = {
+ # "template_scene": self.prompt_template.template_scene,
+ # "input_values": input_values,
+ # }
+ # with root_tracer.start_span(
+ # "BaseChat.__call_base.prompt_template.format", metadata=metadata
+ # ):
+ # current_prompt = self.prompt_template.format(**input_values)
+ # ### prompt context token adapt according to llm max context length
+ # current_prompt = await self.prompt_context_token_adapt(
+ # prompt=current_prompt
+ # )
+ # self.current_message.add_system_message(current_prompt)
+
+ keep_start_rounds = (
+ self.keep_start_rounds
+ if self.prompt_template.need_historical_messages
+ else 0
+ )
+ keep_end_rounds = (
+ self.keep_end_rounds if self.prompt_template.need_historical_messages else 0
+ )
+ node = AppChatComposerOperator(
+ prompt=self.prompt_template.prompt,
+ keep_start_rounds=keep_start_rounds,
+ keep_end_rounds=keep_end_rounds,
+ str_history=self.prompt_template.str_history,
+ )
+ node_input = {
+ "data": ChatComposerInput(
+ messages=self.history_messages, prompt_dict=input_values
+ )
+ }
+ # llm_messages = self.generate_llm_messages()
+ llm_messages = await node.call(call_data=node_input)
if not CFG.NEW_SERVER_MODE:
# Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
# fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
@@ -185,7 +239,7 @@ async def __call_base(self):
payload = {
"model": self.llm_model,
- "prompt": self.generate_llm_text(),
+ "prompt": "",
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
@@ -238,7 +292,7 @@ async def stream_call(self):
view_msg = self.stream_plugin_call(msg)
view_msg = view_msg.replace("\n", "\\n")
yield view_msg
- self.current_message.add_ai_message(msg, update_if_exist=True)
+ self.current_message.add_ai_message(msg)
view_msg = self.stream_call_reinforce_fn(view_msg)
self.current_message.add_view_message(view_msg)
span.end()
@@ -250,7 +304,8 @@ async def stream_call(self):
)
### store current conversation
span.end(metadata={"error": str(e)})
- self.memory.append(self.current_message)
+ # self.memory.append(self.current_message)
+ self.current_message.end_current_round()
async def nostream_call(self):
payload = await self.__call_base()
@@ -274,7 +329,7 @@ async def nostream_call(self):
)
)
### model result deal
- self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
+ self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
@@ -320,7 +375,8 @@ async def nostream_call(self):
)
span.end(metadata={"error": str(e)})
### store dialogue
- self.memory.append(self.current_message)
+ # self.memory.append(self.current_message)
+ self.current_message.end_current_round()
return self.current_ai_response()
async def get_llm_response(self):
@@ -328,6 +384,7 @@ async def get_llm_response(self):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
payload["model_cache_enable"] = self.model_cache_enable
+ prompt_define_response = None
try:
model_output = await self._model_operator.call(call_data={"data": payload})
### output parse
@@ -337,8 +394,7 @@ async def get_llm_response(self):
)
)
### model result deal
- self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
- prompt_define_response = None
+ self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
@@ -384,64 +440,8 @@ def call(self):
async def prepare(self):
pass
- def generate_llm_text(self) -> str:
- warnings.warn("This method is deprecated - please use `generate_llm_messages`.")
- text = ""
- ### Load scene setting or character definition
- if self.prompt_template.template_define:
- text += self.prompt_template.template_define + self.prompt_template.sep
- ### Load prompt
- text += _load_system_message(self.current_message, self.prompt_template)
-
- ### Load examples
- text += _load_example_messages(self.prompt_template)
-
- ### Load History
- text += _load_history_messages(
- self.prompt_template, self.history_message, self.chat_retention_rounds
- )
-
- ### Load User Input
- text += _load_user_message(self.current_message, self.prompt_template)
- return text
-
- def generate_llm_messages(self) -> List[ModelMessage]:
- """
- Structured prompt messages interaction between dbgpt-server and llm-server
- See https://github.com/csunny/DB-GPT/issues/328
- """
- messages = []
- ### Load scene setting or character definition as system message
- if self.prompt_template.template_define:
- messages.append(
- ModelMessage(
- role=ModelMessageRoleType.SYSTEM,
- content=self.prompt_template.template_define,
- )
- )
- ### Load prompt
- messages += _load_system_message(
- self.current_message, self.prompt_template, str_message=False
- )
- ### Load examples
- messages += _load_example_messages(self.prompt_template, str_message=False)
-
- ### Load History
- messages += _load_history_messages(
- self.prompt_template,
- self.history_message,
- self.chat_retention_rounds,
- str_message=False,
- )
-
- ### Load User Input
- messages += _load_user_message(
- self.current_message, self.prompt_template, str_message=False
- )
- return messages
-
def current_ai_response(self) -> str:
- for message in self.current_message.messages:
+ for message in self.current_message.messages[-1:]:
if message.type == "view":
return message.content
return None
@@ -539,16 +539,6 @@ def _generate_numbered_list(self) -> str:
},
]
- # command_strings = []
- # if CFG.command_disply:
- # for name, item in CFG.command_disply.commands.items():
- # if item.enabled:
- # command_strings.append(f"{name}:{item.description}")
- # command_strings += [
- # str(item)
- # for item in CFG.command_disply.commands.values()
- # if item.enabled
- # ]
return "\n".join(
f"{key}:{value}"
for dict_item in antv_charts
@@ -566,6 +556,7 @@ def _build_model_operator(
data using the model. It supports both streaming and non-streaming modes.
.. code-block:: python
+
input_node >> cache_check_branch_node
cache_check_branch_node >> model_node >> save_cached_node >> join_node
cache_check_branch_node >> cached_node >> join_node
@@ -585,12 +576,12 @@ def _build_model_operator(
Returns:
BaseOperator: The final operator in the constructed DAG, typically a join node.
"""
- from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.core.awel import JoinOperator
+ from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.operator.model_operator import (
- ModelCacheBranchOperator,
- CachedModelStreamOperator,
CachedModelOperator,
+ CachedModelStreamOperator,
+ ModelCacheBranchOperator,
ModelSaveCacheOperator,
ModelStreamSaveCacheOperator,
)
@@ -639,127 +630,3 @@ def _build_model_operator(
cache_check_branch_node >> cached_node >> join_node
return join_node
-
-
-def _load_system_message(
- current_message: OnceConversation,
- prompt_template: PromptTemplate,
- str_message: bool = True,
-):
- system_convs = current_message.get_system_messages()
- system_text = ""
- system_messages = []
- for system_conv in system_convs:
- system_text += (
- system_conv.type + ":" + system_conv.content + prompt_template.sep
- )
- system_messages.append(
- ModelMessage(role=system_conv.type, content=system_conv.content)
- )
- return system_text if str_message else system_messages
-
-
-def _load_user_message(
- current_message: OnceConversation,
- prompt_template: PromptTemplate,
- str_message: bool = True,
-):
- user_conv = current_message.get_latest_user_message()
- user_messages = []
- if user_conv:
- user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
- user_messages.append(
- ModelMessage(role=user_conv.type, content=user_conv.content)
- )
- return user_text if str_message else user_messages
- else:
- raise ValueError("Hi! What do you want to talk about?")
-
-
-def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
- example_text = ""
- example_messages = []
- if prompt_template.example_selector:
- for round_conv in prompt_template.example_selector.examples():
- for round_message in round_conv["messages"]:
- if not round_message["type"] in [
- ModelMessageRoleType.VIEW,
- ModelMessageRoleType.SYSTEM,
- ]:
- message_type = round_message["type"]
- message_content = round_message["data"]["content"]
- example_text += (
- message_type + ":" + message_content + prompt_template.sep
- )
- example_messages.append(
- ModelMessage(role=message_type, content=message_content)
- )
- return example_text if str_message else example_messages
-
-
-def _load_history_messages(
- prompt_template: PromptTemplate,
- history_message: List[OnceConversation],
- chat_retention_rounds: int,
- str_message: bool = True,
-):
- history_text = ""
- history_messages = []
- if prompt_template.need_historical_messages:
- if history_message:
- logger.info(
- f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
- )
- if len(history_message) > chat_retention_rounds:
- for first_message in history_message[0]["messages"]:
- if not first_message["type"] in [
- ModelMessageRoleType.VIEW,
- ModelMessageRoleType.SYSTEM,
- ]:
- message_type = first_message["type"]
- message_content = first_message["data"]["content"]
- history_text += (
- message_type + ":" + message_content + prompt_template.sep
- )
- history_messages.append(
- ModelMessage(role=message_type, content=message_content)
- )
- if chat_retention_rounds > 1:
- index = chat_retention_rounds - 1
- for round_conv in history_message[-index:]:
- for round_message in round_conv["messages"]:
- if not round_message["type"] in [
- ModelMessageRoleType.VIEW,
- ModelMessageRoleType.SYSTEM,
- ]:
- message_type = round_message["type"]
- message_content = round_message["data"]["content"]
- history_text += (
- message_type
- + ":"
- + message_content
- + prompt_template.sep
- )
- history_messages.append(
- ModelMessage(role=message_type, content=message_content)
- )
-
- else:
- ### user all history
- for conversation in history_message:
- for message in conversation["messages"]:
- ### histroy message not have promot and view info
- if not message["type"] in [
- ModelMessageRoleType.VIEW,
- ModelMessageRoleType.SYSTEM,
- ]:
- message_type = message["type"]
- message_content = message["data"]["content"]
- history_text += (
- message_type + ":" + message_content + prompt_template.sep
- )
- history_messages.append(
- ModelMessage(role=message_type, content=message_content)
- )
-
- return history_text if str_message else history_messages
diff --git a/dbgpt/app/scene/chat_agent/chat.py b/dbgpt/app/scene/chat_agent/chat.py
index a46d70f37..ded1db216 100644
--- a/dbgpt/app/scene/chat_agent/chat.py
+++ b/dbgpt/app/scene/chat_agent/chat.py
@@ -1,10 +1,10 @@
-from typing import List, Dict
import logging
+from typing import Dict, List
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.agent.plugin.commands.command_mange import ApiCall
from dbgpt.agent.plugin.generator import PluginPromptGenerator
+from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModuleAgent
from dbgpt.util.tracer import root_tracer, trace
@@ -18,7 +18,7 @@ class ChatAgent(BaseChat):
"""Chat With Agent through plugin"""
chat_scene: str = ChatScene.ChatAgent.value()
- chat_retention_rounds = 0
+ keep_end_rounds = 0
def __init__(self, chat_param: Dict):
"""Chat Agent Module Initialization
diff --git a/dbgpt/app/scene/chat_agent/out_parser.py b/dbgpt/app/scene/chat_agent/out_parser.py
index 5cb348ae9..e1bc399e7 100644
--- a/dbgpt/app/scene/chat_agent/out_parser.py
+++ b/dbgpt/app/scene/chat_agent/out_parser.py
@@ -1,4 +1,5 @@
from typing import Dict, NamedTuple
+
from dbgpt.core.interface.output_parser import BaseOutputParser
diff --git a/dbgpt/app/scene/chat_agent/prompt.py b/dbgpt/app/scene/chat_agent/prompt.py
index f716606a0..7764b0c43 100644
--- a/dbgpt/app/scene/chat_agent/prompt.py
+++ b/dbgpt/app/scene/chat_agent/prompt.py
@@ -1,8 +1,7 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser
+from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate
CFG = Config()
@@ -65,16 +64,19 @@
### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = True
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(_PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE),
+ HumanPromptTemplate.from_template("{user_goal}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatAgent.value(),
- input_variables=["tool_list", "expand_constraints", "user_goal"],
- response_format=None,
- template_define=_PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
- temperature=1
- # example_selector=plugin_example,
+ need_historical_messages=False,
+ temperature=1,
)
-
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py
index 7fe91b516..3748ce32e 100644
--- a/dbgpt/app/scene/chat_dashboard/chat.py
+++ b/dbgpt/app/scene/chat_dashboard/chat.py
@@ -1,15 +1,15 @@
import json
import os
import uuid
-from typing import List, Dict
+from typing import Dict, List
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
+from dbgpt.app.scene import BaseChat, ChatScene
+from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
)
-from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import trace
diff --git a/dbgpt/app/scene/chat_dashboard/data_loader.py b/dbgpt/app/scene/chat_dashboard/data_loader.py
index baf001faa..05cc55970 100644
--- a/dbgpt/app/scene/chat_dashboard/data_loader.py
+++ b/dbgpt/app/scene/chat_dashboard/data_loader.py
@@ -1,6 +1,6 @@
-from typing import List
-from decimal import Decimal
import logging
+from decimal import Decimal
+from typing import List
from dbgpt._private.config import Config
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
diff --git a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py
index 19e4d4923..8a8c12749 100644
--- a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py
+++ b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py
@@ -1,5 +1,6 @@
+from typing import Any, List
+
from dbgpt._private.pydantic import BaseModel
-from typing import List, Any
class ValueItem(BaseModel):
diff --git a/dbgpt/app/scene/chat_dashboard/out_parser.py b/dbgpt/app/scene/chat_dashboard/out_parser.py
index 27fc27aef..71955a120 100644
--- a/dbgpt/app/scene/chat_dashboard/out_parser.py
+++ b/dbgpt/app/scene/chat_dashboard/out_parser.py
@@ -1,9 +1,9 @@
import json
import logging
+from typing import List, NamedTuple
-from typing import NamedTuple, List
-from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.app.scene import ChatScene
+from dbgpt.core.interface.output_parser import BaseOutputParser
class ChartItem(NamedTuple):
diff --git a/dbgpt/app/scene/chat_dashboard/prompt.py b/dbgpt/app/scene/chat_dashboard/prompt.py
index 566646cbe..ad53f8d26 100644
--- a/dbgpt/app/scene/chat_dashboard/prompt.py
+++ b/dbgpt/app/scene/chat_dashboard/prompt.py
@@ -1,8 +1,9 @@
import json
-from dbgpt.core.interface.prompt import PromptTemplate
+
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_dashboard.out_parser import ChatDashboardOutputParser
+from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate
CFG = Config()
@@ -39,16 +40,23 @@
}
]
-
PROMPT_NEED_STREAM_OUT = False
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(
+ PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE,
+ response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+ ),
+ HumanPromptTemplate.from_template("{input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatDashboard.value(),
- input_variables=["input", "table_info", "dialect", "supported_chat_type"],
- response_format=json.dumps(RESPONSE_FORMAT, indent=4),
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=ChatDashboardOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py
index 4ace7e0b5..fa96b7a76 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -1,14 +1,14 @@
-import os
import logging
-
+import os
from typing import Dict
-from dbgpt.app.scene import BaseChat, ChatScene
+
from dbgpt._private.config import Config
from dbgpt.agent.plugin.commands.command_mange import ApiCall
-from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
+from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
-from dbgpt.util.path_utils import has_path
+from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
+from dbgpt.util.path_utils import has_path
from dbgpt.util.tracer import root_tracer, trace
CFG = Config()
@@ -20,7 +20,8 @@ class ChatExcel(BaseChat):
"""a Excel analyzer to analyze Excel Data"""
chat_scene: str = ChatScene.ChatExcel.value()
- chat_retention_rounds = 2
+ keep_start_rounds = 1
+ keep_end_rounds = 2
def __init__(self, chat_param: Dict):
"""Chat Excel Module Initialization
@@ -58,7 +59,7 @@ async def generate_input_values(self) -> Dict:
async def prepare(self):
logger.info(f"{self.chat_mode} prepare start!")
- if len(self.history_message) > 0:
+ if self.has_history_messages():
return None
chat_param = {
"chat_session_id": self.chat_session_id,
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py
index 9a2e7af1c..34a47f493 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py
@@ -1,8 +1,9 @@
import json
import logging
from typing import NamedTuple
-from dbgpt.core.interface.output_parser import BaseOutputParser
+
from dbgpt._private.config import Config
+from dbgpt.core.interface.output_parser import BaseOutputParser
CFG = Config()
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py
index b295288a3..49eebd37e 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py
@@ -1,9 +1,14 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser,
)
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
CFG = Config()
@@ -59,15 +64,20 @@
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.3
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(_PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{user_input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatExcel.value(),
- input_variables=["user_input", "table_name", "disply_type"],
- template_define=_PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
need_historical_messages=True,
- # example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py
index db9927fd3..d28b3f79c 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py
@@ -1,10 +1,10 @@
import json
from typing import Any, Dict
-from dbgpt.core.interface.message import ViewMessage, AIMessage
from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt.util.json_utils import EnhancedJSONEncoder
+from dbgpt.core.interface.message import AIMessage, ViewMessage
from dbgpt.util.executor_utils import blocking_func_to_async
+from dbgpt.util.json_utils import EnhancedJSONEncoder
from dbgpt.util.tracer import trace
@@ -52,6 +52,7 @@ async def generate_input_values(self) -> Dict:
def message_adjust(self):
### adjust learning result in messages
+ # TODO: Can't work in multi-round chat
view_message = ""
for message in self.current_message.messages:
if message.type == ViewMessage.type:
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py
index 9cd89c34b..311e349e7 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py
@@ -1,6 +1,7 @@
import json
import logging
-from typing import NamedTuple, List
+from typing import List, NamedTuple
+
from dbgpt.core.interface.output_parser import BaseOutputParser
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py
index 603799e10..120901809 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py
@@ -1,10 +1,17 @@
import json
-from dbgpt.core.interface.prompt import PromptTemplate
+
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser,
)
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
+from dbgpt.core.interface.prompt import PromptTemplate
CFG = Config()
@@ -72,15 +79,24 @@
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.8
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(
+ PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE,
+ response_format=json.dumps(
+ RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4
+ ),
+ ),
+ HumanPromptTemplate.from_template("{file_name}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ExcelLearning.value(),
- input_variables=["data_example"],
- response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
- # example_selector=sql_data_example,
+ need_historical_messages=False,
temperature=PROMPT_TEMPERATURE,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py
index 854c2bfe7..f05393db8 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_learning/verify_sql.py
@@ -1,4 +1,5 @@
import re
+
import sqlparse
diff --git a/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py b/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py
index 0ad2c70d5..f234ab2f6 100644
--- a/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py
+++ b/dbgpt/app/scene/chat_data/chat_excel/excel_reader.py
@@ -1,20 +1,20 @@
import logging
-
-import duckdb
import os
-import sqlparse
-import pandas as pd
+
import chardet
+import duckdb
import numpy as np
+import pandas as pd
+import sqlparse
from pyparsing import (
CaselessKeyword,
- Word,
- alphanums,
- delimitedList,
Forward,
- Optional,
Literal,
+ Optional,
Regex,
+ Word,
+ alphanums,
+ delimitedList,
)
from dbgpt.util.pd_utils import csv_colunm_foramt
diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py
index 3a3269d7e..8b59aeb42 100644
--- a/dbgpt/app/scene/chat_db/auto_execute/chat.py
+++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py
@@ -1,8 +1,8 @@
from typing import Dict
+from dbgpt._private.config import Config
from dbgpt.agent.plugin.commands.command_mange import ApiCall
from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt._private.config import Config
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
index dc1f50c66..d2e1eae96 100644
--- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
+++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
@@ -1,11 +1,13 @@
import json
-from typing import Dict, NamedTuple
import logging
-import sqlparse
import xml.etree.ElementTree as ET
-from dbgpt.util.json_utils import serialize
-from dbgpt.core.interface.output_parser import BaseOutputParser
+from typing import Dict, NamedTuple
+
+import sqlparse
+
from dbgpt._private.config import Config
+from dbgpt.core.interface.output_parser import BaseOutputParser
+from dbgpt.util.json_utils import serialize
CFG = Config()
diff --git a/dbgpt/app/scene/chat_db/auto_execute/prompt.py b/dbgpt/app/scene/chat_db/auto_execute/prompt.py
index 6fb0d7de6..52b153215 100644
--- a/dbgpt/app/scene/chat_db/auto_execute/prompt.py
+++ b/dbgpt/app/scene/chat_db/auto_execute/prompt.py
@@ -1,8 +1,14 @@
import json
-from dbgpt.core.interface.prompt import PromptTemplate
+
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
CFG = Config()
@@ -77,16 +83,25 @@
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(
+ _DEFAULT_TEMPLATE,
+ response_format=json.dumps(
+ RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4
+ ),
+ ),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{user_input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatWithDbExecute.value(),
- input_variables=["table_info", "dialect", "top_k", "response"],
- response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
- # example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
- need_historical_messages=True,
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py b/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py
index 25471da55..e994f51ec 100644
--- a/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py
+++ b/dbgpt/app/scene/chat_db/auto_execute/prompt_baichuan.py
@@ -2,10 +2,11 @@
# -*- coding: utf-8 -*-
import json
-from dbgpt.core.interface.prompt import PromptTemplate
+
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
+from dbgpt.core.interface.prompt import PromptTemplate
CFG = Config()
diff --git a/dbgpt/app/scene/chat_db/data_loader.py b/dbgpt/app/scene/chat_db/data_loader.py
index 7abaf8b53..5004febd8 100644
--- a/dbgpt/app/scene/chat_db/data_loader.py
+++ b/dbgpt/app/scene/chat_db/data_loader.py
@@ -1,24 +1,12 @@
-import xml.etree.ElementTree as ET
import json
import logging
+import xml.etree.ElementTree as ET
from dbgpt.util.json_utils import serialize
class DbDataLoader:
def get_table_view_by_conn(self, data, speak, sql: str = None):
- # import pandas as pd
- #
- # ### tool out data to table view
- # if len(data) < 1:
- # data.insert(0, ["result"])
- # df = pd.DataFrame(data[1:], columns=data[0])
- # html_table = df.to_html(index=False, escape=False, sparsify=False)
- # table_str = "".join(html_table.split())
- # html = f"""
{table_str}
"""
- # view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
- # return view_text
-
param = {}
api_call_element = ET.Element("chart-view")
err_msg = None
diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py
index 01d154da9..fb616cf50 100644
--- a/dbgpt/app/scene/chat_db/professional_qa/chat.py
+++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py
@@ -1,7 +1,7 @@
from typing import Dict
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
+from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import trace
@@ -11,6 +11,8 @@
class ChatWithDbQA(BaseChat):
chat_scene: str = ChatScene.ChatWithDbQA.value()
+ keep_end_rounds = 5
+
"""As a DBA, Chat DB Module, chat with combine DB meta schema """
def __init__(self, chat_param: Dict):
diff --git a/dbgpt/app/scene/chat_db/professional_qa/prompt.py b/dbgpt/app/scene/chat_db/professional_qa/prompt.py
index 6c3efc252..d7a9846c7 100644
--- a/dbgpt/app/scene/chat_db/professional_qa/prompt.py
+++ b/dbgpt/app/scene/chat_db/professional_qa/prompt.py
@@ -1,29 +1,15 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser
-
-CFG = Config()
-
-PROMPT_SCENE_DEFINE = (
- """You are an assistant that answers user specialized database questions. """
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
)
-# PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
-# {table_info}
-#
-# Question: {input}
-#
-# """
+CFG = Config()
-# _DEFAULT_TEMPLATE = """
-# You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
-# Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
-# You can order the results by a relevant column to return the most interesting examples in the database.
-# Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
-# Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
-#
-# """
_DEFAULT_TEMPLATE_EN = """
Provide professional answers to requests and questions. If you can't get an answer from what you've provided, say: "Insufficient information in the knowledge base is available to answer this question." Feel free to fudge information.
@@ -43,7 +29,7 @@
问题:
{input}
-一步步思考
+一步步思考。
"""
_DEFAULT_TEMPLATE = (
@@ -53,14 +39,24 @@
PROMPT_NEED_STREAM_OUT = True
-prompt = PromptTemplate(
+
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatWithDbQA.value(),
- input_variables=["input", "table_info"],
- response_format=None,
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+ need_historical_messages=True,
)
-CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
+
+CFG.prompt_template_registry.register(
+ prompt_adapter, language=CFG.LANGUAGE, is_default=True
+)
diff --git a/dbgpt/app/scene/chat_execution/chat.py b/dbgpt/app/scene/chat_execution/chat.py
index b1c4f75c1..44b9e586e 100644
--- a/dbgpt/app/scene/chat_execution/chat.py
+++ b/dbgpt/app/scene/chat_execution/chat.py
@@ -1,9 +1,9 @@
-from typing import List, Dict
+from typing import Dict, List
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.agent.plugin.commands.command import execute_command
from dbgpt.agent.plugin.generator import PluginPromptGenerator
+from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.tracer import trace
CFG = Config()
diff --git a/dbgpt/app/scene/chat_execution/out_parser.py b/dbgpt/app/scene/chat_execution/out_parser.py
index 37e71766d..bd4da7bc9 100644
--- a/dbgpt/app/scene/chat_execution/out_parser.py
+++ b/dbgpt/app/scene/chat_execution/out_parser.py
@@ -1,8 +1,8 @@
import json
import logging
from typing import Dict, NamedTuple
-from dbgpt.core.interface.output_parser import BaseOutputParser, T
+from dbgpt.core.interface.output_parser import BaseOutputParser, T
logger = logging.getLogger(__name__)
diff --git a/dbgpt/app/scene/chat_execution/prompt.py b/dbgpt/app/scene/chat_execution/prompt.py
index 2292292f3..9ada61f60 100644
--- a/dbgpt/app/scene/chat_execution/prompt.py
+++ b/dbgpt/app/scene/chat_execution/prompt.py
@@ -1,9 +1,14 @@
import json
-from dbgpt.core.interface.prompt import PromptTemplate
-from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt._private.config import Config
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
CFG = Config()
@@ -34,15 +39,23 @@
### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = False
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(
+ PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE,
+ response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+ ),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatExecution.value(),
- input_variables=["input", "constraints", "commands_infos", "response"],
- response_format=json.dumps(RESPONSE_FORMAT, indent=4),
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
- # example_selector=plugin_example,
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_factory.py b/dbgpt/app/scene/chat_factory.py
index 258020fac..73470f65f 100644
--- a/dbgpt/app/scene/chat_factory.py
+++ b/dbgpt/app/scene/chat_factory.py
@@ -1,4 +1,5 @@
from dbgpt.app.scene.base_chat import BaseChat
+from dbgpt.core import PromptTemplate
from dbgpt.util.singleton import Singleton
from dbgpt.util.tracer import root_tracer
@@ -7,31 +8,31 @@ class ChatFactory(metaclass=Singleton):
@staticmethod
def get_implementation(chat_mode, **kwargs):
# Lazy loading
- from dbgpt.app.scene.chat_execution.chat import ChatWithPlugin
- from dbgpt.app.scene.chat_execution.prompt import prompt
- from dbgpt.app.scene.chat_normal.chat import ChatNormal
- from dbgpt.app.scene.chat_normal.prompt import prompt
- from dbgpt.app.scene.chat_db.professional_qa.chat import ChatWithDbQA
- from dbgpt.app.scene.chat_db.professional_qa.prompt import prompt
- from dbgpt.app.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
- from dbgpt.app.scene.chat_db.auto_execute.prompt import prompt
+ from dbgpt.app.scene.chat_agent.chat import ChatAgent
+ from dbgpt.app.scene.chat_agent.prompt import prompt
from dbgpt.app.scene.chat_dashboard.chat import ChatDashboard
from dbgpt.app.scene.chat_dashboard.prompt import prompt
- from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge
- from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt
- from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
- from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt
+ from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
+ from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
+ from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt
+ from dbgpt.app.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
+ from dbgpt.app.scene.chat_db.auto_execute.prompt import prompt
+ from dbgpt.app.scene.chat_db.professional_qa.chat import ChatWithDbQA
+ from dbgpt.app.scene.chat_db.professional_qa.prompt import prompt
+ from dbgpt.app.scene.chat_execution.chat import ChatWithPlugin
+ from dbgpt.app.scene.chat_execution.prompt import prompt
from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt
+ from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
+ from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt
from dbgpt.app.scene.chat_knowledge.refine_summary.chat import (
ExtractRefineSummary,
)
from dbgpt.app.scene.chat_knowledge.refine_summary.prompt import prompt
- from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
- from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
- from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt
- from dbgpt.app.scene.chat_agent.chat import ChatAgent
- from dbgpt.app.scene.chat_agent.prompt import prompt
+ from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge
+ from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt
+ from dbgpt.app.scene.chat_normal.chat import ChatNormal
+ from dbgpt.app.scene.chat_normal.prompt import prompt
chat_classes = BaseChat.__subclasses__()
implementation = None
diff --git a/dbgpt/app/scene/chat_knowledge/__init__.py b/dbgpt/app/scene/chat_knowledge/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py b/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py
index 9551c1a4e..64e296dea 100644
--- a/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py
+++ b/dbgpt/app/scene/chat_knowledge/extract_entity/prompt.py
@@ -1,7 +1,7 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_knowledge.extract_entity.out_parser import ExtractEntityParser
+from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
CFG = Config()
@@ -21,20 +21,21 @@
"""
PROMPT_RESPONSE = """"""
-
-RESPONSE_FORMAT = """"""
-
-
PROMPT_NEED_NEED_STREAM_OUT = False
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE),
+ HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ExtractEntity.value(),
- input_variables=["text"],
- response_format="",
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractEntityParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py b/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py
index 4ee5e967a..49d0d4ed3 100644
--- a/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py
+++ b/dbgpt/app/scene/chat_knowledge/extract_triplet/prompt.py
@@ -1,12 +1,9 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_knowledge.extract_triplet.out_parser import (
ExtractTripleParser,
)
-
+from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
CFG = Config()
@@ -33,19 +30,22 @@
PROMPT_RESPONSE = """"""
-RESPONSE_FORMAT = """"""
+PROMPT_NEED_NEED_STREAM_OUT = False
-PROMPT_NEED_NEED_STREAM_OUT = False
+prompt = ChatPromptTemplate(
+ messages=[
+ # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE),
+ HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE),
+ ]
+)
-prompt = PromptTemplate(
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ExtractTriplet.value(),
- input_variables=["text"],
- response_format="",
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractTripleParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py b/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py
index 7e94168ce..ec2f1e9c1 100644
--- a/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py
+++ b/dbgpt/app/scene/chat_knowledge/refine_summary/prompt.py
@@ -1,10 +1,9 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_knowledge.refine_summary.out_parser import (
ExtractRefineSummaryParser,
)
+from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
CFG = Config()
@@ -27,17 +26,21 @@
PROMPT_RESPONSE = """"""
-
PROMPT_NEED_NEED_STREAM_OUT = True
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE),
+ HumanPromptTemplate.from_template(_DEFAULT_TEMPLATE + PROMPT_RESPONSE),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ExtractRefineSummary.value(),
- input_variables=["existing_answer"],
- response_format=None,
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractRefineSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, is_default=True)
+CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py
index 1d4854254..4090d9238 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/chat.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py
@@ -3,20 +3,22 @@
from functools import reduce
from typing import Dict, List
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
-from dbgpt.component import ComponentType
-
-from dbgpt.configs.model_config import (
- EMBEDDING_MODEL_CONFIG,
-)
-
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.service import KnowledgeService
+from dbgpt.app.scene import BaseChat, ChatScene
+from dbgpt.component import ComponentType
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
@@ -130,8 +132,19 @@ def stream_call_reinforce_fn(self, text):
@trace()
async def generate_input_values(self) -> Dict:
if self.space_context and self.space_context.get("prompt"):
- self.prompt_template.template_define = self.space_context["prompt"]["scene"]
- self.prompt_template.template = self.space_context["prompt"]["template"]
+            # Do not use template_define
+ # self.prompt_template.template_define = self.space_context["prompt"]["scene"]
+ # self.prompt_template.template = self.space_context["prompt"]["template"]
+ # Replace the template with the prompt template
+ self.prompt_template.prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(
+ self.space_context["prompt"]["template"]
+ ),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{question}"),
+ ]
+ )
from dbgpt.util.chat_util import run_async_tasks
tasks = [self.execute_similar_search(self.current_user_input)]
diff --git a/dbgpt/app/scene/chat_knowledge/v1/out_parser.py b/dbgpt/app/scene/chat_knowledge/v1/out_parser.py
index 8d30e98fe..2660cabea 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/out_parser.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/out_parser.py
@@ -1,6 +1,6 @@
import logging
-from dbgpt.core.interface.output_parser import BaseOutputParser, T
+from dbgpt.core.interface.output_parser import BaseOutputParser, T
logger = logging.getLogger(__name__)
diff --git a/dbgpt/app/scene/chat_knowledge/v1/prompt.py b/dbgpt/app/scene/chat_knowledge/v1/prompt.py
index ee6a166a5..b32e8f5e4 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/prompt.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/prompt.py
@@ -1,8 +1,12 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
CFG = Config()
@@ -28,17 +32,23 @@
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
-
PROMPT_NEED_STREAM_OUT = True
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{question}"),
+ ]
+)
-prompt = PromptTemplate(
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatKnowledge.value(),
- input_variables=["context", "question"],
- response_format=None,
- template_define=PROMPT_SCENE_DEFINE,
- template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
-CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
+CFG.prompt_template_registry.register(
+ prompt_adapter, language=CFG.LANGUAGE, is_default=True
+)
diff --git a/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py b/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py
index bfbdaed1f..ed76f69c6 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py
@@ -1,9 +1,12 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser
-
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
CFG = Config()
@@ -32,18 +35,24 @@
PROMPT_NEED_STREAM_OUT = True
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(_DEFAULT_TEMPLATE),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{question}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatKnowledge.value(),
- input_variables=["context", "question"],
- response_format=None,
- template_define=None,
- template=_DEFAULT_TEMPLATE,
- stream_out=PROMPT_NEED_STREAM_OUT,
+ stream_out=True,
output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+ need_historical_messages=False,
)
CFG.prompt_template_registry.register(
- prompt,
+ prompt_adapter,
language=CFG.LANGUAGE,
is_default=False,
model_names=["chatglm-6b-int4", "chatglm-6b", "chatglm2-6b", "chatglm2-6b-int4"],
diff --git a/dbgpt/app/scene/chat_normal/chat.py b/dbgpt/app/scene/chat_normal/chat.py
index 702a190b9..30b22ebd9 100644
--- a/dbgpt/app/scene/chat_normal/chat.py
+++ b/dbgpt/app/scene/chat_normal/chat.py
@@ -1,8 +1,7 @@
from typing import Dict
-from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
-
+from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.tracer import trace
CFG = Config()
@@ -11,6 +10,8 @@
class ChatNormal(BaseChat):
chat_scene: str = ChatScene.ChatNormal.value()
+ keep_end_rounds: int = 10
+
"""Number of results to return from the query"""
def __init__(self, chat_param: Dict):
diff --git a/dbgpt/app/scene/chat_normal/out_parser.py b/dbgpt/app/scene/chat_normal/out_parser.py
index 8d30e98fe..2660cabea 100644
--- a/dbgpt/app/scene/chat_normal/out_parser.py
+++ b/dbgpt/app/scene/chat_normal/out_parser.py
@@ -1,6 +1,6 @@
import logging
-from dbgpt.core.interface.output_parser import BaseOutputParser, T
+from dbgpt.core.interface.output_parser import BaseOutputParser, T
logger = logging.getLogger(__name__)
diff --git a/dbgpt/app/scene/chat_normal/prompt.py b/dbgpt/app/scene/chat_normal/prompt.py
index 752193936..04dd28a51 100644
--- a/dbgpt/app/scene/chat_normal/prompt.py
+++ b/dbgpt/app/scene/chat_normal/prompt.py
@@ -1,25 +1,40 @@
-from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-
+from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser
+from dbgpt.core import (
+ ChatPromptTemplate,
+ HumanPromptTemplate,
+ MessagesPlaceholder,
+ SystemPromptTemplate,
+)
-PROMPT_SCENE_DEFINE = None
+PROMPT_SCENE_DEFINE_EN = "You are a helpful AI assistant."
+PROMPT_SCENE_DEFINE_ZH = "你是一个有用的 AI 助手。"
CFG = Config()
+PROMPT_SCENE_DEFINE = (
+ PROMPT_SCENE_DEFINE_ZH if CFG.LANGUAGE == "zh" else PROMPT_SCENE_DEFINE_EN
+)
PROMPT_NEED_STREAM_OUT = True
-prompt = PromptTemplate(
+prompt = ChatPromptTemplate(
+ messages=[
+ SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE),
+ MessagesPlaceholder(variable_name="chat_history"),
+ HumanPromptTemplate.from_template("{input}"),
+ ]
+)
+
+prompt_adapter = AppScenePromptTemplateAdapter(
+ prompt=prompt,
template_scene=ChatScene.ChatNormal.value(),
- input_variables=["input"],
- response_format=None,
- template_define=PROMPT_SCENE_DEFINE,
- template=None,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+ need_historical_messages=True,
)
-# CFG.prompt_templates.update({prompt.template_scene: prompt})
-CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
+CFG.prompt_template_registry.register(
+ prompt_adapter, language=CFG.LANGUAGE, is_default=True
+)
diff --git a/dbgpt/app/scene/operator/_experimental.py b/dbgpt/app/scene/operator/_experimental.py
deleted file mode 100644
index c422846c0..000000000
--- a/dbgpt/app/scene/operator/_experimental.py
+++ /dev/null
@@ -1,262 +0,0 @@
-from typing import Dict, Optional, List
-from dataclasses import dataclass
-import datetime
-import os
-
-from dbgpt.configs.model_config import PILOT_PATH
-from dbgpt.core.awel import MapOperator
-from dbgpt.core.interface.prompt import PromptTemplate
-from dbgpt._private.config import Config
-from dbgpt.app.scene import ChatScene
-from dbgpt.core.interface.message import OnceConversation
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-
-from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
-from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
-from dbgpt.storage.vector_store.base import VectorStoreConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
-
-# TODO move global config
-CFG = Config()
-
-
-@dataclass
-class ChatContext:
- current_user_input: str
- model_name: Optional[str]
- chat_session_id: Optional[str] = None
- select_param: Optional[str] = None
- chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
- prompt_template: Optional[PromptTemplate] = None
- chat_retention_rounds: Optional[int] = 0
- history_storage: Optional[BaseChatHistoryMemory] = None
- history_manager: Optional["ChatHistoryManager"] = None
- # The input values for prompt template
- input_values: Optional[Dict] = None
- echo: Optional[bool] = False
-
- def build_model_payload(self) -> Dict:
- if not self.input_values:
- raise ValueError("The input value can't be empty")
- llm_messages = self.history_manager._new_chat(self.input_values)
- return {
- "model": self.model_name,
- "prompt": "",
- "messages": llm_messages,
- "temperature": float(self.prompt_template.temperature),
- "max_new_tokens": int(self.prompt_template.max_new_tokens),
- "echo": self.echo,
- }
-
-
-class ChatHistoryManager:
- def __init__(
- self,
- chat_ctx: ChatContext,
- prompt_template: PromptTemplate,
- history_storage: BaseChatHistoryMemory,
- chat_retention_rounds: Optional[int] = 0,
- ) -> None:
- self._chat_ctx = chat_ctx
- self.chat_retention_rounds = chat_retention_rounds
- self.current_message: OnceConversation = OnceConversation(
- chat_ctx.chat_scene.value()
- )
- self.prompt_template = prompt_template
- self.history_storage: BaseChatHistoryMemory = history_storage
- self.history_message: List[OnceConversation] = history_storage.messages()
- self.current_message.model_name = chat_ctx.model_name
- if chat_ctx.select_param:
- if len(chat_ctx.chat_scene.param_types()) > 0:
- self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
- self.current_message.param_value = chat_ctx.select_param
-
- def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
- self.current_message.chat_order = len(self.history_message) + 1
- self.current_message.add_user_message(
- self._chat_ctx.current_user_input, check_duplicate_type=True
- )
- self.current_message.start_date = datetime.datetime.now().strftime(
- "%Y-%m-%d %H:%M:%S"
- )
- self.current_message.tokens = 0
- if self.prompt_template.template:
- current_prompt = self.prompt_template.format(**input_values)
- self.current_message.add_system_message(current_prompt)
- return self._generate_llm_messages()
-
- def _generate_llm_messages(self) -> List[ModelMessage]:
- from dbgpt.app.scene.base_chat import (
- _load_system_message,
- _load_example_messages,
- _load_history_messages,
- _load_user_message,
- )
-
- messages = []
- ### Load scene setting or character definition as system message
- if self.prompt_template.template_define:
- messages.append(
- ModelMessage(
- role=ModelMessageRoleType.SYSTEM,
- content=self.prompt_template.template_define,
- )
- )
- ### Load prompt
- messages += _load_system_message(
- self.current_message, self.prompt_template, str_message=False
- )
- ### Load examples
- messages += _load_example_messages(self.prompt_template, str_message=False)
-
- ### Load History
- messages += _load_history_messages(
- self.prompt_template,
- self.history_message,
- self.chat_retention_rounds,
- str_message=False,
- )
-
- ### Load User Input
- messages += _load_user_message(
- self.current_message, self.prompt_template, str_message=False
- )
- return messages
-
-
-class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
- def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
- super().__init__(**kwargs)
- self._prompt_template = prompt_template
-
- async def map(self, input_value: ChatContext) -> ChatContext:
- if not self._prompt_template:
- self._prompt_template: PromptTemplate = (
- CFG.prompt_template_registry.get_prompt_template(
- input_value.chat_scene.value(),
- language=CFG.LANGUAGE,
- model_name=input_value.model_name,
- proxyllm_backend=CFG.PROXYLLM_BACKEND,
- )
- )
- input_value.prompt_template = self._prompt_template
- return input_value
-
-
-class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
- def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
- super().__init__(**kwargs)
- self._history = history
-
- async def map(self, input_value: ChatContext) -> ChatContext:
- if self._history:
- return self._history
- chat_history_fac = ChatHistory()
- input_value.history_storage = chat_history_fac.get_store_instance(
- input_value.chat_session_id
- )
- return input_value
-
-
-class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
- def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
- super().__init__(**kwargs)
- self._history = history
-
- async def map(self, input_value: ChatContext) -> ChatContext:
- history_storage = self._history or input_value.history_storage
- if not history_storage:
- from dbgpt.storage.chat_history.store_type.mem_history import (
- MemHistoryMemory,
- )
-
- history_storage = MemHistoryMemory(input_value.chat_session_id)
- input_value.history_storage = history_storage
- input_value.history_manager = ChatHistoryManager(
- input_value,
- input_value.prompt_template,
- history_storage,
- input_value.chat_retention_rounds,
- )
- return input_value
-
-
-class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- async def map(self, input_value: ChatContext) -> ChatContext:
- from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
- from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-
- # TODO, decompose the current operator into some atomic operators
- knowledge_space = input_value.select_param
- embedding_factory = self.system_app.get_component(
- "embedding_factory", EmbeddingFactory
- )
-
- space_context = await self._get_space_context(knowledge_space)
- top_k = (
- CFG.KNOWLEDGE_SEARCH_TOP_SIZE
- if space_context is None
- else int(space_context["embedding"]["topk"])
- )
- max_token = (
- CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
- if space_context is None or space_context.get("prompt") is None
- else int(space_context["prompt"]["max_token"])
- )
- input_value.prompt_template.template_is_strict = False
- if space_context and space_context.get("prompt"):
- input_value.prompt_template.template_define = space_context["prompt"][
- "scene"
- ]
- input_value.prompt_template.template = space_context["prompt"]["template"]
-
- config = VectorStoreConfig(
- name=knowledge_space,
- embedding_fn=embedding_factory.create(
- EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
- ),
- )
- vector_store_connector = VectorStoreConnector(
- vector_store_type=CFG.VECTOR_STORE_TYPE,
- vector_store_config=config,
- )
- embedding_retriever = EmbeddingRetriever(
- top_k=top_k, vector_store_connector=vector_store_connector
- )
- docs = await self.blocking_func_to_async(
- embedding_retriever.retrieve,
- input_value.current_user_input,
- )
- if not docs or len(docs) == 0:
- print("no relevant docs to retrieve")
- context = "no relevant docs to retrieve"
- else:
- context = [d.content for d in docs]
- context = context[:max_token]
- relations = list(
- set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
- )
- input_value.input_values = {
- "context": context,
- "question": input_value.current_user_input,
- "relations": relations,
- }
- return input_value
-
- async def _get_space_context(self, space_name):
- from dbgpt.app.knowledge.service import KnowledgeService
-
- service = KnowledgeService()
- return await self.blocking_func_to_async(service.get_space_context, space_name)
-
-
-class BaseChatOperator(MapOperator[ChatContext, Dict]):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- async def map(self, input_value: ChatContext) -> Dict:
- return input_value.build_model_payload()
diff --git a/dbgpt/app/scene/operator/app_operator.py b/dbgpt/app/scene/operator/app_operator.py
new file mode 100644
index 000000000..799b89505
--- /dev/null
+++ b/dbgpt/app/scene/operator/app_operator.py
@@ -0,0 +1,85 @@
+import dataclasses
+from typing import Any, Dict, List, Optional
+
+from dbgpt.core import BaseMessage, ChatPromptTemplate, ModelMessage
+from dbgpt.core.awel import (
+ DAG,
+ BaseOperator,
+ InputOperator,
+ MapOperator,
+ SimpleCallDataInputSource,
+)
+from dbgpt.core.operator import (
+ BufferedConversationMapperOperator,
+ HistoryPromptBuilderOperator,
+)
+
+
+@dataclasses.dataclass
+class ChatComposerInput:
+ """The composer input."""
+
+ messages: List[BaseMessage]
+ prompt_dict: Dict[str, Any]
+
+
+class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]):
+ """App chat composer operator.
+
+    TODO: Support more history merge modes.
+ """
+
+ def __init__(
+ self,
+ prompt: ChatPromptTemplate,
+ history_key: str = "chat_history",
+ history_merge_mode: str = "window",
+ keep_start_rounds: Optional[int] = None,
+ keep_end_rounds: Optional[int] = None,
+ str_history: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self._prompt_template = prompt
+ self._history_key = history_key
+ self._history_merge_mode = history_merge_mode
+ self._keep_start_rounds = keep_start_rounds
+ self._keep_end_rounds = keep_end_rounds
+ self._str_history = str_history
+ self._sub_compose_dag = self._build_composer_dag()
+
+ async def map(self, input_value: ChatComposerInput) -> List[ModelMessage]:
+ end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
+ # Sub dag, use the same dag context in the parent dag
+ return await end_node.call(
+ call_data={"data": input_value}, dag_ctx=self.current_dag_context
+ )
+
+ def _build_composer_dag(self) -> DAG:
+ with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
+ input_task = InputOperator(input_source=SimpleCallDataInputSource())
+ # History transform task
+ history_transform_task = BufferedConversationMapperOperator(
+ keep_start_rounds=self._keep_start_rounds,
+ keep_end_rounds=self._keep_end_rounds,
+ )
+ history_prompt_build_task = HistoryPromptBuilderOperator(
+ prompt=self._prompt_template,
+ history_key=self._history_key,
+ check_storage=False,
+ str_history=self._str_history,
+ )
+ # Build composer dag
+ (
+ input_task
+ >> MapOperator(lambda x: x.messages)
+ >> history_transform_task
+ >> history_prompt_build_task
+ )
+ (
+ input_task
+ >> MapOperator(lambda x: x.prompt_dict)
+ >> history_prompt_build_task
+ )
+
+ return composer_dag
diff --git a/dbgpt/app/tests/test_base.py b/dbgpt/app/tests/test_base.py
index 4f5a21a9f..cefea1cb9 100644
--- a/dbgpt/app/tests/test_base.py
+++ b/dbgpt/app/tests/test_base.py
@@ -1,5 +1,6 @@
+from unittest.mock import MagicMock, patch
+
import pytest
-from unittest.mock import patch, MagicMock
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from dbgpt.app.base import _create_mysql_database
diff --git a/dbgpt/cli/cli_scripts.py b/dbgpt/cli/cli_scripts.py
index fcd8edc51..e45a32819 100644
--- a/dbgpt/cli/cli_scripts.py
+++ b/dbgpt/cli/cli_scripts.py
@@ -1,7 +1,8 @@
-import click
import copy
import logging
+import click
+
logging.basicConfig(
level=logging.WARNING,
encoding="utf-8",
@@ -82,14 +83,14 @@ def stop_all():
try:
from dbgpt.model.cli import (
+ _stop_all_model_server,
model_cli_group,
+ start_apiserver,
start_model_controller,
- stop_model_controller,
start_model_worker,
- stop_model_worker,
- start_apiserver,
stop_apiserver,
- _stop_all_model_server,
+ stop_model_controller,
+ stop_model_worker,
)
add_command_alias(model_cli_group, name="model", parent_group=cli)
@@ -107,10 +108,10 @@ def stop_all():
try:
from dbgpt.app._cli import (
- start_webserver,
- stop_webserver,
_stop_all_dbgpt_server,
migration,
+ start_webserver,
+ stop_webserver,
)
add_command_alias(start_webserver, name="webserver", parent_group=start)
diff --git a/dbgpt/component.py b/dbgpt/component.py
index 466d01c80..f094213a1 100644
--- a/dbgpt/component.py
+++ b/dbgpt/component.py
@@ -1,13 +1,14 @@
from __future__ import annotations
-from abc import ABC, abstractmethod
+import asyncio
+import logging
import sys
-from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
+from abc import ABC, abstractmethod
from enum import Enum
-import logging
-import asyncio
-from dbgpt.util.annotations import PublicAPI
+from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union
+
from dbgpt.util import AppConfig
+from dbgpt.util.annotations import PublicAPI
# Checking for type hints during runtime
if TYPE_CHECKING:
@@ -80,6 +81,9 @@ class ComponentType(str, Enum):
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
+_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
+
+
@PublicAPI(stability="beta")
class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
@@ -98,11 +102,37 @@ def init_app(self, system_app: SystemApp):
with the main system app.
"""
+ @classmethod
+ def get_instance(
+ cls,
+ system_app: SystemApp,
+ default_component=_EMPTY_DEFAULT_COMPONENT,
+ or_register_component: Type[BaseComponent] = None,
+ *args,
+ **kwargs,
+ ) -> BaseComponent:
+ """Get the current component instance.
+
+ Args:
+ system_app (SystemApp): The system app
+ default_component: The default component instance to return if the component is not retrieved by name
+ or_register_component (Type[BaseComponent]): The new component to register if the component is not retrieved by name
+
+ Returns:
+ BaseComponent: The component instance
+ """
+ return system_app.get_component(
+ cls.name,
+ cls,
+ default_component=default_component,
+ or_register_component=or_register_component,
+ *args,
+ **kwargs,
+ )
+
T = TypeVar("T", bound=BaseComponent)
-_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
-
@PublicAPI(stability="beta")
class SystemApp(LifeCycle):
diff --git a/dbgpt/core/_private/prompt_registry.py b/dbgpt/core/_private/prompt_registry.py
index e35282b6f..5684b1f6a 100644
--- a/dbgpt/core/_private/prompt_registry.py
+++ b/dbgpt/core/_private/prompt_registry.py
@@ -20,8 +20,9 @@ def register(
self,
prompt_template,
language: str = "en",
- is_default=False,
+ is_default: bool = False,
model_names: List[str] = None,
+ scene_name: str = None,
) -> None:
"""Register prompt template with scene name, language
registry dict format:
@@ -37,7 +38,8 @@ def register(
}
}
"""
- scene_name = prompt_template.template_scene
+ if not scene_name:
+ scene_name = prompt_template.template_scene
if not scene_name:
raise ValueError("Prompt template scene name cannot be empty")
if not model_names:
diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py
index cded3f18a..25ba8b665 100755
--- a/dbgpt/core/interface/message.py
+++ b/dbgpt/core/interface/message.py
@@ -45,6 +45,18 @@ def to_dict(self) -> Dict:
"round_index": self.round_index,
}
+ @staticmethod
+ def messages_to_string(messages: List["BaseMessage"]) -> str:
+ """Convert messages to str
+
+ Args:
+ messages (List[BaseMessage]): The messages
+
+ Returns:
+ str: The str messages
+ """
+ return _messages_to_str(messages)
+
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -251,6 +263,41 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
return [_message_to_dict(m) for m in messages]
+def _messages_to_str(
+ messages: List[BaseMessage],
+ human_prefix: str = "Human",
+ ai_prefix: str = "AI",
+ system_prefix: str = "System",
+) -> str:
+ """Convert messages to str
+
+ Args:
+ messages (List[BaseMessage]): The messages
+ human_prefix (str): The human prefix
+ ai_prefix (str): The ai prefix
+ system_prefix (str): The system prefix
+
+ Returns:
+ str: The str messages
+ """
+ str_messages = []
+ for message in messages:
+ role = None
+ if isinstance(message, HumanMessage):
+ role = human_prefix
+ elif isinstance(message, AIMessage):
+ role = ai_prefix
+ elif isinstance(message, SystemMessage):
+ role = system_prefix
+ elif isinstance(message, ViewMessage):
+ pass
+ else:
+ raise ValueError(f"Got unsupported message type: {message}")
+ if role:
+ str_messages.append(f"{role}: {message.content}")
+ return "\n".join(str_messages)
+
+
def _message_from_dict(message: Dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
@@ -382,6 +429,9 @@ def _append_message(self, message: BaseMessage) -> None:
self._message_index += 1
message.index = index
message.round_index = self.chat_order
+ message.additional_kwargs["param_type"] = self.param_type
+ message.additional_kwargs["param_value"] = self.param_value
+ message.additional_kwargs["model_name"] = self.model_name
self.messages.append(message)
def start_new_round(self) -> None:
@@ -504,9 +554,12 @@ def from_conversation(self, conversation: OnceConversation) -> None:
self.messages = conversation.messages
self.start_date = conversation.start_date
self.chat_order = conversation.chat_order
- self.model_name = conversation.model_name
- self.param_type = conversation.param_type
- self.param_value = conversation.param_value
+ if not self.model_name and conversation.model_name:
+ self.model_name = conversation.model_name
+ if not self.param_type and conversation.param_type:
+ self.param_type = conversation.param_type
+ if not self.param_value and conversation.param_value:
+ self.param_value = conversation.param_value
self.cost = conversation.cost
self.tokens = conversation.tokens
self.user_name = conversation.user_name
@@ -801,6 +854,7 @@ def __init__(
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
+ load_message: bool = True,
**kwargs,
):
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
@@ -811,6 +865,8 @@ def __init__(
self._has_stored_message_index = (
len(kwargs["messages"]) - 1 if "messages" in kwargs else -1
)
+ # Whether to load the message from the storage
+ self._load_message = load_message
self.save_message_independent = save_message_independent
self._id = ConversationIdentifier(conv_uid)
if conv_storage is None:
@@ -853,7 +909,9 @@ def save_to_storage(self) -> None:
]
messages_to_save = message_list[self._has_stored_message_index + 1 :]
self._has_stored_message_index = len(message_list) - 1
- self.message_storage.save_list(messages_to_save)
+ if self.save_message_independent:
+ # Save messages independently
+ self.message_storage.save_list(messages_to_save)
# Save conversation
self.conv_storage.save_or_update(self)
@@ -876,23 +934,71 @@ def load_from_storage(
return
message_ids = conversation._message_ids or []
- # Load messages
- message_list = message_storage.load_list(
- [
- MessageIdentifier.from_str_identifier(message_id)
- for message_id in message_ids
- ],
- MessageStorageItem,
- )
- messages = [message.to_message() for message in message_list]
- conversation.messages = messages
+ if self._load_message:
+ # Load messages
+ message_list = message_storage.load_list(
+ [
+ MessageIdentifier.from_str_identifier(message_id)
+ for message_id in message_ids
+ ],
+ MessageStorageItem,
+ )
+ messages = [message.to_message() for message in message_list]
+ else:
+ messages = []
+ real_messages = messages or conversation.messages
+ conversation.messages = real_messages
# This index is used to save the message to the storage(Has not been saved)
# The new message append to the messages, so the index is len(messages)
- conversation._message_index = len(messages)
+ conversation._message_index = len(real_messages)
+ conversation.chat_order = (
+ max(m.round_index for m in real_messages) if real_messages else 0
+ )
+ self._append_additional_kwargs(conversation, real_messages)
self._message_ids = message_ids
- self._has_stored_message_index = len(messages) - 1
+ self._has_stored_message_index = len(real_messages) - 1
+ self.save_message_independent = conversation.save_message_independent
self.from_conversation(conversation)
+ def _append_additional_kwargs(
+ self, conversation: StorageConversation, messages: List[BaseMessage]
+ ) -> None:
+ """Parse the additional kwargs and append to the conversation
+
+ Args:
+ conversation (StorageConversation): The conversation
+ messages (List[BaseMessage]): The messages
+ """
+ param_type = None
+ param_value = None
+ for message in messages[::-1]:
+ if message.additional_kwargs:
+ param_type = message.additional_kwargs.get("param_type")
+ param_value = message.additional_kwargs.get("param_value")
+ break
+ if not conversation.param_type:
+ conversation.param_type = param_type
+ if not conversation.param_value:
+ conversation.param_value = param_value
+
+ def delete(self) -> None:
+ """Delete all the messages and conversation from the storage"""
+ # Delete messages first
+ message_list = self._get_message_items()
+ message_ids = [message.identifier for message in message_list]
+ self.message_storage.delete_list(message_ids)
+ # Delete conversation
+ self.conv_storage.delete(self.identifier)
+ # Overwrite the current conversation with empty conversation
+ self.from_conversation(
+ StorageConversation(
+ self.conv_uid,
+ save_message_independent=self.save_message_independent,
+ conv_storage=self.conv_storage,
+ message_storage=self.message_storage,
+ )
+ )
+
def _conversation_to_dict(once: OnceConversation) -> Dict:
start_str: str = ""
@@ -937,3 +1043,61 @@ def _conversation_from_dict(once: dict) -> OnceConversation:
print(once.get("messages"))
conversation.messages = _messages_from_dict(once.get("messages", []))
return conversation
+
+
+def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessage]]:
+ """Split the messages by round index.
+
+ Args:
+ messages (List[BaseMessage]): The messages.
+
+ Returns:
+ List[List[BaseMessage]]: The messages split by round.
+ """
+ messages_by_round: List[List[BaseMessage]] = []
+ last_round_index = 0
+ for message in messages:
+ if not message.round_index:
+ # Round index must be bigger than 0
+ raise ValueError("Message round_index is not set")
+ if message.round_index > last_round_index:
+ last_round_index = message.round_index
+ messages_by_round.append([])
+ messages_by_round[-1].append(message)
+ return messages_by_round
+
+
+def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
+ """Append the view message to the messages
+ Only for display in DB-GPT-Web.
+ If already have view message, do nothing.
+
+ Args:
+ messages (List[BaseMessage]): The messages
+
+ Returns:
+ List[BaseMessage]: The messages with view message
+ """
+ messages_by_round = _split_messages_by_round(messages)
+ for current_round in messages_by_round:
+ ai_message = None
+ view_message = None
+ for message in current_round:
+ if message.type == "ai":
+ ai_message = message
+ elif message.type == "view":
+ view_message = message
+ if view_message:
+ # Already have view message, do nothing
+ continue
+ if ai_message:
+ view_message = ViewMessage(
+ content=ai_message.content,
+ index=ai_message.index,
+ round_index=ai_message.round_index,
+ additional_kwargs=ai_message.additional_kwargs.copy()
+ if ai_message.additional_kwargs
+ else {},
+ )
+ current_round.append(view_message)
+ return sum(messages_by_round, [])
diff --git a/dbgpt/core/interface/operator/composer_operator.py b/dbgpt/core/interface/operator/composer_operator.py
index 7c0093777..36eb97fbc 100644
--- a/dbgpt/core/interface/operator/composer_operator.py
+++ b/dbgpt/core/interface/operator/composer_operator.py
@@ -45,7 +45,8 @@ def __init__(
self,
prompt_template: ChatPromptTemplate,
history_key: str = "chat_history",
- last_k_round: int = 2,
+ keep_start_rounds: Optional[int] = None,
+ keep_end_rounds: Optional[int] = None,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
@@ -53,7 +54,8 @@ def __init__(
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
- self._last_k_round = last_k_round
+ self._keep_start_rounds = keep_start_rounds
+ self._keep_end_rounds = keep_end_rounds
self._storage = storage
self._message_storage = message_storage
self._sub_compose_dag = self._build_composer_dag()
@@ -74,7 +76,8 @@ def _build_composer_dag(self) -> DAG:
)
# History transform task, here we keep last 5 round messages
history_transform_task = BufferedConversationMapperOperator(
- last_k_round=self._last_k_round
+ keep_start_rounds=self._keep_start_rounds,
+ keep_end_rounds=self._keep_end_rounds,
)
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py
index f6eb1b24b..fea21d0f9 100644
--- a/dbgpt/core/interface/operator/message_operator.py
+++ b/dbgpt/core/interface/operator/message_operator.py
@@ -1,8 +1,9 @@
import uuid
from abc import ABC
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
from dbgpt.core import (
+ LLMClient,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
@@ -11,7 +12,12 @@
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator
-from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
+from dbgpt.core.interface.message import (
+ BaseMessage,
+ _messages_to_str,
+ _MultiRoundMessageMapper,
+ _split_messages_by_round,
+)
class BaseConversationOperator(BaseOperator, ABC):
@@ -31,7 +37,6 @@ def __init__(
**kwargs,
):
self._check_storage = check_storage
- super().__init__(**kwargs)
self._storage = storage
self._message_storage = message_storage
@@ -167,12 +172,10 @@ def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
self._message_mapper = message_mapper
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
- return self.map_messages(input_value)
+ return await self.map_messages(input_value)
- def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
- messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
- messages
- )
+ async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
+ messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
@@ -233,93 +236,66 @@ def map_multi_round_messages(
Args:
"""
# Just merge and return
- # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
- return sum(messages_by_round, [])
-
- def _split_messages_by_round(
- self, messages: List[BaseMessage]
- ) -> List[List[BaseMessage]]:
- """Split the messages by round index.
-
- Args:
- messages (List[BaseMessage]): The messages.
-
- Returns:
- List[List[BaseMessage]]: The messages split by round.
- """
- messages_by_round: List[List[BaseMessage]] = []
- last_round_index = 0
- for message in messages:
- if not message.round_index:
- # Round index must bigger than 0
- raise ValueError("Message round_index is not set")
- if message.round_index > last_round_index:
- last_round_index = message.round_index
- messages_by_round.append([])
- messages_by_round[-1].append(message)
- return messages_by_round
+ return _merge_multi_round_messages(messages_by_round)
class BufferedConversationMapperOperator(ConversationMapperOperator):
- """The buffered conversation mapper operator.
+ """
+ The buffered conversation mapper operator which can be configured to keep
+ a certain number of starting and/or ending rounds of a conversation.
- This Operator must be used after the PreChatHistoryLoadOperator,
- and it will map the messages in the storage conversation.
+ Args:
+ keep_start_rounds (Optional[int]): Number of initial rounds to keep.
+ keep_end_rounds (Optional[int]): Number of final rounds to keep.
Examples:
-
- Transform no history messages
-
- .. code-block:: python
-
- from dbgpt.core import ModelMessage
- from dbgpt.core.operator import BufferedConversationMapperOperator
-
- # No history
- messages = [ModelMessage(role="human", content="Hello", round_index=1)]
- operator = BufferedConversationMapperOperator(last_k_round=1)
- assert operator.map_messages(messages) == [
- ModelMessage(role="human", content="Hello", round_index=1)
- ]
-
- Transform with history messages
-
- .. code-block:: python
-
- # With history
- messages = [
- ModelMessage(role="human", content="Hi", round_index=1),
- ModelMessage(role="ai", content="Hello!", round_index=1),
- ModelMessage(role="system", content="Error 404", round_index=2),
- ModelMessage(role="human", content="What's the error?", round_index=2),
- ModelMessage(role="ai", content="Just a joke.", round_index=2),
- ModelMessage(role="human", content="Funny!", round_index=3),
- ]
- operator = BufferedConversationMapperOperator(last_k_round=1)
- # Just keep the last one round, so the first round messages will be removed
- # Note: The round index 3 is not a complete round
- assert operator.map_messages(messages) == [
- ModelMessage(role="system", content="Error 404", round_index=2),
- ModelMessage(role="human", content="What's the error?", round_index=2),
- ModelMessage(role="ai", content="Just a joke.", round_index=2),
- ModelMessage(role="human", content="Funny!", round_index=3),
- ]
+ # Keeping the first 2 and the last 1 rounds of a conversation
+ import asyncio
+ from dbgpt.core.interface.message import AIMessage, HumanMessage
+ from dbgpt.core.operator import BufferedConversationMapperOperator
+
+ operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1)
+ messages = [
+ # Assume each HumanMessage and AIMessage belongs to separate rounds
+ HumanMessage(content="Hi", round_index=1),
+ AIMessage(content="Hello!", round_index=1),
+ HumanMessage(content="How are you?", round_index=2),
+ AIMessage(content="I'm good, thanks!", round_index=2),
+ HumanMessage(content="What's new today?", round_index=3),
+ AIMessage(content="Lots of things!", round_index=3),
+ ]
+ # This will keep rounds 1, 2, and 3
+ assert asyncio.run(operator.map_messages(messages)) == [
+ HumanMessage(content="Hi", round_index=1),
+ AIMessage(content="Hello!", round_index=1),
+ HumanMessage(content="How are you?", round_index=2),
+ AIMessage(content="I'm good, thanks!", round_index=2),
+ HumanMessage(content="What's new today?", round_index=3),
+ AIMessage(content="Lots of things!", round_index=3),
+ ]
"""
def __init__(
self,
- last_k_round: Optional[int] = 2,
+ keep_start_rounds: Optional[int] = None,
+ keep_end_rounds: Optional[int] = None,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
- self._last_k_round = last_k_round
+ # Validate the input parameters
+ if keep_start_rounds is not None and keep_start_rounds < 0:
+ raise ValueError("keep_start_rounds must be non-negative")
+ if keep_end_rounds is not None and keep_end_rounds < 0:
+ raise ValueError("keep_end_rounds must be non-negative")
+
+ self._keep_start_rounds = keep_start_rounds
+ self._keep_end_rounds = keep_end_rounds
if message_mapper:
def new_message_mapper(
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
- # Apply keep k round messages first, then apply the custom message mapper
- messages_by_round = self._keep_last_round_messages(messages_by_round)
+ messages_by_round = self._filter_round_messages(messages_by_round)
return message_mapper(messages_by_round)
else:
@@ -327,21 +303,189 @@ def new_message_mapper(
def new_message_mapper(
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
- messages_by_round = self._keep_last_round_messages(messages_by_round)
- return sum(messages_by_round, [])
+ messages_by_round = self._filter_round_messages(messages_by_round)
+ return _merge_multi_round_messages(messages_by_round)
super().__init__(new_message_mapper, **kwargs)
- def _keep_last_round_messages(
+ def _filter_round_messages(
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
- """Keep the last k round messages.
+ """Filters the messages to keep only the specified starting and/or ending rounds.
+
+ Examples:
+
+ >>> from dbgpt.core import AIMessage, HumanMessage
+ >>> from dbgpt.core.operator import BufferedConversationMapperOperator
+ >>> messages = [
+ ... [
+ ... HumanMessage(content="Hi", round_index=1),
+ ... AIMessage(content="Hello!", round_index=1),
+ ... ],
+ ... [
+ ... HumanMessage(content="How are you?", round_index=2),
+ ... AIMessage(content="I'm good, thanks!", round_index=2),
+ ... ],
+ ... [
+ ... HumanMessage(content="What's new today?", round_index=3),
+ ... AIMessage(content="Lots of things!", round_index=3),
+ ... ],
+ ... ]
+
+ # Test keeping only the first 2 rounds
+ >>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
+ >>> assert operator._filter_round_messages(messages) == [
+ ... [
+ ... HumanMessage(content="Hi", round_index=1),
+ ... AIMessage(content="Hello!", round_index=1),
+ ... ],
+ ... [
+ ... HumanMessage(content="How are you?", round_index=2),
+ ... AIMessage(content="I'm good, thanks!", round_index=2),
+ ... ],
+ ... ]
+
+ # Test keeping only the last 2 rounds
+ >>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
+ >>> assert operator._filter_round_messages(messages) == [
+ ... [
+ ... HumanMessage(content="How are you?", round_index=2),
+ ... AIMessage(content="I'm good, thanks!", round_index=2),
+ ... ],
+ ... [
+ ... HumanMessage(content="What's new today?", round_index=3),
+ ... AIMessage(content="Lots of things!", round_index=3),
+ ... ],
+ ... ]
+
+ # Test keeping the first 2 and last 1 rounds
+ >>> operator = BufferedConversationMapperOperator(
+ ... keep_start_rounds=2, keep_end_rounds=1
+ ... )
+ >>> assert operator._filter_round_messages(messages) == [
+ ... [
+ ... HumanMessage(content="Hi", round_index=1),
+ ... AIMessage(content="Hello!", round_index=1),
+ ... ],
+ ... [
+ ... HumanMessage(content="How are you?", round_index=2),
+ ... AIMessage(content="I'm good, thanks!", round_index=2),
+ ... ],
+ ... [
+ ... HumanMessage(content="What's new today?", round_index=3),
+ ... AIMessage(content="Lots of things!", round_index=3),
+ ... ],
+ ... ]
+
+ # Test without specifying start or end rounds (keep all rounds)
+ >>> operator = BufferedConversationMapperOperator()
+ >>> assert operator._filter_round_messages(messages) == [
+ ... [
+ ... HumanMessage(content="Hi", round_index=1),
+ ... AIMessage(content="Hello!", round_index=1),
+ ... ],
+ ... [
+ ... HumanMessage(content="How are you?", round_index=2),
+ ... AIMessage(content="I'm good, thanks!", round_index=2),
+ ... ],
+ ... [
+ ... HumanMessage(content="What's new today?", round_index=3),
+ ... AIMessage(content="Lots of things!", round_index=3),
+ ... ],
+ ... ]
+
+ Args:
+ messages_by_round (List[List[BaseMessage]]): The messages grouped by round.
+
+ Returns:
+ List[List[BaseMessage]]: Filtered list of messages.
+ """
+ total_rounds = len(messages_by_round)
+ if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
+ if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
+ # Avoid overlapping when the sum of start and end rounds exceeds total rounds
+ return messages_by_round
+ return (
+ messages_by_round[: self._keep_start_rounds]
+ + messages_by_round[-self._keep_end_rounds :]
+ )
+ elif self._keep_start_rounds is not None:
+ return messages_by_round[: self._keep_start_rounds]
+ elif self._keep_end_rounds is not None:
+ return messages_by_round[-self._keep_end_rounds :]
+ else:
+ return messages_by_round
+
+
+EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
+
+
+class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
+ """The token buffered conversation mapper operator.
+
+ If the token count of the messages is greater than the max token limit, we will evict the messages by round.
+
+ Args:
+ model (str): The model name.
+ llm_client (LLMClient): The LLM client.
+ max_token_limit (int): The max token limit.
+ eviction_policy (EvictionPolicyType): The eviction policy.
+ message_mapper (_MultiRoundMessageMapper): The message mapper, applied after all messages are handled.
+ """
+
+ def __init__(
+ self,
+ model: str,
+ llm_client: LLMClient,
+ max_token_limit: int = 2000,
+ eviction_policy: EvictionPolicyType = None,
+ message_mapper: _MultiRoundMessageMapper = None,
+ **kwargs,
+ ):
+ if max_token_limit < 0:
+ raise ValueError("Max token limit can't be negative")
+ self._model = model
+ self._llm_client = llm_client
+ self._max_token_limit = max_token_limit
+ self._eviction_policy = eviction_policy
+ self._message_mapper = message_mapper
+ super().__init__(**kwargs)
+
+ async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
+ eviction_policy = self._eviction_policy or self.eviction_policy
+ messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
+ messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round))
+ # First time, we count the tokens of the messages
+ current_tokens = await self._llm_client.count_token(self._model, messages_str)
+
+ while current_tokens > self._max_token_limit:
+ # Evict messages round by round until the token count is no greater than the max token limit
+ # TODO: We should find a high performance way to do this
+ messages_by_round = eviction_policy(messages_by_round)
+ messages_str = _messages_to_str(
+ _merge_multi_round_messages(messages_by_round)
+ )
+ current_tokens = await self._llm_client.count_token(
+ self._model, messages_str
+ )
+ message_mapper = self._message_mapper or self.map_multi_round_messages
+ return message_mapper(messages_by_round)
+
+ def eviction_policy(
+ self, messages_by_round: List[List[BaseMessage]]
+ ) -> List[List[BaseMessage]]:
+ """Evict the messages by round, default is FIFO.
Args:
messages_by_round (List[List[BaseMessage]]): The messages by round.
Returns:
- List[List[BaseMessage]]: The latest round messages.
+ List[List[BaseMessage]]: The evicted messages by round.
"""
- index = self._last_k_round + 1
- return messages_by_round[-index:]
+ messages_by_round.pop(0)
+ return messages_by_round
+
+
+def _merge_multi_round_messages(messages: List[List[BaseMessage]]) -> List[BaseMessage]:
+ # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
+ return sum(messages, [])
diff --git a/dbgpt/core/interface/operator/prompt_operator.py b/dbgpt/core/interface/operator/prompt_operator.py
index 18c727d14..7cdde4349 100644
--- a/dbgpt/core/interface/operator/prompt_operator.py
+++ b/dbgpt/core/interface/operator/prompt_operator.py
@@ -216,18 +216,27 @@ class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
def __init__(
- self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
+ self,
+ prompt: ChatPromptTemplate,
+ history_key: Optional[str] = None,
+ check_storage: bool = True,
+ str_history: bool = False,
+ **kwargs,
):
self._prompt = prompt
self._history_key = history_key
-
+ self._str_history = str_history
+ BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
- prompt_dict[self._history_key] = history
+ if self._str_history:
+ prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
+ else:
+ prompt_dict[self._history_key] = history
return await self.format_prompt(self._prompt, prompt_dict)
@@ -239,9 +248,16 @@ class HistoryDynamicPromptBuilderOperator(
The prompt template is dynamic, and it created by parent operator.
"""
- def __init__(self, history_key: Optional[str] = None, **kwargs):
+ def __init__(
+ self,
+ history_key: Optional[str] = None,
+ check_storage: bool = True,
+ str_history: bool = False,
+ **kwargs,
+ ):
self._history_key = history_key
-
+ self._str_history = str_history
+ BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
@@ -251,5 +267,8 @@ async def merge_history(
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
- prompt_dict[self._history_key] = history
+ if self._str_history:
+ prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
+ else:
+ prompt_dict[self._history_key] = history
return await self.format_prompt(prompt, prompt_dict)
diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py
index 21fda87f6..1788357c4 100644
--- a/dbgpt/core/interface/prompt.py
+++ b/dbgpt/core/interface/prompt.py
@@ -49,13 +49,25 @@ class BasePromptTemplate(BaseModel):
"""The prompt template."""
template_format: Optional[str] = "f-string"
+ """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+
+ response_format: Optional[str] = None
+
+ response_key: Optional[str] = "response"
+
+ template_is_strict: Optional[bool] = True
+ """strict template will check template args"""
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
- return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
- self.template, **kwargs
- )
+ if self.response_format:
+ kwargs[self.response_key] = json.dumps(
+ self.response_format, ensure_ascii=False, indent=4
+ )
+ return _DEFAULT_FORMATTER_MAPPING[self.template_format](
+ self.template_is_strict
+ )(self.template, **kwargs)
@classmethod
def from_template(
@@ -75,10 +87,6 @@ class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
"""this template define"""
- """strict template will check template args"""
- template_is_strict: bool = True
- """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
- response_format: Optional[str]
"""default use stream out"""
stream_out: bool = True
""""""
@@ -103,17 +111,6 @@ def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
- def format(self, **kwargs: Any) -> str:
- """Format the prompt with the inputs."""
- if self.template:
- if self.response_format:
- kwargs["response"] = json.dumps(
- self.response_format, ensure_ascii=False, indent=4
- )
- return _DEFAULT_FORMATTER_MAPPING[self.template_format](
- self.template_is_strict
- )(self.template, **kwargs)
-
class BaseChatPromptTemplate(BaseModel, ABC):
prompt: BasePromptTemplate
@@ -129,10 +126,22 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
@classmethod
def from_template(
- cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
+ cls,
+ template: str,
+ template_format: Optional[str] = "f-string",
+ response_format: Optional[str] = None,
+ response_key: Optional[str] = "response",
+ template_is_strict: bool = True,
+ **kwargs: Any,
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
- prompt = BasePromptTemplate.from_template(template, template_format)
+ prompt = BasePromptTemplate.from_template(
+ template,
+ template_format,
+ response_format=response_format,
+ response_key=response_key,
+ template_is_strict=template_is_strict,
+ )
return cls(prompt=prompt, **kwargs)
diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py
index d05258e9b..6a486cab7 100644
--- a/dbgpt/core/interface/storage.py
+++ b/dbgpt/core/interface/storage.py
@@ -284,6 +284,15 @@ def delete(self, resource_id: ResourceIdentifier) -> None:
resource_id (ResourceIdentifier): The resource identifier of the data
"""
+ def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
+ """Delete multiple data entries from the storage.
+
+ Args:
+ resource_id (List[ResourceIdentifier]): The resource identifiers of the data to delete
+ """
+ for r in resource_id:
+ self.delete(r)
+
@abstractmethod
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py
index 7221dadb2..72023c36f 100755
--- a/dbgpt/core/interface/tests/test_message.py
+++ b/dbgpt/core/interface/tests/test_message.py
@@ -138,14 +138,14 @@ def test_clear_messages(basic_conversation, human_message):
def test_get_latest_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
latest_message = basic_conversation.get_latest_user_message()
- assert latest_message == human_message
+ assert latest_message.content == human_message.content
def test_get_system_messages(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
system_messages = basic_conversation.get_system_messages()
assert len(system_messages) == 1
- assert system_messages[0] == system_message
+ assert system_messages[0].content == system_message.content
def test_from_conversation(basic_conversation):
@@ -324,6 +324,35 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
assert isinstance(new_conversation.messages[1], AIMessage)
+def test_delete(storage_conversation, in_memory_storage):
+ # Set storage
+ storage_conversation.conv_storage = in_memory_storage
+ storage_conversation.message_storage = in_memory_storage
+
+ # Add messages and save to storage
+ storage_conversation.start_new_round()
+ storage_conversation.add_user_message("User message")
+ storage_conversation.add_ai_message("AI response")
+ storage_conversation.end_current_round()
+
+ # Create a new StorageConversation instance to load the data
+ new_conversation = StorageConversation(
+ "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
+ )
+
+ # Delete the conversation
+ new_conversation.delete()
+
+ # Check if the conversation is deleted
+ assert new_conversation.conv_uid == storage_conversation.conv_uid
+ assert len(new_conversation.messages) == 0
+
+ no_messages_conv = StorageConversation(
+ "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
+ )
+ assert len(no_messages_conv.messages) == 0
+
+
def test_parse_model_messages_no_history_messages():
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
diff --git a/dbgpt/core/operator/__init__.py b/dbgpt/core/operator/__init__.py
index 952b89143..ffc60d2c7 100644
--- a/dbgpt/core/operator/__init__.py
+++ b/dbgpt/core/operator/__init__.py
@@ -14,6 +14,7 @@
BufferedConversationMapperOperator,
ConversationMapperOperator,
PreChatHistoryLoadOperator,
+ TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operator.prompt_operator import (
DynamicPromptBuilderOperator,
@@ -30,6 +31,7 @@
"BaseStreamingLLMOperator",
"BaseConversationOperator",
"BufferedConversationMapperOperator",
+ "TokenBufferedConversationMapperOperator",
"ConversationMapperOperator",
"PreChatHistoryLoadOperator",
"PromptBuilderOperator",
diff --git a/dbgpt/datasource/__init__.py b/dbgpt/datasource/__init__.py
index 8cc9799db..5c619dc41 100644
--- a/dbgpt/datasource/__init__.py
+++ b/dbgpt/datasource/__init__.py
@@ -1 +1 @@
-from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
+from .manages.connect_config_db import ConnectConfigDao, ConnectConfigEntity
diff --git a/dbgpt/datasource/base.py b/dbgpt/datasource/base.py
index a5bf666b4..d95e173a0 100644
--- a/dbgpt/datasource/base.py
+++ b/dbgpt/datasource/base.py
@@ -3,7 +3,7 @@
"""We need to design a base class. That other connector can Write with this"""
from abc import ABC
-from typing import Iterable, List, Optional, Any, Dict
+from typing import Any, Dict, Iterable, List, Optional
class BaseConnect(ABC):
diff --git a/dbgpt/datasource/conn_spark.py b/dbgpt/datasource/conn_spark.py
index fa7d9555e..cc8107f4b 100644
--- a/dbgpt/datasource/conn_spark.py
+++ b/dbgpt/datasource/conn_spark.py
@@ -1,4 +1,4 @@
-from typing import Optional, Any
+from typing import Any, Optional
from dbgpt.datasource.base import BaseConnect
diff --git a/dbgpt/datasource/manages/connect_config_db.py b/dbgpt/datasource/manages/connect_config_db.py
index 13f265211..7c2cc2ff6 100644
--- a/dbgpt/datasource/manages/connect_config_db.py
+++ b/dbgpt/datasource/manages/connect_config_db.py
@@ -1,5 +1,4 @@
-from sqlalchemy import Column, Integer, String, Index, Text, text
-from sqlalchemy import UniqueConstraint
+from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text
from dbgpt.storage.metadata import BaseDao, Model
diff --git a/dbgpt/datasource/manages/connect_storage_duckdb.py b/dbgpt/datasource/manages/connect_storage_duckdb.py
index 59be27dc7..a08d33094 100644
--- a/dbgpt/datasource/manages/connect_storage_duckdb.py
+++ b/dbgpt/datasource/manages/connect_storage_duckdb.py
@@ -1,5 +1,7 @@
import os
+
import duckdb
+
from dbgpt.configs.model_config import PILOT_PATH
default_db_path = os.path.join(PILOT_PATH, "message")
diff --git a/dbgpt/datasource/manages/connection_manager.py b/dbgpt/datasource/manages/connection_manager.py
index 5594b5207..01c4c5ec6 100644
--- a/dbgpt/datasource/manages/connection_manager.py
+++ b/dbgpt/datasource/manages/connection_manager.py
@@ -1,22 +1,22 @@
from typing import List, Type
-from dbgpt.datasource import ConnectConfigDao
-from dbgpt.storage.schema import DBType
-from dbgpt.component import SystemApp, ComponentType
-from dbgpt.util.executor_utils import ExecutorFactory
-from dbgpt.datasource.db_conn_info import DBConfig
-from dbgpt.rag.summary.db_summary_client import DBSummaryClient
+from dbgpt.component import ComponentType, SystemApp
+from dbgpt.datasource import ConnectConfigDao
from dbgpt.datasource.base import BaseConnect
-from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect
-from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
-from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
-from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect
+from dbgpt.datasource.conn_spark import SparkConnect
+from dbgpt.datasource.db_conn_info import DBConfig
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect
+from dbgpt.datasource.rdbms.conn_doris import DorisConnect
+from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
+from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect
+from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect
from dbgpt.datasource.rdbms.conn_postgresql import PostgreSQLDatabase
+from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect
-from dbgpt.datasource.conn_spark import SparkConnect
-from dbgpt.datasource.rdbms.conn_doris import DorisConnect
+from dbgpt.rag.summary.db_summary_client import DBSummaryClient
+from dbgpt.storage.schema import DBType
+from dbgpt.util.executor_utils import ExecutorFactory
class ConnectManager:
diff --git a/dbgpt/datasource/operator/datasource_operator.py b/dbgpt/datasource/operator/datasource_operator.py
index 752465d18..ef0a03e65 100644
--- a/dbgpt/datasource/operator/datasource_operator.py
+++ b/dbgpt/datasource/operator/datasource_operator.py
@@ -1,4 +1,5 @@
from typing import Any
+
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
from dbgpt.datasource.base import BaseConnect
diff --git a/dbgpt/datasource/rdbms/_base_dao.py b/dbgpt/datasource/rdbms/_base_dao.py
index ec19c9634..9a4221bec 100644
--- a/dbgpt/datasource/rdbms/_base_dao.py
+++ b/dbgpt/datasource/rdbms/_base_dao.py
@@ -1,9 +1,11 @@
import logging
+
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
+
from dbgpt._private.config import Config
-from dbgpt.storage.schema import DBType
from dbgpt.datasource.rdbms.base import RDBMSDatabase
+from dbgpt.storage.schema import DBType
logger = logging.getLogger(__name__)
CFG = Config()
diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py
index d38330988..5eb415e62 100644
--- a/dbgpt/datasource/rdbms/base.py
+++ b/dbgpt/datasource/rdbms/base.py
@@ -1,26 +1,21 @@
from __future__ import annotations
-import sqlparse
-import regex as re
+
+from typing import Any, Dict, Iterable, List, Optional
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
-from typing import Any, Iterable, List, Optional, Dict
+
+import regex as re
import sqlalchemy
-from sqlalchemy import (
- MetaData,
- Table,
- create_engine,
- inspect,
- select,
- text,
-)
+import sqlparse
+from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
+from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.schema import CreateTable
-from sqlalchemy.orm import sessionmaker, scoped_session
-from dbgpt.storage.schema import DBType
-from dbgpt.datasource.base import BaseConnect
from dbgpt._private.config import Config
+from dbgpt.datasource.base import BaseConnect
+from dbgpt.storage.schema import DBType
CFG = Config()
diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py
index 27719e73c..06277050e 100644
--- a/dbgpt/datasource/rdbms/conn_clickhouse.py
+++ b/dbgpt/datasource/rdbms/conn_clickhouse.py
@@ -1,12 +1,11 @@
import re
+from typing import Any, Dict, Iterable, List, Optional
+
import sqlparse
-from typing import List, Optional, Any, Iterable, Dict
-from sqlalchemy import text
+from sqlalchemy import MetaData, text
+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.storage.schema import DBType
-from sqlalchemy import (
- MetaData,
-)
class ClickhouseConnect(RDBMSDatabase):
diff --git a/dbgpt/datasource/rdbms/conn_doris.py b/dbgpt/datasource/rdbms/conn_doris.py
index 1e4d9d8aa..10b290af0 100644
--- a/dbgpt/datasource/rdbms/conn_doris.py
+++ b/dbgpt/datasource/rdbms/conn_doris.py
@@ -1,7 +1,9 @@
-from typing import Iterable, Optional, Any
-from sqlalchemy import text
+from typing import Any, Iterable, Optional
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
+
+from sqlalchemy import text
+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
diff --git a/dbgpt/datasource/rdbms/conn_duckdb.py b/dbgpt/datasource/rdbms/conn_duckdb.py
index e9893d181..f2f53a2a8 100644
--- a/dbgpt/datasource/rdbms/conn_duckdb.py
+++ b/dbgpt/datasource/rdbms/conn_duckdb.py
@@ -1,8 +1,6 @@
-from typing import Optional, Any, Iterable
-from sqlalchemy import (
- create_engine,
- text,
-)
+from typing import Any, Iterable, Optional
+
+from sqlalchemy import create_engine, text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
diff --git a/dbgpt/datasource/rdbms/conn_mssql.py b/dbgpt/datasource/rdbms/conn_mssql.py
index 12e1d79fa..961ff20d1 100644
--- a/dbgpt/datasource/rdbms/conn_mssql.py
+++ b/dbgpt/datasource/rdbms/conn_mssql.py
@@ -1,15 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-from typing import Optional, Any, Iterable
+from typing import Any, Iterable, Optional
+
+from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
-from sqlalchemy import (
- MetaData,
- Table,
- create_engine,
- inspect,
- select,
- text,
-)
from dbgpt.datasource.rdbms.base import RDBMSDatabase
diff --git a/dbgpt/datasource/rdbms/conn_postgresql.py b/dbgpt/datasource/rdbms/conn_postgresql.py
index ad169118d..7470d5d90 100644
--- a/dbgpt/datasource/rdbms/conn_postgresql.py
+++ b/dbgpt/datasource/rdbms/conn_postgresql.py
@@ -1,7 +1,9 @@
-from typing import Iterable, Optional, Any
-from sqlalchemy import text
+from typing import Any, Iterable, Optional
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
+
+from sqlalchemy import text
+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py
index bed71bfd6..9eb9f735f 100644
--- a/dbgpt/datasource/rdbms/conn_sqlite.py
+++ b/dbgpt/datasource/rdbms/conn_sqlite.py
@@ -1,11 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import logging
import os
-from typing import Optional, Any, Iterable
-from sqlalchemy import create_engine, text
import tempfile
-import logging
+from typing import Any, Iterable, Optional
+
+from sqlalchemy import create_engine, text
+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
logger = logging.getLogger(__name__)
@@ -160,6 +162,7 @@ def create_temporary_db(
Examples:
.. code-block:: python
+
with SQLiteTempConnect.create_temporary_db() as db:
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.run(db.session, "insert into test(id) values (1)")
@@ -201,6 +204,7 @@ def create_temp_tables(self, tables_info):
Examples:
.. code-block:: python
+
tables_info = {
"test": {
"columns": {
diff --git a/dbgpt/datasource/rdbms/conn_starrocks.py b/dbgpt/datasource/rdbms/conn_starrocks.py
index 380d8eca8..0c79b5d2a 100644
--- a/dbgpt/datasource/rdbms/conn_starrocks.py
+++ b/dbgpt/datasource/rdbms/conn_starrocks.py
@@ -1,7 +1,9 @@
-from typing import Iterable, Optional, Any
-from sqlalchemy import text
+from typing import Any, Iterable, Optional
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
+
+from sqlalchemy import text
+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import *
diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py
index e8542f940..4f479e79b 100644
--- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py
+++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py
@@ -1,8 +1,8 @@
import logging
import re
-from typing import Optional, List, Any, Type, Dict
+from typing import Any, Dict, List, Optional, Type
-from sqlalchemy import Numeric, Integer, Float
+from sqlalchemy import Float, Integer, Numeric
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine
diff --git a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py
index 049db13a5..d563b9603 100644
--- a/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py
+++ b/dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py
@@ -15,7 +15,7 @@
import logging
from typing import Any, Dict, List
-from sqlalchemy import log, exc, text
+from sqlalchemy import exc, log, text
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
from sqlalchemy.engine import Connection
diff --git a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
index 015296b8b..1b08f5506 100644
--- a/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
+++ b/dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
@@ -2,9 +2,10 @@
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_duckdb.py
"""
-import pytest
import tempfile
+import pytest
+
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
diff --git a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
index 267ee1575..f741158e4 100644
--- a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
+++ b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
@@ -1,9 +1,11 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
"""
-import pytest
-import tempfile
import os
+import tempfile
+
+import pytest
+
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py
index e3874393c..793f63c8a 100644
--- a/dbgpt/model/adapter/base.py
+++ b/dbgpt/model/adapter/base.py
@@ -1,19 +1,20 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Any, Tuple, Type, Callable
import logging
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.adapter.template import (
+ ConversationAdapter,
+ ConversationAdapterFactory,
+ get_conv_template,
+)
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
BaseModelParameters,
- ModelParameters,
LlamaCppModelParameters,
+ ModelParameters,
ProxyModelParameters,
)
-from dbgpt.model.adapter.template import (
- get_conv_template,
- ConversationAdapter,
- ConversationAdapterFactory,
-)
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/adapter/fschat_adapter.py b/dbgpt/model/adapter/fschat_adapter.py
index ca5413645..956cccf66 100644
--- a/dbgpt/model/adapter/fschat_adapter.py
+++ b/dbgpt/model/adapter/fschat_adapter.py
@@ -2,17 +2,17 @@
You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it.
"""
+import logging
import os
import threading
-import logging
from functools import cache
-from typing import TYPE_CHECKING, Callable, Tuple, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
try:
from fastchat.conversation import (
Conversation,
- register_conv_template,
SeparatorStyle,
+ register_conv_template,
)
except ImportError as exc:
raise ValueError(
@@ -20,8 +20,8 @@
"Please install fastchat by command `pip install fschat` "
) from exc
-from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py
index 204133b68..7e3faabcb 100644
--- a/dbgpt/model/adapter/hf_adapter.py
+++ b/dbgpt/model/adapter/hf_adapter.py
@@ -1,10 +1,10 @@
-from abc import ABC, abstractmethod
-from typing import Dict, Optional, List, Any
import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
from dbgpt.core import ModelMessage
-from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
+from dbgpt.model.base import ModelType
logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None):
def load(self, model_path: str, from_pretrained_kwargs: dict):
try:
import transformers
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
except ImportError as exc:
raise ValueError(
"Could not import depend python package "
diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py
index efaf9f1b9..b4d06a432 100644
--- a/dbgpt/model/adapter/model_adapter.py
+++ b/dbgpt/model/adapter/model_adapter.py
@@ -1,21 +1,15 @@
from __future__ import annotations
-from typing import (
- List,
- Type,
- Optional,
-)
import logging
-import threading
import os
+import threading
from functools import cache
+from typing import List, Optional, Type
+
+from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
+from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import BaseModelParameters
-from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
-from dbgpt.model.adapter.template import (
- ConversationAdapter,
- ConversationAdapterFactory,
-)
logger = logging.getLogger(__name__)
@@ -64,9 +58,9 @@ def get_llm_model_adapter(
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
from dbgpt.model.adapter.fschat_adapter import (
- _get_fastchat_model_adapter,
- _fastchat_get_adapter_monkey_patch,
FastChatLLMModelAdapterWrapper,
+ _fastchat_get_adapter_monkey_patch,
+ _get_fastchat_model_adapter,
)
adapter = _get_fastchat_model_adapter(
@@ -79,11 +73,11 @@ def get_llm_model_adapter(
result_adapter = FastChatLLMModelAdapterWrapper(adapter)
else:
+ from dbgpt.app.chat_adapter import get_llm_chat_adapter
+ from dbgpt.model.adapter.old_adapter import OldLLMModelAdapterWrapper
from dbgpt.model.adapter.old_adapter import (
get_llm_model_adapter as _old_get_llm_model_adapter,
- OldLLMModelAdapterWrapper,
)
- from dbgpt.app.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
result_adapter = OldLLMModelAdapterWrapper(
@@ -139,12 +133,12 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
Returns:
Optional[List[Type[BaseModelParameters]]]: The model parameters class list.
"""
- from dbgpt.util.parameter_utils import _SimpleArgParser
from dbgpt.model.parameter import (
+ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
EmbeddingModelParameters,
WorkerType,
- EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
+ from dbgpt.util.parameter_utils import _SimpleArgParser
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
diff --git a/dbgpt/model/adapter/old_adapter.py b/dbgpt/model/adapter/old_adapter.py
index a63054695..e9ed01af4 100644
--- a/dbgpt/model/adapter/old_adapter.py
+++ b/dbgpt/model/adapter/old_adapter.py
@@ -5,31 +5,26 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import logging
import os
import re
-import logging
-from pathlib import Path
-from typing import List, Tuple, TYPE_CHECKING, Optional
from functools import cache
-from transformers import (
- AutoModel,
- AutoModelForCausalLM,
- AutoTokenizer,
- LlamaTokenizer,
-)
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
+from dbgpt._private.config import Config
+from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.base import ModelType
-
+from dbgpt.model.conversation import Conversation
from dbgpt.model.parameter import (
- ModelParameters,
LlamaCppModelParameters,
+ ModelParameters,
ProxyModelParameters,
)
-from dbgpt.model.conversation import Conversation
-from dbgpt.configs.model_config import get_device
-from dbgpt._private.config import Config
if TYPE_CHECKING:
from dbgpt.app.chat_adapter import BaseChatAdpter
diff --git a/dbgpt/model/adapter/template.py b/dbgpt/model/adapter/template.py
index 3fb9a6ec1..421e98a4e 100644
--- a/dbgpt/model/adapter/template.py
+++ b/dbgpt/model/adapter/template.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Optional, Tuple, Union, List
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
if TYPE_CHECKING:
from fastchat.conversation import Conversation
@@ -124,6 +124,7 @@ def get_conv_template(name: str) -> ConversationAdapter:
Conversation: The conversation template.
"""
from fastchat.conversation import get_conv_template
+
from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
conv_template = get_conv_template(name)
diff --git a/dbgpt/model/adapter/vllm_adapter.py b/dbgpt/model/adapter/vllm_adapter.py
index 2ffe0c764..268ce82c6 100644
--- a/dbgpt/model/adapter/vllm_adapter.py
+++ b/dbgpt/model/adapter/vllm_adapter.py
@@ -1,12 +1,13 @@
import dataclasses
import logging
-from dbgpt.model.base import ModelType
+
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
+from dbgpt.model.base import ModelType
from dbgpt.model.parameter import BaseModelParameters
from dbgpt.util.parameter_utils import (
- _extract_parameter_details,
_build_parameter_class,
+ _extract_parameter_details,
_get_dataclass_print_str,
)
@@ -27,6 +28,7 @@ def model_type(self) -> str:
def model_param_class(self, model_type: str = None) -> BaseModelParameters:
import argparse
+
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
@@ -56,9 +58,9 @@ def model_param_class(self, model_type: str = None) -> BaseModelParameters:
return _build_parameter_class(descs)
def load_from_params(self, params):
+ import torch
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
- import torch
num_gpus = torch.cuda.device_count()
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
diff --git a/dbgpt/model/base.py b/dbgpt/model/base.py
index a662c8001..1bd5a3734 100644
--- a/dbgpt/model/base.py
+++ b/dbgpt/model/base.py
@@ -1,14 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-from enum import Enum
-from typing import TypedDict, Optional, Dict, List, Any
-
-from dataclasses import dataclass, asdict
import time
+from dataclasses import asdict, dataclass
from datetime import datetime
-from dbgpt.util.parameter_utils import ParameterDescription
+from enum import Enum
+from typing import Any, Dict, List, Optional, TypedDict
+
from dbgpt.util.model_utils import GPUInfo
+from dbgpt.util.parameter_utils import ParameterDescription
class ModelType:
diff --git a/dbgpt/model/cli.py b/dbgpt/model/cli.py
index 67e12d7eb..24a367fb1 100644
--- a/dbgpt/model/cli.py
+++ b/dbgpt/model/cli.py
@@ -1,30 +1,30 @@
-import click
import functools
import logging
import os
-from typing import Callable, List, Type, Optional
+from typing import Callable, List, Optional, Type
+
+import click
from dbgpt.configs.model_config import LOGDIR
from dbgpt.model.base import WorkerApplyType
from dbgpt.model.parameter import (
- ModelControllerParameters,
+ BaseParameters,
ModelAPIServerParameters,
- ModelWorkerParameters,
+ ModelControllerParameters,
ModelParameters,
- BaseParameters,
+ ModelWorkerParameters,
)
from dbgpt.util import get_or_create_event_loop
+from dbgpt.util.command_utils import (
+ _detect_controller_address,
+ _run_current_with_daemon,
+ _stop_service,
+)
from dbgpt.util.parameter_utils import (
EnvArgumentParser,
_build_parameter_class,
build_lazy_click_command,
)
-from dbgpt.util.command_utils import (
- _run_current_with_daemon,
- _stop_service,
- _detect_controller_address,
-)
-
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
@@ -32,7 +32,7 @@
def _get_worker_manager(address: str):
- from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient
+ from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
@@ -70,6 +70,7 @@ def model_cli_group(address: str):
def list(model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable
+
from dbgpt.model.cluster import ModelRegistryClient
loop = get_or_create_event_loop()
@@ -152,7 +153,7 @@ def wrapper(*args, **kwargs):
)
def stop(model_name: str, model_type: str, host: str, port: int):
"""Stop model instances"""
- from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager
+ from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
req = WorkerStartupRequest(
@@ -168,10 +169,11 @@ def stop(model_name: str, model_type: str, host: str, port: int):
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
- from dbgpt.util.parameter_utils import _SimpleArgParser
+ from dataclasses import dataclass, field
+
from dbgpt.model.cluster import RemoteWorkerManager
from dbgpt.model.parameter import WorkerType
- from dataclasses import dataclass, field
+ from dbgpt.util.parameter_utils import _SimpleArgParser
pre_args = _SimpleArgParser("model_name", "address", "host", "port")
pre_args.parse()
@@ -270,7 +272,7 @@ class RemoteModelWorkerParameters(BaseParameters):
)
def start(**kwargs):
"""Start model instances"""
- from dbgpt.model.cluster import WorkerStartupRequest, RemoteWorkerManager
+ from dbgpt.model.cluster import RemoteWorkerManager, WorkerStartupRequest
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
req = WorkerStartupRequest(
@@ -339,8 +341,8 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None):
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
- from dbgpt.model.cluster import PromptRequest
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+ from dbgpt.model.cluster import PromptRequest
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
hist = []
diff --git a/dbgpt/model/cluster/__init__.py b/dbgpt/model/cluster/__init__.py
index 8249ebd6f..aed894a6f 100644
--- a/dbgpt/model/cluster/__init__.py
+++ b/dbgpt/model/cluster/__init__.py
@@ -1,3 +1,4 @@
+from dbgpt.model.cluster.apiserver.api import run_apiserver
from dbgpt.model.cluster.base import (
EmbeddingsRequest,
PromptRequest,
@@ -5,25 +6,21 @@
WorkerParameterRequest,
WorkerStartupRequest,
)
+from dbgpt.model.cluster.controller.controller import (
+ BaseModelController,
+ ModelRegistryClient,
+ run_model_controller,
+)
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
-from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker.default_worker import DefaultModelWorker
-
from dbgpt.model.cluster.worker.manager import (
initialize_worker_manager_in_client,
run_worker_manager,
worker_manager,
)
-
-from dbgpt.model.cluster.registry import ModelRegistry
-from dbgpt.model.cluster.controller.controller import (
- ModelRegistryClient,
- run_model_controller,
- BaseModelController,
-)
-from dbgpt.model.cluster.apiserver.api import run_apiserver
-
from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager
+from dbgpt.model.cluster.worker_base import ModelWorker
__all__ = [
"EmbeddingsRequest",
diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py
index 43fa62646..1a508fddb 100644
--- a/dbgpt/model/cluster/apiserver/api.py
+++ b/dbgpt/model/cluster/apiserver/api.py
@@ -3,45 +3,42 @@
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
-from typing import Optional, List, Dict, Any, Generator
-
-import logging
import asyncio
-import shortuuid
import json
-from fastapi import APIRouter, FastAPI
-from fastapi import Depends, HTTPException
+import logging
+from typing import Any, Dict, Generator, List, Optional
+
+import shortuuid
+from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
-
-
+from fastchat.constants import ErrorCode
+from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponse,
+ ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
- ChatCompletionResponseChoice,
DeltaMessage,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
-from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
-from fastchat.constants import ErrorCode
+from dbgpt._private.pydantic import BaseModel
from dbgpt.component import BaseComponent, ComponentType, SystemApp
-from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import ModelInstance
-from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
-from dbgpt.model.cluster import ModelRegistry
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
+from dbgpt.model.cluster.registry import ModelRegistry
+from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
+from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.utils import setup_logging
-from dbgpt._private.pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -393,8 +390,9 @@ async def create_chat_completion(
def _initialize_all(controller_addr: str, system_app: SystemApp):
- from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient
+ from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
+ from dbgpt.model.cluster.worker.remote_manager import RemoteWorkerManager
if not system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
diff --git a/dbgpt/model/cluster/apiserver/tests/test_api.py b/dbgpt/model/cluster/apiserver/tests/test_api.py
index 0e3188d6c..00a374780 100644
--- a/dbgpt/model/cluster/apiserver/tests/test_api.py
+++ b/dbgpt/model/cluster/apiserver/tests/test_api.py
@@ -1,29 +1,28 @@
+import importlib.metadata as metadata
+
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
-import importlib.metadata as metadata
from dbgpt.component import SystemApp
-from dbgpt.util.openai_utils import chat_completion_stream, chat_completion
-
from dbgpt.model.cluster.apiserver.api import (
- api_settings,
- initialize_apiserver,
- ModelList,
- UsageInfo,
ChatCompletionResponse,
+ ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
- ChatCompletionResponseChoice,
DeltaMessage,
+ ModelList,
+ UsageInfo,
+ api_settings,
+ initialize_apiserver,
)
from dbgpt.model.cluster.tests.conftest import _new_cluster
-
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
+from dbgpt.util.openai_utils import chat_completion, chat_completion_stream
app = FastAPI()
app.add_middleware(
diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py
index 97b76fd30..f60fb0e54 100644
--- a/dbgpt/model/cluster/base.py
+++ b/dbgpt/model/cluster/base.py
@@ -1,9 +1,9 @@
from typing import Dict, List
+from dbgpt._private.pydantic import BaseModel
+from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import WorkerApplyType
from dbgpt.model.parameter import WorkerType
-from dbgpt.core.interface.message import ModelMessage
-from dbgpt._private.pydantic import BaseModel
WORKER_MANAGER_SERVICE_TYPE = "service"
WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
diff --git a/dbgpt/model/cluster/controller/controller.py b/dbgpt/model/cluster/controller/controller.py
index e52d6f264..1591c120b 100644
--- a/dbgpt/model/cluster/controller/controller.py
+++ b/dbgpt/model/cluster/controller/controller.py
@@ -1,19 +1,17 @@
-from abc import ABC, abstractmethod
-
import logging
+from abc import ABC, abstractmethod
from typing import List
from fastapi import APIRouter, FastAPI
+
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
-from dbgpt.model.parameter import ModelControllerParameters
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
+from dbgpt.model.parameter import ModelControllerParameters
+from dbgpt.util.api_utils import _api_remote as api_remote
+from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
from dbgpt.util.parameter_utils import EnvArgumentParser
-from dbgpt.util.api_utils import (
- _api_remote as api_remote,
- _sync_api_remote as sync_api_remote,
-)
-from dbgpt.util.utils import setup_logging, setup_http_service_logging
+from dbgpt.util.utils import setup_http_service_logging, setup_logging
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/cluster/controller/tests/test_registry.py b/dbgpt/model/cluster/controller/tests/test_registry.py
index 9596ef7a5..51494364d 100644
--- a/dbgpt/model/cluster/controller/tests/test_registry.py
+++ b/dbgpt/model/cluster/controller/tests/test_registry.py
@@ -1,7 +1,8 @@
-import pytest
+import asyncio
from datetime import datetime, timedelta
-import asyncio
+import pytest
+
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.registry import EmbeddedModelRegistry
diff --git a/dbgpt/model/cluster/embedding/loader.py b/dbgpt/model/cluster/embedding/loader.py
index aea9d0530..b4fba70b9 100644
--- a/dbgpt/model/cluster/embedding/loader.py
+++ b/dbgpt/model/cluster/embedding/loader.py
@@ -4,8 +4,8 @@
from dbgpt.model.parameter import BaseEmbeddingModelParameters
from dbgpt.util.parameter_utils import _get_dict_from_obj
-from dbgpt.util.tracer import root_tracer, SpanType, SpanTypeRunName
from dbgpt.util.system_utils import get_system_info
+from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
diff --git a/dbgpt/model/cluster/embedding/remote_embedding.py b/dbgpt/model/cluster/embedding/remote_embedding.py
index 277bc4d24..d45c9dd85 100644
--- a/dbgpt/model/cluster/embedding/remote_embedding.py
+++ b/dbgpt/model/cluster/embedding/remote_embedding.py
@@ -1,4 +1,5 @@
from typing import List
+
from langchain.embeddings.base import Embeddings
from dbgpt.model.cluster.manager_base import WorkerManager
diff --git a/dbgpt/model/cluster/manager_base.py b/dbgpt/model/cluster/manager_base.py
index 636b2822f..37b09b0fd 100644
--- a/dbgpt/model/cluster/manager_base.py
+++ b/dbgpt/model/cluster/manager_base.py
@@ -1,15 +1,16 @@
import asyncio
-from dataclasses import dataclass
-from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
-from datetime import datetime
from concurrent.futures import Future
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Callable, Dict, Iterator, List, Optional
+
from dbgpt.component import BaseComponent, ComponentType, SystemApp
-from dbgpt.core import ModelOutput, ModelMetadata
-from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput
+from dbgpt.core import ModelMetadata, ModelOutput
+from dbgpt.model.base import WorkerApplyOutput, WorkerSupportedModel
+from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
-from dbgpt.model.parameter import ModelWorkerParameters, ModelParameters
+from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters
from dbgpt.util.parameter_utils import ParameterDescription
diff --git a/dbgpt/model/cluster/registry.py b/dbgpt/model/cluster/registry.py
index 6d2d7cf86..6a5b12bf5 100644
--- a/dbgpt/model/cluster/registry.py
+++ b/dbgpt/model/cluster/registry.py
@@ -1,17 +1,16 @@
+import itertools
+import logging
import random
import threading
import time
-import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
-import itertools
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
-
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/cluster/tests/conftest.py b/dbgpt/model/cluster/tests/conftest.py
index eeba4826f..dc2288327 100644
--- a/dbgpt/model/cluster/tests/conftest.py
+++ b/dbgpt/model/cluster/tests/conftest.py
@@ -1,21 +1,22 @@
+from contextlib import asynccontextmanager, contextmanager
+from typing import Dict, Iterator, List, Tuple
+
import pytest
import pytest_asyncio
-from contextlib import contextmanager, asynccontextmanager
-from typing import List, Iterator, Dict, Tuple
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
-from dbgpt.core import ModelOutput, ModelMetadata
-from dbgpt.model.cluster.worker_base import ModelWorker
+
+from dbgpt.core import ModelMetadata, ModelOutput
+from dbgpt.model.base import ModelInstance
+from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.cluster.worker.manager import (
- WorkerManager,
+ ApplyFunction,
+ DeregisterFunc,
LocalWorkerManager,
RegisterFunc,
- DeregisterFunc,
SendHeartbeatFunc,
- ApplyFunction,
+ WorkerManager,
)
-
-from dbgpt.model.base import ModelInstance
-from dbgpt.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
+from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
@pytest.fixture
diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py
index 453071004..c1fe109ea 100644
--- a/dbgpt/model/cluster/worker/default_worker.py
+++ b/dbgpt/model/cluster/worker/default_worker.py
@@ -1,26 +1,25 @@
-import os
import logging
-
-from typing import Dict, Iterator, List, Optional
+import os
import time
import traceback
+from typing import Dict, Iterator, List, Optional
from dbgpt.configs.model_config import get_device
-from dbgpt.model.adapter.base import LLMModelAdapter
-from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import (
- ModelOutput,
+ ModelExtraMedata,
ModelInferenceMetrics,
ModelMetadata,
- ModelExtraMedata,
+ ModelOutput,
)
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
+from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
-from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
-from dbgpt.util.tracer import root_tracer, SpanType, SpanTypeRunName
from dbgpt.util.system_utils import get_system_info
+from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/cluster/worker/embedding_worker.py b/dbgpt/model/cluster/worker/embedding_worker.py
index c8df5d779..06b6ce3bb 100644
--- a/dbgpt/model/cluster/worker/embedding_worker.py
+++ b/dbgpt/model/cluster/worker/embedding_worker.py
@@ -1,17 +1,17 @@
import logging
-from typing import Dict, List, Type, Optional
+from typing import Dict, List, Optional, Type
from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
+from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
+from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
- EmbeddingModelParameters,
+ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
+ EmbeddingModelParameters,
WorkerType,
- EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
-from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.util.model_utils import _clear_model_cache
from dbgpt.util.parameter_utils import EnvArgumentParser
diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py
index 8a8af3b35..f775c9118 100644
--- a/dbgpt/model/cluster/worker/manager.py
+++ b/dbgpt/model/cluster/worker/manager.py
@@ -15,7 +15,7 @@
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
-from dbgpt.core import ModelOutput, ModelMetadata
+from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
@@ -38,9 +38,9 @@
_dict_to_command_args,
_get_dict_from_obj,
)
-from dbgpt.util.utils import setup_logging, setup_http_service_logging
-from dbgpt.util.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
from dbgpt.util.system_utils import get_system_info
+from dbgpt.util.tracer import SpanType, SpanTypeRunName, initialize_tracer, root_tracer
+from dbgpt.util.utils import setup_http_service_logging, setup_logging
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/cluster/worker/remote_worker.py b/dbgpt/model/cluster/worker/remote_worker.py
index 00e707768..c25a71ca7 100644
--- a/dbgpt/model/cluster/worker/remote_worker.py
+++ b/dbgpt/model/cluster/worker/remote_worker.py
@@ -1,10 +1,10 @@
import json
-from typing import Dict, Iterator, List
import logging
-from dbgpt.core import ModelOutput, ModelMetadata
-from dbgpt.model.parameter import ModelParameters
-from dbgpt.model.cluster.worker_base import ModelWorker
+from typing import Dict, Iterator, List
+from dbgpt.core import ModelMetadata, ModelOutput
+from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.parameter import ModelParameters
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/cluster/worker/tests/test_manager.py b/dbgpt/model/cluster/worker/tests/test_manager.py
index 0c02e71be..d8ee8bdc3 100644
--- a/dbgpt/model/cluster/worker/tests/test_manager.py
+++ b/dbgpt/model/cluster/worker/tests/test_manager.py
@@ -1,29 +1,30 @@
-from unittest.mock import patch, AsyncMock
-import pytest
-from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
-from dbgpt.model.base import WorkerApplyType, ModelInstance
+from typing import Dict, Iterator, List, Tuple
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from dbgpt.model.base import ModelInstance, WorkerApplyType
from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
-from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.manager_base import WorkerRunData
-from dbgpt.model.cluster.worker.manager import (
- LocalWorkerManager,
- RegisterFunc,
- DeregisterFunc,
- SendHeartbeatFunc,
- ApplyFunction,
-)
from dbgpt.model.cluster.tests.conftest import (
MockModelWorker,
- manager_2_workers,
- manager_with_2_workers,
- manager_2_embedding_workers,
_create_workers,
- _start_worker_manager,
_new_worker_params,
+ _start_worker_manager,
+ manager_2_embedding_workers,
+ manager_2_workers,
+ manager_with_2_workers,
)
-
+from dbgpt.model.cluster.worker.manager import (
+ ApplyFunction,
+ DeregisterFunc,
+ LocalWorkerManager,
+ RegisterFunc,
+ SendHeartbeatFunc,
+)
+from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
diff --git a/dbgpt/model/cluster/worker_base.py b/dbgpt/model/cluster/worker_base.py
index 1f0005a8f..36413745f 100644
--- a/dbgpt/model/cluster/worker_base.py
+++ b/dbgpt/model/cluster/worker_base.py
@@ -1,12 +1,9 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type
-from dbgpt.core import ModelOutput, ModelMetadata
+from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.parameter import ModelParameters, WorkerType
-from dbgpt.util.parameter_utils import (
- ParameterDescription,
- _get_parameter_descriptions,
-)
+from dbgpt.util.parameter_utils import ParameterDescription, _get_parameter_descriptions
class ModelWorker(ABC):
diff --git a/dbgpt/model/compression.py b/dbgpt/model/compression.py
index dc626ea23..83b681c6f 100644
--- a/dbgpt/model/compression.py
+++ b/dbgpt/model/compression.py
@@ -52,7 +52,7 @@ def compress_module(module, target_device):
def compress(tensor, config):
- """Simulate team-wise quantization."""
+ """Simulate group-wise quantization."""
if not config.enabled:
return tensor
@@ -105,7 +105,7 @@ def compress(tensor, config):
def decompress(packed_data, config):
- """Simulate team-wise dequantization."""
+ """Simulate group-wise dequantization."""
if not config.enabled:
return packed_data
diff --git a/dbgpt/model/conversation.py b/dbgpt/model/conversation.py
index 875299ccd..e3e465531 100644
--- a/dbgpt/model/conversation.py
+++ b/dbgpt/model/conversation.py
@@ -9,8 +9,8 @@
"""
import dataclasses
-from enum import auto, IntEnum
-from typing import List, Dict, Callable
+from enum import IntEnum, auto
+from typing import Callable, Dict, List
class SeparatorStyle(IntEnum):
diff --git a/dbgpt/model/inference.py b/dbgpt/model/inference.py
index 400f57ab9..fcb0d08f3 100644
--- a/dbgpt/model/inference.py
+++ b/dbgpt/model/inference.py
@@ -7,10 +7,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gc
-from typing import Iterable, Dict
+from typing import Dict, Iterable
import torch
-
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@@ -19,7 +18,7 @@
TopPLogitsWarper,
)
-from dbgpt.model.llm_utils import is_sentence_complete, is_partial_stop
+from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
def prepare_logits_processor(
diff --git a/dbgpt/model/llm/llama_cpp/llama_cpp.py b/dbgpt/model/llm/llama_cpp/llama_cpp.py
index dc4851451..c7ee69205 100644
--- a/dbgpt/model/llm/llama_cpp/llama_cpp.py
+++ b/dbgpt/model/llm/llama_cpp/llama_cpp.py
@@ -1,11 +1,12 @@
"""
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
"""
+import logging
import re
from typing import Dict
-import logging
-import torch
+
import llama_cpp
+import torch
from dbgpt.model.parameter import LlamaCppModelParameters
diff --git a/dbgpt/model/llm_out/chatglm_llm.py b/dbgpt/model/llm_out/chatglm_llm.py
index 0bb734181..acd6b226a 100644
--- a/dbgpt/model/llm_out/chatglm_llm.py
+++ b/dbgpt/model/llm_out/chatglm_llm.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
-from typing import List
import re
+from typing import List
import torch
diff --git a/dbgpt/model/llm_out/falcon_llm.py b/dbgpt/model/llm_out/falcon_llm.py
index 12aebc6e9..4b56a754b 100644
--- a/dbgpt/model/llm_out/falcon_llm.py
+++ b/dbgpt/model/llm_out/falcon_llm.py
@@ -1,6 +1,7 @@
-import torch
from threading import Thread
-from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+import torch
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
diff --git a/dbgpt/model/llm_out/guanaco_llm.py b/dbgpt/model/llm_out/guanaco_llm.py
index dd727d19a..4fae428fb 100644
--- a/dbgpt/model/llm_out/guanaco_llm.py
+++ b/dbgpt/model/llm_out/guanaco_llm.py
@@ -1,6 +1,7 @@
-import torch
from threading import Thread
-from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+
+import torch
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
diff --git a/dbgpt/model/llm_out/hf_chat_llm.py b/dbgpt/model/llm_out/hf_chat_llm.py
index b7e6caaac..e8fb58054 100644
--- a/dbgpt/model/llm_out/hf_chat_llm.py
+++ b/dbgpt/model/llm_out/hf_chat_llm.py
@@ -1,6 +1,7 @@
import logging
-import torch
from threading import Thread
+
+import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/llm_out/llama_cpp_llm.py b/dbgpt/model/llm_out/llama_cpp_llm.py
index 921670065..b25ea6845 100644
--- a/dbgpt/model/llm_out/llama_cpp_llm.py
+++ b/dbgpt/model/llm_out/llama_cpp_llm.py
@@ -1,4 +1,5 @@
from typing import Dict
+
import torch
diff --git a/dbgpt/model/llm_out/proxy_llm.py b/dbgpt/model/llm_out/proxy_llm.py
index 390ef4886..cbe9c42fc 100644
--- a/dbgpt/model/llm_out/proxy_llm.py
+++ b/dbgpt/model/llm_out/proxy_llm.py
@@ -2,16 +2,16 @@
# -*- coding: utf-8 -*-
import time
-from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream
+from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
from dbgpt.model.proxy.llms.bard import bard_generate_stream
+from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream
from dbgpt.model.proxy.llms.claude import claude_generate_stream
-from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
-from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
-from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
-from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
-from dbgpt.model.proxy.llms.spark import spark_generate_stream
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+from dbgpt.model.proxy.llms.spark import spark_generate_stream
+from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
+from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
+from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
def proxyllm_generate_stream(
diff --git a/dbgpt/model/llm_out/vicuna_llm.py b/dbgpt/model/llm_out/vicuna_llm.py
index 447254efb..3713946df 100644
--- a/dbgpt/model/llm_out/vicuna_llm.py
+++ b/dbgpt/model/llm_out/vicuna_llm.py
@@ -8,9 +8,9 @@
import requests
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM
-from dbgpt._private.pydantic import BaseModel
from dbgpt._private.config import Config
+from dbgpt._private.pydantic import BaseModel
CFG = Config()
diff --git a/dbgpt/model/llm_out/vllm_llm.py b/dbgpt/model/llm_out/vllm_llm.py
index de108c87c..838bcc35a 100644
--- a/dbgpt/model/llm_out/vllm_llm.py
+++ b/dbgpt/model/llm_out/vllm_llm.py
@@ -1,9 +1,9 @@
-from typing import Dict
import os
+from typing import Dict
+
from vllm import AsyncLLMEngine
-from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
-
+from vllm.utils import random_uuid
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
diff --git a/dbgpt/model/llm_utils.py b/dbgpt/model/llm_utils.py
index 031896e86..fff11ba77 100644
--- a/dbgpt/model/llm_utils.py
+++ b/dbgpt/model/llm_utils.py
@@ -2,11 +2,11 @@
# -*- coding:utf-8 -*-
from pathlib import Path
+from typing import Dict, List
-from typing import List, Dict
import cachetools
-from dbgpt.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
from dbgpt.model.base import SupportedModel
from dbgpt.util.parameter_utils import _get_parameter_descriptions
diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py
index cfd6e0c3c..43d4e27f8 100644
--- a/dbgpt/model/loader.py
+++ b/dbgpt/model/loader.py
@@ -1,16 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-from typing import Optional, Dict, Any
-
import logging
+from typing import Any, Dict, Optional
+
from dbgpt.configs.model_config import get_device
-from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
+from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
- ModelParameters,
LlamaCppModelParameters,
+ ModelParameters,
ProxyModelParameters,
)
from dbgpt.util import get_gpu_memory
@@ -135,6 +135,7 @@ def loader_with_params(
def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
import torch
+
from dbgpt.model.compression import compress_module
device = model_params.device
@@ -178,7 +179,6 @@ def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParamete
# NOTE: Recent transformers library seems to fix the mps issue, also
# it has made some changes causing compatibility issues with our
# original patch. So we only apply the patch for older versions.
-
# Avoid bugs in mps backend by not using in-place operations.
replace_llama_attn_with_non_inplace_operations()
@@ -274,18 +274,18 @@ def load_huggingface_quantization_model(
import torch
try:
+ import transformers
from accelerate import init_empty_weights
from accelerate.utils import infer_auto_device_map
- import transformers
from transformers import (
- BitsAndBytesConfig,
AutoConfig,
AutoModel,
AutoModelForCausalLM,
- LlamaForCausalLM,
AutoModelForSeq2SeqLM,
- LlamaTokenizer,
AutoTokenizer,
+ BitsAndBytesConfig,
+ LlamaForCausalLM,
+ LlamaTokenizer,
)
except ImportError as exc:
raise ValueError(
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index afa77cb8b..bcaa46c4d 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -4,7 +4,7 @@
import os
from dataclasses import dataclass, field
from enum import Enum
-from typing import Dict, Optional, Union, Tuple
+from typing import Dict, Optional, Tuple, Union
from dbgpt.model.conversation import conv_templates
from dbgpt.util.parameter_utils import BaseParameters
diff --git a/dbgpt/model/proxy/llms/baichuan.py b/dbgpt/model/proxy/llms/baichuan.py
index ed641c72f..aed91e151 100644
--- a/dbgpt/model/proxy/llms/baichuan.py
+++ b/dbgpt/model/proxy/llms/baichuan.py
@@ -1,9 +1,11 @@
-import requests
import json
from typing import List
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
+import requests
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.parameter import ProxyModelParameters
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k"
diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py
index 7e43661d8..6317cad6d 100755
--- a/dbgpt/model/proxy/llms/bard.py
+++ b/dbgpt/model/proxy/llms/bard.py
@@ -1,5 +1,7 @@
-import requests
from typing import List
+
+import requests
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py
index 5d1882141..15932bfcc 100755
--- a/dbgpt/model/proxy/llms/chatgpt.py
+++ b/dbgpt/model/proxy/llms/chatgpt.py
@@ -1,15 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import importlib.metadata as metadata
+import logging
import os
from typing import List
-import logging
-import importlib.metadata as metadata
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
-from dbgpt.model.parameter import ProxyModelParameters
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+
import httpx
+from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.parameter import ProxyModelParameters
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py
index 04975817c..9a3b3b868 100644
--- a/dbgpt/model/proxy/llms/gemini.py
+++ b/dbgpt/model/proxy/llms/gemini.py
@@ -1,7 +1,7 @@
-from typing import List, Tuple, Dict, Any
+from typing import Any, Dict, List, Tuple
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, parse_model_messages
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
GEMINI_DEFAULT_MODEL = "gemini-pro"
diff --git a/dbgpt/model/proxy/llms/proxy_model.py b/dbgpt/model/proxy/llms/proxy_model.py
index 4e55ec3ea..b287ea88f 100644
--- a/dbgpt/model/proxy/llms/proxy_model.py
+++ b/dbgpt/model/proxy/llms/proxy_model.py
@@ -1,12 +1,13 @@
from __future__ import annotations
-from typing import Union, List, Optional, TYPE_CHECKING
import logging
+from typing import TYPE_CHECKING, List, Optional, Union
+
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
- from dbgpt.core.interface.message import ModelMessage, BaseMessage
+ from dbgpt.core.interface.message import BaseMessage, ModelMessage
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py
index aedf8a951..81a7cc2b4 100644
--- a/dbgpt/model/proxy/llms/spark.py
+++ b/dbgpt/model/proxy/llms/spark.py
@@ -1,14 +1,15 @@
-import json
import base64
-import hmac
import hashlib
-from websockets.sync.client import connect
+import hmac
+import json
from datetime import datetime
-from typing import List
from time import mktime
-from urllib.parse import urlencode
-from urllib.parse import urlparse
+from typing import List
+from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
+
+from websockets.sync.client import connect
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py
index bbcec2f42..e5d008de1 100644
--- a/dbgpt/model/proxy/llms/tongyi.py
+++ b/dbgpt/model/proxy/llms/tongyi.py
@@ -1,7 +1,8 @@
import logging
from typing import List
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
logger = logging.getLogger(__name__)
diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py
index 28e5a65d0..74e44fa80 100644
--- a/dbgpt/model/proxy/llms/wenxin.py
+++ b/dbgpt/model/proxy/llms/wenxin.py
@@ -1,9 +1,11 @@
-import requests
import json
from typing import List
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
+import requests
+from cachetools import TTLCache, cached
+
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from cachetools import cached, TTLCache
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@cached(TTLCache(1, 1800))
diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py
index 2f854e17a..8108ad5e0 100644
--- a/dbgpt/model/proxy/llms/zhipu.py
+++ b/dbgpt/model/proxy/llms/zhipu.py
@@ -1,7 +1,7 @@
from typing import List
-from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py
index 7d1344aec..f4d022a3a 100644
--- a/dbgpt/model/utils/chatgpt_utils.py
+++ b/dbgpt/model/utils/chatgpt_utils.py
@@ -1,42 +1,41 @@
from __future__ import annotations
-import os
+import importlib.metadata as metadata
import logging
-from dataclasses import dataclass
+import os
from abc import ABC
-import importlib.metadata as metadata
+from dataclasses import dataclass
from typing import (
- List,
- Dict,
- Any,
- Optional,
TYPE_CHECKING,
- Union,
+ Any,
AsyncIterator,
- Callable,
Awaitable,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Union,
)
+from dbgpt._private.pydantic import model_to_json
from dbgpt.component import ComponentType
-from dbgpt.core.operator import BaseLLM
-from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
+from dbgpt.core.awel import BaseOperator, TransformStreamAbsOperator
from dbgpt.core.interface.llm import (
- ModelOutput,
- ModelRequest,
- ModelMetadata,
LLMClient,
MessageConverter,
+ ModelMetadata,
+ ModelOutput,
+ ModelRequest,
)
-from dbgpt.model.cluster.client import DefaultLLMClient
+from dbgpt.core.operator import BaseLLM
from dbgpt.model.cluster import WorkerManagerFactory
-from dbgpt._private.pydantic import model_to_json
+from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
import httpx
from httpx._types import ProxiesTypes
- from openai import AsyncAzureOpenAI
- from openai import AsyncOpenAI
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
@@ -292,9 +291,10 @@ async def _to_openai_stream(
model (Optional[str], optional): The model name. Defaults to None.
model_caller (Callable[[None], Union[Awaitable[str], str]], optional): The model caller. Defaults to None.
"""
+ import asyncio
import json
+
import shortuuid
- import asyncio
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
diff --git a/dbgpt/model/utils/token_utils.py b/dbgpt/model/utils/token_utils.py
index 281ed5eed..1c84490f4 100644
--- a/dbgpt/model/utils/token_utils.py
+++ b/dbgpt/model/utils/token_utils.py
@@ -1,10 +1,10 @@
from __future__ import annotations
-from typing import Union, List, Optional, TYPE_CHECKING
import logging
+from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING:
- from dbgpt.core.interface.message import ModelMessage, BaseMessage
+ from dbgpt.core.interface.message import BaseMessage, ModelMessage
logger = logging.getLogger(__name__)
diff --git a/dbgpt/serve/conversation/api/endpoints.py b/dbgpt/serve/conversation/api/endpoints.py
index 8ca494b98..59b155aaf 100644
--- a/dbgpt/serve/conversation/api/endpoints.py
+++ b/dbgpt/serve/conversation/api/endpoints.py
@@ -1,3 +1,4 @@
+import uuid
from functools import cache
from typing import List, Optional
@@ -10,7 +11,7 @@
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
-from .schemas import ServeRequest, ServerResponse
+from .schemas import MessageVo, ServeRequest, ServerResponse
router = APIRouter()
@@ -95,12 +96,14 @@ async def test_auth():
@router.post(
- "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
+ "/query",
+ response_model=Result[ServerResponse],
+ dependencies=[Depends(check_api_key)],
)
-async def create(
+async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
- """Create a new Conversation entity
+ """Query Conversation entities
Args:
request (ServeRequest): The request
@@ -108,43 +111,46 @@ async def create(
Returns:
ServerResponse: The response
"""
- return Result.succ(service.create(request))
+ return Result.succ(service.get(request))
-@router.put(
- "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
+@router.post(
+ "/new",
+ response_model=Result[ServerResponse],
+ dependencies=[Depends(check_api_key)],
)
-async def update(
- request: ServeRequest, service: Service = Depends(get_service)
-) -> Result[ServerResponse]:
- """Update a Conversation entity
-
- Args:
- request (ServeRequest): The request
- service (Service): The service
- Returns:
- ServerResponse: The response
- """
- return Result.succ(service.update(request))
+async def dialogue_new(
+ chat_mode: str = "chat_normal",
+ user_name: str = None,
+ # TODO remove user id
+ user_id: str = None,
+ sys_code: str = None,
+):
+ user_name = user_name or user_id
+ unique_id = uuid.uuid1()
+ res = ServerResponse(
+ user_input="",
+ conv_uid=str(unique_id),
+ chat_mode=chat_mode,
+ user_name=user_name,
+ sys_code=sys_code,
+ )
+ return Result.succ(res)
@router.post(
- "/query",
- response_model=Result[ServerResponse],
+ "/delete",
dependencies=[Depends(check_api_key)],
)
-async def query(
- request: ServeRequest, service: Service = Depends(get_service)
-) -> Result[ServerResponse]:
- """Query Conversation entities
+async def delete(con_uid: str, service: Service = Depends(get_service)):
+ """Delete a Conversation entity
Args:
- request (ServeRequest): The request
+ con_uid (str): The conversation UID
service (Service): The service
- Returns:
- ServerResponse: The response
"""
- return Result.succ(service.get(request))
+ service.delete(ServeRequest(conv_uid=con_uid))
+ return Result.succ(None)
@router.post(
@@ -155,7 +161,7 @@ async def query(
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
- page_size: Optional[int] = Query(default=20, description="page size"),
+ page_size: Optional[int] = Query(default=10, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query Conversation entities
@@ -171,6 +177,37 @@ async def query_page(
return Result.succ(service.get_list_by_page(request, page, page_size))
+@router.get(
+ "/list",
+ response_model=Result[List[ServerResponse]],
+ dependencies=[Depends(check_api_key)],
+)
+async def list_latest_conv(
+ user_name: str = None,
+ user_id: str = None,
+ sys_code: str = None,
+ page: Optional[int] = Query(default=1, description="current page"),
+ page_size: Optional[int] = Query(default=10, description="page size"),
+ service: Service = Depends(get_service),
+) -> Result[List[ServerResponse]]:
+ """Return latest conversations"""
+ request = ServeRequest(
+ user_name=user_name or user_id,
+ sys_code=sys_code,
+ )
+ return Result.succ(service.get_list_by_page(request, page, page_size).items)
+
+
+@router.get(
+ "/messages/history",
+ response_model=Result[List[MessageVo]],
+ dependencies=[Depends(check_api_key)],
+)
+async def get_history_messages(con_uid: str, service: Service = Depends(get_service)):
+ """Get the history messages of a conversation"""
+ return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
+
+
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
diff --git a/dbgpt/serve/conversation/api/schemas.py b/dbgpt/serve/conversation/api/schemas.py
index d83c79e15..2558d8ad9 100644
--- a/dbgpt/serve/conversation/api/schemas.py
+++ b/dbgpt/serve/conversation/api/schemas.py
@@ -1,4 +1,6 @@
# Define your Pydantic schemas here
+from typing import Any, Optional
+
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
@@ -7,15 +9,133 @@
class ServeRequest(BaseModel):
"""Conversation request model"""
- # TODO define your own fields here
-
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
+ # Just for query
+ chat_mode: str = Field(
+ default=None,
+ description="The chat mode.",
+ examples=[
+ "chat_normal",
+ ],
+ )
+ conv_uid: Optional[str] = Field(
+ default=None,
+ description="The conversation uid.",
+ examples=[
+ "5e7100bc-9017-11ee-9876-8fe019728d79",
+ ],
+ )
+ user_name: Optional[str] = Field(
+ default=None,
+ description="The user name.",
+ examples=[
+ "zhangsan",
+ ],
+ )
+ sys_code: Optional[str] = Field(
+ default=None,
+ description="The system code.",
+ examples=[
+ "dbgpt",
+ ],
+ )
+
class ServerResponse(BaseModel):
"""Conversation response model"""
- # TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
+
+ conv_uid: str = Field(
+ ...,
+ description="The conversation uid.",
+ examples=[
+ "5e7100bc-9017-11ee-9876-8fe019728d79",
+ ],
+ )
+ user_input: str = Field(
+ ...,
+ description="The user input, we return it as the summary the conversation.",
+ examples=[
+ "Hello world",
+ ],
+ )
+ chat_mode: str = Field(
+ ...,
+ description="The chat mode.",
+ examples=[
+ "chat_normal",
+ ],
+ )
+ select_param: Optional[str] = Field(
+ default=None,
+ description="The select param.",
+ examples=[
+ "my_knowledge_space_name",
+ ],
+ )
+ model_name: Optional[str] = Field(
+ default=None,
+ description="The model name.",
+ examples=[
+ "vicuna-13b-v1.5",
+ ],
+ )
+ user_name: Optional[str] = Field(
+ default=None,
+ description="The user name.",
+ examples=[
+ "zhangsan",
+ ],
+ )
+ sys_code: Optional[str] = Field(
+ default=None,
+ description="The system code.",
+ examples=[
+ "dbgpt",
+ ],
+ )
+
+
+class MessageVo(BaseModel):
+ role: str = Field(
+ ...,
+ description="The role that sends out the current message.",
+ examples=["human", "ai", "view"],
+ )
+ context: str = Field(
+ ...,
+ description="The current message content.",
+ examples=[
+ "Hello",
+ "Hi, how are you?",
+ ],
+ )
+
+ order: int = Field(
+ ...,
+ description="The current message order.",
+ examples=[
+ 1,
+ 2,
+ ],
+ )
+
+ time_stamp: Optional[Any] = Field(
+ default=None,
+ description="The current message time stamp.",
+ examples=[
+ "2023-01-07 09:00:00",
+ ],
+ )
+
+ model_name: Optional[str] = Field(
+ default=None,
+ description="The model name.",
+ examples=[
+ "vicuna-13b-v1.5",
+ ],
+ )
diff --git a/dbgpt/serve/conversation/config.py b/dbgpt/serve/conversation/config.py
index 60819389c..77574d2fc 100644
--- a/dbgpt/serve/conversation/config.py
+++ b/dbgpt/serve/conversation/config.py
@@ -20,3 +20,8 @@ class ServeConfig(BaseServeConfig):
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)
+
+ default_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "Default model name"},
+ )
diff --git a/dbgpt/serve/conversation/models/models.py b/dbgpt/serve/conversation/models/models.py
index fb333b051..860f456f4 100644
--- a/dbgpt/serve/conversation/models/models.py
+++ b/dbgpt/serve/conversation/models/models.py
@@ -1,30 +1,20 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
+import json
from datetime import datetime
-from typing import Any, Dict, Union
-
-from sqlalchemy import Column, DateTime, Index, Integer, String, Text
+from typing import Any, Dict, List, Optional, Union
+from dbgpt.core import MessageStorageItem
+from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity as ServeEntity
+from dbgpt.storage.chat_history.chat_history_db import ChatHistoryMessageEntity
from dbgpt.storage.metadata import BaseDao, Model, db
+from dbgpt.util import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
-class ServeEntity(Model):
- __tablename__ = SERVER_APP_TABLE_NAME
- id = Column(Integer, primary_key=True, comment="Auto increment id")
-
- # TODO: define your own fields here
-
- gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
- gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
-
- def __repr__(self):
- return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
-
-
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Conversation"""
@@ -68,4 +58,95 @@ def to_response(self, entity: ServeEntity) -> ServerResponse:
RES: The response
"""
# TODO implement your own logic here, transfer the entity to a response
- return ServerResponse()
+ return ServerResponse(
+ conv_uid=entity.conv_uid,
+ user_input=entity.summary,
+ chat_mode=entity.chat_mode,
+ user_name=entity.user_name,
+ sys_code=entity.sys_code,
+ )
+
+ def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
+ """Get the latest message of a conversation
+
+ Args:
+ conv_uid (str): The conversation UID
+
+ Returns:
+            Optional[MessageStorageItem]: The latest message, or None if not found
+ """
+ with self.session() as session:
+ entity: ChatHistoryMessageEntity = (
+ session.query(ChatHistoryMessageEntity)
+ .filter(ChatHistoryMessageEntity.conv_uid == conv_uid)
+ .order_by(ChatHistoryMessageEntity.gmt_created.desc())
+ .first()
+ )
+ if not entity:
+ return None
+ message_detail = (
+ json.loads(entity.message_detail) if entity.message_detail else {}
+ )
+ return MessageStorageItem(entity.conv_uid, entity.index, message_detail)
+
+ def _parse_old_messages(self, entity: ServeEntity) -> List[Dict[str, Any]]:
+ """Parse the old messages
+
+ Args:
+ entity (ServeEntity): The entity
+
+ Returns:
+            List[Dict[str, Any]]: The old messages
+ """
+ messages = json.loads(entity.messages)
+ return messages
+
+ def get_conv_by_page(
+ self, req: ServeRequest, page: int, page_size: int
+ ) -> PaginationResult[ServerResponse]:
+ """Get conversation by page
+
+ Args:
+ req (ServeRequest): The request
+ page (int): The page number
+ page_size (int): The page size
+
+ Returns:
+            PaginationResult[ServerResponse]: The paginated conversation list
+ """
+ with self.session(commit=False) as session:
+ query = self._create_query_object(session, req)
+ query = query.order_by(ServeEntity.gmt_created.desc())
+ total_count = query.count()
+ items = query.offset((page - 1) * page_size).limit(page_size)
+ total_pages = (total_count + page_size - 1) // page_size
+ result_items = []
+ for item in items:
+ select_param, model_name = "", None
+ if item.messages:
+ messages = self._parse_old_messages(item)
+ last_round = max(messages, key=lambda x: x["chat_order"])
+ if "param_value" in last_round:
+ select_param = last_round["param_value"]
+ else:
+ select_param = ""
+ else:
+ latest_message = self.get_latest_message(item.conv_uid)
+ if latest_message:
+ message = latest_message.to_message()
+ select_param = message.additional_kwargs.get("param_value")
+ model_name = message.additional_kwargs.get("model_name")
+ res_item = self.to_response(item)
+ res_item.select_param = select_param
+ res_item.model_name = model_name
+ result_items.append(res_item)
+
+ result = PaginationResult(
+ items=result_items,
+ total_count=total_count,
+ total_pages=total_pages,
+ page=page,
+ page_size=page_size,
+ )
+
+ return result
diff --git a/dbgpt/serve/conversation/serve.py b/dbgpt/serve/conversation/serve.py
index 2dd63ca6a..643ef6432 100644
--- a/dbgpt/serve/conversation/serve.py
+++ b/dbgpt/serve/conversation/serve.py
@@ -8,6 +8,7 @@
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
+from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
@@ -15,6 +16,7 @@
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
+from .service.service import Service
logger = logging.getLogger(__name__)
@@ -58,6 +60,10 @@ def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
+ self._system_app.app.include_router(
+ router, prefix=self._api_prefix, tags=self._api_tags
+ )
+ init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
diff --git a/dbgpt/serve/conversation/service/service.py b/dbgpt/serve/conversation/service/service.py
index 71592bd13..99fddfb8b 100644
--- a/dbgpt/serve/conversation/service/service.py
+++ b/dbgpt/serve/conversation/service/service.py
@@ -1,11 +1,20 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
from dbgpt.component import BaseComponent, SystemApp
+from dbgpt.core import (
+ InMemoryStorage,
+ MessageStorageItem,
+ QuerySpec,
+ StorageConversation,
+ StorageInterface,
+)
+from dbgpt.core.interface.message import _append_view_messages
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
+from dbgpt.storage.metadata._base_dao import REQ, RES
from dbgpt.util.pagination_utils import PaginationResult
-from ..api.schemas import ServeRequest, ServerResponse
+from ..api.schemas import MessageVo, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -15,10 +24,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
- def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
+ def __init__(
+ self,
+ system_app: SystemApp,
+ dao: Optional[ServeDao] = None,
+ storage: Optional[StorageInterface[StorageConversation, Any]] = None,
+ message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
+ ):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
+ self._storage = storage
+ self._message_storage = message_storage
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@@ -34,7 +51,7 @@ def init_app(self, system_app: SystemApp) -> None:
self._system_app = system_app
@property
- def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
+ def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao
@@ -43,6 +60,54 @@ def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
+ def create(self, request: REQ) -> RES:
+ raise NotImplementedError()
+
+ @property
+ def conv_storage(self) -> StorageInterface:
+ """The conversation storage, store the conversation items."""
+ if self._storage:
+ return self._storage
+ from ..serve import Serve
+
+ return Serve.call_on_current_serve(
+ self._system_app, lambda serve: serve.conv_storage
+ )
+
+ @property
+ def message_storage(self) -> StorageInterface:
+ """The message storage, store the messages of one conversation."""
+ if self._message_storage:
+ return self._message_storage
+ from ..serve import Serve
+
+ return Serve.call_on_current_serve(
+ self._system_app,
+ lambda serve: serve.message_storage,
+ )
+
+ def create_storage_conv(
+ self, request: Union[ServeRequest, Dict[str, Any]], load_message: bool = True
+ ) -> StorageConversation:
+ conv_storage = self.conv_storage
+ message_storage = self.message_storage
+ if not conv_storage or not message_storage:
+ raise RuntimeError(
+ "Can't get the conversation storage or message storage from current serve component."
+ )
+ if isinstance(request, dict):
+ request = ServeRequest(**request)
+ storage_conv: StorageConversation = StorageConversation(
+ conv_uid=request.conv_uid,
+ chat_mode=request.chat_mode,
+ user_name=request.user_name,
+ sys_code=request.sys_code,
+ conv_storage=conv_storage,
+ message_storage=message_storage,
+ load_message=load_message,
+ )
+ return storage_conv
+
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a Conversation entity
@@ -74,18 +139,13 @@ def get(self, request: ServeRequest) -> Optional[ServerResponse]:
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
- """Delete a Conversation entity
+ """Delete current conversation and its messages
Args:
request (ServeRequest): The request
"""
-
- # TODO: implement your own logic here
- # Build the query request from the request
- query_request = {
- # "id": request.id
- }
- self.dao.delete(query_request)
+ conv: StorageConversation = self.create_storage_conv(request)
+ conv.delete()
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Conversation entities
@@ -114,5 +174,29 @@ def get_list_by_page(
Returns:
List[ServerResponse]: The response
"""
- query_request = request
- return self.dao.get_list_page(query_request, page, page_size)
+ return self.dao.get_conv_by_page(request, page, page_size)
+
+ def get_history_messages(
+ self, request: Union[ServeRequest, Dict[str, Any]]
+ ) -> List[MessageVo]:
+ """Get a list of Conversation entities
+
+ Args:
+ request (ServeRequest): The request
+
+ Returns:
+            List[MessageVo]: The history messages of the conversation
+ """
+ conv: StorageConversation = self.create_storage_conv(request)
+ result = []
+ messages = _append_view_messages(conv.messages)
+ for msg in messages:
+ result.append(
+ MessageVo(
+ role=msg.type,
+ context=msg.content,
+ order=msg.round_index,
+ model_name=self.config.default_model,
+ )
+ )
+ return result
diff --git a/dbgpt/serve/conversation/tests/test_models.py b/dbgpt/serve/conversation/tests/test_models.py
index 1d111644d..c26660d1e 100644
--- a/dbgpt/serve/conversation/tests/test_models.py
+++ b/dbgpt/serve/conversation/tests/test_models.py
@@ -29,7 +29,7 @@ def dao(server_config):
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
- return {}
+ return {"conv_uid": "test_conv_uid", "summary": "hello", "chat_mode": "chat_normal"}
def test_table_exist():
@@ -67,19 +67,6 @@ def test_entity_all():
pass
-def test_dao_create(dao, default_entity_dict):
- # TODO: implement your test case
- req = ServeRequest(**default_entity_dict)
- res: ServerResponse = dao.create(req)
- assert res is not None
-
-
-def test_dao_get_one(dao, default_entity_dict):
- # TODO: implement your test case
- req = ServeRequest(**default_entity_dict)
- res: ServerResponse = dao.create(req)
-
-
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
diff --git a/dbgpt/serve/core/serve.py b/dbgpt/serve/core/serve.py
index e909ad4cb..ff0203b65 100644
--- a/dbgpt/serve/core/serve.py
+++ b/dbgpt/serve/core/serve.py
@@ -73,7 +73,7 @@ def get_current_serve(cls, system_app: SystemApp) -> Optional["BaseServe"]:
Returns:
Optional[BaseServe]: The current serve component.
"""
- return system_app.get_component(cls.name, cls, default_component=None)
+ return cls.get_instance(system_app, default_component=None)
@classmethod
def call_on_current_serve(
diff --git a/dbgpt/storage/cache/__init__.py b/dbgpt/storage/cache/__init__.py
index 906c18c32..80b23bf25 100644
--- a/dbgpt/storage/cache/__init__.py
+++ b/dbgpt/storage/cache/__init__.py
@@ -1,6 +1,6 @@
+from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from dbgpt.storage.cache.manager import CacheManager, initialize_cache
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
-from dbgpt.storage.cache.llm_cache import LLMCacheKey, LLMCacheValue, LLMCacheClient
__all__ = [
"LLMCacheKey",
diff --git a/dbgpt/storage/cache/llm_cache.py b/dbgpt/storage/cache/llm_cache.py
index 682276349..441c67a6c 100644
--- a/dbgpt/storage/cache/llm_cache.py
+++ b/dbgpt/storage/cache/llm_cache.py
@@ -1,16 +1,11 @@
-from typing import Optional, Dict, Any, Union, List
-from dataclasses import dataclass, asdict
import hashlib
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Union
-from dbgpt.core.interface.cache import (
- CacheKey,
- CacheValue,
- CacheClient,
- CacheConfig,
-)
-from dbgpt.storage.cache.manager import CacheManager
from dbgpt.core import ModelOutput, Serializer
+from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
from dbgpt.model.base import ModelType
+from dbgpt.storage.cache.manager import CacheManager
@dataclass
diff --git a/dbgpt/storage/cache/manager.py b/dbgpt/storage/cache/manager.py
index 8eebd0883..97063c347 100644
--- a/dbgpt/storage/cache/manager.py
+++ b/dbgpt/storage/cache/manager.py
@@ -1,17 +1,12 @@
-from abc import ABC, abstractmethod
-from typing import Optional, Type
import logging
+from abc import ABC, abstractmethod
from concurrent.futures import Executor
-from dbgpt.storage.cache.storage.base import CacheStorage
-from dbgpt.core.interface.cache import K, V
-from dbgpt.core import (
- CacheKey,
- CacheValue,
- CacheConfig,
- Serializer,
- Serializable,
-)
+from typing import Optional, Type
+
from dbgpt.component import BaseComponent, ComponentType, SystemApp
+from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer
+from dbgpt.core.interface.cache import K, V
+from dbgpt.storage.cache.storage.base import CacheStorage
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
logger = logging.getLogger(__name__)
@@ -102,8 +97,8 @@ def serializer(self) -> Serializer:
def initialize_cache(
system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
- from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
+ from dbgpt.util.serialization.json_serialization import JsonSerializer
cache_storage = None
if storage_type == "disk":
diff --git a/dbgpt/storage/cache/storage/base.py b/dbgpt/storage/cache/storage/base.py
index a31453b56..779552b98 100644
--- a/dbgpt/storage/cache/storage/base.py
+++ b/dbgpt/storage/cache/storage/base.py
@@ -1,18 +1,19 @@
+import logging
from abc import ABC, abstractmethod
-from typing import Optional
-from dataclasses import dataclass
from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Optional
+
import msgpack
-import logging
from dbgpt.core.interface.cache import (
- K,
- V,
+ CacheConfig,
CacheKey,
+ CachePolicy,
CacheValue,
- CacheConfig,
+ K,
RetrievalPolicy,
- CachePolicy,
+ V,
)
from dbgpt.util.memory_utils import _get_object_bytes
diff --git a/dbgpt/storage/cache/storage/disk/disk_storage.py b/dbgpt/storage/cache/storage/disk/disk_storage.py
index 0f4f32523..fa1c01003 100644
--- a/dbgpt/storage/cache/storage/disk/disk_storage.py
+++ b/dbgpt/storage/cache/storage/disk/disk_storage.py
@@ -1,16 +1,17 @@
-from typing import Optional
import logging
-from rocksdict import Rdict, Options
+from typing import Optional
+
+from rocksdict import Options, Rdict
from dbgpt.core.interface.cache import (
- K,
- V,
+ CacheConfig,
CacheKey,
CacheValue,
- CacheConfig,
+ K,
RetrievalPolicy,
+ V,
)
-from dbgpt.storage.cache.storage.base import StorageItem, CacheStorage
+from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem
logger = logging.getLogger(__name__)
diff --git a/dbgpt/storage/cache/storage/tests/test_storage.py b/dbgpt/storage/cache/storage/tests/test_storage.py
index 1a175b2e4..1ba2545f0 100644
--- a/dbgpt/storage/cache/storage/tests/test_storage.py
+++ b/dbgpt/storage/cache/storage/tests/test_storage.py
@@ -1,7 +1,9 @@
import pytest
-from ..base import StorageItem
+
from dbgpt.util.memory_utils import _get_object_bytes
+from ..base import StorageItem
+
def test_build_from():
key_hash = b"key_hash"
diff --git a/dbgpt/storage/chat_history/base.py b/dbgpt/storage/chat_history/base.py
index abc750bba..2f83f0c5d 100644
--- a/dbgpt/storage/chat_history/base.py
+++ b/dbgpt/storage/chat_history/base.py
@@ -1,8 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import List, Optional, Dict
from enum import Enum
+from typing import Dict, List, Optional
+
from dbgpt.core.interface.message import OnceConversation
diff --git a/dbgpt/storage/chat_history/chat_hisotry_factory.py b/dbgpt/storage/chat_history/chat_hisotry_factory.py
index d07afb7d4..f202c20ff 100644
--- a/dbgpt/storage/chat_history/chat_hisotry_factory.py
+++ b/dbgpt/storage/chat_history/chat_hisotry_factory.py
@@ -1,9 +1,14 @@
+import logging
from typing import Type
-from .base import MemoryStoreType
+
from dbgpt._private.config import Config
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
+from .base import MemoryStoreType
+
+# TODO remove global variable
CFG = Config()
+logger = logging.getLogger(__name__)
# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory
@@ -13,14 +18,16 @@ class ChatHistory:
def __init__(self):
self.memory_type = MemoryStoreType.DB.value
self.mem_store_class_map = {}
- from .store_type.duckdb_history import DuckdbHistoryMemory
- from .store_type.file_history import FileHistoryMemory
- from .store_type.mem_history import MemHistoryMemory
- self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
- self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
+ # Just support db store type after v0.4.6
+ # from .store_type.duckdb_history import DuckdbHistoryMemory
+ # from .store_type.file_history import FileHistoryMemory
+ # from .store_type.mem_history import MemHistoryMemory
+ # self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
+ # self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
+ # self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
+
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
- self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
"""New store instance for store chat histories
@@ -31,9 +38,39 @@ def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
Returns:
BaseChatHistoryMemory: Store instance
"""
+ self._check_store_type(CFG.CHAT_HISTORY_STORE_TYPE)
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
chat_session_id
)
def get_store_cls(self) -> Type[BaseChatHistoryMemory]:
+ self._check_store_type(CFG.CHAT_HISTORY_STORE_TYPE)
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
+
+ def _check_store_type(self, store_type: str):
+ """Check store type
+
+ Raises:
+ ValueError: Invalid store type
+ """
+ from .store_type.duckdb_history import DuckdbHistoryMemory
+ from .store_type.file_history import FileHistoryMemory
+ from .store_type.mem_history import MemHistoryMemory
+
+ if store_type == MemHistoryMemory.store_type:
+ logger.error(
+ "Not support memory store type, just support db store type now"
+ )
+ raise ValueError(f"Invalid store type: {store_type}")
+
+ if store_type == FileHistoryMemory.store_type:
+ logger.error("Not support file store type, just support db store type now")
+ raise ValueError(f"Invalid store type: {store_type}")
+ if store_type == DuckdbHistoryMemory.store_type:
+ link1 = "https://docs.dbgpt.site/docs/faq/install#q6-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-sqlitel"
+ link2 = "https://docs.dbgpt.site/docs/faq/install#q7-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-mysql"
+ logger.error(
+ "Not support duckdb store type after v0.4.6, just support db store type now, "
+ f"you can migrate your message according to {link1} or {link2}"
+ )
+ raise ValueError(f"Invalid store type: {store_type}")
diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py
index 029abadbf..8faaca8b2 100644
--- a/dbgpt/storage/chat_history/chat_history_db.py
+++ b/dbgpt/storage/chat_history/chat_history_db.py
@@ -1,7 +1,7 @@
-from typing import Optional
from datetime import datetime
-from sqlalchemy import Column, Integer, String, Index, Text, DateTime
-from sqlalchemy import UniqueConstraint
+from typing import Optional
+
+from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
diff --git a/dbgpt/storage/chat_history/storage_adapter.py b/dbgpt/storage/chat_history/storage_adapter.py
index 4a7a3b96b..93302ef89 100644
--- a/dbgpt/storage/chat_history/storage_adapter.py
+++ b/dbgpt/storage/chat_history/storage_adapter.py
@@ -1,16 +1,19 @@
-from typing import List, Dict, Type
import json
+from typing import Dict, List, Type
+
from sqlalchemy.orm import Session
-from dbgpt.core.interface.storage import StorageItemAdapter
+
from dbgpt.core.interface.message import (
- StorageConversation,
+ BaseMessage,
ConversationIdentifier,
MessageIdentifier,
MessageStorageItem,
- _messages_from_dict,
+ StorageConversation,
_conversation_to_dict,
- BaseMessage,
+ _messages_from_dict,
)
+from dbgpt.core.interface.storage import StorageItemAdapter
+
from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity
@@ -21,7 +24,8 @@ def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity:
message_ids = ",".join(item.message_ids)
messages = None
if not item.save_message_independent and item.messages:
- messages = _conversation_to_dict(item)
+ message_dict_list = [_conversation_to_dict(item)]
+ messages = json.dumps(message_dict_list, ensure_ascii=False)
return ChatHistoryEntity(
conv_uid=item.conv_uid,
chat_mode=item.chat_mode,
@@ -42,15 +46,10 @@ def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation:
save_message_independent = True
if old_conversations:
# Load old messages from old conversations, in old design, we save messages to chat_history table
- old_messages_dict = []
- for old_conversation in old_conversations:
- old_messages_dict.extend(
- old_conversation["messages"]
- if "messages" in old_conversation
- else []
- )
save_message_independent = False
- old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict)
+ old_messages: List[BaseMessage] = _parse_old_conversations(
+ old_conversations
+ )
return StorageConversation(
conv_uid=model.conv_uid,
chat_mode=model.chat_mode,
@@ -114,3 +113,24 @@ def get_query_for_identifier(
ChatHistoryMessageEntity.conv_uid == resource_id.conv_uid,
ChatHistoryMessageEntity.index == resource_id.index,
)
+
+
+def _parse_old_conversations(old_conversations: List[Dict]) -> List[BaseMessage]:
+ old_messages_dict = []
+ for old_conversation in old_conversations:
+ messages = (
+ old_conversation["messages"] if "messages" in old_conversation else []
+ )
+ for message in messages:
+ if "data" in message:
+ message_data = message["data"]
+ additional_kwargs = message_data.get("additional_kwargs", {})
+ additional_kwargs["param_value"] = old_conversation.get("param_value")
+ additional_kwargs["param_type"] = old_conversation.get("param_type")
+ additional_kwargs["model_name"] = old_conversation.get("model_name")
+ message_data["additional_kwargs"] = additional_kwargs
+
+ old_messages_dict.extend(messages)
+
+ old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict)
+ return old_messages
diff --git a/dbgpt/storage/chat_history/store_type/duckdb_history.py b/dbgpt/storage/chat_history/store_type/duckdb_history.py
index 2e99fc185..81644f346 100644
--- a/dbgpt/storage/chat_history/store_type/duckdb_history.py
+++ b/dbgpt/storage/chat_history/store_type/duckdb_history.py
@@ -1,15 +1,14 @@
import json
import os
+from typing import Dict, List, Optional
+
import duckdb
-from typing import List, Dict, Optional
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
+from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
-from dbgpt.core.interface.message import (
- OnceConversation,
- _conversation_to_dict,
-)
+
from ..base import MemoryStoreType
default_db_path = os.path.join(PILOT_PATH, "message")
diff --git a/dbgpt/storage/chat_history/store_type/file_history.py b/dbgpt/storage/chat_history/store_type/file_history.py
index 013a889b0..efa9f4e69 100644
--- a/dbgpt/storage/chat_history/store_type/file_history.py
+++ b/dbgpt/storage/chat_history/store_type/file_history.py
@@ -1,9 +1,8 @@
-from typing import List
+import datetime
import json
import os
-import datetime
-from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
+from typing import List
from dbgpt._private.config import Config
from dbgpt.core.interface.message import (
@@ -11,7 +10,7 @@
_conversation_from_dict,
_conversations_to_dict,
)
-from dbgpt.storage.chat_history.base import MemoryStoreType
+from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
CFG = Config()
diff --git a/dbgpt/storage/chat_history/store_type/mem_history.py b/dbgpt/storage/chat_history/store_type/mem_history.py
index 4f35c2cd0..3c5264627 100644
--- a/dbgpt/storage/chat_history/store_type/mem_history.py
+++ b/dbgpt/storage/chat_history/store_type/mem_history.py
@@ -1,10 +1,9 @@
from typing import List
-from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from dbgpt._private.config import Config
from dbgpt.core.interface.message import OnceConversation
+from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
from dbgpt.util.custom_data_structure import FixedSizeDict
-from dbgpt.storage.chat_history.base import MemoryStoreType
CFG = Config()
diff --git a/dbgpt/storage/chat_history/store_type/meta_db_history.py b/dbgpt/storage/chat_history/store_type/meta_db_history.py
index 9676259d0..ca08c69fd 100644
--- a/dbgpt/storage/chat_history/store_type/meta_db_history.py
+++ b/dbgpt/storage/chat_history/store_type/meta_db_history.py
@@ -1,12 +1,11 @@
import json
import logging
-from typing import List, Dict, Optional
+from typing import Dict, List, Optional
+
from dbgpt._private.config import Config
from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
-from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
-from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao
-
-from dbgpt.storage.chat_history.base import MemoryStoreType
+from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
+from dbgpt.storage.chat_history.chat_history_db import ChatHistoryDao, ChatHistoryEntity
CFG = Config()
logger = logging.getLogger(__name__)
diff --git a/dbgpt/storage/chat_history/tests/test_storage_adapter.py b/dbgpt/storage/chat_history/tests/test_storage_adapter.py
index 1802a8fd0..3596adc9a 100644
--- a/dbgpt/storage/chat_history/tests/test_storage_adapter.py
+++ b/dbgpt/storage/chat_history/tests/test_storage_adapter.py
@@ -1,20 +1,21 @@
-import pytest
from typing import List
-from dbgpt.util.pagination_utils import PaginationResult
-from dbgpt.util.serialization.json_serialization import JsonSerializer
-from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage
+import pytest
+
+from dbgpt.core.interface.message import AIMessage, HumanMessage, StorageConversation
from dbgpt.core.interface.storage import QuerySpec
-from dbgpt.storage.metadata import db
-from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
ChatHistoryMessageEntity,
)
from dbgpt.storage.chat_history.storage_adapter import (
- DBStorageConversationItemAdapter,
DBMessageStorageItemAdapter,
+ DBStorageConversationItemAdapter,
)
+from dbgpt.storage.metadata import db
+from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
+from dbgpt.util.pagination_utils import PaginationResult
+from dbgpt.util.serialization.json_serialization import JsonSerializer
@pytest.fixture
diff --git a/dbgpt/storage/metadata/__init__.py b/dbgpt/storage/metadata/__init__.py
index ee409f6cb..8660866d9 100644
--- a/dbgpt/storage/metadata/__init__.py
+++ b/dbgpt/storage/metadata/__init__.py
@@ -1,12 +1,12 @@
+from dbgpt.storage.metadata._base_dao import BaseDao
+from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory
from dbgpt.storage.metadata.db_manager import (
- db,
- Model,
+ BaseModel,
DatabaseManager,
+ Model,
create_model,
- BaseModel,
+ db,
)
-from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory
-from dbgpt.storage.metadata._base_dao import BaseDao
__ALL__ = [
"db",
diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py
index 770d82c93..96294c00b 100644
--- a/dbgpt/storage/metadata/_base_dao.py
+++ b/dbgpt/storage/metadata/_base_dao.py
@@ -1,6 +1,8 @@
from contextlib import contextmanager
-from typing import TypeVar, Generic, Any, Optional, Dict, Union, List
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+
from sqlalchemy.orm.session import Session
+
from dbgpt.util.pagination_utils import PaginationResult
# The entity type
@@ -10,8 +12,7 @@
# The response schema type
RES = TypeVar("RES")
-from .db_manager import db, DatabaseManager, BaseQuery
-
+from .db_manager import BaseQuery, DatabaseManager, db
QUERY_SPEC = Union[REQ, Dict[str, Any]]
@@ -21,6 +22,7 @@ class BaseDao(Generic[T, REQ, RES]):
Examples:
.. code-block:: python
+
class UserDao(BaseDao):
def get_user_by_name(self, name: str) -> User:
with self.session() as session:
@@ -70,8 +72,9 @@ def session(self, commit: Optional[bool] = True) -> Session:
Example:
.. code-block:: python
+
with self.session() as session:
- session.query(User).filter(User.name == 'Edward Snowden').first()
+ session.query(User).filter(User.name == "Edward Snowden").first()
Args:
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
diff --git a/dbgpt/storage/metadata/db_factory.py b/dbgpt/storage/metadata/db_factory.py
index 14cf0339a..c288a149d 100644
--- a/dbgpt/storage/metadata/db_factory.py
+++ b/dbgpt/storage/metadata/db_factory.py
@@ -1,4 +1,4 @@
-from dbgpt.component import SystemApp, BaseComponent, ComponentType
+from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .db_manager import DatabaseManager
diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py
index f222dd38a..67ec35954 100644
--- a/dbgpt/storage/metadata/db_manager.py
+++ b/dbgpt/storage/metadata/db_manager.py
@@ -1,29 +1,21 @@
from __future__ import annotations
-from contextlib import contextmanager
-from typing import (
- TypeVar,
- Generic,
- Union,
- Dict,
- Optional,
- Type,
- ClassVar,
-)
import logging
-from sqlalchemy import create_engine, URL, Engine
-from sqlalchemy import orm, inspect, MetaData
+from contextlib import contextmanager
+from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
+
+from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
from sqlalchemy.orm import (
- scoped_session,
- sessionmaker,
+ DeclarativeMeta,
Session,
declarative_base,
- DeclarativeMeta,
+ scoped_session,
+ sessionmaker,
)
-
from sqlalchemy.pool import QueuePool
-from dbgpt.util.string_utils import _to_str
+
from dbgpt.util.pagination_utils import PaginationResult
+from dbgpt.util.string_utils import _to_str
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")
@@ -314,7 +306,7 @@ def init_default_db(
Examples:
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
- >>> session.query(...)
+ ... session.query(...)
Args:
sqlite_path (str): The sqlite path.
diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py
index daa87ceba..6b6e9716a 100644
--- a/dbgpt/storage/metadata/db_storage.py
+++ b/dbgpt/storage/metadata/db_storage.py
@@ -1,25 +1,28 @@
from contextlib import contextmanager
+from typing import Dict, List, Optional, Type, Union
+
+from sqlalchemy import URL
+from sqlalchemy.orm import DeclarativeMeta, Session
-from typing import Type, List, Optional, Union, Dict
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
- StorageInterface,
QuerySpec,
ResourceIdentifier,
+ StorageInterface,
StorageItemAdapter,
T,
)
-from sqlalchemy import URL
-from sqlalchemy.orm import Session, DeclarativeMeta
-from .db_manager import BaseModel, DatabaseManager, BaseQuery
+from .db_manager import BaseModel, BaseQuery, DatabaseManager
def _copy_public_properties(src: BaseModel, dest: BaseModel):
"""Simple copy public properties from src to dest"""
for column in src.__table__.columns:
if column.name != "id":
- setattr(dest, column.name, getattr(src, column.name))
+ value = getattr(src, column.name)
+ if value is not None:
+ setattr(dest, column.name, value)
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
@@ -51,8 +54,16 @@ def save(self, data: T) -> None:
def update(self, data: T) -> None:
with self.session() as session:
- model_instance = self.adapter.to_storage_format(data)
- session.merge(model_instance)
+ query = self.adapter.get_query_for_identifier(
+ self._model_class, data.identifier, session=session
+ )
+ exist_model_instance = query.with_session(session).first()
+ if exist_model_instance:
+ _copy_public_properties(
+ self.adapter.to_storage_format(data), exist_model_instance
+ )
+ session.merge(exist_model_instance)
+ return
def save_or_update(self, data: T) -> None:
with self.session() as session:
diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py
index 9728fd482..cfa86ba3c 100644
--- a/dbgpt/storage/metadata/tests/test_base_dao.py
+++ b/dbgpt/storage/metadata/tests/test_base_dao.py
@@ -1,12 +1,15 @@
-from typing import Type, Optional, Union, Dict, Any
+from typing import Any, Dict, Optional, Type, Union
+
import pytest
from sqlalchemy import Column, Integer, String
-from dbgpt._private.pydantic import BaseModel as PydanticBaseModel, Field
+
+from dbgpt._private.pydantic import BaseModel as PydanticBaseModel
+from dbgpt._private.pydantic import Field
from dbgpt.storage.metadata.db_manager import (
+ BaseModel,
DatabaseManager,
PaginationResult,
create_model,
- BaseModel,
)
from .._base_dao import BaseDao
diff --git a/dbgpt/storage/metadata/tests/test_db_manager.py b/dbgpt/storage/metadata/tests/test_db_manager.py
index aa67787fa..551f0d9d7 100644
--- a/dbgpt/storage/metadata/tests/test_db_manager.py
+++ b/dbgpt/storage/metadata/tests/test_db_manager.py
@@ -1,14 +1,17 @@
from __future__ import annotations
-import pytest
+
import tempfile
from typing import Type
+
+import pytest
+from sqlalchemy import Column, Integer, String
+
from dbgpt.storage.metadata.db_manager import (
+ BaseModel,
DatabaseManager,
PaginationResult,
create_model,
- BaseModel,
)
-from sqlalchemy import Column, Integer, String
@pytest.fixture
diff --git a/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py
index fcae83215..df4688c71 100644
--- a/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py
+++ b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py
@@ -1,21 +1,19 @@
from typing import Dict, Type
-from sqlalchemy.orm import declarative_base, Session
-from sqlalchemy import Column, Integer, String
import pytest
+from sqlalchemy import Column, Integer, String
+from sqlalchemy.orm import Session, declarative_base
from dbgpt.core.interface.storage import (
- StorageItem,
+ QuerySpec,
ResourceIdentifier,
+ StorageItem,
StorageItemAdapter,
- QuerySpec,
)
-from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
-
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
+from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
-
Base = declarative_base()
diff --git a/dbgpt/storage/schema.py b/dbgpt/storage/schema.py
index ab6dc7b40..d5874ba8e 100644
--- a/dbgpt/storage/schema.py
+++ b/dbgpt/storage/schema.py
@@ -1,5 +1,5 @@
-from enum import Enum
import os
+from enum import Enum
class DbInfo:
diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py
index 2f944fc71..2d2918f71 100644
--- a/dbgpt/storage/vector_store/base.py
+++ b/dbgpt/storage/vector_store/base.py
@@ -1,8 +1,8 @@
-from abc import ABC, abstractmethod
import math
-from typing import Optional, Callable, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
from dbgpt.rag.chunk import Chunk
diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py
index 9149e3665..cd5b76e71 100644
--- a/dbgpt/storage/vector_store/chroma_store.py
+++ b/dbgpt/storage/vector_store/chroma_store.py
@@ -1,14 +1,14 @@
-import os
import logging
+import os
from typing import Any, List
-from chromadb.config import Settings
from chromadb import PersistentClient
+from chromadb.config import Settings
from pydantic import Field
+from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
-from dbgpt.configs.model_config import PILOT_PATH
logger = logging.getLogger(__name__)
diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py
index a256e6a87..27ee1b871 100644
--- a/dbgpt/storage/vector_store/connector.py
+++ b/dbgpt/storage/vector_store/connector.py
@@ -1,5 +1,5 @@
import os
-from typing import Optional, List, Callable, Any
+from typing import Any, Callable, List, Optional
from dbgpt.rag.chunk import Chunk
from dbgpt.storage import vector_store
diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py
index 9917ca54b..2563a2893 100644
--- a/dbgpt/storage/vector_store/pgvector_store.py
+++ b/dbgpt/storage/vector_store/pgvector_store.py
@@ -1,11 +1,11 @@
-from typing import Any, List
import logging
+from typing import Any, List
from pydantic import Field
+from dbgpt._private.config import Config
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
-from dbgpt._private.config import Config
logger = logging.getLogger(__name__)
diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py
index 92ded9980..77744489e 100644
--- a/dbgpt/storage/vector_store/weaviate_store.py
+++ b/dbgpt/storage/vector_store/weaviate_store.py
@@ -1,5 +1,5 @@
-import os
import logging
+import os
from typing import List
from langchain.schema import Document
diff --git a/dbgpt/util/__init__.py b/dbgpt/util/__init__.py
index 74945e5ae..d642b7dd4 100644
--- a/dbgpt/util/__init__.py
+++ b/dbgpt/util/__init__.py
@@ -1,10 +1,7 @@
-from .utils import (
- get_gpu_memory,
- get_or_create_event_loop,
-)
-from .pagination_utils import PaginationResult
-from .parameter_utils import BaseParameters, ParameterDescription, EnvArgumentParser
from .config_utils import AppConfig
+from .pagination_utils import PaginationResult
+from .parameter_utils import BaseParameters, EnvArgumentParser, ParameterDescription
+from .utils import get_gpu_memory, get_or_create_event_loop
__ALL__ = [
"get_gpu_memory",
diff --git a/dbgpt/util/_db_migration_utils.py b/dbgpt/util/_db_migration_utils.py
index 13734960d..f8171f28a 100644
--- a/dbgpt/util/_db_migration_utils.py
+++ b/dbgpt/util/_db_migration_utils.py
@@ -1,12 +1,12 @@
-from typing import Optional
-import os
import logging
-from sqlalchemy import Engine, text
-from sqlalchemy.orm import Session, DeclarativeMeta
+import os
+from typing import Optional
+
from alembic import command
-from alembic.util.exc import CommandError
from alembic.config import Config as AlembicConfig
-
+from alembic.util.exc import CommandError
+from sqlalchemy import Engine, text
+from sqlalchemy.orm import DeclarativeMeta, Session
logger = logging.getLogger(__name__)
@@ -67,8 +67,8 @@ def create_migration_script(
Returns:
The path of the generated migration script.
"""
- from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
+ from alembic.script import ScriptDirectory
# Check if the database is up-to-date
script_dir = ScriptDirectory.from_config(alembic_cfg)
@@ -244,8 +244,8 @@ def _check_database_migration_status(alembic_cfg: AlembicConfig, engine: Engine)
Raises:
Exception: If the database is not at the latest revision.
"""
- from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
+ from alembic.script import ScriptDirectory
script = ScriptDirectory.from_config(alembic_cfg)
diff --git a/dbgpt/util/api_utils.py b/dbgpt/util/api_utils.py
index 1fd3499d2..175cd55de 100644
--- a/dbgpt/util/api_utils.py
+++ b/dbgpt/util/api_utils.py
@@ -1,7 +1,7 @@
-from inspect import signature
import logging
-from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
-from dataclasses import is_dataclass, asdict
+from dataclasses import asdict, is_dataclass
+from inspect import signature
+from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints
T = TypeVar("T")
diff --git a/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py b/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py
index 9c07f7a0a..fc0a847a1 100644
--- a/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py
+++ b/dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py
@@ -4,9 +4,10 @@
"""
import gc
-from typing import Iterable, Dict
+from typing import Dict, Iterable
import torch
+from fastchat.utils import get_context_length, is_partial_stop, is_sentence_complete
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@@ -16,9 +17,6 @@
)
-from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
-
-
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
diff --git a/dbgpt/util/benchmarks/llm/llm_benchmarks.py b/dbgpt/util/benchmarks/llm/llm_benchmarks.py
index 81e839f5a..a5ded5eff 100644
--- a/dbgpt/util/benchmarks/llm/llm_benchmarks.py
+++ b/dbgpt/util/benchmarks/llm/llm_benchmarks.py
@@ -1,25 +1,23 @@
-from typing import Dict, List
+import argparse
import asyncio
+import csv
+import logging
import os
import sys
import time
-import csv
-import argparse
-import logging
import traceback
-from dbgpt.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
from datetime import datetime
+from typing import Dict, List
+from dbgpt.configs.model_config import LLM_MODEL_CONFIG, ROOT_PATH
+from dbgpt.core import ModelInferenceMetrics, ModelOutput
+from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.cluster.worker.manager import (
- run_worker_manager,
- initialize_worker_manager_in_client,
WorkerManager,
+ initialize_worker_manager_in_client,
+ run_worker_manager,
)
-from dbgpt.core import ModelOutput, ModelInferenceMetrics
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-
-
model_name = "vicuna-7b-v1.5"
model_path = LLM_MODEL_CONFIG[model_name]
# or vllm
diff --git a/dbgpt/util/chat_util.py b/dbgpt/util/chat_util.py
index 1c76a43f5..0755ef706 100644
--- a/dbgpt/util/chat_util.py
+++ b/dbgpt/util/chat_util.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import Coroutine, List, Any
+from typing import Any, Coroutine, List
from dbgpt.app.scene import BaseChat, ChatFactory
diff --git a/dbgpt/util/command_utils.py b/dbgpt/util/command_utils.py
index e7aa241bc..de554119d 100644
--- a/dbgpt/util/command_utils.py
+++ b/dbgpt/util/command_utils.py
@@ -1,10 +1,11 @@
-import sys
import os
-import subprocess
-from typing import List, Dict
-import psutil
import platform
+import subprocess
+import sys
from functools import lru_cache
+from typing import Dict, List
+
+import psutil
def _get_abspath_of_current_command(command_path: str):
diff --git a/dbgpt/util/custom_data_structure.py b/dbgpt/util/custom_data_structure.py
index a8502143a..d86736c81 100644
--- a/dbgpt/util/custom_data_structure.py
+++ b/dbgpt/util/custom_data_structure.py
@@ -1,5 +1,4 @@
-from collections import OrderedDict
-from collections import deque
+from collections import OrderedDict, deque
class FixedSizeDict(OrderedDict):
diff --git a/dbgpt/util/executor_utils.py b/dbgpt/util/executor_utils.py
index 1e9aa4b3e..d530967b4 100644
--- a/dbgpt/util/executor_utils.py
+++ b/dbgpt/util/executor_utils.py
@@ -1,9 +1,9 @@
-from typing import Callable, Awaitable, Any
import asyncio
import contextvars
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
+from typing import Any, Awaitable, Callable
from dbgpt.component import BaseComponent, ComponentType, SystemApp
diff --git a/dbgpt/util/function_utils.py b/dbgpt/util/function_utils.py
index 5bfd578ea..ccce14ddd 100644
--- a/dbgpt/util/function_utils.py
+++ b/dbgpt/util/function_utils.py
@@ -1,7 +1,7 @@
-from typing import Any, get_type_hints, get_origin, get_args
-from functools import wraps
-import inspect
import asyncio
+import inspect
+from functools import wraps
+from typing import Any, get_args, get_origin, get_type_hints
def _is_instance_of_generic_type(obj, generic_type):
@@ -65,10 +65,12 @@ def rearrange_args_by_type(func):
from dbgpt.util.function_utils import rearrange_args_by_type
+
@rearrange_args_by_type
def sync_regular_function(a: int, b: str, c: float):
return a, b, c
+
assert instance.sync_class_method(1, "b", 3.0) == (1, "b", 3.0)
assert instance.sync_class_method("b", 3.0, 1) == (1, "b", 3.0)
diff --git a/dbgpt/util/global_helper.py b/dbgpt/util/global_helper.py
index 617995373..c3972bba2 100644
--- a/dbgpt/util/global_helper.py
+++ b/dbgpt/util/global_helper.py
@@ -196,7 +196,7 @@ def truncate_text(text: str, max_length: int) -> str:
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""Iterate over an iterable in batches.
- >>> list(iter_batch([1,2,3,4,5], 3))
+ >>> list(iter_batch([1, 2, 3, 4, 5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
diff --git a/dbgpt/util/json_utils.py b/dbgpt/util/json_utils.py
index e861e7284..9e12e880a 100644
--- a/dbgpt/util/json_utils.py
+++ b/dbgpt/util/json_utils.py
@@ -1,10 +1,10 @@
"""Utilities for the json_fixes package."""
import json
-from datetime import date, datetime
-from dataclasses import dataclass, asdict, is_dataclass
+import logging
import os.path
import re
-import logging
+from dataclasses import asdict, dataclass, is_dataclass
+from datetime import date, datetime
from jsonschema import Draft7Validator
diff --git a/dbgpt/util/memory_utils.py b/dbgpt/util/memory_utils.py
index cb0427c08..07d69339d 100644
--- a/dbgpt/util/memory_utils.py
+++ b/dbgpt/util/memory_utils.py
@@ -1,4 +1,5 @@
from typing import Any
+
from pympler import asizeof
diff --git a/dbgpt/util/model_utils.py b/dbgpt/util/model_utils.py
index 5df73dff0..0bf77bb94 100644
--- a/dbgpt/util/model_utils.py
+++ b/dbgpt/util/model_utils.py
@@ -1,6 +1,6 @@
-from typing import List, Tuple
-from dataclasses import dataclass
import logging
+from dataclasses import dataclass
+from typing import List, Tuple
logger = logging.getLogger(__name__)
@@ -17,9 +17,10 @@ def _clear_model_cache(device="cuda"):
def _clear_torch_cache(device="cuda"):
- import torch
import gc
+ import torch
+
gc.collect()
if device != "cpu":
if torch.has_mps:
diff --git a/dbgpt/util/module_utils.py b/dbgpt/util/module_utils.py
index c2d857440..d26a387d2 100644
--- a/dbgpt/util/module_utils.py
+++ b/dbgpt/util/module_utils.py
@@ -1,5 +1,5 @@
-from typing import Type
from importlib import import_module
+from typing import Type
def import_from_string(module_path: str, ignore_import_error: bool = False):
diff --git a/dbgpt/util/net_utils.py b/dbgpt/util/net_utils.py
index 8fc803e6f..fc9fb3f86 100644
--- a/dbgpt/util/net_utils.py
+++ b/dbgpt/util/net_utils.py
@@ -1,5 +1,5 @@
-import socket
import errno
+import socket
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
diff --git a/dbgpt/util/openai_utils.py b/dbgpt/util/openai_utils.py
index 5e1673f08..0e66c727a 100644
--- a/dbgpt/util/openai_utils.py
+++ b/dbgpt/util/openai_utils.py
@@ -1,8 +1,9 @@
-from typing import Dict, Any, Awaitable, Callable, Optional, Iterator
-import httpx
import asyncio
-import logging
import json
+import logging
+from typing import Any, Awaitable, Callable, Dict, Iterator, Optional
+
+import httpx
logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]
diff --git a/dbgpt/util/pagination_utils.py b/dbgpt/util/pagination_utils.py
index cbe21dda0..4b9288cb8 100644
--- a/dbgpt/util/pagination_utils.py
+++ b/dbgpt/util/pagination_utils.py
@@ -1,4 +1,5 @@
-from typing import TypeVar, Generic, List
+from typing import Generic, List, TypeVar
+
from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
diff --git a/dbgpt/util/parameter_utils.py b/dbgpt/util/parameter_utils.py
index 8baddffa3..202c5f8a6 100644
--- a/dbgpt/util/parameter_utils.py
+++ b/dbgpt/util/parameter_utils.py
@@ -1,8 +1,8 @@
import argparse
import os
-from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
-from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
from collections import OrderedDict
+from dataclasses import MISSING, asdict, dataclass, field, fields, is_dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
if TYPE_CHECKING:
from dbgpt._private.pydantic import BaseModel
diff --git a/dbgpt/util/prompt_util.py b/dbgpt/util/prompt_util.py
index e0c0a3846..bad972a7a 100644
--- a/dbgpt/util/prompt_util.py
+++ b/dbgpt/util/prompt_util.py
@@ -12,12 +12,11 @@
from string import Formatter
from typing import Callable, List, Optional, Sequence, Set
-from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
-
-from dbgpt.util.global_helper import globals_helper
-from dbgpt.core.interface.prompt import get_template_vars
from dbgpt._private.llm_metadata import LLMMetadata
+from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr
+from dbgpt.core.interface.prompt import get_template_vars
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
+from dbgpt.util.global_helper import globals_helper
DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
diff --git a/dbgpt/util/serialization/json_serialization.py b/dbgpt/util/serialization/json_serialization.py
index 58811cae2..b66530b76 100644
--- a/dbgpt/util/serialization/json_serialization.py
+++ b/dbgpt/util/serialization/json_serialization.py
@@ -1,6 +1,6 @@
+import json
from abc import ABC, abstractmethod
from typing import Dict, Type
-import json
from dbgpt.core.interface.serialization import Serializable, Serializer
diff --git a/dbgpt/util/speech/eleven_labs.py b/dbgpt/util/speech/eleven_labs.py
index 6bafd5597..5173b63d9 100644
--- a/dbgpt/util/speech/eleven_labs.py
+++ b/dbgpt/util/speech/eleven_labs.py
@@ -1,6 +1,7 @@
"""ElevenLabs speech module"""
-import os
import logging
+import os
+
import requests
from dbgpt._private.config import Config
diff --git a/dbgpt/util/splitter_utils.py b/dbgpt/util/splitter_utils.py
index 9c57f2111..8905aa3b4 100644
--- a/dbgpt/util/splitter_utils.py
+++ b/dbgpt/util/splitter_utils.py
@@ -25,7 +25,6 @@ def split_by_sentence_tokenizer() -> Callable[[str], List[str]]:
import os
import nltk
-
from llama_index.utils import get_cache_dir
cache_dir = get_cache_dir()
diff --git a/dbgpt/util/system_utils.py b/dbgpt/util/system_utils.py
index 50b54be0d..463fe354f 100644
--- a/dbgpt/util/system_utils.py
+++ b/dbgpt/util/system_utils.py
@@ -1,11 +1,11 @@
-from dataclasses import dataclass, asdict
-from enum import Enum
-from typing import Tuple, Dict
import os
import platform
-import subprocess
import re
+import subprocess
+from dataclasses import asdict, dataclass
+from enum import Enum
from functools import cache
+from typing import Dict, Tuple
@dataclass
diff --git a/dbgpt/util/tests/test_function_utils.py b/dbgpt/util/tests/test_function_utils.py
index f245b8c14..aad1dd600 100644
--- a/dbgpt/util/tests/test_function_utils.py
+++ b/dbgpt/util/tests/test_function_utils.py
@@ -1,6 +1,7 @@
-from typing import List, Dict, Any
+from typing import Any, Dict, List
import pytest
+
from dbgpt.util.function_utils import rearrange_args_by_type
diff --git a/dbgpt/util/tests/test_parameter_utils.py b/dbgpt/util/tests/test_parameter_utils.py
index 79b672ef6..4749076ed 100644
--- a/dbgpt/util/tests/test_parameter_utils.py
+++ b/dbgpt/util/tests/test_parameter_utils.py
@@ -1,5 +1,7 @@
import argparse
+
import pytest
+
from dbgpt.util.parameter_utils import _extract_parameter_details
diff --git a/dbgpt/util/tracer/__init__.py b/dbgpt/util/tracer/__init__.py
index fd6c7b4de..25cbe372e 100644
--- a/dbgpt/util/tracer/__init__.py
+++ b/dbgpt/util/tracer/__init__.py
@@ -1,23 +1,23 @@
from dbgpt.util.tracer.base import (
- SpanType,
Span,
- SpanTypeRunName,
- Tracer,
SpanStorage,
SpanStorageType,
+ SpanType,
+ SpanTypeRunName,
+ Tracer,
TracerContext,
)
from dbgpt.util.tracer.span_storage import (
- MemorySpanStorage,
FileSpanStorage,
+ MemorySpanStorage,
SpanStorageContainer,
)
from dbgpt.util.tracer.tracer_impl import (
- root_tracer,
- trace,
- initialize_tracer,
DefaultTracer,
TracerManager,
+ initialize_tracer,
+ root_tracer,
+ trace,
)
__all__ = [
diff --git a/dbgpt/util/tracer/base.py b/dbgpt/util/tracer/base.py
index dcf988bdf..77f8049f1 100644
--- a/dbgpt/util/tracer/base.py
+++ b/dbgpt/util/tracer/base.py
@@ -1,13 +1,13 @@
from __future__ import annotations
-from typing import Dict, Callable, Optional, List
-from dataclasses import dataclass
-from abc import ABC, abstractmethod
-from enum import Enum
import uuid
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
from datetime import datetime
+from enum import Enum
+from typing import Callable, Dict, List, Optional
-from dbgpt.component import BaseComponent, SystemApp, ComponentType
+from dbgpt.component import BaseComponent, ComponentType, SystemApp
class SpanType(str, Enum):
diff --git a/dbgpt/util/tracer/span_storage.py b/dbgpt/util/tracer/span_storage.py
index 4872336dc..296480d9d 100644
--- a/dbgpt/util/tracer/span_storage.py
+++ b/dbgpt/util/tracer/span_storage.py
@@ -1,12 +1,12 @@
-import os
-import json
-import time
import datetime
-import threading
-import queue
+import json
import logging
-from typing import Optional, List
+import os
+import queue
+import threading
+import time
from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import List, Optional
from dbgpt.component import SystemApp
from dbgpt.util.tracer.base import Span, SpanStorage
diff --git a/dbgpt/util/tracer/tests/test_base.py b/dbgpt/util/tracer/tests/test_base.py
index 2a0449afa..933f55b97 100644
--- a/dbgpt/util/tracer/tests/test_base.py
+++ b/dbgpt/util/tracer/tests/test_base.py
@@ -1,8 +1,7 @@
from typing import Dict
-from dbgpt.component import SystemApp
-
-from dbgpt.util.tracer import Span, SpanType, SpanStorage, Tracer
+from dbgpt.component import SystemApp
+from dbgpt.util.tracer import Span, SpanStorage, SpanType, Tracer
# Mock implementations
diff --git a/dbgpt/util/tracer/tests/test_span_storage.py b/dbgpt/util/tracer/tests/test_span_storage.py
index 6bf9a48ee..984c8dd1c 100644
--- a/dbgpt/util/tracer/tests/test_span_storage.py
+++ b/dbgpt/util/tracer/tests/test_span_storage.py
@@ -1,18 +1,19 @@
-import os
-import pytest
import asyncio
import json
+import os
import tempfile
import time
-from unittest.mock import patch
from datetime import datetime, timedelta
+from unittest.mock import patch
+
+import pytest
from dbgpt.util.tracer import (
- SpanStorage,
FileSpanStorage,
Span,
- SpanType,
+ SpanStorage,
SpanStorageContainer,
+ SpanType,
)
diff --git a/dbgpt/util/tracer/tests/test_tracer_impl.py b/dbgpt/util/tracer/tests/test_tracer_impl.py
index 9e09ea623..2ea077402 100644
--- a/dbgpt/util/tracer/tests/test_tracer_impl.py
+++ b/dbgpt/util/tracer/tests/test_tracer_impl.py
@@ -1,14 +1,15 @@
import pytest
+
+from dbgpt.component import SystemApp
from dbgpt.util.tracer import (
+ DefaultTracer,
+ MemorySpanStorage,
Span,
- SpanStorageType,
SpanStorage,
- DefaultTracer,
- TracerManager,
+ SpanStorageType,
Tracer,
- MemorySpanStorage,
+ TracerManager,
)
-from dbgpt.component import SystemApp
@pytest.fixture
diff --git a/dbgpt/util/tracer/tracer_cli.py b/dbgpt/util/tracer/tracer_cli.py
index f902910e0..40fdf87e3 100644
--- a/dbgpt/util/tracer/tracer_cli.py
+++ b/dbgpt/util/tracer/tracer_cli.py
@@ -1,10 +1,12 @@
-import os
-import click
-import logging
import glob
import json
+import logging
+import os
from datetime import datetime
-from typing import Iterable, Dict, Callable
+from typing import Callable, Dict, Iterable
+
+import click
+
from dbgpt.configs.model_config import LOGDIR
from dbgpt.util.tracer import SpanType, SpanTypeRunName
diff --git a/dbgpt/util/tracer/tracer_impl.py b/dbgpt/util/tracer/tracer_impl.py
index ac9485e6c..cd7f87737 100644
--- a/dbgpt/util/tracer/tracer_impl.py
+++ b/dbgpt/util/tracer/tracer_impl.py
@@ -1,22 +1,21 @@
-from typing import Dict, Optional
-from contextvars import ContextVar
-from functools import wraps
import asyncio
import inspect
import logging
+from contextvars import ContextVar
+from functools import wraps
+from typing import Dict, Optional
-
-from dbgpt.component import SystemApp, ComponentType
+from dbgpt.component import ComponentType, SystemApp
+from dbgpt.util.module_utils import import_from_checked_string
from dbgpt.util.tracer.base import (
- SpanType,
Span,
- Tracer,
SpanStorage,
SpanStorageType,
+ SpanType,
+ Tracer,
TracerContext,
)
from dbgpt.util.tracer.span_storage import MemorySpanStorage
-from dbgpt.util.module_utils import import_from_checked_string
logger = logging.getLogger(__name__)
diff --git a/dbgpt/util/tracer/tracer_middleware.py b/dbgpt/util/tracer/tracer_middleware.py
index 55920278c..6e0d35222 100644
--- a/dbgpt/util/tracer/tracer_middleware.py
+++ b/dbgpt/util/tracer/tracer_middleware.py
@@ -4,8 +4,8 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp
-from dbgpt.util.tracer import TracerContext, Tracer
+from dbgpt.util.tracer import Tracer, TracerContext
_DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"]
diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py
index 4f7c2678d..bed637cba 100644
--- a/dbgpt/util/utils.py
+++ b/dbgpt/util/utils.py
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
+import asyncio
import logging
import logging.handlers
-from typing import Any, List
-
import os
-import asyncio
+from typing import Any, List
from dbgpt.configs.model_config import LOGDIR
diff --git a/dbgpt/vis/__init__.py b/dbgpt/vis/__init__.py
index 2cc1c3aa8..dab2559a8 100644
--- a/dbgpt/vis/__init__.py
+++ b/dbgpt/vis/__init__.py
@@ -1,7 +1,7 @@
+from .client import vis_client
+from .tags.vis_agent_message import VisAgentMessages
+from .tags.vis_agent_plans import VisAgentPlans
from .tags.vis_chart import VisChart
from .tags.vis_code import VisCode
from .tags.vis_dashboard import VisDashboard
-from .tags.vis_agent_plans import VisAgentPlans
-from .tags.vis_agent_message import VisAgentMessages
from .tags.vis_plugin import VisPlugin
-from .client import vis_client
diff --git a/dbgpt/vis/base.py b/dbgpt/vis/base.py
index c5f87c4c6..a01b6859a 100644
--- a/dbgpt/vis/base.py
+++ b/dbgpt/vis/base.py
@@ -1,6 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
+
from dbgpt.util.json_utils import serialize
diff --git a/dbgpt/vis/client.py b/dbgpt/vis/client.py
index b106d6ca3..4f62fd452 100644
--- a/dbgpt/vis/client.py
+++ b/dbgpt/vis/client.py
@@ -1,11 +1,12 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
-from .tags.vis_code import VisCode
+
+from .base import Vis
+from .tags.vis_agent_message import VisAgentMessages
+from .tags.vis_agent_plans import VisAgentPlans
from .tags.vis_chart import VisChart
+from .tags.vis_code import VisCode
from .tags.vis_dashboard import VisDashboard
-from .tags.vis_agent_plans import VisAgentPlans
-from .tags.vis_agent_message import VisAgentMessages
from .tags.vis_plugin import VisPlugin
-from .base import Vis
class VisClient:
diff --git a/dbgpt/vis/tags/vis_agent_message.py b/dbgpt/vis/tags/vis_agent_message.py
index c33f6e618..f15e8c210 100644
--- a/dbgpt/vis/tags/vis_agent_message.py
+++ b/dbgpt/vis/tags/vis_agent_message.py
@@ -1,4 +1,5 @@
from typing import Optional
+
from ..base import Vis
diff --git a/dbgpt/vis/tags/vis_agent_plans.py b/dbgpt/vis/tags/vis_agent_plans.py
index e4a24955c..45f8b6b36 100644
--- a/dbgpt/vis/tags/vis_agent_plans.py
+++ b/dbgpt/vis/tags/vis_agent_plans.py
@@ -1,4 +1,5 @@
from typing import Optional
+
from ..base import Vis
diff --git a/dbgpt/vis/tags/vis_chart.py b/dbgpt/vis/tags/vis_chart.py
index 946504793..40237786e 100644
--- a/dbgpt/vis/tags/vis_chart.py
+++ b/dbgpt/vis/tags/vis_chart.py
@@ -1,6 +1,8 @@
+import json
from typing import Optional
+
import yaml
-import json
+
from ..base import Vis
diff --git a/dbgpt/vis/tags/vis_code.py b/dbgpt/vis/tags/vis_code.py
index bd2647ba0..672eb9cf8 100644
--- a/dbgpt/vis/tags/vis_code.py
+++ b/dbgpt/vis/tags/vis_code.py
@@ -1,4 +1,5 @@
from typing import Optional
+
from ..base import Vis
diff --git a/dbgpt/vis/tags/vis_dashboard.py b/dbgpt/vis/tags/vis_dashboard.py
index 44a95abe0..7f6dbc627 100644
--- a/dbgpt/vis/tags/vis_dashboard.py
+++ b/dbgpt/vis/tags/vis_dashboard.py
@@ -1,5 +1,6 @@
import json
from typing import Optional
+
from ..base import Vis
diff --git a/dbgpt/vis/tags/vis_plugin.py b/dbgpt/vis/tags/vis_plugin.py
index 29117a746..eed6db9f3 100644
--- a/dbgpt/vis/tags/vis_plugin.py
+++ b/dbgpt/vis/tags/vis_plugin.py
@@ -1,4 +1,5 @@
from typing import Optional
+
from ..base import Vis
diff --git a/examples/awel/data_analyst_assistant.py b/examples/awel/data_analyst_assistant.py
index 7997a60e0..c80d06b8b 100644
--- a/examples/awel/data_analyst_assistant.py
+++ b/examples/awel/data_analyst_assistant.py
@@ -322,10 +322,12 @@ async def build_model_request(
# Load and store chat history
chat_history_load_task = ServePreChatHistoryLoadOperator()
- last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5))
- # History transform task, here we keep last k round messages
+ keep_start_rounds = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_KEEP_START_ROUNDS", 0))
+ keep_end_rounds = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_KEEP_END_ROUNDS", 5))
+ # History transform task: keep the first `keep_start_rounds` rounds and the
+ # last `keep_end_rounds` rounds of the history messages.
history_transform_task = BufferedConversationMapperOperator(
- last_k_round=last_k_round
+ keep_start_rounds=keep_start_rounds, keep_end_rounds=keep_end_rounds
)
history_prompt_build_task = HistoryDynamicPromptBuilderOperator(
history_key="chat_history"
diff --git a/examples/awel/simple_chat_history_example.py b/examples/awel/simple_chat_history_example.py
index c1977117c..138cf73d9 100644
--- a/examples/awel/simple_chat_history_example.py
+++ b/examples/awel/simple_chat_history_example.py
@@ -137,7 +137,7 @@ async def build_model_request(
composer_operator = ChatHistoryPromptComposerOperator(
prompt_template=prompt,
- last_k_round=5,
+ keep_end_rounds=5,
storage=InMemoryStorage(),
message_storage=InMemoryStorage(),
)
diff --git a/examples/awel/simple_rag_example.py b/examples/awel/simple_rag_example.py
deleted file mode 100644
index 6ea44889d..000000000
--- a/examples/awel/simple_rag_example.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""AWEL: Simple rag example
-
- DB-GPT will automatically load and execute the current file after startup.
-
- Example:
-
- .. code-block:: shell
-
- curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
- -H "Content-Type: application/json" -d '{
- "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
- "model_name": "proxyllm",
- "chat_mode": "chat_knowledge",
- "user_input": "What is DB-GPT?",
- "select_param": "default"
- }'
-
-"""
-
-from dbgpt.app.openapi.api_view_model import ConversationVo
-from dbgpt.app.scene import ChatScene
-from dbgpt.app.scene.operator._experimental import (
- BaseChatOperator,
- ChatContext,
- ChatHistoryOperator,
- ChatHistoryStorageOperator,
- EmbeddingEngingOperator,
- PromptManagerOperator,
-)
-from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
-from dbgpt.model.operator.model_operator import ModelOperator
-
-
-class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- async def map(self, input_value: ConversationVo) -> ChatContext:
- return ChatContext(
- current_user_input=input_value.user_input,
- model_name=input_value.model_name,
- chat_session_id=input_value.conv_uid,
- select_param=input_value.select_param,
- chat_scene=ChatScene.ChatKnowledge,
- )
-
-
-with DAG("simple_rag_example") as dag:
- trigger_task = HttpTrigger(
- "/examples/simple_rag", methods="POST", request_body=ConversationVo
- )
- req_parse_task = RequestParseOperator()
- # TODO should register prompt template first
- prompt_task = PromptManagerOperator()
- history_storage_task = ChatHistoryStorageOperator()
- history_task = ChatHistoryOperator()
- embedding_task = EmbeddingEngingOperator()
- chat_task = BaseChatOperator()
- model_task = ModelOperator()
- output_parser_task = MapOperator(lambda out: out.to_dict()["text"])
-
- (
- trigger_task
- >> req_parse_task
- >> prompt_task
- >> history_storage_task
- >> history_task
- >> embedding_task
- >> chat_task
- >> model_task
- >> output_parser_task
- )
-
-
-if __name__ == "__main__":
- if dag.leaf_nodes[0].dev_mode:
- from dbgpt.core.awel import setup_dev_environment
-
- setup_dev_environment([dag])
- else:
- pass