Skip to content

Commit

Permalink
feat: 登录支持多点登录的配置项
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jul 2, 2024
1 parent cb619b4 commit 0b672f9
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 142 deletions.
14 changes: 3 additions & 11 deletions src/backend/bisheng/api/JWT.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import List

from pydantic import BaseModel
from fastapi import Depends
from fastapi_jwt_auth import AuthJWT

from bisheng.settings import settings
from bisheng.api.services.user_service import UserPayload

# 配置JWT token的有效期
ACCESS_TOKEN_EXPIRE_TIME = 86400


class Settings(BaseModel):
Expand All @@ -17,11 +17,3 @@ class Settings(BaseModel):
authjwt_cookie_csrf_protect: bool = False


async def get_login_user(authorize: AuthJWT = Depends()):
"""
获取当前登录的用户
"""
authorize.jwt_required()
current_user = json.loads(authorize.get_jwt_subject())
user = UserPayload(**current_user)
return user
102 changes: 56 additions & 46 deletions src/backend/bisheng/api/services/audit_log.py

Large diffs are not rendered by default.

30 changes: 28 additions & 2 deletions src/backend/bisheng/api/services/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import functools
from typing import List

from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from fastapi import HTTPException
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT

from bisheng.settings import settings
from bisheng.utils.constants import USER_CURRENT_SESSION


class UserPayload:

Expand Down Expand Up @@ -117,7 +122,7 @@ def gen_user_jwt(db_user: User):
# 生成JWT令牌
payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
# Create the tokens and passing to set_access_cookies or set_refresh_cookies
access_token = AuthJWT().create_access_token(subject=json.dumps(payload), expires_time=86400)
access_token = AuthJWT().create_access_token(subject=json.dumps(payload), expires_time=ACCESS_TOKEN_EXPIRE_TIME)

refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)

Expand Down Expand Up @@ -202,3 +207,24 @@ def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_si
'total':
total_count
}


async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取当前登录的用户
"""
# 校验是否过期,过期则直接返回http 状态码的 401
authorize.jwt_required()

current_user = json.loads(authorize.get_jwt_subject())
user = UserPayload(**current_user)

# 判断是否允许多点登录
if not settings.get_system_login_method().allow_multi_login:
# 获取access_token
current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
# 登录被挤下线了,状态码是200, 内部的status_code是403
if current_token != authorize._token:
raise HTTPException(status_code=403,
detail='您的账户已在另一设备上登录,此设备上的会话已被注销。\n如果这不是您本人的操作,请尽快修改您的账户密码。')
return user
70 changes: 24 additions & 46 deletions src/backend/bisheng/api/v1/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import yaml
from bisheng_langchain.gpts.tools.api_tools.openapi import OpenApiTools

from bisheng.api.JWT import get_login_user
from bisheng.api.services.assistant import AssistantService
from bisheng.api.services.openapi import OpenApiSchema
from bisheng.api.services.user_service import UserPayload
from bisheng.api.services.user_service import UserPayload, get_login_user
from bisheng.api.utils import get_url_content
from bisheng.api.v1.schemas import (AssistantCreateReq, AssistantInfo, AssistantUpdateReq,
StreamData, UnifiedResponseModel, resp_200, resp_500, DeleteToolTypeReq,
Expand All @@ -34,11 +33,8 @@ def get_assistant(*,
page: Optional[int] = Query(default=1, gt=0, description='页码'),
limit: Optional[int] = Query(default=10, gt=0, description='每页条数'),
status: Optional[int] = Query(default=None, description='是否上线状态'),
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
return AssistantService.get_assistant(user, name, status, page, limit)
login_user: UserPayload = Depends(get_login_user)):
return AssistantService.get_assistant(login_user, name, status, page, limit)


# 获取某个助手的详细信息
Expand Down Expand Up @@ -111,39 +107,29 @@ async def event_stream():
async def update_prompt(*,
assistant_id: UUID = Body(description='助手唯一ID', alias='id'),
prompt: str = Body(description='用户使用的prompt'),
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
return AssistantService.update_prompt(assistant_id, prompt, user)
login_user: UserPayload = Depends(get_login_user)):
return AssistantService.update_prompt(assistant_id, prompt, login_user)


@router.post('/flow', response_model=UnifiedResponseModel)
async def update_flow_list(*,
assistant_id: UUID = Body(description='助手唯一ID', alias='id'),
flow_list: List[str] = Body(description='用户选择的技能列表'),
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
return AssistantService.update_flow_list(assistant_id, flow_list, user)
login_user: UserPayload = Depends(get_login_user)):
return AssistantService.update_flow_list(assistant_id, flow_list, login_user)


@router.post('/tool', response_model=UnifiedResponseModel)
async def update_tool_list(*,
assistant_id: UUID = Body(description='助手唯一ID', alias='id'),
tool_list: List[int] = Body(description='用户选择的工具列表'),
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
return AssistantService.update_tool_list(assistant_id, tool_list, user)
login_user: UserPayload = Depends(get_login_user)):
return AssistantService.update_tool_list(assistant_id, tool_list, login_user)


# 获取助手可用的模型列表
@router.get('/models', response_model=UnifiedResponseModel)
async def get_models(*, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
async def get_models(*, login_user: UserPayload = Depends(get_login_user)):
return AssistantService.get_models()


Expand Down Expand Up @@ -191,7 +177,7 @@ async def get_tool_schema(*,
download_url: Optional[str] = Body(default=None,
description='下载url不为空的话优先用下载url'),
file_content: Optional[str] = Body(default=None, description='上传的文件'),
Authorize: AuthJWT = Depends()):
login_user: UserPayload = Depends(get_login_user)):
""" 下载或者解析openapi schema的内容 转为助手自定义工具的格式 """
if download_url:
try:
Expand Down Expand Up @@ -241,42 +227,34 @@ async def get_tool_schema(*,

@router.post('/tool_list', response_model=UnifiedResponseModel[GptsToolsTypeRead])
def add_tool_type(*, req: Dict = Body(default={}, description="openapi解析后的工具对象"),
Authorize: AuthJWT = Depends()):
login_user: UserPayload = Depends(get_login_user)):
""" 新增自定义tool """
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
req = GptsToolsTypeRead(**req)
return AssistantService.add_gpts_tools(user, req)
return AssistantService.add_gpts_tools(login_user, req)


@router.put('/tool_list', response_model=UnifiedResponseModel[GptsToolsTypeRead])
def update_tool_type(*, req: Dict = Body(default={}, description="通过openapi 解析后的内容,包含类别的唯一ID"),
Authorize: AuthJWT = Depends()):
def update_tool_type(*,
login_user: UserPayload = Depends(get_login_user),
req: Dict = Body(default={}, description="通过openapi 解析后的内容,包含类别的唯一ID")):
""" 更新自定义tool """
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
req = GptsToolsTypeRead(**req)
return AssistantService.update_gpts_tools(user, req)
return AssistantService.update_gpts_tools(login_user, req)


@router.delete('/tool_list', response_model=UnifiedResponseModel)
def delete_tool_type(*, req: DeleteToolTypeReq, Authorize: AuthJWT = Depends()):
def delete_tool_type(*,
login_user: UserPayload = Depends(get_login_user),
req: DeleteToolTypeReq):
""" 删除自定义工具 """
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)
return AssistantService.delete_gpts_tools(user, req.tool_type_id)
return AssistantService.delete_gpts_tools(login_user, req.tool_type_id)


@router.post('/tool_test', response_model=UnifiedResponseModel)
async def test_tool_type(*, req: TestToolReq, Authorize: AuthJWT = Depends()):
async def test_tool_type(*,
login_user: UserPayload = Depends(get_login_user),
req: TestToolReq):
""" 测试自定义工具 """
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
user = UserPayload(**current_user)

tool_params = OpenApiSchema.parse_openapi_tool_params('test', 'test', req.extra, req.server_host,
req.auth_method, req.auth_type, req.api_key)

Expand Down
5 changes: 2 additions & 3 deletions src/backend/bisheng/api/v1/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from fastapi import APIRouter, Query, Depends

from bisheng.api.JWT import get_login_user
from bisheng.api.services.user_service import UserPayload
from bisheng.api.v1.schemas import UnifiedResponseModel, resp_200
from bisheng.api.services.user_service import UserPayload, get_login_user
from bisheng.api.v1.schemas import UnifiedResponseModel
from bisheng.api.services.audit_log import AuditLogService

router = APIRouter(prefix='/audit', tags=['AuditLog'])
Expand Down
7 changes: 3 additions & 4 deletions src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from typing import List, Optional
from uuid import UUID

from bisheng.api.JWT import get_login_user
from bisheng.api.services.assistant import AssistantService
from bisheng.api.services.audit_log import AuditLogService
from bisheng.api.services.chat_imp import comment_answer
from bisheng.api.services.knowledge_imp import delete_es, delete_vector
from bisheng.api.services.user_service import UserPayload
from bisheng.api.services.user_service import UserPayload, get_login_user
from bisheng.api.utils import build_flow, build_input_keys_response, get_request_ip
from bisheng.api.v1.schemas import (BuildStatus, BuiltResponse, ChatInput, ChatList,
FlowGptsOnlineList, InitResponse, StreamData,
Expand Down Expand Up @@ -183,11 +182,11 @@ def get_chatlist_list(*,
payload = json.loads(Authorize.get_jwt_subject())

smt = (select(ChatMessage.flow_id, ChatMessage.chat_id,
func.max(ChatMessage.create_time).label('create_time'),
func.min(ChatMessage.create_time).label('create_time'),
func.max(ChatMessage.update_time).label('update_time')).where(
ChatMessage.user_id == payload.get('user_id')).group_by(
ChatMessage.flow_id,
ChatMessage.chat_id).order_by(func.max(ChatMessage.create_time).desc()))
ChatMessage.chat_id).order_by(func.max(ChatMessage.update_time).desc()))
with session_getter() as session:
db_message = session.exec(smt).all()
flow_ids = [message.flow_id for message in db_message]
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def getn_env():
# add tips from settings
env['dialog_tips'] = settings.settings.get_from_db('dialog_tips')
# 判断是否SSO
env['sso'] = settings.settings.get_system_login_method().get('SSO_OAuth', False)
env['sso'] = settings.settings.get_system_login_method().SSO_OAuth
# add env dict from settings
env.update(settings.settings.get_from_db('env') or {})
return resp_200(env)
Expand Down
6 changes: 1 addition & 5 deletions src/backend/bisheng/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@
from typing import Any
from uuid import UUID

from bisheng.api.JWT import get_login_user
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.services.audit_log import AuditLogService
from bisheng.api.services.flow import FlowService
from bisheng.api.services.user_service import UserPayload
from bisheng.api.services.user_service import UserPayload, get_login_user
from bisheng.api.utils import build_flow_no_yield, get_L2_param_from_flow, remove_api_keys
from bisheng.api.v1.schemas import (FlowCompareReq, FlowListRead, FlowVersionCreate, StreamData,
UnifiedResponseModel, resp_200)
from bisheng.database.base import session_getter
from bisheng.database.models.flow import (Flow, FlowCreate, FlowDao, FlowRead, FlowReadWithStyle,
FlowUpdate)
from bisheng.database.models.flow_version import FlowVersionDao
from bisheng.database.models.group_resource import GroupResource, GroupResourceDao, ResourceTypeEnum
from bisheng.database.models.role_access import AccessType
from bisheng.database.models.user_group import UserGroupDao
from bisheng.settings import settings
from bisheng.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, Query, Request
Expand Down
6 changes: 2 additions & 4 deletions src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import json
import os
import re
import time
from typing import List, Optional
from uuid import uuid4

from bisheng.api.JWT import get_login_user
from bisheng.api.utils import get_request_ip
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.services.audit_log import AuditLogService
from bisheng.api.services.knowledge_imp import (addEmbedding, decide_vectorstores,
delete_knowledge_file_vectors, retry_files)
from bisheng.api.services.user_service import UserPayload
from bisheng.api.services.user_service import UserPayload, get_login_user
from bisheng.api.v1.schemas import UnifiedResponseModel, UploadFileResponse, resp_200, resp_500
from bisheng.cache.utils import file_download, save_uploaded_file
from bisheng.database.base import session_getter
Expand All @@ -20,7 +18,7 @@
KnowledgeRead)
from bisheng.database.models.knowledge_file import (KnowledgeFile, KnowledgeFileDao,
KnowledgeFileRead)
from bisheng.database.models.role_access import AccessType, RoleAccess, RoleAccessDao
from bisheng.database.models.role_access import AccessType, RoleAccess
from bisheng.database.models.user import User
from bisheng.database.models.user_group import UserGroupDao
from bisheng.interface.embeddings.custom import FakeEmbedding
Expand Down
Loading

0 comments on commit 0b672f9

Please sign in to comment.