Skip to content

Commit

Permalink
feat: 每次问答时清空流式输出队列,防止污染本次的回答
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 9, 2024
1 parent c5b1d85 commit e793b30
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/backend/bisheng/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s
self.gpts_conf = settings.get_from_db('gpts')
# 异步任务列表
self.task_ids = []
# 流式输出的队列,用来接受流式输出的内容
# 流式输出的队列,用来接受流式输出的内容,每次处理新的question时都清空
self.stream_queue = Queue()

async def send_message(self, message: str):
Expand Down Expand Up @@ -217,9 +217,13 @@ async def stop_handle_message(self, message: Dict[any, any]):
# 有流式输出内容的话,记录流式输出内容到数据库
if answer.strip():
res = await self.add_message('bot', answer, 'answer', 'break_answer')
await self.send_response('answer', 'end', answer, message_id=res.id if res else None)
await self.send_response('answer', 'end', '', message_id=res.id if res else None)
await self.send_response('processing', 'close', '')

async def clear_stream_queue(self):
while not self.stream_queue.empty():
self.stream_queue.get()

async def handle_gpts_message(self, message: Dict[any, any]):
if not message:
return
Expand All @@ -228,6 +232,8 @@ async def handle_gpts_message(self, message: Dict[any, any]):
await self.stop_handle_message(message)
return

# 清空流式队列,防止把上一次的回答,污染本次回答
await self.clear_stream_queue()
inputs = message.get('inputs', {})
input_msg = inputs.get('input')
if not input_msg:
Expand Down
4 changes: 4 additions & 0 deletions src/backend/bisheng/chat/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str
action = 'default'
if action not in self.handler_dict:
raise Exception(f'unknown action {action}')
if action != 'stop':
# 清空流式输出队列,防止上次的回答污染本次回答
while not self.stream_queue.empty():
self.stream_queue.get()

await self.handler_dict[action](session, client_id, chat_id, payload, user_id)
logger.info(f'dispatch_task done timecost={time.time() - start_time}')
Expand Down

0 comments on commit e793b30

Please sign in to comment.