Skip to content

Commit

Permalink
merge:合并
Browse files Browse the repository at this point in the history
  • Loading branch information
QwQ-wuwuwu committed Jul 8, 2024
2 parents eda4d23 + cc58fb8 commit 25061f7
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 31 deletions.
18 changes: 17 additions & 1 deletion src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
import json
from queue import Queue
from typing import Any, Dict, List, Union

from bisheng.api.v1.schemas import ChatResponse
Expand All @@ -19,7 +20,7 @@
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None):
def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
self.websocket = websocket
self.flow_id = flow_id
self.chat_id = chat_id
Expand Down Expand Up @@ -385,6 +386,21 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler):

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
super().__init__(websocket, flow_id, chat_id, user_id, **kwargs)
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)

# 将流式输出内容放入到队列内,以方便中断流式输出后,可以将内容记录到数据库
await self.websocket.send_json(resp.dict())
self.stream_queue.put(token)

@staticmethod
def parse_tool_category(tool_name) -> (str, str):
"""
Expand Down
75 changes: 60 additions & 15 deletions src/backend/bisheng/chat/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json
import os
import time
from typing import Dict
from typing import Dict, Callable
from uuid import UUID, uuid4
from queue import Queue

from loguru import logger
from langchain_core.messages import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_tool
from fastapi import WebSocket, status, Request

from bisheng.api.services.assistant_agent import AssistantAgent
Expand All @@ -20,6 +18,7 @@
from bisheng.database.models.message import ChatMessageDao
from bisheng.settings import settings
from bisheng.api.utils import get_request_ip
from bisheng.utils.threadpool import ThreadPoolManager, thread_pool


class ChatClient:
Expand All @@ -42,6 +41,10 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s
# 和模型对话时传入的 完整的历史对话轮数
self.latest_history_num = 5
self.gpts_conf = settings.get_from_db('gpts')
# 异步任务列表
self.task_ids = []
# 流式输出的队列,用来接受流式输出的内容
self.stream_queue = Queue()

async def send_message(self, message: str):
await self.websocket.send_text(message)
Expand All @@ -51,15 +54,33 @@ async def send_json(self, message: ChatMessage):

async def handle_message(self, message: Dict[any, any]):
trace_id = uuid4().hex
logger.info(f'client_id={self.client_key} trace_id={trace_id} message={message}')
with logger.contextualize(trace_id=trace_id):
# 处理客户端发过来的信息
# 处理客户端发过来的信息, 提交到线程池内执行
if self.work_type == WorkType.GPTS:
await self.handle_gpts_message(message)
thread_pool.submit(trace_id,
self.wrapper_task,
trace_id,
self.handle_gpts_message,
message,
trace_id=trace_id)
# await self.handle_gpts_message(message)

async def add_message(self, msg_type: str, message: str, category: str):
async def wrapper_task(self, task_id: str, fn: Callable, *args, **kwargs):
# 包装处理函数为异步任务
self.task_ids.append(task_id)
try:
# 执行处理函数
await fn(*args, **kwargs)
finally:
# 执行完成后将任务id从列表移除
self.task_ids.remove(task_id)

async def add_message(self, msg_type: str, message: str, category: str, remark: str = ''):
self.chat_history.append({
'category': category,
'message': message
'message': message,
'remark': remark
})
if not self.chat_id:
# debug模式无需保存历史
Expand All @@ -75,6 +96,7 @@ async def add_message(self, msg_type: str, message: str, category: str):
flow_id=self.client_id,
chat_id=self.chat_id,
user_id=self.user_id,
remark=remark,
))
# 记录审计日志, 是新建会话
if len(self.chat_history) <= 1:
Expand Down Expand Up @@ -142,7 +164,8 @@ async def init_chat_history(self):
for one in res:
self.chat_history.append({
'message': one.message,
'category': one.category
'category': one.category,
'remark': one.remark
})

async def get_latest_history(self):
Expand All @@ -152,13 +175,15 @@ async def get_latest_history(self):
is_answer = True
# 从聊天历史里获取
for i in range(len(self.chat_history) - 1, -1, -1):
one_item = self.chat_history[i]
if find_i >= self.latest_history_num:
break
if self.chat_history[i]['category'] == 'answer' and is_answer:
tmp.insert(0, AIMessage(content=self.chat_history[i]['message']))
# 不包含中断的答案
if one_item['category'] == 'answer' and one_item.get('remark') != 'break_answer' and is_answer:
tmp.insert(0, AIMessage(content=one_item['message']))
is_answer = False
elif self.chat_history[i]['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(self.chat_history[i]['message'])['input']))
elif one_item['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(one_item['message'])['input']))
is_answer = True
find_i += 1

Expand All @@ -171,16 +196,36 @@ async def init_gpts_callback(self):
'websocket': self.websocket,
'flow_id': self.client_id,
'chat_id': self.chat_id,
'user_id': self.user_id
'user_id': self.user_id,
'stream_queue': self.stream_queue,
})]
self.gpts_async_callback = async_callbacks

    async def stop_handle_message(self, message: Dict[any, any]):
        """Abort the in-flight streaming answer and persist whatever was streamed so far."""
        # The most recent entry in task_ids is the stop request itself, so only
        # earlier tasks are cancelled — a task must not cancel itself.
        logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')

        # Cancel the previously submitted handler task(s).
        thread_pool.cancel_task(self.task_ids[:-1])

        # Drain the stream queue and stitch the already-streamed tokens back
        # into a single partial answer string.
        answer = ''
        while not self.stream_queue.empty():
            msg = self.stream_queue.get()
            answer += msg

        # If anything was streamed, store the partial answer with the
        # 'break_answer' remark so it is excluded from future model history
        # (see get_latest_history), then tell the client the answer ended.
        if answer.strip():
            res = await self.add_message('bot', answer, 'answer', 'break_answer')
            await self.send_response('answer', 'end', answer, message_id=res.id if res else None)
        # NOTE(review): the diff view does not preserve indentation — this close
        # event is assumed to be sent unconditionally; confirm against the repo.
        await self.send_response('processing', 'close', '')

async def handle_gpts_message(self, message: Dict[any, any]):
if not message:
return
logger.debug(f'receive client message, client_key: {self.client_key} message: {message}')
if message.get('action') == 'stop':
logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')
await self.stop_handle_message(message)
return

inputs = message.get('inputs', {})
Expand Down
15 changes: 12 additions & 3 deletions src/backend/bisheng/database/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class MessageBase(SQLModelSerializable):
receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的发送方')
intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)),
description='备注。break_answer: 中断的回复不作为history传给模型')
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
update_time: Optional[datetime] = Field(
Expand Down Expand Up @@ -88,11 +89,19 @@ def get_latest_message_by_chatid(cls, chat_id: str):

@classmethod
def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None):
statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids))
"""
获取每个会话最近的一次消息内容
"""
statement = select(ChatMessage.chat_id, func.max(ChatMessage.id)).where(ChatMessage.chat_id.in_(chat_ids))
if category:
statement = statement.where(ChatMessage.category == category)
statement = statement.order_by(ChatMessage.create_time.desc()).limit(1)
statement = statement.group_by(ChatMessage.chat_id)
with session_getter() as session:
# 获取最新的id列表
res = session.exec(statement).all()
ids = [one[1] for one in res]
# 获取消息的具体内容
statement = select(ChatMessage).where(ChatMessage.id.in_(ids))
return session.exec(statement).all()

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ export default function MessagePanne({ useName, guideWord, loadMore }) {
type = 'separator'
} else if (msg.files?.length) {
type = 'file'
} else if (['tool', 'flow', 'knowledge'].includes(msg.category)){
// || msg.category === 'processing') { // 项目演示?
} else if (['tool', 'flow', 'knowledge'].includes(msg.category)
|| (msg.category === 'processing' && msg.thought.indexOf(`status_code`) === -1)
) { // 项目演示?
type = 'runLog'
} else if (msg.thought) {
type = 'system'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ export default function MessageSystem({ data }) {

const border = { system: 'border-slate-500', question: 'border-amber-500', processing: 'border-cyan-600', answer: 'border-lime-600', report: 'border-slate-500', guide: 'border-none' }

// 中英去掉最终的回答(report)
// if(data.category === 'report') return null

return <div className="py-1">
<div className={`relative rounded-sm px-6 py-4 border text-sm dark:bg-gray-900 ${data.category === 'guide' ? 'bg-[#EDEFF6]' : 'bg-slate-50'} ${border[data.category || 'system']}`}>
{logMkdown}
{/* 中英 */}
{data.category === 'report' && <CopyIcon className=" absolute right-4 top-2 cursor-pointer" onClick={(e) => handleCopy(e.target.parentNode)}></CopyIcon>}
{/* {<CopyIcon className=" absolute right-4 top-2 cursor-pointer" onClick={(e) => handleCopy(e.target.parentNode)}></CopyIcon>} */}
</div>
</div>
};
4 changes: 0 additions & 4 deletions src/frontend/src/pages/LoginPage/login.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ export const LoginPage = () => {
</div>
)
}
{/* 中英 */}
{/* <Button
className='h-[48px] mt-[32px] dark:bg-button'
disabled={isLoading} onClick={handleLogin} >{t('login.loginButton')}</Button> */}
{
showLogin ? <>
<div className="text-center">
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { viteStaticCopy } from 'vite-plugin-static-copy';
import svgr from "vite-plugin-svgr";

// Use environment variable to determine the target.
const target = process.env.VITE_PROXY_TARGET || "http://192.168.106.120:3002";
const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7861";
const apiRoutes = ["^/api/", "/health"];

const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
Expand Down

0 comments on commit 25061f7

Please sign in to comment.