Skip to content

Commit

Permalink
Feat/0.3.3 (#734)
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 8, 2024
2 parents 8d01f2c + c744a15 commit baf2a86
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 54 deletions.
18 changes: 17 additions & 1 deletion src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
import json
from queue import Queue
from typing import Any, Dict, List, Union

from bisheng.api.v1.schemas import ChatResponse
Expand All @@ -19,7 +20,7 @@
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None):
def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
self.websocket = websocket
self.flow_id = flow_id
self.chat_id = chat_id
Expand Down Expand Up @@ -385,6 +386,21 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler):

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
super().__init__(websocket, flow_id, chat_id, user_id, **kwargs)
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)

# 将流式输出内容放入到队列内,以方便中断流式输出后,可以将内容记录到数据库
await self.websocket.send_json(resp.dict())
self.stream_queue.put(token)

@staticmethod
def parse_tool_category(tool_name) -> (str, str):
"""
Expand Down
75 changes: 60 additions & 15 deletions src/backend/bisheng/chat/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json
import os
import time
from typing import Dict
from typing import Dict, Callable
from uuid import UUID, uuid4
from queue import Queue

from loguru import logger
from langchain_core.messages import AIMessage, HumanMessage
from langchain.tools.render import format_tool_to_openai_tool
from fastapi import WebSocket, status, Request

from bisheng.api.services.assistant_agent import AssistantAgent
Expand All @@ -20,6 +18,7 @@
from bisheng.database.models.message import ChatMessageDao
from bisheng.settings import settings
from bisheng.api.utils import get_request_ip
from bisheng.utils.threadpool import ThreadPoolManager, thread_pool


class ChatClient:
Expand All @@ -42,6 +41,10 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s
# 和模型对话时传入的 完整的历史对话轮数
self.latest_history_num = 5
self.gpts_conf = settings.get_from_db('gpts')
# 异步任务列表
self.task_ids = []
# 流式输出的队列,用来接受流式输出的内容
self.stream_queue = Queue()

async def send_message(self, message: str):
await self.websocket.send_text(message)
Expand All @@ -51,15 +54,33 @@ async def send_json(self, message: ChatMessage):

async def handle_message(self, message: Dict[any, any]):
trace_id = uuid4().hex
logger.info(f'client_id={self.client_key} trace_id={trace_id} message={message}')
with logger.contextualize(trace_id=trace_id):
# 处理客户端发过来的信息
# 处理客户端发过来的信息, 提交到线程池内执行
if self.work_type == WorkType.GPTS:
await self.handle_gpts_message(message)
thread_pool.submit(trace_id,
self.wrapper_task,
trace_id,
self.handle_gpts_message,
message,
trace_id=trace_id)
# await self.handle_gpts_message(message)

async def add_message(self, msg_type: str, message: str, category: str):
async def wrapper_task(self, task_id: str, fn: Callable, *args, **kwargs):
# 包装处理函数为异步任务
self.task_ids.append(task_id)
try:
# 执行处理函数
await fn(*args, **kwargs)
finally:
# 执行完成后将任务id从列表移除
self.task_ids.remove(task_id)

async def add_message(self, msg_type: str, message: str, category: str, remark: str = ''):
self.chat_history.append({
'category': category,
'message': message
'message': message,
'remark': remark
})
if not self.chat_id:
# debug模式无需保存历史
Expand All @@ -75,6 +96,7 @@ async def add_message(self, msg_type: str, message: str, category: str):
flow_id=self.client_id,
chat_id=self.chat_id,
user_id=self.user_id,
remark=remark,
))
# 记录审计日志, 是新建会话
if len(self.chat_history) <= 1:
Expand Down Expand Up @@ -142,7 +164,8 @@ async def init_chat_history(self):
for one in res:
self.chat_history.append({
'message': one.message,
'category': one.category
'category': one.category,
'remark': one.remark
})

async def get_latest_history(self):
Expand All @@ -152,13 +175,15 @@ async def get_latest_history(self):
is_answer = True
# 从聊天历史里获取
for i in range(len(self.chat_history) - 1, -1, -1):
one_item = self.chat_history[i]
if find_i >= self.latest_history_num:
break
if self.chat_history[i]['category'] == 'answer' and is_answer:
tmp.insert(0, AIMessage(content=self.chat_history[i]['message']))
# 不包含中断的答案
if one_item['category'] == 'answer' and one_item.get('remark') != 'break_answer' and is_answer:
tmp.insert(0, AIMessage(content=one_item['message']))
is_answer = False
elif self.chat_history[i]['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(self.chat_history[i]['message'])['input']))
elif one_item['category'] == 'question' and not is_answer:
tmp.insert(0, HumanMessage(content=json.loads(one_item['message'])['input']))
is_answer = True
find_i += 1

Expand All @@ -171,16 +196,36 @@ async def init_gpts_callback(self):
'websocket': self.websocket,
'flow_id': self.client_id,
'chat_id': self.chat_id,
'user_id': self.user_id
'user_id': self.user_id,
'stream_queue': self.stream_queue,
})]
self.gpts_async_callback = async_callbacks

async def stop_handle_message(self, message: Dict[any, any]):
# 中止流式输出, 因为最新的任务id是中止任务的id,不能取消自己
logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')

# 中止之前的处理函数
thread_pool.cancel_task(self.task_ids[:-1])

# 将流式输出的内容写到数据库内
answer = ''
while not self.stream_queue.empty():
msg = self.stream_queue.get()
answer += msg

# 有流式输出内容的话,记录流式输出内容到数据库
if answer.strip():
res = await self.add_message('bot', answer, 'answer', 'break_answer')
await self.send_response('answer', 'end', answer, message_id=res.id if res else None)
await self.send_response('processing', 'close', '')

async def handle_gpts_message(self, message: Dict[any, any]):
if not message:
return
logger.debug(f'receive client message, client_key: {self.client_key} message: {message}')
if message.get('action') == 'stop':
logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}')
await self.stop_handle_message(message)
return

inputs = message.get('inputs', {})
Expand Down
4 changes: 2 additions & 2 deletions src/backend/bisheng/database/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MessageBase(SQLModelSerializable):
receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的发送方')
intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注。break_answer: 中断的回复不作为history传给模型')
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
update_time: Optional[datetime] = Field(
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = Non
statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids))
if category:
statement = statement.where(ChatMessage.category == category)
statement = statement.order_by(ChatMessage.create_time.desc()).limit(1)
statement = statement.order_by(ChatMessage.create_time.desc())
with session_getter() as session:
return session.exec(statement).all()

Expand Down
11 changes: 4 additions & 7 deletions src/frontend/src/components/bs-comp/chatComponent/MessageBs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import remarkMath from "remark-math";
import MessageButtons from "./MessageButtons";
import SourceEntry from "./SourceEntry";
import { useMessageStore } from "./messageStore";
import { isSameDay } from "@/util/utils";
import { formatStrTime } from "@/util/utils";

// 颜色列表
const colorList = [
Expand Down Expand Up @@ -77,17 +77,14 @@ export default function MessageBs({ data, onUnlike = () => { }, onSource }: { da
}

const chatId = useMessageStore(state => state.chatId)
const [show, setShow] = useState(false)

return <div className="flex w-full">
<div className="w-fit max-w-[90%]">
<div className={`text-right ${show ? 'opacity-100' : 'opacity-0'}`}>
<span className="text-slate-400 text-sm">{isSameDay(data.update_time, new Date())}</span>
<div className={`text-right hover:opacity-100 opacity-0`}>
<span className="text-slate-400 text-sm">{formatStrTime(data.update_time, 'MM 月 dd 日 HH:mm')}</span>
</div>
{data.sender && <p className="text-gray-600 text-xs mb-2">{data.sender}</p>}
<div className="min-h-8 px-6 py-4 rounded-2xl bg-[#F5F6F8] dark:bg-[#313336]"
onMouseEnter={(e) => setShow(true)}
onMouseLeave={(e) => setShow(false)}>
<div className="min-h-8 px-6 py-4 rounded-2xl bg-[#F5F6F8] dark:bg-[#313336]">
<div className="flex gap-2">
<div className="w-6 h-6 min-w-6 flex justify-center items-center rounded-full" style={{ background: avatarColor }} ><AvatarIcon /></div>
{data.message.toString() ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ import { ChatMessageType } from "@/types/chat";
import { MagnifyingGlassIcon, Pencil2Icon, ReloadIcon } from "@radix-ui/react-icons";
import { useContext, useState } from "react";
import { useMessageStore } from "./messageStore";
import { isSameDay } from "@/util/utils";
import { formatStrTime } from "@/util/utils";

export default function MessageUser({ useName, data }: { data: ChatMessageType }) {
export default function MessageUser({ useName = 'xxx', data }: { data: ChatMessageType }) {
const msg = data.message[data.chatKey]

const { appConfig } = useContext(locationContext)
Expand All @@ -25,17 +25,13 @@ export default function MessageUser({ useName, data }: { data: ChatMessageType }
document.dispatchEvent(myEvent);
}

const [show, setShow] = useState(false)

return <div className="flex justify-end w-full">
<div className="w-fit min-h-8 max-w-[90%]">
<div className={`text-right ${show ? 'opacity-100' : 'opacity-0'}`}>
<span className="text-slate-400 text-sm">{isSameDay(data.update_time, new Date())}</span>
<div className={`text-right hover:opacity-100 opacity-0`}>
<span className="text-slate-400 text-sm">{formatStrTime(data.update_time, 'MM 月 dd 日 HH:mm')}</span>
</div>
{useName && <p className="text-gray-600 text-xs mb-2 text-right">{useName}</p>}
<div className="rounded-2xl px-6 py-4 bg-[#EEF2FF] dark:bg-[#333A48]"
onMouseEnter={(e) => setShow(true)}
onMouseLeave={(e) => setShow(false)}>
<div className="rounded-2xl px-6 py-4 bg-[#EEF2FF] dark:bg-[#333A48]">
<div className="flex gap-2 ">
<div className="text-[#0D1638] dark:text-[#CFD5E8] text-sm break-all whitespace-break-spaces">{msg}</div>
<div className="w-6 h-6 min-w-6"><img src={__APP_ENV__.BASE_URL + '/user.png'} alt="" /></div>
Expand Down
13 changes: 12 additions & 1 deletion src/frontend/src/components/bs-comp/chatComponent/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import ChatInput from "./ChatInput";
import MessagePanne from "./MessagePanne";

export default function ChatComponent({ stop = false, clear = false, questions = [], form = false, useName, inputForm = null, guideWord, wsUrl, onBeforSend, loadMore = () => { } }) {
export default function ChatComponent({
stop = false,
clear = false,
questions = [],
form = false,
useName,
inputForm = null,
guideWord,
wsUrl,
onBeforSend,
loadMore = () => { }
}) {

return <div className="relative h-full">
<MessagePanne useName={useName} guideWord={guideWord} loadMore={loadMore}></MessagePanne>
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/pages/ChatAppPage/components/ChatPanne.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ export default function ChatPanne({ customWsHost = '', appendHistory = false, da
</div>
<ChatComponent
form={flowSate.isForm}
stop={flowSate.isReport || flowSate.isRoom}
stop
// stop={flowSate.isReport || flowSate.isRoom}
useName={sendUserName}
guideWord={flow.guide_word}
wsUrl={wsUrl}
Expand All @@ -207,6 +208,7 @@ export default function ChatPanne({ customWsHost = '', appendHistory = false, da
<span className="text-sm">{assistant.name}</span>
</div>
<ChatComponent
stop
useName={sendUserName}
questions={assistantState.guide_question.filter((item) => item)}
guideWord={assistantState.guide_word}
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/pages/ChatAppPage/components/FileView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,15 @@ export default function FileView({ data }) {
fileWidthRef.current = w
}

return <div ref={paneRef} className="flex-1 bg-gray-100 rounded-md py-4 px-2 relative">
return <div ref={paneRef} className="flex-1 bg-gray-100 rounded-md py-4 px-2 relative" onContextMenu={(e) => e.preventDefault()}>
{
loading
? <div className="absolute w-full h-full top-0 left-0 flex justify-center items-center z-10 bg-[rgba(255,255,255,0.6)] dark:bg-blur-shared">
<span className="loading loading-infinity loading-lg"></span>
</div>
// {/* 中英 */}
: <div id="warp-pdf" className="file-view absolute">
{/* : <div id="warp-pdf" className="file-view absolute pointer-events-none"> */}
<List
ref={listRef}
itemCount={pdf?.numPages || 100}
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/pages/ChatAppPage/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { captureAndAlertRequestErrorHoc } from "../../controllers/request";
import { useDebounce } from "../../util/hook";
import { generateUUID } from "../../utils";
import ChatPanne from "./components/ChatPanne";
import { isSameDay } from "@/util/utils";
import { formatStrTime } from "@/util/utils";
import { SkillIcon, AssistantIcon } from "@/components/bs-icons";

export default function SkillChatPage() {
Expand Down Expand Up @@ -99,13 +99,13 @@ export default function SkillChatPage() {
onClick={() => handleSelectChat(chat)}>
<div className="flex place-items-center space-x-3">
<div className=" inline-block bg-purple-500 rounded-md">
{chat.flow_type === 'assistant' ? <AssistantIcon/> : <SkillIcon/>}
{chat.flow_type === 'assistant' ? <AssistantIcon /> : <SkillIcon />}
</div>
<p className="truncate text-sm font-bold text-gray-950 dark:text-[#F2F2F2] leading-6">{chat.flow_name}</p>
</div>
<span className="block text-xs text-gray-600 dark:text-[#8D8D8E] mt-3 break-words truncate">{chat.flow_description}</span>
<span className="block text-xs text-gray-600 dark:text-[#8D8D8E] mt-3 break-words truncate">{chat.latest_message?.message || ''}</span>
<div className="mt-6">
<span className="text-gray-400 text-xs absolute bottom-2 left-2">{isSameDay(chat.update_time, new Date())}</span>
<span className="text-gray-400 text-xs absolute bottom-2 left-2">{formatStrTime(chat.update_time, 'MM 月 dd 日')}</span>
<Trash2 size={14} className="absolute bottom-2 right-2 text-gray-400 hidden group-hover:block" onClick={(e) => handleDeleteChat(e, chat.chat_id)}></Trash2>
</div>
</div>
Expand Down
31 changes: 18 additions & 13 deletions src/frontend/src/util/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,30 @@ export function formatDate(date: Date, format: string): string {
const addZero = (num) => num < 10 ? `0${num}` : `${num}`
const replacements = {
'yyyy': date.getFullYear(),
'MM': addZero(date.getMonth() + 1),
'MM': addZero(date.getMonth() + 1),
'dd': addZero(date.getDate()),
'HH': addZero(date.getHours()),
'mm': addZero(date.getMinutes()),
'ss': addZero(date.getSeconds())
}
return format.replace(/yyyy|MM|dd|HH|mm|ss/g, (match) => replacements[match])
}

// string类型和Date对比是否同一天
export function isSameDay(time: string, date: Date): string {
if(!time) return '暂无时间'
const newTime = time.substring(0, time.indexOf('T')).split('-')
const arrayTime = newTime.map(t => Number(t))
const [year, month, day] = [date.getFullYear(), date.getMonth() + 1, date.getDay()]
if(year === arrayTime[0] && month === arrayTime[1] && day === arrayTime[2]) {
return time.substring(time.indexOf('T') + 1, time.length - 3)
}
return `${newTime[1]}${newTime[2]}日`
}

// param time: yyyy-mm-ddTxxxx
export function formatStrTime(time: string, notSameDayFormat: string): string {
if (!time) return ''
const date1 = new Date(time)
const date2 = new Date()
return date1.getFullYear() === date2.getFullYear() &&
date1.getMonth() === date2.getMonth() &&
date1.getDate() === date2.getDate() ? formatDate(date1, 'HH:mm') : formatDate(date1, notSameDayFormat)
// const newTime = time.substring(0, time.indexOf('T')).split('-')
// const arrayTime = newTime.map(t => Number(t))
// const [year, month, day] = [date.getFullYear(), date.getMonth() + 1, date.getDay()]
// if(year === arrayTime[0] && month === arrayTime[1] && day === arrayTime[2]) {
// return time.substring(time.indexOf('T') + 1, time.length - 3)
// }
// return `${newTime[1]}月${newTime[2]}日`
}

export function toTitleCase(str: string | undefined): string {
Expand Down

0 comments on commit baf2a86

Please sign in to comment.