From e793b309404ecb1f5093fdcc57555b44ee5e1a0d Mon Sep 17 00:00:00 2001 From: GuoQing Zhang Date: Tue, 9 Jul 2024 17:42:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=AF=8F=E6=AC=A1=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E6=97=B6=E6=B8=85=E7=A9=BA=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E9=98=9F=E5=88=97=EF=BC=8C=E9=98=B2=E6=AD=A2=E6=B1=A1=E6=9F=93?= =?UTF-8?q?=E6=9C=AC=E6=AC=A1=E7=9A=84=E5=9B=9E=E7=AD=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/bisheng/chat/client.py | 10 ++++++++-- src/backend/bisheng/chat/handlers.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/backend/bisheng/chat/client.py b/src/backend/bisheng/chat/client.py index 3aff9599c..e3fc46599 100644 --- a/src/backend/bisheng/chat/client.py +++ b/src/backend/bisheng/chat/client.py @@ -43,7 +43,7 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s self.gpts_conf = settings.get_from_db('gpts') # 异步任务列表 self.task_ids = [] - # 流式输出的队列,用来接受流式输出的内容 + # 流式输出的队列,用来接受流式输出的内容,每次处理新的question时都清空 self.stream_queue = Queue() async def send_message(self, message: str): @@ -217,9 +217,13 @@ async def stop_handle_message(self, message: Dict[any, any]): # 有流式输出内容的话,记录流式输出内容到数据库 if answer.strip(): res = await self.add_message('bot', answer, 'answer', 'break_answer') - await self.send_response('answer', 'end', answer, message_id=res.id if res else None) + await self.send_response('answer', 'end', '', message_id=res.id if res else None) await self.send_response('processing', 'close', '') + async def clear_stream_queue(self): + while not self.stream_queue.empty(): + self.stream_queue.get() + async def handle_gpts_message(self, message: Dict[any, any]): if not message: return @@ -228,6 +232,8 @@ async def handle_gpts_message(self, message: Dict[any, any]): await self.stop_handle_message(message) return + # 清空流式队列,防止把上一次的回答,污染本次回答 + await self.clear_stream_queue() inputs = message.get('inputs', {}) input_msg = inputs.get('input') if not input_msg: diff --git a/src/backend/bisheng/chat/handlers.py b/src/backend/bisheng/chat/handlers.py index 70d232a23..431beaa23 100644 --- a/src/backend/bisheng/chat/handlers.py +++ b/src/backend/bisheng/chat/handlers.py @@ -44,6 +44,10 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str action = 'default' if action not in self.handler_dict: raise Exception(f'unknown action {action}') + if action != 'stop': + # 清空流式输出队列,防止上次的回答污染本次回答 + while not self.stream_queue.empty(): + self.stream_queue.get() await self.handler_dict[action](session, client_id, chat_id, payload, user_id) logger.info(f'dispatch_task done timecost={time.time() - start_time}')