Skip to content

Commit

Permalink
feat(core): Support multi round conversation operator (#986)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Dec 27, 2023
1 parent 9aec636 commit b13d3f6
Show file tree
Hide file tree
Showing 63 changed files with 1,996 additions and 299 deletions.
1 change: 1 addition & 0 deletions assets/schema/knowledge_management.sql
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ CREATE TABLE IF NOT EXISTS `prompt_manage`
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/_private/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,13 @@
validator,
PrivateAttr,
)


def model_to_json(model, **kwargs):
"""Convert a pydantic model to json"""
if PYDANTIC_VERSION == 1:
return model.json(**kwargs)
else:
if "ensure_ascii" in kwargs:
del kwargs["ensure_ascii"]
return model.model_dump_json(**kwargs)
24 changes: 19 additions & 5 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
cfg = Config()
cfg.SYSTEM_APP = system_app
# Initialize db storage first
_initialize_db_storage(param)
_initialize_db_storage(param, system_app)

# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
Expand Down Expand Up @@ -86,12 +86,14 @@ def startup_event(wh):
return startup_event


def _initialize_db_storage(param: "WebServerParameters"):
def _initialize_db_storage(param: "WebServerParameters", system_app: SystemApp):
"""Initialize the db storage.
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
_initialize_db(
try_to_create_db=not param.disable_alembic_upgrade, system_app=system_app
)


def _migration_db_storage(param: "WebServerParameters"):
Expand All @@ -114,7 +116,9 @@ def _migration_db_storage(param: "WebServerParameters"):
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)


def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
def _initialize_db(
try_to_create_db: Optional[bool] = False, system_app: Optional[SystemApp] = None
) -> str:
"""Initialize the database
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
Expand Down Expand Up @@ -147,7 +151,11 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args)
db = initialize_db(db_url, db_name, engine_args)
if system_app:
from dbgpt.storage.metadata import UnifiedDBManagerFactory

system_app.register(UnifiedDBManagerFactory, db)
return default_meta_data_path


Expand Down Expand Up @@ -273,3 +281,9 @@ class WebServerParameters(BaseParameters):
"help": "Whether to disable alembic to initialize and upgrade database metadata",
},
)
awel_dirs: Optional[str] = field(
default=None,
metadata={
"help": "The directories to search awel files, split by `,`",
},
)
14 changes: 10 additions & 4 deletions dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
_initialize_awel(system_app, param)
# Register serve apps
register_serve_apps(system_app)
register_serve_apps(system_app, CFG)


def _initialize_model_cache(system_app: SystemApp):
Expand All @@ -64,8 +64,14 @@ def _initialize_model_cache(system_app: SystemApp):
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)


def _initialize_awel(system_app: SystemApp):
def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
from dbgpt.core.awel import initialize_awel
from dbgpt.configs.model_config import _DAG_DEFINITION_DIR

initialize_awel(system_app, _DAG_DEFINITION_DIR)
# Add default dag definition dir
dag_dirs = [_DAG_DEFINITION_DIR]
if param.awel_dirs:
dag_dirs += param.awel_dirs.strip().split(",")
dag_dirs = [x.strip() for x in dag_dirs]

initialize_awel(system_app, dag_dirs)
8 changes: 5 additions & 3 deletions dbgpt/app/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
mount_routers(app)
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
system_app.on_init()

# Before start, after initialize_components
# TODO: initialize_worker_manager_in_client as a component register in system_app
system_app.before_start()
# Migration db storage, so you db models must be imported before this
_migration_db_storage(param)

model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
# TODO: initialize_worker_manager_in_client as a component register in system_app
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
Expand Down Expand Up @@ -186,6 +185,9 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
CFG.SERVER_LIGHT_MODE = True

mount_static_files(app)

# Before start, after on_init
system_app.before_start()
return param


Expand Down
23 changes: 19 additions & 4 deletions dbgpt/app/initialization/serve_initialization.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
from dbgpt.component import SystemApp
from dbgpt._private.config import Config


def register_serve_apps(system_app: SystemApp):
def register_serve_apps(system_app: SystemApp, cfg: Config):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)

# ################################ Prompt Serve Register Begin ######################################
from dbgpt.serve.prompt.serve import (
Serve as PromptServe,
SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX,
)

# Replace old prompt serve
# Set config
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
# Register serve app
system_app.register(PromptServe, api_prefix="/prompt")
# ################################ Prompt Serve Register End ########################################

# ################################ Conversation Serve Register Begin ######################################
from dbgpt.serve.conversation.serve import Serve as ConversationServe

# Register serve app
system_app.register(ConversationServe)
# ################################ Conversation Serve Register End ########################################
4 changes: 4 additions & 0 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ async def dialogue_list(
model_name = item.get("model_name", CFG.LLM_MODEL)
user_name = item.get("user_name")
sys_code = item.get("sys_code")
if not item.get("messages"):
# Skip the empty messages
# TODO support new conversation and message mode
continue

messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x["chat_order"])
Expand Down
33 changes: 31 additions & 2 deletions dbgpt/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,28 @@


class LifeCycle:
"""This class defines hooks for lifecycle events of a component."""
"""This class defines hooks for lifecycle events of a component.
Execution order of lifecycle hooks:
1. on_init
2. before_start(async_before_start)
3. after_start(async_after_start)
4. before_stop(async_before_stop)
"""

def on_init(self):
"""Called when the component is being initialized."""
pass

async def async_on_init(self):
"""Asynchronous version of on_init."""
pass

def before_start(self):
"""Called before the component starts."""
"""Called before the component starts.
This method is called after the component has been initialized and before it is started.
"""
pass

async def async_before_start(self):
Expand Down Expand Up @@ -59,6 +77,7 @@ class ComponentType(str, Enum):
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"


@PublicAPI(stability="beta")
Expand Down Expand Up @@ -177,6 +196,16 @@ def get_component(
raise TypeError(f"Component {name} is not of type {component_type}")
return component

def on_init(self):
"""Invoke the on_init hooks for all registered components."""
for _, v in self.components.items():
v.on_init()

async def async_on_init(self):
"""Asynchronously invoke the on_init hooks for all registered components."""
tasks = [v.async_on_init() for _, v in self.components.items()]
await asyncio.gather(*tasks)

def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.components.items():
Expand Down
51 changes: 30 additions & 21 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,36 @@
"""

from typing import List, Optional
from dbgpt.component import SystemApp

from .dag.base import DAGContext, DAG
from dbgpt.component import SystemApp

from .dag.base import DAG, DAGContext
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchFunc,
BranchOperator,
InputOperator,
BranchFunc,
JoinOperator,
MapOperator,
ReduceStreamOperator,
)

from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
UnstreamifyAbsOperator,
)

from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .runner.local_runner import DefaultWorkflowRunner
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
DefaultTaskContext,
SimpleCallDataInputSource,
SimpleInputSource,
SimpleStreamTaskOutput,
SimpleTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
"initialize_awel",
Expand Down Expand Up @@ -73,34 +71,45 @@
]


def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner
from .trigger.trigger_manager import DefaultTriggerManager

DAGVar.set_current_system_app(system_app)

system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
dag_manager = DAGManager(system_app, dag_dirs)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()


def setup_dev_environment(
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
dags: List[DAG],
host: Optional[str] = "0.0.0.0",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
) -> None:
"""Setup a development environment for AWEL.
Just using in development environment, not production environment.
"""
import uvicorn
from fastapi import FastAPI

from dbgpt.component import SystemApp
from .trigger.trigger_manager import DefaultTriggerManager
from dbgpt.util.utils import setup_logging

from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager

if not logger_filename:
logger_filename = "dbgpt_awel_dev.log"
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)

app = FastAPI()
system_app = SystemApp(app)
Expand Down
16 changes: 10 additions & 6 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import contextvars
import logging
import threading
import uuid
from abc import ABC, abstractmethod
from collections import deque
from functools import cache
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union

from dbgpt.component import SystemApp

from ..resource.base import ResourceGroup
from ..task.base import TaskContext, TaskOutput

Expand Down Expand Up @@ -502,6 +503,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()

def __repr__(self):
return f"DAG(dag_id={self.dag_id})"


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
Expand Down
10 changes: 6 additions & 4 deletions dbgpt/core/awel/dag/dag_manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import Dict, Optional
import logging
from typing import Dict, List

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader

from .base import DAG
from .loader import LocalFileDAGLoader

logger = logging.getLogger(__name__)


class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER

def __init__(self, system_app: SystemApp, dag_filepath: str):
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}

Expand Down
Loading

0 comments on commit b13d3f6

Please sign in to comment.