Skip to content

Commit

Permalink
feat: 给助手的免登录提供对应的接口
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 11, 2024
1 parent 839b321 commit 9e9e44a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
5 changes: 2 additions & 3 deletions src/backend/bisheng/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import hashlib
import json
import time
import xml.dom.minidom
from pathlib import Path
from typing import Dict, List

import aiohttp
from fastapi import Request
from fastapi import Request, WebSocket
from fastapi_jwt_auth import AuthJWT
from platformdirs import user_cache_dir
from sqlalchemy import delete
Expand Down Expand Up @@ -421,7 +420,7 @@ async def get_url_content(url: str) -> str:
return res.decode('utf-8')


def get_request_ip(request: Request) -> str:
def get_request_ip(request: Request | WebSocket) -> str:
""" 获取客户端真实IP """
x_forwarded_for = request.headers.get('X-Forwarded-For')
if x_forwarded_for:
Expand Down
31 changes: 27 additions & 4 deletions src/backend/bisheng/api/v2/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Request, HTTPException, WebSocket
from fastapi import APIRouter, Request, HTTPException, WebSocket, WebSocketException
from fastapi import status as http_status
from fastapi.responses import StreamingResponse, ORJSONResponse
from langchain_core.messages import HumanMessage, AIMessage
from loguru import logger
Expand All @@ -13,8 +14,10 @@
from bisheng.api.services.assistant_agent import AssistantAgent
from bisheng.api.services.user_service import UserPayload
from bisheng.api.utils import get_request_ip
from bisheng.api.v1.chat import chat_manager
from bisheng.api.v1.schemas import OpenAIChatCompletionResp, OpenAIChatCompletionReq, UnifiedResponseModel, \
AssistantInfo, OpenAIChoice
from bisheng.chat.types import WorkType
from bisheng.database.models.user import UserDao
from bisheng.settings import settings

Expand Down Expand Up @@ -129,7 +132,7 @@ async def get_assistant_info(request: Request, assistant_id: UUID):
"""
获取助手信息, 用系统配置里的default_operator.user的用户信息来做权限校验
"""
logger.info('act=assistant_info')
logger.info(f'act=get_default_operator assistant_id={assistant_id}, ip={get_request_ip(request)}')
default_user = get_default_operator()
login_user = UserPayload(**{
'user_id': default_user.user_id,
Expand All @@ -141,10 +144,30 @@ async def get_assistant_info(request: Request, assistant_id: UUID):

@router.websocket('/chat/{assistant_id}')
async def chat(*,
assistant_id: str,
websocket: WebSocket,
assistant_id: str,
chat_id: Optional[str] = None):
"""
助手的ws免登录接口
"""
pass
logger.info(f'act=assistant_chat_ws assistant_id={assistant_id}, ip={get_request_ip(websocket)}')
default_user = get_default_operator()
login_user = UserPayload(**{
'user_id': default_user.user_id,
'user_name': default_user.user_name,
'role': ''
})
try:
request = websocket
await chat_manager.dispatch_client(request, assistant_id, chat_id, login_user, WorkType.GPTS,
websocket)
except WebSocketException as exc:
logger.error(f'Websocket exception: {str(exc)}')
await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as exc:
logger.exception(f'Error in chat websocket: {str(exc)}')
message = exc.detail if isinstance(exc, HTTPException) else str(exc)
if 'Could not validate credentials' in str(exc):
await websocket.close(code=http_status.WS_1008_POLICY_VIOLATION, reason='Unauthorized')
else:
await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=message)

0 comments on commit 9e9e44a

Please sign in to comment.