From 419cc318e7ec35abf9c16ba48d3ada46b3cbf7dc Mon Sep 17 00:00:00 2001 From: GuoQing Zhang Date: Fri, 5 Jul 2024 15:52:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=9A=E8=AF=9D=E5=88=97=E8=A1=A8?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E8=BF=94=E5=9B=9E=E5=AF=B9=E5=BA=94=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E6=9C=80=E6=96=B0=E7=9A=84AI=E5=9B=9E=E5=A4=8D?= =?UTF-8?q?=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/bisheng/api/v1/chat.py | 11 ++++++++++- src/backend/bisheng/api/v1/endpoints.py | 4 ++-- src/backend/bisheng/api/v1/schemas.py | 9 ++++++--- src/backend/bisheng/database/models/message.py | 10 +++++++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/backend/bisheng/api/v1/chat.py b/src/backend/bisheng/api/v1/chat.py index afac2ce6b..92af9006e 100644 --- a/src/backend/bisheng/api/v1/chat.py +++ b/src/backend/bisheng/api/v1/chat.py @@ -1,4 +1,5 @@ import json +import time from typing import List, Optional from uuid import UUID @@ -172,7 +173,6 @@ def get_chatlist_list(*, page: Optional[int] = 1, limit: Optional[int] = 10, login_user: UserPayload = Depends(get_login_user)): - smt = (select(ChatMessage.flow_id, ChatMessage.chat_id, func.min(ChatMessage.create_time).label('create_time'), func.max(ChatMessage.update_time).label('update_time')).where( @@ -181,6 +181,7 @@ def get_chatlist_list(*, ChatMessage.chat_id).order_by(func.max(ChatMessage.update_time).desc())) with session_getter() as session: db_message = session.exec(smt).all() + flow_ids = [message.flow_id for message in db_message] with session_getter() as session: db_flow = session.exec(select(Flow).where(Flow.id.in_(flow_ids))).all() @@ -212,6 +213,14 @@ def get_chatlist_list(*, else: # 通过接口创建的会话记录,不关联技能或者助手 logger.debug(f'unknown message.flow_id={message.flow_id}') + res = chat_list[(page - 1) * limit:page * limit] + chat_ids = [one.chat_id for one in res] + latest_messages = ChatMessageDao.get_latest_message_by_chat_ids(chat_ids, 'answer') + latest_messages = {one.chat_id: one for one in latest_messages} + + for one in res: + # 获取每个会话的最后一条回复内容 + one.latest_message = latest_messages.get(one.chat_id, None) return resp_200(chat_list[(page - 1) * limit:page * limit]) diff --git a/src/backend/bisheng/api/v1/endpoints.py b/src/backend/bisheng/api/v1/endpoints.py index aec409f4e..b2bb7d3a6 100644 --- a/src/backend/bisheng/api/v1/endpoints.py +++ b/src/backend/bisheng/api/v1/endpoints.py @@ -37,13 +37,13 @@ def process_graph_cached_task(*args, **kwargs): @router.get('/all') -def get_all(login_user: UserPayload = Depends(get_login_user)): +def get_all(): """获取所有参数""" return resp_200(get_all_types_dict()) @router.get('/env') -def get_env(login_user: UserPayload = Depends(get_login_user)): +def get_env(): """获取环境变量参数""" uns_support = [ 'png', 'jpg', 'jpeg', 'bmp', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md', diff --git a/src/backend/bisheng/api/v1/schemas.py b/src/backend/bisheng/api/v1/schemas.py index f66868295..3f12eb4dc 100644 --- a/src/backend/bisheng/api/v1/schemas.py +++ b/src/backend/bisheng/api/v1/schemas.py @@ -3,14 +3,16 @@ from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from uuid import UUID +from langchain.docstore.document import Document +from orjson import orjson +from pydantic import BaseModel, Field, validator + from bisheng.database.models.assistant import AssistantBase from bisheng.database.models.finetune import TrainMethod from bisheng.database.models.flow import FlowCreate, FlowRead from bisheng.database.models.gpts_tools import GptsToolsRead, AuthMethod, AuthType from bisheng.database.models.knowledge import KnowledgeRead -from langchain.docstore.document import Document -from orjson import orjson -from pydantic import BaseModel, Field, validator +from bisheng.database.models.message import ChatMessageRead class CaptchaInput(BaseModel): @@ -119,6 +121,7 @@ class ChatList(BaseModel): create_time: datetime = None update_time: datetime = None flow_type: str = None # flow: 技能 assistant:gpts助手 + latest_message: ChatMessageRead = None class FlowGptsOnlineList(BaseModel): diff --git a/src/backend/bisheng/database/models/message.py b/src/backend/bisheng/database/models/message.py index 73a62d96f..c4b21d15e 100644 --- a/src/backend/bisheng/database/models/message.py +++ b/src/backend/bisheng/database/models/message.py @@ -86,6 +86,15 @@ def get_latest_message_by_chatid(cls, chat_id: str): else: return None + @classmethod + def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None): + statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids)) + if category: + statement = statement.where(ChatMessage.category == category) + statement = statement.order_by(ChatMessage.create_time.desc()).limit(1) + with session_getter() as session: + return session.exec(statement).all() + @classmethod def get_messages_by_chat_id(cls, chat_id: str, category_list: list = None, limit: int = 10): with session_getter() as session: @@ -142,7 +151,6 @@ def get_message_by_id(cls, message_id: int) -> Optional[ChatMessage]: with session_getter() as session: return session.exec(select(ChatMessage).where(ChatMessage.id == message_id)).first() - @classmethod def update_message(cls, message_id: int, user_id: int, message: str): with session_getter() as session: