diff --git a/src/backend/bisheng/api/v1/chat.py b/src/backend/bisheng/api/v1/chat.py index afac2ce6b..92af9006e 100644 --- a/src/backend/bisheng/api/v1/chat.py +++ b/src/backend/bisheng/api/v1/chat.py @@ -1,4 +1,5 @@ import json +import time from typing import List, Optional from uuid import UUID @@ -172,7 +173,6 @@ def get_chatlist_list(*, page: Optional[int] = 1, limit: Optional[int] = 10, login_user: UserPayload = Depends(get_login_user)): - smt = (select(ChatMessage.flow_id, ChatMessage.chat_id, func.min(ChatMessage.create_time).label('create_time'), func.max(ChatMessage.update_time).label('update_time')).where( @@ -181,6 +181,7 @@ def get_chatlist_list(*, ChatMessage.chat_id).order_by(func.max(ChatMessage.update_time).desc())) with session_getter() as session: db_message = session.exec(smt).all() + flow_ids = [message.flow_id for message in db_message] with session_getter() as session: db_flow = session.exec(select(Flow).where(Flow.id.in_(flow_ids))).all() @@ -212,6 +213,14 @@ def get_chatlist_list(*, else: # 通过接口创建的会话记录,不关联技能或者助手 logger.debug(f'unknown message.flow_id={message.flow_id}') + res = chat_list[(page - 1) * limit:page * limit] + chat_ids = [one.chat_id for one in res] + latest_messages = ChatMessageDao.get_latest_message_by_chat_ids(chat_ids, 'answer') + latest_messages = {one.chat_id: one for one in latest_messages} + + for one in res: + # 获取每个会话的最后一条回复内容 + one.latest_message = latest_messages.get(one.chat_id, None) return resp_200(chat_list[(page - 1) * limit:page * limit]) diff --git a/src/backend/bisheng/api/v1/endpoints.py b/src/backend/bisheng/api/v1/endpoints.py index aec409f4e..b2bb7d3a6 100644 --- a/src/backend/bisheng/api/v1/endpoints.py +++ b/src/backend/bisheng/api/v1/endpoints.py @@ -37,13 +37,13 @@ def process_graph_cached_task(*args, **kwargs): @router.get('/all') -def get_all(login_user: UserPayload = Depends(get_login_user)): +def get_all(): """获取所有参数""" return resp_200(get_all_types_dict()) @router.get('/env') -def get_env(login_user: UserPayload = Depends(get_login_user)): +def get_env(): """获取环境变量参数""" uns_support = [ 'png', 'jpg', 'jpeg', 'bmp', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md', diff --git a/src/backend/bisheng/api/v1/schemas.py b/src/backend/bisheng/api/v1/schemas.py index f66868295..3f12eb4dc 100644 --- a/src/backend/bisheng/api/v1/schemas.py +++ b/src/backend/bisheng/api/v1/schemas.py @@ -3,14 +3,16 @@ from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from uuid import UUID +from langchain.docstore.document import Document +from orjson import orjson +from pydantic import BaseModel, Field, validator + from bisheng.database.models.assistant import AssistantBase from bisheng.database.models.finetune import TrainMethod from bisheng.database.models.flow import FlowCreate, FlowRead from bisheng.database.models.gpts_tools import GptsToolsRead, AuthMethod, AuthType from bisheng.database.models.knowledge import KnowledgeRead -from langchain.docstore.document import Document -from orjson import orjson -from pydantic import BaseModel, Field, validator +from bisheng.database.models.message import ChatMessageRead class CaptchaInput(BaseModel): @@ -119,6 +121,7 @@ class ChatList(BaseModel): create_time: datetime = None update_time: datetime = None flow_type: str = None # flow: 技能 assistant:gpts助手 + latest_message: ChatMessageRead = None class FlowGptsOnlineList(BaseModel): diff --git a/src/backend/bisheng/database/models/message.py b/src/backend/bisheng/database/models/message.py index 73a62d96f..c4b21d15e 100644 --- a/src/backend/bisheng/database/models/message.py +++ b/src/backend/bisheng/database/models/message.py @@ -86,6 +86,15 @@ def get_latest_message_by_chatid(cls, chat_id: str): else: return None + @classmethod + def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None): + statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids)) + if category: + statement = statement.where(ChatMessage.category == category) + statement = statement.order_by(ChatMessage.create_time.desc()).limit(1) + with session_getter() as session: + return session.exec(statement).all() + @classmethod def get_messages_by_chat_id(cls, chat_id: str, category_list: list = None, limit: int = 10): with session_getter() as session: @@ -142,7 +151,6 @@ def get_message_by_id(cls, message_id: int) -> Optional[ChatMessage]: with session_getter() as session: return session.exec(select(ChatMessage).where(ChatMessage.id == message_id)).first() - @classmethod def update_message(cls, message_id: int, user_id: int, message: str): with session_getter() as session: