diff --git a/src/backend/bisheng/chat/client.py b/src/backend/bisheng/chat/client.py index 3aff9599c..e3fc46599 100644 --- a/src/backend/bisheng/chat/client.py +++ b/src/backend/bisheng/chat/client.py @@ -43,7 +43,7 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s self.gpts_conf = settings.get_from_db('gpts') # 异步任务列表 self.task_ids = [] - # 流式输出的队列,用来接受流式输出的内容 + # 流式输出的队列,用来接受流式输出的内容,每次处理新的question时都清空 self.stream_queue = Queue() async def send_message(self, message: str): @@ -217,9 +217,13 @@ async def stop_handle_message(self, message: Dict[any, any]): # 有流式输出内容的话,记录流式输出内容到数据库 if answer.strip(): res = await self.add_message('bot', answer, 'answer', 'break_answer') - await self.send_response('answer', 'end', answer, message_id=res.id if res else None) + await self.send_response('answer', 'end', '', message_id=res.id if res else None) await self.send_response('processing', 'close', '') + async def clear_stream_queue(self): + while not self.stream_queue.empty(): + self.stream_queue.get() + async def handle_gpts_message(self, message: Dict[any, any]): if not message: return @@ -228,6 +232,8 @@ async def handle_gpts_message(self, message: Dict[any, any]): await self.stop_handle_message(message) return + # 清空流式队列,防止把上一次的回答,污染本次回答 + await self.clear_stream_queue() inputs = message.get('inputs', {}) input_msg = inputs.get('input') if not input_msg: diff --git a/src/backend/bisheng/chat/handlers.py b/src/backend/bisheng/chat/handlers.py index 70d232a23..431beaa23 100644 --- a/src/backend/bisheng/chat/handlers.py +++ b/src/backend/bisheng/chat/handlers.py @@ -44,6 +44,10 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str action = 'default' if action not in self.handler_dict: raise Exception(f'unknown action {action}') + if action != 'stop': + # 清空流式输出队列,防止上次的回答污染本次回答 + while not self.stream_queue.empty(): + self.stream_queue.get() await self.handler_dict[action](session, client_id, chat_id, payload, user_id) logger.info(f'dispatch_task done timecost={time.time() - start_time}')