From 1a97b127ec705edaf645ce962c5ac8bf7f9644d5 Mon Sep 17 00:00:00 2001 From: GuoQing Zhang Date: Thu, 11 Jul 2024 11:57:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=B8=A4=E4=B8=AA?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E5=92=8C=E8=8E=B7=E5=8F=96=E5=89=8D=E7=AB=AF?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=9A=84=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/bisheng/api/v1/endpoints.py | 36 +++- src/backend/bisheng/database/base.py | 159 +---------------- src/backend/bisheng/database/init_config.py | 2 +- src/backend/bisheng/database/init_data.py | 162 ++++++++++++++++++ src/backend/bisheng/database/models/config.py | 23 ++- src/backend/bisheng/main.py | 2 +- src/backend/bisheng/settings.py | 3 +- 7 files changed, 220 insertions(+), 167 deletions(-) create mode 100644 src/backend/bisheng/database/init_data.py diff --git a/src/backend/bisheng/api/v1/endpoints.py b/src/backend/bisheng/api/v1/endpoints.py index 8a95c4731..0eb558758 100644 --- a/src/backend/bisheng/api/v1/endpoints.py +++ b/src/backend/bisheng/api/v1/endpoints.py @@ -5,7 +5,8 @@ import yaml from bisheng import settings -from bisheng.api.services.user_service import UserPayload, get_admin_user, get_login_user +from bisheng.api.services.user_service import UserPayload, get_admin_user +from bisheng.api.utils import get_request_ip from bisheng.api.v1 import knowledge from bisheng.api.v1.schemas import (ProcessResponse, UnifiedResponseModel, UploadFileResponse, resp_200) @@ -13,7 +14,7 @@ from bisheng.cache.utils import save_uploaded_file from bisheng.chat.utils import judge_source, process_source_document from bisheng.database.base import session_getter -from bisheng.database.models.config import Config +from bisheng.database.models.config import Config, ConfigDao from bisheng.database.models.flow import Flow from bisheng.database.models.message import ChatMessage from bisheng.interface.types import get_all_types_dict @@ -21,8 +22,7 @@ from bisheng.services.deps import get_session_service, get_task_service from bisheng.services.task.service import TaskService from bisheng.utils.logger import logger -from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile -from fastapi_jwt_auth import AuthJWT +from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, Request from sqlmodel import select try: @@ -100,6 +100,34 @@ def save_config(data: dict, admin_user: UserPayload = Depends(get_admin_user)): return resp_200('保存成功') +@router.get('/web/config') +async def get_web_config(): + """ 获取一些前端所需要的配置项,内容由前端决定 """ + web_conf = ConfigDao.get_config('web_config') + if not web_conf: + return resp_200(data='') + return resp_200(data={ + "value": web_conf.value + }) + + +@router.post('/web/config') +async def update_web_config(request: Request, + admin_user: UserPayload = Depends(get_admin_user), + value: str = Body(embed=True)): + """ 更新一些前端所需要的配置项,内容由前端决定 """ + logger.info(f'update_web_config user_name={admin_user.user_name}, ip={get_request_ip(request)}') + web_conf = ConfigDao.get_config('web_config') + if not web_conf: + web_conf = Config(key='web_config', value=value) + else: + web_conf.value = value + ConfigDao.insert_config(web_conf) + return resp_200(data={ + "value": web_conf.value + }) + + @router.post('/process/{flow_id}') async def process_flow_old( flow_id: UUID, diff --git a/src/backend/bisheng/database/base.py b/src/backend/bisheng/database/base.py index 336c8b65d..e55eac55a 100644 --- a/src/backend/bisheng/database/base.py +++ b/src/backend/bisheng/database/base.py @@ -1,158 +1,14 @@ -import hashlib -import json -import os import uuid from contextlib import contextmanager -from typing import List -from bisheng.database.init_config import init_config from bisheng.database.service import DatabaseService from bisheng.settings import settings from bisheng.utils.logger import logger -from sqlalchemy import text -from sqlmodel import Session, select, update +from sqlmodel import Session db_service: 'DatabaseService' = DatabaseService(settings.database_url) -def init_default_data(): - """初始化数据库""" - from bisheng.cache.redis import redis_client - from bisheng.database.models.component import Component - from bisheng.database.models.role import Role, AdminRole, DefaultRole - from bisheng.database.models.user import User - from bisheng.database.models.gpts_tools import GptsTools - from bisheng.database.models.gpts_tools import GptsToolsType - from bisheng.database.models.sft_model import SftModel - from bisheng.database.models.flow_version import FlowVersion - from bisheng.database.models.user_role import UserRoleDao - from bisheng.database.models.group import Group, DefaultGroup - from bisheng.database.models.role_access import RoleAccess, AccessType - - if redis_client.setNx('init_default_data', '1'): - try: - db_service.create_db_and_tables() - with session_getter() as session: - db_role = session.exec(select(Role).limit(1)).all() - if not db_role: - # 初始化系统配置, 管理员拥有所有权限 - db_role = Role(id=AdminRole, role_name='系统管理员', remark='系统所有权限管理员', - group_id=DefaultGroup) - session.add(db_role) - db_role_normal = Role(id=DefaultRole, role_name='普通用户', remark='默认用户', - group_id=DefaultGroup) - session.add(db_role_normal) - # 给普通用户赋予 构建、知识、模型菜单栏的查看权限 - session.add_all([ - RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='build'), - RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='knowledge'), - RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='model'), - ]) - session.commit() - # 添加默认用户组 - group = session.exec(select(Group).limit(1)).all() - if not group: - group = Group(id=DefaultGroup, group_name='默认用户组', create_user=1, update_user=1) - session.add(group) - session.commit() - session.refresh(group) - - user = session.exec(select(User).limit(1)).all() - if not user and settings.admin: - md5 = hashlib.md5() - md5.update(settings.admin.get('password').encode('utf-8')) - user = User( - user_id=1, - user_name=settings.admin.get('user_name'), - password=md5.hexdigest(), - ) - session.add(user) - session.commit() - session.refresh(user) - UserRoleDao.set_admin_user(user.user_id) - - component_db = session.exec(select(Component).limit(1)).all() - if not component_db: - db_components = [] - json_items = json.loads(read_from_conf('component.json')) - for item in json_items: - for k, v in item.items(): - db_component = Component(name=k, user_id=1, user_name='admin', data=v) - db_components.append(db_component) - session.add_all(db_components) - session.commit() - - # 初始化预置工具列表 - preset_tools = session.exec(select(GptsTools).limit(1)).all() - if not preset_tools: - preset_tools = [] - json_items = json.loads(read_from_conf('t_gpts_tools.json')) - for item in json_items: - item['api_params'] = json.loads(item['api_params']) - preset_tool = GptsTools(**item) - preset_tools.append(preset_tool) - session.add_all(preset_tools) - session.commit() - # 初始化预置工具类别 - preset_tools_type = session.exec(select(GptsToolsType).limit(1)).all() - if not preset_tools_type: - preset_tools_type = [] - json_items = json.loads(read_from_conf('t_gpts_tools_type.json')) - for item in json_items: - preset_tool_type = GptsToolsType(**item) - preset_tools_type.append(preset_tool_type) - session.add_all(preset_tools_type) - session.commit() - # 设置预置工具所属的类别, 需要和预置数据一致,所以id是固定的 - for i in range(1, 7): - session.exec(update(GptsTools).where(GptsTools.id == i).values(type=i)) - # 属于天眼查类别下的工具 - tyc_types: List[int] = list(range(7, 18)) - session.exec( - update(GptsTools).where(GptsTools.id.in_(tyc_types)).values(type=7)) - # 属于金融类别下的工具 - jr_types: List[int] = list(range(18, 28)) - session.exec( - update(GptsTools).where(GptsTools.id.in_(jr_types)).values(type=8)) - session.commit() - # 初始化配置可用于微调的基准模型 - preset_models = session.exec(select(SftModel).limit(1)).all() - if not preset_models: - preset_models = [] - json_items = json.loads(read_from_conf('sft_model.json')) - for item in json_items: - preset_model = SftModel(**item) - preset_models.append(preset_model) - session.add_all(preset_models) - session.commit() - - # 初始化补充默认的技能版本表 - flow_version = session.exec(select(FlowVersion).limit(1)).all() - if not flow_version: - sql_query = text( - "INSERT INTO `flowversion` (`name`, `flow_id`, `data`, `user_id`, `is_current`, `is_delete`) \ - select 'v0', `id` as flow_id, `data`, `user_id`, 1, 0 from `flow`;") - session.execute(sql_query) - session.commit() - # 修改表单数据表 - sql_query = text( - 'UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` WHERE flow_id=a.flow_id and is_current=1)' - # noqa - ) - session.execute(sql_query) - session.commit() - # 初始化数据库config - init_config() - except Exception as exc: - # if the exception involves tables already existing - # we can ignore it - if 'already exists' not in str(exc): - logger.error(f'Error creating DB and tables: {exc}') - raise RuntimeError('Error creating DB and tables') from exc - finally: - redis_client.delete('init_default_data') - - @contextmanager def session_getter() -> Session: """轻量级session context""" @@ -167,19 +23,6 @@ def session_getter() -> Session: session.close() -def read_from_conf(file_path: str) -> str: - if '/' not in file_path: - # Get current path - current_path = os.path.dirname(os.path.abspath(__file__)) - - file_path = os.path.join(current_path, file_path) - - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - - return content - - def generate_uuid() -> str: """ 生成uuid的字符串 diff --git a/src/backend/bisheng/database/init_config.py b/src/backend/bisheng/database/init_config.py index 08ff88cbc..3d1bac1db 100644 --- a/src/backend/bisheng/database/init_config.py +++ b/src/backend/bisheng/database/init_config.py @@ -2,6 +2,7 @@ import yaml from bisheng.database.models.config import Config +from bisheng.database.base import session_getter from bisheng.settings import parse_key, read_from_conf from bisheng.utils.logger import logger from sqlmodel import select @@ -9,7 +10,6 @@ def init_config(): # 初始化config - from bisheng.database.base import session_getter # 首先通过yaml 获取配置文件所有的key config_content = read_from_conf('initdb_config.yaml') diff --git a/src/backend/bisheng/database/init_data.py b/src/backend/bisheng/database/init_data.py new file mode 100644 index 000000000..687a281dc --- /dev/null +++ b/src/backend/bisheng/database/init_data.py @@ -0,0 +1,162 @@ +import hashlib +import json +import os +from typing import List + +from loguru import logger +from sqlmodel import select, update, text + +from bisheng.database.init_config import init_config +from bisheng.database.base import session_getter, db_service +from bisheng.settings import settings +from bisheng.cache.redis import redis_client +from bisheng.database.models.component import Component +from bisheng.database.models.role import Role, AdminRole, DefaultRole +from bisheng.database.models.user import User +from bisheng.database.models.gpts_tools import GptsTools +from bisheng.database.models.gpts_tools import GptsToolsType +from bisheng.database.models.sft_model import SftModel +from bisheng.database.models.flow_version import FlowVersion +from bisheng.database.models.user_role import UserRoleDao +from bisheng.database.models.group import Group, DefaultGroup +from bisheng.database.models.role_access import RoleAccess, AccessType + + +def init_default_data(): + """初始化数据库""" + + if redis_client.setNx('init_default_data', '1'): + try: + db_service.create_db_and_tables() + with session_getter() as session: + db_role = session.exec(select(Role).limit(1)).all() + if not db_role: + # 初始化系统配置, 管理员拥有所有权限 + db_role = Role(id=AdminRole, role_name='系统管理员', remark='系统所有权限管理员', + group_id=DefaultGroup) + session.add(db_role) + db_role_normal = Role(id=DefaultRole, role_name='普通用户', remark='默认用户', + group_id=DefaultGroup) + session.add(db_role_normal) + # 给普通用户赋予 构建、知识、模型菜单栏的查看权限 + session.add_all([ + RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='build'), + RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='knowledge'), + RoleAccess(role_id=DefaultRole, type=AccessType.WEB_MENU.value, third_id='model'), + ]) + session.commit() + # 添加默认用户组 + group = session.exec(select(Group).limit(1)).all() + if not group: + group = Group(id=DefaultGroup, group_name='默认用户组', create_user=1, update_user=1) + session.add(group) + session.commit() + session.refresh(group) + + user = session.exec(select(User).limit(1)).all() + if not user and settings.admin: + md5 = hashlib.md5() + md5.update(settings.admin.get('password').encode('utf-8')) + user = User( + user_id=1, + user_name=settings.admin.get('user_name'), + password=md5.hexdigest(), + ) + session.add(user) + session.commit() + session.refresh(user) + UserRoleDao.set_admin_user(user.user_id) + + component_db = session.exec(select(Component).limit(1)).all() + if not component_db: + db_components = [] + json_items = json.loads(read_from_conf('component.json')) + for item in json_items: + for k, v in item.items(): + db_component = Component(name=k, user_id=1, user_name='admin', data=v) + db_components.append(db_component) + session.add_all(db_components) + session.commit() + + # 初始化预置工具列表 + preset_tools = session.exec(select(GptsTools).limit(1)).all() + if not preset_tools: + preset_tools = [] + json_items = json.loads(read_from_conf('t_gpts_tools.json')) + for item in json_items: + item['api_params'] = json.loads(item['api_params']) + preset_tool = GptsTools(**item) + preset_tools.append(preset_tool) + session.add_all(preset_tools) + session.commit() + # 初始化预置工具类别 + preset_tools_type = session.exec(select(GptsToolsType).limit(1)).all() + if not preset_tools_type: + preset_tools_type = [] + json_items = json.loads(read_from_conf('t_gpts_tools_type.json')) + for item in json_items: + preset_tool_type = GptsToolsType(**item) + preset_tools_type.append(preset_tool_type) + session.add_all(preset_tools_type) + session.commit() + # 设置预置工具所属的类别, 需要和预置数据一致,所以id是固定的 + for i in range(1, 7): + session.exec(update(GptsTools).where(GptsTools.id == i).values(type=i)) + # 属于天眼查类别下的工具 + tyc_types: List[int] = list(range(7, 18)) + session.exec( + update(GptsTools).where(GptsTools.id.in_(tyc_types)).values(type=7)) + # 属于金融类别下的工具 + jr_types: List[int] = list(range(18, 28)) + session.exec( + update(GptsTools).where(GptsTools.id.in_(jr_types)).values(type=8)) + session.commit() + # 初始化配置可用于微调的基准模型 + preset_models = session.exec(select(SftModel).limit(1)).all() + if not preset_models: + preset_models = [] + json_items = json.loads(read_from_conf('sft_model.json')) + for item in json_items: + preset_model = SftModel(**item) + preset_models.append(preset_model) + session.add_all(preset_models) + session.commit() + + # 初始化补充默认的技能版本表 + flow_version = session.exec(select(FlowVersion).limit(1)).all() + if not flow_version: + sql_query = text( + "INSERT INTO `flowversion` (`name`, `flow_id`, `data`, `user_id`, `is_current`, `is_delete`) \ + select 'v0', `id` as flow_id, `data`, `user_id`, 1, 0 from `flow`;") + session.execute(sql_query) + session.commit() + # 修改表单数据表 + sql_query = text( + 'UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` ' + 'WHERE flow_id=a.flow_id and is_current=1)' + ) + session.execute(sql_query) + session.commit() + # 初始化数据库config + init_config() + except Exception as exc: + # if the exception involves tables already existing + # we can ignore it + if 'already exists' not in str(exc): + logger.error(f'Error creating DB and tables: {exc}') + raise RuntimeError('Error creating DB and tables') from exc + finally: + redis_client.delete('init_default_data') + + +def read_from_conf(file_path: str) -> str: + if '/' not in file_path: + # Get current path + current_path = os.path.dirname(os.path.abspath(__file__)) + + file_path = os.path.join(current_path, file_path) + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + return content diff --git a/src/backend/bisheng/database/models/config.py b/src/backend/bisheng/database/models/config.py index dcb4c05e2..4d7b3344e 100644 --- a/src/backend/bisheng/database/models/config.py +++ b/src/backend/bisheng/database/models/config.py @@ -1,9 +1,11 @@ from datetime import datetime from typing import Optional +from sqlalchemy import Column, DateTime, text, Text +from sqlmodel import Field, select + from bisheng.database.models.base import SQLModelSerializable -from sqlalchemy import Column, DateTime, String, text, Text -from sqlmodel import Field +from bisheng.database.base import session_getter class ConfigBase(SQLModelSerializable): @@ -35,3 +37,20 @@ class ConfigUpdate(SQLModelSerializable): key: str value: Optional[str] comment: Optional[str] + + +class ConfigDao(ConfigBase): + + @classmethod + def get_config(cls, key: str) -> Optional[Config]: + with session_getter() as session: + statement = select(Config).where(Config.key == key) + return session.exec(statement).first() + + @classmethod + def insert_config(cls, config: Config) -> Config: + with session_getter() as session: + session.add(config) + session.commit() + session.refresh(config) + return config diff --git a/src/backend/bisheng/main.py b/src/backend/bisheng/main.py index eb6c48d92..e20595c4c 100644 --- a/src/backend/bisheng/main.py +++ b/src/backend/bisheng/main.py @@ -3,7 +3,7 @@ from typing import Optional from bisheng.api import router, router_rpc -from bisheng.database.base import init_default_data +from bisheng.database.init_data import init_default_data from bisheng.interface.utils import setup_llm_caching from bisheng.restructure.register import register_restructure from bisheng.services.utils import initialize_services, teardown_services diff --git a/src/backend/bisheng/settings.py b/src/backend/bisheng/settings.py index 0e26294a2..00b2a3e0d 100644 --- a/src/backend/bisheng/settings.py +++ b/src/backend/bisheng/settings.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Union import yaml -from bisheng.database.models.config import Config from cryptography.fernet import Fernet from langchain.pydantic_v1 import BaseSettings, root_validator, validator from loguru import logger @@ -167,6 +166,8 @@ def get_from_db(self, key: str): def get_all_config(self): from bisheng.database.base import session_getter from bisheng.cache.redis import redis_client + from bisheng.database.models.config import Config + redis_key = 'config:initdb_config' cache = redis_client.get(redis_key) if cache: