Commit

Feat/0.3.3 (#738)

zgqgit committed Jul 9, 2024
2 parents c21ed6b + e793b30 commit 965180f
Showing 26 changed files with 602 additions and 183 deletions.
34 changes: 28 additions & 6 deletions src/backend/bisheng/api/services/assistant_agent.py
@@ -63,6 +63,11 @@ def __init__(self, assistant_info: Assistant, chat_id: str):
self.tools: List[BaseTool] = []
self.offline_flows = []
self.agent: ConfigurableAssistant | None = None
self.agent_executor_dict = {
'ReAct': 'get_react_agent_executor',
'function call': 'get_openai_functions_agent_executor',
}
self.current_agent_executor = None
self.llm: BaseLanguageModel | None = None
self.llm_agent_executor = None
self.knowledge_skill_path = str(Path(__file__).parent / 'knowledge_skill.json')
@@ -281,6 +286,9 @@ async def init_agent(self):
# Load the agent executor parameters
agent_executor_params = self.get_agent_executor()
agent_executor_type = self.llm_agent_executor or agent_executor_params.pop('type')
self.current_agent_executor = agent_executor_type
# Map the executor type name to its factory method
agent_executor_type = self.agent_executor_dict.get(agent_executor_type, agent_executor_type)

prompt = self.assistant.prompt
if self.assistant.model_name.startswith("command-r"):
@@ -334,12 +342,6 @@ async def run(self, query: str, chat_history: List = None, callback: Callbacks =
"""
Run the assistant conversation
"""
if chat_history:
chat_history.append(HumanMessage(content=query))
inputs = chat_history
else:
inputs = [HumanMessage(content=query)]

# Fake callback: report skills that have been taken offline back to the frontend
for one in self.offline_flows:
if callback is not None:
@@ -348,6 +350,14 @@ async def run(self, query: str, chat_history: List = None, callback: Callbacks =
'name': one,
}, input_str='flow is offline', run_id=run_id)
await callback[0].on_tool_end(output='flow is offline', name=one, run_id=run_id)
if self.current_agent_executor == 'ReAct':
return await self.react_run(query, chat_history, callback)

if chat_history:
chat_history.append(HumanMessage(content=query))
inputs = chat_history
else:
inputs = [HumanMessage(content=query)]
result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback))
# The result includes the chat history; drop it and take the last message as the final answer
res = [result[-1]]
@@ -364,3 +374,15 @@ async def run(self, query: str, chat_history: List = None, callback: Callbacks =
except Exception as e:
logger.error(f"record assistant history error: {str(e)}")
return res

async def react_run(self, query: str, chat_history: List = None, callback: Callbacks = None):
""" react 模式的输入和执行 """
result = await self.agent.ainvoke({
'input': query,
'chat_history': chat_history
}, config=RunnableConfig(callbacks=callback))
logger.debug(f'react_run result={result}')
output = result['agent_outcome'].return_values['output']
if isinstance(output, dict):
output = output['text']
return [AIMessage(content=output)]
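
The net effect of the assistant_agent.py changes: the configured executor type is recorded, mapped through agent_executor_dict to a factory name, and at run time a ReAct executor gets a dict-shaped input while the function-call executor keeps the flat message list. A minimal standalone sketch of that dispatch — resolve_executor_type and build_inputs are illustrative helpers, not functions from this codebase, and langchain_core's HumanMessage is assumed available:

from typing import List, Union
from langchain_core.messages import BaseMessage, HumanMessage

AGENT_EXECUTOR_DICT = {
    'ReAct': 'get_react_agent_executor',
    'function call': 'get_openai_functions_agent_executor',
}

def resolve_executor_type(configured: str) -> str:
    # Unknown values fall through unchanged, mirroring dict.get(k, k) above
    return AGENT_EXECUTOR_DICT.get(configured, configured)

def build_inputs(executor_type: str, query: str,
                 chat_history: List[BaseMessage] = None) -> Union[dict, List[BaseMessage]]:
    if executor_type == 'ReAct':
        # ReAct executors expect an 'input' / 'chat_history' dict
        return {'input': query, 'chat_history': chat_history}
    # The function-call executor takes a flat message list instead
    messages = list(chat_history or [])
    messages.append(HumanMessage(content=query))
    return messages

assert resolve_executor_type('ReAct') == 'get_react_agent_executor'
assert build_inputs('function call', 'hi') == [HumanMessage(content='hi')]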
29 changes: 13 additions & 16 deletions src/backend/bisheng/api/v1/callback.py
@@ -35,13 +35,19 @@ def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: in
# },  # stores the input info of tool calls
# }

# Queue for streamed output
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)
# Push streamed content into the queue so that, even if streaming is interrupted, the content can still be recorded to the database
await self.websocket.send_json(resp.dict())
if self.stream_queue:
self.stream_queue.put(token)

async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str],
**kwargs: Any) -> Any:
@@ -233,10 +239,13 @@ async def on_chat_model_start(self, serialized: Dict[str, Any],
class StreamingLLMCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming LLM responses."""

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str):
def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
self.websocket = websocket
self.flow_id = flow_id
self.chat_id = chat_id
self.user_id = user_id

self.stream_queue: Queue = kwargs.get('stream_queue')

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token,
@@ -248,6 +257,9 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)

if self.stream_queue:
self.stream_queue.put(token)

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
log = f'\nThought: {action.log}'
# if there are line breaks, split them and send them
@@ -386,21 +398,6 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:

class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler):

def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any):
super().__init__(websocket, flow_id, chat_id, user_id, **kwargs)
self.stream_queue: Queue = kwargs.get('stream_queue')

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}')
resp = ChatResponse(message=token,
type='stream',
flow_id=self.flow_id,
chat_id=self.chat_id)

# Push streamed content into the queue so that, even if streaming is interrupted, the content can still be recorded to the database
await self.websocket.send_json(resp.dict())
self.stream_queue.put(token)

@staticmethod
def parse_tool_category(tool_name) -> (str, str):
"""
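Both callback handlers now share the same capture pattern: every streamed token goes to the websocket and, when a queue was supplied, into stream_queue. A minimal sketch of that pattern in isolation — TokenRecorder is a stand-in class, not bisheng's handler:

from queue import Queue
from typing import Optional

class TokenRecorder:
    """Stand-in for the callback handlers above: tokens are recorded to a
    queue so a partially streamed answer survives an interrupted stream."""

    def __init__(self, stream_queue: Optional[Queue] = None):
        self.stream_queue = stream_queue

    def on_llm_new_token(self, token: str) -> None:
        # The real handler also sends the token over the websocket here
        if self.stream_queue:
            self.stream_queue.put(token)

q = Queue()
recorder = TokenRecorder(stream_queue=q)
for tok in ('Hel', 'lo'):
    recorder.on_llm_new_token(tok)

# After a manual stop, whatever was streamed so far can be drained and persisted.
partial = ''.join(q.get() for _ in range(q.qsize()))
assert partial == 'Hello'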
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/endpoints.py
@@ -83,6 +83,8 @@ def get_config(admin_user: UserPayload = Depends(get_admin_user)):

@router.post('/config/save')
def save_config(data: dict, admin_user: UserPayload = Depends(get_admin_user)):
if not data.get('data', '').strip():
raise HTTPException(status_code=500, detail='Configuration cannot be empty')
try:
# Validate that the content is well-formed YAML
_ = yaml.safe_load(data.get('data'))
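A quick way to see what the new guard rejects — a hedged sketch using a standalone route shaped like the one above (this is not bisheng's app; it assumes fastapi, httpx, and pyyaml are installed):

import yaml
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

app = FastAPI()

@app.post('/config/save')
def save_config(data: dict):
    # Same guard as in the diff: reject empty or whitespace-only config
    if not data.get('data', '').strip():
        raise HTTPException(status_code=500, detail='Configuration cannot be empty')
    # Validate that the payload parses as YAML before accepting it
    yaml.safe_load(data.get('data'))
    return {'status': 'ok'}

client = TestClient(app)
assert client.post('/config/save', json={'data': '   '}).status_code == 500
assert client.post('/config/save', json={'data': 'key: value'}).status_code == 200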
5 changes: 5 additions & 0 deletions src/backend/bisheng/api/v2/assistant.py
@@ -0,0 +1,5 @@
# Assistant endpoints that do not require login
from fastapi import APIRouter


router = APIRouter(prefix='/chat', tags=['AssistantOpenApi'])

10 changes: 8 additions & 2 deletions src/backend/bisheng/chat/client.py
@@ -43,7 +43,7 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s
self.gpts_conf = settings.get_from_db('gpts')
# List of async tasks
self.task_ids = []
# Queue for streamed output, used to receive the streamed content
# Queue for streamed output, used to receive the streamed content; cleared each time a new question is handled
self.stream_queue = Queue()

async def send_message(self, message: str):
@@ -217,9 +217,13 @@ async def stop_handle_message(self, message: Dict[any, any]):
# If there is streamed content, record it to the database
if answer.strip():
res = await self.add_message('bot', answer, 'answer', 'break_answer')
await self.send_response('answer', 'end', answer, message_id=res.id if res else None)
await self.send_response('answer', 'end', '', message_id=res.id if res else None)
await self.send_response('processing', 'close', '')

async def clear_stream_queue(self):
while not self.stream_queue.empty():
self.stream_queue.get()

async def handle_gpts_message(self, message: Dict[any, any]):
if not message:
return
@@ -228,6 +232,8 @@ async def handle_gpts_message(self, message: Dict[any, any]):
await self.stop_handle_message(message)
return

# Clear the stream queue so the previous answer cannot pollute this one
await self.clear_stream_queue()
inputs = message.get('inputs', {})
input_msg = inputs.get('input')
if not input_msg:
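The clear-then-handle ordering is the point of clear_stream_queue: without draining first, tokens left over from a stopped answer would be prepended to the next one. The idiom in isolation, as a minimal sketch:

from queue import Queue

def clear_stream_queue(q: Queue) -> None:
    # Same non-blocking drain as ChatClient.clear_stream_queue above
    while not q.empty():
        q.get()

q = Queue()
for stale in ('left', 'over'):   # tokens from an interrupted answer
    q.put(stale)
clear_stream_queue(q)
q.put('fresh')                   # first token of the new answer
assert q.qsize() == 1 and q.get() == 'fresh'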
77 changes: 51 additions & 26 deletions src/backend/bisheng/chat/handlers.py
@@ -1,5 +1,6 @@
import json
import time
from queue import Queue
from typing import Dict

from bisheng.api.v1.schemas import ChatMessage, ChatResponse
@@ -23,12 +24,16 @@

class Handler:

def __init__(self) -> None:
self.handler_dict = {}
self.handler_dict['default'] = self.process_message
self.handler_dict['autogen'] = self.process_autogen
self.handler_dict['auto_file'] = self.process_file
self.handler_dict['report'] = self.process_report
def __init__(self, stream_queue: Queue) -> None:
self.handler_dict = {
'default': self.process_message,
'autogen': self.process_autogen,
'auto_file': self.process_file,
'report': self.process_report,
'stop': self.process_stop
}
# Records the streamed output content
self.stream_queue = stream_queue

async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str, action: str,
payload: dict, user_id):
@@ -39,11 +44,48 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str
action = 'default'
if action not in self.handler_dict:
raise Exception(f'unknown action {action}')
if action != 'stop':
# Clear the stream queue so the last answer cannot pollute this one
while not self.stream_queue.empty():
self.stream_queue.get()

await self.handler_dict[action](session, client_id, chat_id, payload, user_id)
logger.info(f'dispatch_task done timecost={time.time() - start_time}')
return client_id, chat_id

async def process_stop(self, session: ChatManager, client_id: str, chat_id: str, payload: Dict, user_id):
key = get_cache_key(client_id, chat_id)
langchain_object = session.in_memory_cache.get(key)
action = payload.get('action')
if isinstance(langchain_object, AutoGenChain):
if hasattr(langchain_object, 'stop'):
logger.info('receiver_human_interactive langchain_object')
await langchain_object.stop()
else:
logger.error(f'act=auto_gen act={action}')
else:
# Stop for an ordinary skill
res = thread_pool.cancel_task([key])  # cancel the task in progress
if res[0]:
message = payload.get('inputs') or 'Manually stopped'
res = ChatResponse(type='end', user_id=user_id, message='')
close = ChatResponse(type='close')
await session.send_json(client_id, chat_id, res, add=False)
await session.send_json(client_id, chat_id, close, add=False)
answer = ''
# Record the streamed output produced before the stop
while not self.stream_queue.empty():
answer += self.stream_queue.get()
if answer.strip():
chat_message = ChatMessage(message=answer,
category='answer',
type='end',
user_id=user_id,
remark='break_answer',
is_bot=True)
session.chat_history.add_message(client_id, chat_id, chat_message)
logger.info('process_stop done')

async def process_report(self,
session: ChatManager,
client_id: str,
@@ -170,6 +212,7 @@ async def process_message(self,
websocket=session.active_connections[get_cache_key(client_id, chat_id)],
flow_id=client_id,
chat_id=chat_id,
stream_queue=self.stream_queue,
)

except Exception as e:
@@ -219,8 +262,7 @@ async def process_message(self,
source=int(source))
await session.send_json(client_id, chat_id, response)


# End of the loop
# End of the loop
if is_begin:
close_resp = ChatResponse(type='close', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)
@@ -300,24 +342,7 @@ async def process_autogen(self, session: ChatManager, client_id: str, chat_id: s
langchain_object = session.in_memory_cache.get(key)
logger.info(f'receiver_human_interactive langchain={langchain_object}')
action = payload.get('action')
if action.lower() == 'stop':
if isinstance(langchain_object, AutoGenChain):
if hasattr(langchain_object, 'stop'):
logger.info('receiver_human_interactive langchain_object')
await langchain_object.stop()
else:
logger.error(f'act=auto_gen act={action}')
else:
# Stop for an ordinary skill
res = thread_pool.cancel_task([key])  # cancel the task in progress
if res[0]:
message = payload.get('inputs') or 'Manually stopped'
res = ChatResponse(type='end', user_id=user_id, message=message)
close = ChatResponse(type='close')
await session.send_json(client_id, chat_id, res)
await session.send_json(client_id, chat_id, close)

elif action.lower() == 'continue':
if action.lower() == 'continue':
# During an autogen_user conversation, the process wait() needs to be renewed
if hasattr(langchain_object, 'input'):
await langchain_object.input(payload.get('inputs'))
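Condensing the handlers.py refactor: stop handling leaves process_autogen and becomes its own entry in handler_dict, and every non-stop dispatch drains the queue first. A self-contained sketch of that flow — MiniHandler stubs out ChatManager, thread_pool, and the websocket plumbing, so only the dispatch and drain logic is taken from the diff:

import asyncio
from queue import Queue

class MiniHandler:
    def __init__(self, stream_queue: Queue) -> None:
        self.stream_queue = stream_queue
        self.handler_dict = {
            'default': self.process_message,
            'stop': self.process_stop,
        }

    async def dispatch_task(self, action: str, payload: dict) -> None:
        if action not in self.handler_dict:
            raise Exception(f'unknown action {action}')
        if action != 'stop':
            # Clear leftovers so the last answer cannot pollute this one
            while not self.stream_queue.empty():
                self.stream_queue.get()
        await self.handler_dict[action](payload)

    async def process_message(self, payload: dict) -> None:
        # The real handler streams tokens via callbacks; we just record one
        self.stream_queue.put(payload['token'])

    async def process_stop(self, payload: dict) -> None:
        # Drain whatever was streamed before cancellation and persist it
        answer = ''
        while not self.stream_queue.empty():
            answer += self.stream_queue.get()
        if answer.strip():
            print(f'persisting partial answer: {answer!r}')

async def demo() -> None:
    handler = MiniHandler(Queue())
    await handler.dispatch_task('default', {'token': 'partial answer'})
    await handler.dispatch_task('stop', {})

asyncio.run(demo())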
17 changes: 12 additions & 5 deletions src/backend/bisheng/chat/manager.py
@@ -6,6 +6,7 @@
from collections import defaultdict
from typing import Any, Dict, List
from uuid import UUID
from queue import Queue

from loguru import logger
from fastapi import WebSocket, WebSocketDisconnect, status, Request
Expand Down Expand Up @@ -81,6 +82,9 @@ def __init__(self):
# Connected clients
self.active_clients: Dict[str, ChatClient] = {}

# Records streamed output results
self.stream_queue: Dict[str, Queue] = {}

def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
@@ -98,9 +102,11 @@ def update(self):
async def connect(self, client_id: str, chat_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()

def reuse_connect(self, client_id: str, chat_id: str, websocket: WebSocket):
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()

def disconnect(self, client_id: str, chat_id: str, key: str = None):
if key:
Expand Down Expand Up @@ -305,7 +311,7 @@ async def handle_websocket(

# Check whether the current iteration is an empty loop
process_param = {
'autogen_pool': autogen_pool,
'autogen_pool': thread_pool,
'user_id': user_id,
'payload': payload,
'graph_data': gragh_data,
@@ -321,8 +327,7 @@

# Handle task status
complete_normal = await thread_pool.as_completed(key_list)
autoComplete = await autogen_pool.as_completed(key_list)
complete = complete_normal + autoComplete
complete = complete_normal
# if async_task and async_task.done():
# logger.debug(f'async_task_complete result={async_task.result}')
if complete:
@@ -456,9 +461,9 @@ async def _process_when_payload(self, flow_id: str, chat_id: str,
if isinstance(self.in_memory_cache.get(langchain_obj_key), AutoGenChain):
# autogen chain
logger.info(f'autogen_submit {langchain_obj_key}')
autogen_pool.submit(key, Handler().dispatch_task, **params)
autogen_pool.submit(key, Handler(stream_queue=self.stream_queue[key]).dispatch_task, **params)
else:
thread_pool.submit(key, Handler().dispatch_task, **params)
thread_pool.submit(key, Handler(stream_queue=self.stream_queue[key]).dispatch_task, **params)
status_ = 'init'
context.update({'status': status_})
context.update({'payload': {}}) # clean message
@@ -504,6 +509,8 @@ async def preper_action(self, client_id, chat_id, langchain_obj_key, payload,
action = 'report'
step_resp.intermediate_steps = 'File parsing complete, starting report generation'
await self.send_json(client_id, chat_id, step_resp)
elif payload.get('action') == 'stop':
action = 'stop'
elif 'action' in payload:
action = 'autogen'
elif 'clear_history' in payload and payload['clear_history']:
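The manager-side wiring in brief: each connection gets its own queue under the same cache key as its websocket, and that queue is what Handler(stream_queue=...) receives at submit time. A sketch of the wiring — get_cache_key here is a simplified stand-in for bisheng's helper:

from queue import Queue
from typing import Any, Dict

def get_cache_key(client_id: str, chat_id: str) -> str:
    # Simplified stand-in; the real helper lives in bisheng
    return f'{client_id}_{chat_id}'

class MiniChatManager:
    def __init__(self) -> None:
        self.active_connections: Dict[str, Any] = {}
        # One stream queue per live connection, created on (re)connect
        self.stream_queue: Dict[str, Queue] = {}

    def connect(self, client_id: str, chat_id: str, websocket: Any) -> None:
        key = get_cache_key(client_id, chat_id)
        self.active_connections[key] = websocket
        self.stream_queue[key] = Queue()

manager = MiniChatManager()
manager.connect('flow-1', 'chat-1', websocket=object())
key = get_cache_key('flow-1', 'chat-1')
assert isinstance(manager.stream_queue[key], Queue)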
6 changes: 4 additions & 2 deletions src/backend/bisheng/chat/utils.py
@@ -23,7 +23,8 @@ async def process_graph(langchain_object,
chat_inputs: ChatMessage,
websocket: WebSocket,
flow_id: str = None,
chat_id: str = None):
chat_id: str = None,
**kwargs):
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug('Loaded langchain object')

@@ -45,7 +46,8 @@ async def process_graph(langchain_object,
chat_inputs.message,
websocket=websocket,
flow_id=flow_id,
chat_id=chat_id)
chat_id=chat_id,
**kwargs)
logger.debug('Generated result and intermediate_steps')
return result, intermediate_steps, source_document
except Exception as e:
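What the **kwargs change buys: process_graph no longer needs to know about stream_queue (or any future callback option); extra keyword arguments flow straight through to get_result_and_steps. A hedged sketch of the chain with stubbed bodies — the real get_result_and_steps builds callback handlers from these kwargs:

import asyncio
from queue import Queue

async def get_result_and_steps(message: str, websocket=None, flow_id=None,
                               chat_id=None, **kwargs):
    # Stub: the real helper wires these kwargs into the callback handlers
    stream_queue = kwargs.get('stream_queue')
    return f'result for {message!r}', [], stream_queue is not None

async def process_graph(message: str, websocket=None, flow_id=None,
                        chat_id=None, **kwargs):
    # Extra kwargs (e.g. stream_queue) pass straight through, as in the diff
    return await get_result_and_steps(message, websocket=websocket,
                                      flow_id=flow_id, chat_id=chat_id,
                                      **kwargs)

result, steps, queue_attached = asyncio.run(
    process_graph('hello', stream_queue=Queue()))
assert queue_attached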