Skip to content

Commit

Permalink
feat(core): Support more chat flows (eosphoros-ai#1180)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored and csunny committed Mar 1, 2024
1 parent 197360e commit 3b6da64
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 55 deletions.
2 changes: 1 addition & 1 deletion dbgpt/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.4.7"
version = "0.5.0"
32 changes: 1 addition & 31 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,7 @@ async def chat_completions(
context=flow_ctx,
)
return StreamingResponse(
flow_stream_generator(
flow_service.chat_flow(dialogue.select_param, flow_req),
dialogue.incremental,
dialogue.model_name,
),
flow_service.chat_flow(dialogue.select_param, flow_req),
headers=headers,
media_type="text/event-stream",
)
Expand Down Expand Up @@ -426,32 +422,6 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n"


async def flow_stream_generator(func, incremental: bool, model_name: str):
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in func:
if chunk:
msg = chunk.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
if incremental:
yield "data: [DONE]\n\n"


async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses
Expand Down
30 changes: 25 additions & 5 deletions dbgpt/core/awel/flow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,16 +632,36 @@ def get_runnable_parameters(
runnable_parameters: Dict[str, Any] = {}
if not self.parameters or not view_parameters:
return runnable_parameters
if len(self.parameters) != len(view_parameters):
view_required_parameters = {
parameter.name: parameter
for parameter in view_parameters
if not parameter.optional
}
current_required_parameters = {
parameter.name: parameter
for parameter in self.parameters
if not parameter.optional
}
current_parameters = {
parameter.name: parameter for parameter in self.parameters
}
if len(view_required_parameters) < len(current_required_parameters):
# TODO, skip the optional parameters.
raise FlowParameterMetadataException(
f"Parameters count not match. Expected {len(self.parameters)}, "
f"Parameters count not match(current key: {self.id}). "
f"Expected {len(self.parameters)}, "
f"but got {len(view_parameters)} from JSON metadata."
f"Required parameters: {current_required_parameters.keys()}, "
f"but got {view_required_parameters.keys()}."
)
for i, parameter in enumerate(self.parameters):
view_param = view_parameters[i]
for view_param in view_parameters:
view_param_key = view_param.name
if view_param_key not in current_parameters:
raise FlowParameterMetadataException(
f"Parameter {view_param_key} not found in the metadata."
)
runnable_parameters.update(
parameter.to_runnable_parameter(
current_parameters[view_param_key].to_runnable_parameter(
view_param.get_typed_value(), resources, key_to_resource_instance
)
)
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
This class extends DAGNode by adding execution capabilities.
"""

streaming_operator: bool = False

def __init__(
self,
task_id: Optional[str] = None,
Expand Down
4 changes: 4 additions & 0 deletions dbgpt/core/awel/operators/stream_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""

streaming_operator = True

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
call_data = curr_task_ctx.call_data
Expand Down Expand Up @@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
AsyncIterator[IN] to another AsyncIterator[OUT].
"""

streaming_operator = True

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/core/interface/operators/prompt_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the messages."""
if "system_message" not in values:
raise ValueError("No system message")
values["system_message"] = "You are a helpful AI Assistant."
if "human_message" not in values:
raise ValueError("No human message")
values["human_message"] = "{user_input}"
if "message_placeholder" not in values:
raise ValueError("No message placeholder")
values["message_placeholder"] = "chat_history"
system_message = values.pop("system_message")
human_message = values.pop("human_message")
message_placeholder = values.pop("message_placeholder")
Expand Down
148 changes: 136 additions & 12 deletions dbgpt/serve/flow/service/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import traceback
from typing import List, Optional, cast
from typing import Any, List, Optional, cast

from fastapi import HTTPException

Expand All @@ -14,6 +15,7 @@
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
Expand Down Expand Up @@ -276,12 +278,39 @@ def get_list_by_page(
"""
return self.dao.get_list_page(request, page, page_size)

async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
async def chat_flow(
self,
flow_uid: str,
request: CommonLLMHttpRequestBody,
incremental: bool = False,
):
"""Chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
incremental (bool): Whether to return the result incrementally
"""
try:
async for output in self._call_chat_flow(flow_uid, request, incremental):
yield output
except HTTPException as e:
yield f"data:[SERVER_ERROR]{e.detail}\n\n"
except Exception as e:
yield f"data:[SERVER_ERROR]{str(e)}\n\n"

async def _call_chat_flow(
self,
flow_uid: str,
request: CommonLLMHttpRequestBody,
incremental: bool = False,
):
"""Chat with the AWEL flow.
Args:
flow_uid (str): The flow uid
request (CommonLLMHttpRequestBody): The request
incremental (bool): Whether to return the result incrementally
"""
flow = self.get({"uid": flow_uid})
if not flow:
Expand All @@ -291,18 +320,18 @@ async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
raise HTTPException(
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
)
if flow.flow_category != FlowCategory.CHAT_FLOW:
raise ValueError(f"Flow {flow_uid} is not a chat flow")
dag = self.dag_manager.dag_map[dag_id]
if (
flow.flow_category != FlowCategory.CHAT_FLOW
and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
):
raise ValueError(f"Flow {flow_uid} is not a chat flow")
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("Chat Flow just support one leaf node in dag")
end_node = cast(BaseOperator, leaf_nodes[0])
if request.stream:
async for output in await end_node.call_stream(request):
yield output
else:
yield await end_node.call(request)
async for output in _chat_with_dag_task(end_node, request, incremental):
yield output

def _parse_flow_category(self, dag: DAG) -> FlowCategory:
"""Parse the flow category
Expand Down Expand Up @@ -335,9 +364,104 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
output = leaf_node.metadata.outputs[0]
try:
real_class = _get_type_cls(output.type_cls)
if common_http_trigger and (
real_class == str or real_class == CommonLLMHttpResponseBody
):
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON


def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
try:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
except ImportError:
OpenAIStreamingOutputOperator = None
if is_class:
return (
obj == str
or obj == CommonLLMHttpResponseBody
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
)
else:
chat_types = (str, CommonLLMHttpResponseBody)
if OpenAIStreamingOutputOperator:
chat_types += (OpenAIStreamingOutputOperator,)
return isinstance(obj, chat_types)


async def _chat_with_dag_task(
task: BaseOperator,
request: CommonLLMHttpRequestBody,
incremental: bool = False,
):
"""Chat with the DAG task.
Args:
task (BaseOperator): The task
request (CommonLLMHttpRequestBody): The request
"""
if request.stream and task.streaming_operator:
try:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
except ImportError:
OpenAIStreamingOutputOperator = None
if incremental:
async for output in await task.call_stream(request):
yield output
else:
if OpenAIStreamingOutputOperator and isinstance(
task, OpenAIStreamingOutputOperator
):
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
)

previous_text = ""
async for output in await task.call_stream(request):
if not isinstance(output, str):
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
return
if output == "data: [DONE]\n\n":
return
json_data = "".join(output.split("data: ")[1:])
dict_data = json.loads(json_data)
if "choices" not in dict_data:
error_msg = dict_data.get("text", "Unknown error")
yield f"data:[SERVER_ERROR]{error_msg}\n\n"
return
choices = dict_data["choices"]
if choices:
choice = choices[0]
delta_data = ChatCompletionResponseStreamChoice(**choice)
if delta_data.delta.content:
previous_text += delta_data.delta.content
if previous_text:
full_text = previous_text.replace("\n", "\\n")
yield f"data:{full_text}\n\n"
else:
async for output in await task.call_stream(request):
if isinstance(output, str):
if output.strip():
yield output
else:
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
return
else:
result = await task.call(request)
if result is None:
yield "data:[SERVER_ERROR]The result is None\n\n"
elif isinstance(result, str):
yield f"data:{result}\n\n"
elif isinstance(result, ModelOutput):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
else:
yield f"data:{result.text}\n\n"
elif isinstance(result, CommonLLMHttpResponseBody):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
else:
yield f"data:{result.text}\n\n"
elif isinstance(result, dict):
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
else:
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
2 changes: 1 addition & 1 deletion dbgpt/util/dbgpts/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def update_repo(repo: str):
logger.info(f"Repo '{repo}' is not a git repository.")
return
logger.info(f"Updating repo '{repo}'...")
subprocess.run(["git", "pull"], check=True)
subprocess.run(["git", "pull"], check=False)


def install(
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/upgrade/v0.5.0.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Upgrade To v0.5.0(Draft)
# Upgrade To v0.5.0

## Overview

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files:
# dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")

BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
Expand Down

0 comments on commit 3b6da64

Please sign in to comment.