Skip to content

Commit

Permalink
feat: 会话列表接口返回对应会话最新的AI回复内容
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 5, 2024
1 parent abf67d0 commit 419cc31
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
11 changes: 10 additions & 1 deletion src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
from typing import List, Optional
from uuid import UUID

Expand Down Expand Up @@ -172,7 +173,6 @@ def get_chatlist_list(*,
page: Optional[int] = 1,
limit: Optional[int] = 10,
login_user: UserPayload = Depends(get_login_user)):

smt = (select(ChatMessage.flow_id, ChatMessage.chat_id,
func.min(ChatMessage.create_time).label('create_time'),
func.max(ChatMessage.update_time).label('update_time')).where(
Expand All @@ -181,6 +181,7 @@ def get_chatlist_list(*,
ChatMessage.chat_id).order_by(func.max(ChatMessage.update_time).desc()))
with session_getter() as session:
db_message = session.exec(smt).all()

flow_ids = [message.flow_id for message in db_message]
with session_getter() as session:
db_flow = session.exec(select(Flow).where(Flow.id.in_(flow_ids))).all()
Expand Down Expand Up @@ -212,6 +213,14 @@ def get_chatlist_list(*,
else:
# 通过接口创建的会话记录,不关联技能或者助手
logger.debug(f'unknown message.flow_id={message.flow_id}')
res = chat_list[(page - 1) * limit:page * limit]
chat_ids = [one.chat_id for one in res]
latest_messages = ChatMessageDao.get_latest_message_by_chat_ids(chat_ids, 'answer')
latest_messages = {one.chat_id: one for one in latest_messages}

for one in res:
# 获取每个会话的最后一条回复内容
one.latest_message = latest_messages.get(one.chat_id, None)
return resp_200(chat_list[(page - 1) * limit:page * limit])


Expand Down
4 changes: 2 additions & 2 deletions src/backend/bisheng/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def process_graph_cached_task(*args, **kwargs):


@router.get('/all')
def get_all(login_user: UserPayload = Depends(get_login_user)):
def get_all():
"""获取所有参数"""
return resp_200(get_all_types_dict())


@router.get('/env')
def get_env(login_user: UserPayload = Depends(get_login_user)):
def get_env():
"""获取环境变量参数"""
uns_support = [
'png', 'jpg', 'jpeg', 'bmp', 'doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md',
Expand Down
9 changes: 6 additions & 3 deletions src/backend/bisheng/api/v1/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID

from langchain.docstore.document import Document
from orjson import orjson
from pydantic import BaseModel, Field, validator

from bisheng.database.models.assistant import AssistantBase
from bisheng.database.models.finetune import TrainMethod
from bisheng.database.models.flow import FlowCreate, FlowRead
from bisheng.database.models.gpts_tools import GptsToolsRead, AuthMethod, AuthType
from bisheng.database.models.knowledge import KnowledgeRead
from langchain.docstore.document import Document
from orjson import orjson
from pydantic import BaseModel, Field, validator
from bisheng.database.models.message import ChatMessageRead


class CaptchaInput(BaseModel):
Expand Down Expand Up @@ -119,6 +121,7 @@ class ChatList(BaseModel):
create_time: datetime = None
update_time: datetime = None
flow_type: str = None # flow: 技能 assistant:gpts助手
latest_message: ChatMessageRead = None


class FlowGptsOnlineList(BaseModel):
Expand Down
10 changes: 9 additions & 1 deletion src/backend/bisheng/database/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def get_latest_message_by_chatid(cls, chat_id: str):
else:
return None

@classmethod
def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None):
statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids))
if category:
statement = statement.where(ChatMessage.category == category)
statement = statement.order_by(ChatMessage.create_time.desc()).limit(1)
with session_getter() as session:
return session.exec(statement).all()

@classmethod
def get_messages_by_chat_id(cls, chat_id: str, category_list: list = None, limit: int = 10):
with session_getter() as session:
Expand Down Expand Up @@ -142,7 +151,6 @@ def get_message_by_id(cls, message_id: int) -> Optional[ChatMessage]:
with session_getter() as session:
return session.exec(select(ChatMessage).where(ChatMessage.id == message_id)).first()


@classmethod
def update_message(cls, message_id: int, user_id: int, message: str):
with session_getter() as session:
Expand Down

0 comments on commit 419cc31

Please sign in to comment.