diff --git a/src/backend/bisheng/api/utils.py b/src/backend/bisheng/api/utils.py index b813e98da..87ecdabc2 100644 --- a/src/backend/bisheng/api/utils.py +++ b/src/backend/bisheng/api/utils.py @@ -1,12 +1,11 @@ import hashlib import json -import time import xml.dom.minidom from pathlib import Path from typing import Dict, List import aiohttp -from fastapi import Request +from fastapi import Request, WebSocket from fastapi_jwt_auth import AuthJWT from platformdirs import user_cache_dir from sqlalchemy import delete @@ -421,7 +420,7 @@ async def get_url_content(url: str) -> str: return res.decode('utf-8') -def get_request_ip(request: Request) -> str: +def get_request_ip(request: Request | WebSocket) -> str: """ 获取客户端真实IP """ x_forwarded_for = request.headers.get('X-Forwarded-For') if x_forwarded_for: diff --git a/src/backend/bisheng/api/v2/assistant.py b/src/backend/bisheng/api/v2/assistant.py index b33cbccfc..da0bcbc87 100644 --- a/src/backend/bisheng/api/v2/assistant.py +++ b/src/backend/bisheng/api/v2/assistant.py @@ -4,7 +4,8 @@ from typing import Optional from uuid import UUID -from fastapi import APIRouter, Request, HTTPException, WebSocket +from fastapi import APIRouter, Request, HTTPException, WebSocket, WebSocketException +from fastapi import status as http_status from fastapi.responses import StreamingResponse, ORJSONResponse from langchain_core.messages import HumanMessage, AIMessage from loguru import logger @@ -13,8 +14,10 @@ from bisheng.api.services.assistant_agent import AssistantAgent from bisheng.api.services.user_service import UserPayload from bisheng.api.utils import get_request_ip +from bisheng.api.v1.chat import chat_manager from bisheng.api.v1.schemas import OpenAIChatCompletionResp, OpenAIChatCompletionReq, UnifiedResponseModel, \ AssistantInfo, OpenAIChoice +from bisheng.chat.types import WorkType from bisheng.database.models.user import UserDao from bisheng.settings import settings @@ -129,7 +132,7 @@ async def get_assistant_info(request: Request, assistant_id: UUID): """ 获取助手信息, 用系统配置里的default_operator.user的用户信息来做权限校验 """ - logger.info('act=assistant_info') + logger.info(f'act=get_default_operator assistant_id={assistant_id}, ip={get_request_ip(request)}') default_user = get_default_operator() login_user = UserPayload(**{ 'user_id': default_user.user_id, @@ -141,10 +144,30 @@ async def get_assistant_info(request: Request, assistant_id: UUID): @router.websocket('/chat/{assistant_id}') async def chat(*, - assistant_id: str, websocket: WebSocket, + assistant_id: str, chat_id: Optional[str] = None): """ 助手的ws免登录接口 """ - pass + logger.info(f'act=assistant_chat_ws assistant_id={assistant_id}, ip={get_request_ip(websocket)}') + default_user = get_default_operator() + login_user = UserPayload(**{ + 'user_id': default_user.user_id, + 'user_name': default_user.user_name, + 'role': '' + }) + try: + request = websocket + await chat_manager.dispatch_client(request, assistant_id, chat_id, login_user, WorkType.GPTS, + websocket) + except WebSocketException as exc: + logger.error(f'Websocket exception: {str(exc)}') + await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=str(exc)) + except Exception as exc: + logger.exception(f'Error in chat websocket: {str(exc)}') + message = exc.detail if isinstance(exc, HTTPException) else str(exc) + if 'Could not validate credentials' in str(exc): + await websocket.close(code=http_status.WS_1008_POLICY_VIOLATION, reason='Unauthorized') + else: + await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=message)