Commit
fix: support aborting streaming output for skills
zgqgit committed Jul 9, 2024
1 parent 25061f7 commit a3e3bb1
Showing 7 changed files with 94 additions and 58 deletions.
29 changes: 13 additions & 16 deletions src/backend/bisheng/api/v1/callback.py
@@ -35,13 +35,19 @@ def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int
# },  # stores the input info for tool calls
# }

# queue for streamed output
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)
# put the streamed tokens into the queue so the content can still be recorded to the database after the stream is aborted
await self.websocket.send_json(resp.dict())
if self.stream_queue:
self.stream_queue.put(token)

async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str],
**kwargs: Any) -> Any:
@@ -233,10 +239,13 @@ async def on_chat_model_start(self, serialized: Dict[str, Any],
class StreamingLLMCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming LLM responses."""

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str):
def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
self.websocket = websocket
self.flow_id = flow_id
self.chat_id = chat_id
self.user_id = user_id

self.stream_queue: Queue = kwargs.get('stream_queue')

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token,
@@ -248,6 +257,9 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)

if self.stream_queue:
self.stream_queue.put(token)

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
log = f'\nThought: {action.log}'
# if there are line breaks, split them and send them
@@ -386,21 +398,6 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler):

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
super().__init__(websocket, flow_id, chat_id, user_id, **kwargs)
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)

# put the streamed tokens into the queue so the content can still be recorded to the database after the stream is aborted
await self.websocket.send_json(resp.dict())
self.stream_queue.put(token)

@staticmethod
def parse_tool_category(tool_name) -> (str, str):
"""
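The core pattern of this change: every streamed token is mirrored into a per-connection queue at the same time it is sent over the websocket, so that when the user aborts the stream, the tokens generated so far can still be persisted. The duplicated override in AsyncGptsDebugCallbackHandler is deleted because the base handler now does the mirroring itself. A minimal, self-contained sketch of the pattern — class and parameter names here are illustrative, not the project's actual API:

```python
from queue import Queue
from typing import Any, Callable


class QueueMirroringHandler:
    """Streams tokens to a client while buffering them in a queue,
    so an aborted stream can still be persisted afterwards."""

    def __init__(self, send_token: Callable[[str], None], stream_queue: Queue = None):
        self.send_token = send_token      # e.g. a websocket send function
        self.stream_queue = stream_queue  # optional buffer for abort recovery

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.send_token(token)            # forward to the client first
        if self.stream_queue:             # then mirror into the queue
            self.stream_queue.put(token)


# Usage: drain the queue after an abort to reconstruct the partial answer.
q = Queue()
handler = QueueMirroringHandler(send_token=print, stream_queue=q)
for t in ['Hel', 'lo']:
    handler.on_llm_new_token(t)
partial = ''
while not q.empty():
    partial += q.get()
assert partial == 'Hello'
```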
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/endpoints.py
@@ -83,6 +83,8 @@ def get_config(admin_user: UserPayload = Depends(get_admin_user)):

@router.post('/config/save')
def save_config(data: dict, admin_user: UserPayload = Depends(get_admin_user)):
if not data.get('data', '').strip():
raise HTTPException(status_code=500, detail='配置不能为空')
try:
# validate that the content is well-formed YAML
_ = yaml.safe_load(data.get('data'))
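The endpoint now rejects an empty configuration before attempting to parse it, so a blank save can no longer wipe the stored YAML. A sketch of the two-step validation order, with illustrative names:

```python
import yaml


def validate_config_payload(data: dict) -> None:
    # reject empty or whitespace-only config before parsing
    if not data.get('data', '').strip():
        raise ValueError('config must not be empty')
    # raises yaml.YAMLError if the payload is not well-formed YAML
    yaml.safe_load(data['data'])
```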
73 changes: 47 additions & 26 deletions src/backend/bisheng/chat/handlers.py
@@ -1,5 +1,6 @@
import json
import time
from queue import Queue
from typing import Dict

from bisheng.api.v1.schemas import ChatMessage, ChatResponse
@@ -23,12 +24,16 @@

class Handler:

def __init__(self) -> None:
self.handler_dict = {}
self.handler_dict['default'] = self.process_message
self.handler_dict['autogen'] = self.process_autogen
self.handler_dict['auto_file'] = self.process_file
self.handler_dict['report'] = self.process_report
def __init__(self, stream_queue: Queue) -> None:
self.handler_dict = {
'default': self.process_message,
'autogen': self.process_autogen,
'auto_file': self.process_file,
'report': self.process_report,
'stop': self.process_stop
}
# records the streamed output content
self.stream_queue = stream_queue

async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str, action: str,
payload: dict, user_id):
@@ -44,6 +49,39 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str
logger.info(f'dispatch_task done timecost={time.time() - start_time}')
return client_id, chat_id

async def process_stop(self, session: ChatManager, client_id: str, chat_id: str, payload: Dict, user_id):
key = get_cache_key(client_id, chat_id)
langchain_object = session.in_memory_cache.get(key)
action = payload.get('action')
if isinstance(langchain_object, AutoGenChain):
if hasattr(langchain_object, 'stop'):
logger.info('reciever_human_interactive langchain_objct')
await langchain_object.stop()
else:
logger.error(f'act=auto_gen act={action}')
else:
# stop for an ordinary skill
res = thread_pool.cancel_task([key])  # cancel the in-progress task
if res[0]:
message = payload.get('inputs') or '手动停止'
res = ChatResponse(type='end', user_id=user_id, message='')
close = ChatResponse(type='close')
await session.send_json(client_id, chat_id, res, add=False)
await session.send_json(client_id, chat_id, close, add=False)
answer = ''
# record the streamed output content produced up to the abort
while not self.stream_queue.empty():
answer += self.stream_queue.get()
if answer.strip():
chat_message = ChatMessage(message=answer,
category='answer',
type='end',
user_id=user_id,
remark='break_answer',
is_bot=True)
session.chat_history.add_message(client_id, chat_id, chat_message)
logger.info(f'process_stop done')

async def process_report(self,
session: ChatManager,
client_id: str,
@@ -170,6 +208,7 @@ async def process_message(self,
websocket=session.active_connections[get_cache_key(client_id, chat_id)],
flow_id=client_id,
chat_id=chat_id,
stream_queue=self.stream_queue,
)

except Exception as e:
@@ -219,8 +258,7 @@ async def process_message(self,
source=int(source))
await session.send_json(client_id, chat_id, response)


# end of the loop
if is_begin:
close_resp = ChatResponse(type='close', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)
@@ -300,24 +338,7 @@ async def process_autogen(self, session: ChatManager, client_id: str, chat_id: str
langchain_object = session.in_memory_cache.get(key)
logger.info(f'reciever_human_interactive langchain={langchain_object}')
action = payload.get('action')
if action.lower() == 'stop':
if isinstance(langchain_object, AutoGenChain):
if hasattr(langchain_object, 'stop'):
logger.info('reciever_human_interactive langchain_objct')
await langchain_object.stop()
else:
logger.error(f'act=auto_gen act={action}')
else:
# 普通技能的stop
res = thread_pool.cancel_task([key]) # 将进行中的任务进行cancel
if res[0]:
message = payload.get('inputs') or '手动停止'
res = ChatResponse(type='end', user_id=user_id, message=message)
close = ChatResponse(type='close')
await session.send_json(client_id, chat_id, res)
await session.send_json(client_id, chat_id, close)

elif action.lower() == 'continue':
if action.lower() == 'continue':
# for autogen_user conversations, the process wait() needs to be renewed
if hasattr(langchain_object, 'input'):
await langchain_object.input(payload.get('inputs'))
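The stop logic that used to live inside process_autogen is now a dedicated 'stop' handler: it cancels the in-flight task, sends end/close frames to the client, then drains the token queue and saves any non-empty partial answer with the remark break_answer. A minimal sketch of that flow — the callables here are illustrative stand-ins for the real session and thread-pool objects:

```python
from queue import Queue
from typing import Callable


def collect_partial_answer(stream_queue: Queue) -> str:
    """Drain all tokens that were streamed before the abort."""
    answer = ''
    while not stream_queue.empty():
        answer += stream_queue.get()
    return answer


def handle_stop(stream_queue: Queue,
                cancel_task: Callable[[], bool],
                send_json: Callable[[dict], None],
                save_message: Callable[[str], None]) -> None:
    if cancel_task():                  # cancel the in-flight generation
        send_json({'type': 'end'})     # tell the client the stream ended
        send_json({'type': 'close'})
        answer = collect_partial_answer(stream_queue)
        if answer.strip():             # persist only non-empty partials
            save_message(answer)       # stored with remark='break_answer'
```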
17 changes: 12 additions & 5 deletions src/backend/bisheng/chat/manager.py
@@ -6,6 +6,7 @@
from collections import defaultdict
from typing import Any, Dict, List
from uuid import UUID
from queue import Queue

from loguru import logger
from fastapi import WebSocket, WebSocketDisconnect, status, Request
@@ -81,6 +82,9 @@ def __init__(self):
# connected clients
self.active_clients: Dict[str, ChatClient] = {}

# records streamed output results
self.stream_queue: Dict[str, Queue] = {}

def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
@@ -98,9 +102,11 @@ def update(self):
async def connect(self, client_id: str, chat_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()

def reuse_connect(self, client_id: str, chat_id: str, websocket: WebSocket):
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()

def disconnect(self, client_id: str, chat_id: str, key: str = None):
if key:
@@ -305,7 +311,7 @@ async def handle_websocket(

# check whether the current iteration is an empty loop
process_param = {
'autogen_pool': autogen_pool,
'autogen_pool': thread_pool,
'user_id': user_id,
'payload': payload,
'graph_data': gragh_data,
@@ -321,8 +327,7 @@

# handle task status
complete_normal = await thread_pool.as_completed(key_list)
autoComplete = await autogen_pool.as_completed(key_list)
complete = complete_normal + autoComplete
complete = complete_normal
# if async_task and async_task.done():
# logger.debug(f'async_task_complete result={async_task.result}')
if complete:
@@ -456,9 +461,9 @@ async def _process_when_payload(self, flow_id: str, chat_id: str,
if isinstance(self.in_memory_cache.get(langchain_obj_key), AutoGenChain):
# autogen chain
logger.info(f'autogen_submit {langchain_obj_key}')
autogen_pool.submit(key, Handler().dispatch_task, **params)
autogen_pool.submit(key, Handler(stream_queue=self.stream_queue[key]).dispatch_task, **params)
else:
thread_pool.submit(key, Handler().dispatch_task, **params)
thread_pool.submit(key, Handler(stream_queue=self.stream_queue[key]).dispatch_task, **params)
status_ = 'init'
context.update({'status': status_})
context.update({'payload': {}}) # clean message
@@ -504,6 +509,8 @@ async def preper_action(self, client_id, chat_id, langchain_obj_key, payload,
action = 'report'
step_resp.intermediate_steps = '文件解析完成,开始生成报告'
await self.send_json(client_id, chat_id, step_resp)
elif payload.get('action') == 'stop':
action = 'stop'
elif 'action' in payload:
action = 'autogen'
elif 'clear_history' in payload and payload['clear_history']:
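Each websocket connection now gets its own Queue, stored under the same (client_id, chat_id) cache key as the connection itself and handed to the Handler that serves it, so partial answers from different chats never mix. A sketch of the wiring, with illustrative names:

```python
from queue import Queue
from typing import Dict


def get_cache_key(client_id: str, chat_id: str) -> str:
    # assumed key format; the real project derives its own cache key
    return f'{client_id}_{chat_id}'


class StreamQueueRegistry:
    def __init__(self) -> None:
        self.stream_queue: Dict[str, Queue] = {}

    def connect(self, client_id: str, chat_id: str) -> Queue:
        # one fresh queue per connection, created when the socket connects
        key = get_cache_key(client_id, chat_id)
        self.stream_queue[key] = Queue()
        return self.stream_queue[key]
```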
6 changes: 4 additions & 2 deletions src/backend/bisheng/chat/utils.py
@@ -23,7 +23,8 @@ async def process_graph(langchain_object,
chat_inputs: ChatMessage,
websocket: WebSocket,
flow_id: str = None,
chat_id: str = None):
chat_id: str = None,
**kwargs):
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug('Loaded langchain object')

@@ -45,7 +46,8 @@
chat_inputs.message,
websocket=websocket,
flow_id=flow_id,
chat_id=chat_id)
chat_id=chat_id,
**kwargs)
logger.debug('Generated result and intermediate_steps')
return result, intermediate_steps, source_document
except Exception as e:
23 changes: 14 additions & 9 deletions src/backend/bisheng/initdb_config.yaml
@@ -16,7 +16,7 @@ minio_conf: &minio_conf
knowledges:
# LLM config required by the knowledge base, used to summarize a title for the document content; the title and chunk are then stored together in the vector store. If unset, documents are not summarized
llm:
type: "ChatOpenAi"
type: "ChatOpenAI"
model: "gpt-3.5-turbo"
<<: *openai_conf
unstructured_api_url: "" # address of Bisheng's unstructured-data parsing service, providing OCR text recognition, table recognition, layout analysis, etc. Optional; required for source tracing
@@ -68,6 +68,8 @@ llm_request:
default_operator:
user: 3
url: https://bisheng.dataelem.com
# whether login-free APIs check permissions
api_need_login: false

# whether a captcha is required
use_captcha:
@@ -83,6 +85,7 @@ env:
# a websocket address different from the http address can be configured
# websocket_url: 192.168.106.120:3003
office_url: http://IP:8701 # access address of the office component; the browser must be able to reach it directly
pro: false # whether to enable the closed-source gateway

# config related to the intelligent assistant
gpts:
@@ -139,11 +142,13 @@ password_conf:
max_error_times: 5

system_login_method:
# SSO login
SSO_OAuth: true
# # LDAP login
# LDAP: true
# LDAP server address config: XX

# admin username after switching to SSO/LDAP login
admin_username: admin
allow_multi_login: true # whether to allow concurrent logins from multiple clients

2 changes: 2 additions & 0 deletions src/backend/bisheng/utils/embedding.py
@@ -13,6 +13,8 @@ def decide_embeddings(model: str) -> Embeddings:
"""embed method"""
model_list = settings.get_knowledge().get('embeddings')
params = model_list.get(model)
if not params:
raise Exception(f'not found embedding {model} in system settings')
component = params.pop('component', '')
if model == 'text-embedding-ada-002' or component == 'openai':
if is_openai_v1() and params.get('openai_proxy'):
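decide_embeddings now fails fast with an explicit error when the requested embedding model is absent from the settings, instead of crashing later with an opaque AttributeError when params.pop is called on None. A sketch of the guard in isolation, with illustrative names:

```python
def lookup_embedding_params(model_list: dict, model: str) -> dict:
    params = model_list.get(model)
    if not params:
        # fail fast with a clear message rather than letting the
        # subsequent params.pop(...) raise AttributeError on None
        raise Exception(f'not found embedding {model} in system settings')
    return params
```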
