Skip to content

Commit

Permalink
feat: 助手支持中断流式输出的内容,且中止的回答不会作为历史记忆
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 8, 2024
1 parent 78ec0f3 commit 50a8c4f
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 17 deletions.
18 changes: 17 additions & 1 deletion src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
import json
from queue import Queue
from typing import Any, Dict, List, Union

from bisheng.api.v1.schemas import ChatResponse
Expand All @@ -19,7 +20,7 @@
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None):
def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
self.websocket = websocket
self.flow_id = flow_id
self.chat_id = chat_id
Expand Down Expand Up @@ -385,6 +386,21 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler):

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
super().__init__(websocket, flow_id, chat_id, user_id, **kwargs)
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)

# 将流式输出内容放入到队列内,以方便中断流式输出后,可以将内容记录到数据库
await self.websocket.send_json(resp.dict())
self.stream_queue.put(token)

@staticmethod
def parse_tool_category(tool_name) -> (str, str):
"""
Expand Down
75 changes: 60 additions & 15 deletions src/backend/bisheng/chat/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json
import os
import time
from typing import Dict
from typing import Dict, Callable
from uuid import UUID, uuid4
from queue import Queue

from loguru import logger
from langchain_core.messages import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_tool
from fastapi import WebSocket, status, Request

from bisheng.api.services.assistant_agent import AssistantAgent
Expand All @@ -20,6 +18,7 @@
from bisheng.database.models.message import ChatMessageDao
from bisheng.settings import settings
from bisheng.api.utils import get_request_ip
from bisheng.utils.threadpool import ThreadPoolManager, thread_pool


class ChatClient:
Expand All @@ -42,6 +41,10 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s
# 和模型对话时传入的 完整的历史对话轮数
self.latest_history_num = 5
self.gpts_conf = settings.get_from_db('gpts')
# 异步任务列表
self.task_ids = []
# 流式输出的队列,用来接受流式输出的内容
self.stream_queue = Queue()

async def send_message(self, message: str):
await self.websocket.send_text(message)
Expand All @@ -51,15 +54,33 @@ async def send_json(self, message: ChatMessage):

async def handle_message(self, message: Dict[any, any]):
trace_id = uuid4().hex
logger.info(f'client_id={self.client_key} trace_id={trace_id} message={message}')
with logger.contextualize(trace_id=trace_id):
# 处理客户端发过来的信息
# 处理客户端发过来的信息, 提交到线程池内执行
if self.work_type == WorkType.GPTS:
await self.handle_gpts_message(message)
thread_pool.submit(trace_id,
self.wrapper_task,
trace_id,
self.handle_gpts_message,
message,
trace_id=trace_id)
# await self.handle_gpts_message(message)

async def add_message(self, msg_type: str, message: str, category: str):
async def wrapper_task(self, task_id: str, fn: Callable, *args, **kwargs):
# 包装处理函数为异步任务
self.task_ids.append(task_id)
try:
# 执行处理函数
await fn(*args, **kwargs)
finally:
# 执行完成后将任务id从列表移除
self.task_ids.remove(task_id)

async def add_message(self, msg_type: str, message: str, category: str, remark: str = ''):
self.chat_history.append({
'category': category,
'message': message
'message': message,
'remark': remark
})
if not self.chat_id:
# debug模式无需保存历史
Expand All @@ -75,6 +96,7 @@ async def add_message(self, msg_type: str, message: str, category: str):
flow_id=self.client_id,
chat_id=self.chat_id,
user_id=self.user_id,
remark=remark,
))
# 记录审计日志, 是新建会话
if len(self.chat_history) <= 1:
Expand Down Expand Up @@ -142,7 +164,8 @@ async def init_chat_history(self):
for one in res:
self.chat_history.append({
'message': one.message,
'category': one.category
'category': one.category,
'remark': one.remark
})

async def get_latest_history(self):
Expand All @@ -152,13 +175,15 @@ async def get_latest_history(self):
is_answer = True
# 从聊天历史里获取
for i in range(len(self.chat_history) - 1, -1, -1):
one_item = self.chat_history[i]
if find_i >= self.latest_history_num:
break
if self.chat_history[i]['category'] == 'answer' and is_answer:
tmp.insert(0, AIMessage(content=self.chat_history[i]['message']))
# 不包含中断的答案
if one_item['category'] == 'answer' and one_item.get('remark') != 'break_answer' and is_answer:
tmp.insert(0, AIMessage(content=one_item['message']))
is_answer = False
elif self.chat_history[i]['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(self.chat_history[i]['message'])['input']))
elif one_item['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(one_item['message'])['input']))
is_answer = True
find_i += 1

Expand All @@ -171,16 +196,36 @@ async def init_gpts_callback(self):
'websocket': self.websocket,
'flow_id': self.client_id,
'chat_id': self.chat_id,
'user_id': self.user_id
'user_id': self.user_id,
'stream_queue': self.stream_queue,
})]
self.gpts_async_callback = async_callbacks

async def stop_handle_message(self, message: Dict[any, any]):
# 中止流式输出, 因为最新的任务id是中止任务的id,不能取消自己
logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')

# 中止之前的处理函数
thread_pool.cancel_task(self.task_ids[:-1])

# 将流式输出的内容写到数据库内
answer = ''
while not self.stream_queue.empty():
msg = self.stream_queue.get()
answer += msg

# 有流式输出内容的话,记录流式输出内容到数据库
if answer.strip():
res = await self.add_message('bot', answer, 'answer', 'break_answer')
await self.send_response('answer', 'end', answer, message_id=res.id if res else None)
await self.send_response('processing', 'close', '')

async def handle_gpts_message(self, message: Dict[any, any]):
if not message:
return
logger.debug(f'receive client message, client_key: {self.client_key} message: {message}')
if message.get('action') == 'stop':
logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')
await self.stop_handle_message(message)
return

inputs = message.get('inputs', {})
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/database/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MessageBase(SQLModelSerializable):
receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的发送方')
intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注。break_answer: 中断的回复不作为history传给模型')
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
update_time: Optional[datetime] = Field(
Expand Down

0 comments on commit 50a8c4f

Please sign in to comment.