diff --git a/src/backend/bisheng/api/router.py b/src/backend/bisheng/api/router.py index c10cdb626..7b2800cdf 100644 --- a/src/backend/bisheng/api/router.py +++ b/src/backend/bisheng/api/router.py @@ -3,7 +3,7 @@ finetune_router, flows_router, group_router, knowledge_router, qa_router, report_router, server_router, skillcenter_router, user_router, validate_router, variable_router, audit_router, evaluation_router) -from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc, flow_router +from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc, flow_router, assistant_router_rpc from fastapi import APIRouter router = APIRouter(prefix='/api/v1', ) @@ -30,3 +30,4 @@ router_rpc.include_router(chat_router_rpc) router_rpc.include_router(rpc_router_rpc) router_rpc.include_router(flow_router) +router_rpc.include_router(assistant_router_rpc) diff --git a/src/backend/bisheng/api/services/assistant_agent.py b/src/backend/bisheng/api/services/assistant_agent.py index 5646d22ed..6d0bc6ea1 100644 --- a/src/backend/bisheng/api/services/assistant_agent.py +++ b/src/backend/bisheng/api/services/assistant_agent.py @@ -3,7 +3,7 @@ import time import uuid from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Any from uuid import UUID import httpx @@ -338,51 +338,79 @@ def choose_tools(self, tool_list: List[Dict[str, str]], prompt: str) -> List[str tool_selector = ToolSelector(llm=self.llm, tools=tool_list) return tool_selector.select(self.assistant.name, prompt) + async def fake_callback(self, callback: Callbacks): + if not callback: + return + # 假回调,将已下线的技能回调给前端 + for one in self.offline_flows: + run_id = uuid.uuid4() + await callback[0].on_tool_start({ + 'name': one, + }, input_str='flow if offline', run_id=run_id) + await callback[0].on_tool_end(output='flow is offline', name=one, run_id=run_id) + + async def record_chat_history(self, message: List[Any]): + # 记录助手的聊天历史 + if not os.getenv("BISHENG_RECORD_HISTORY"): + return + try: + os.makedirs("/app/data/history", exist_ok=True) + with open(f"/app/data/history/{self.assistant.id}_{time.time()}.json", "w", encoding="utf-8") as f: + json.dump({ + "system": self.assistant.prompt, + "message": message, + "tools": [format_tool_to_openai_tool(t) for t in self.tools] + }, f, ensure_ascii=False) + except Exception as e: + logger.error(f"record assistant history error: {str(e)}") + + async def arun(self, query: str, chat_history: List = None, callback: Callbacks = None): + await self.fake_callback(callback) + + if chat_history: + chat_history.append(HumanMessage(content=query)) + inputs = chat_history + else: + inputs = [HumanMessage(content=query)] + + async for one in self.agent.astream(inputs, config=RunnableConfig(callbacks=callback)): + yield one + async def run(self, query: str, chat_history: List = None, callback: Callbacks = None): """ 运行智能体对话 """ - # 假回调,将已下线的技能回调给前端 - for one in self.offline_flows: - if callback is not None: - run_id = uuid.uuid4() - await callback[0].on_tool_start({ - 'name': one, - }, input_str='flow if offline', run_id=run_id) - await callback[0].on_tool_end(output='flow is offline', name=one, run_id=run_id) - if self.current_agent_executor == 'ReAct': - return await self.react_run(query, chat_history, callback) + await self.fake_callback(callback) if chat_history: chat_history.append(HumanMessage(content=query)) inputs = chat_history else: inputs = [HumanMessage(content=query)] - result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback)) + + if self.current_agent_executor == 'ReAct': + result = await self.react_run(inputs, callback) + else: + result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback)) # 包含了history,将history排除, 默认取最后一个为最终结果 res = [result[-1]] - # 记录助手的聊天历史 - if os.getenv("BISHENG_RECORD_HISTORY"): - try: - os.makedirs("/app/data/history", exist_ok=True) - with open(f"/app/data/history/{self.assistant.id}_{time.time()}.json", "w", encoding="utf-8") as f: - json.dump({ - "system": self.assistant.prompt, - "message": [one.to_json() for one in result], - "tools": [format_tool_to_openai_tool(t) for t in self.tools] - }, f, ensure_ascii=False) - except Exception as e: - logger.error(f"record assistant history error: {str(e)}") + + # 记录聊天历史 + await self.record_chat_history([one.to_json() for one in result]) + return res - async def react_run(self, query: str, chat_history: List = None, callback: Callbacks = None): + async def react_run(self, inputs: List, callback: Callbacks = None): """ react 模式的输入和执行 """ result = await self.agent.ainvoke({ - 'input': query, - 'chat_history': chat_history + 'input': inputs[-1].content, + 'chat_history': inputs[:-1], }, config=RunnableConfig(callbacks=callback)) logger.debug(f"react_run result: {result}") output = result['agent_outcome'].return_values['output'] if isinstance(output, dict): output = list(output.values())[0] - return [AIMessage(content=output)] + for one in result['intermediate_steps']: + inputs.append(one[0]) + inputs.append(AIMessage(content=output)) + return inputs diff --git a/src/backend/bisheng/api/services/user_service.py b/src/backend/bisheng/api/services/user_service.py index 8e82891c4..ae9619c4f 100644 --- a/src/backend/bisheng/api/services/user_service.py +++ b/src/backend/bisheng/api/services/user_service.py @@ -16,6 +16,7 @@ from bisheng.database.models.assistant import Assistant, AssistantDao from bisheng.database.models.flow import Flow, FlowDao, FlowRead from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead +from bisheng.database.models.role import AdminRole from bisheng.database.models.role_access import AccessType, RoleAccessDao from bisheng.database.models.user import User, UserDao from bisheng.database.models.user_group import UserGroupDao @@ -35,7 +36,13 @@ def __init__(self, **kwargs): self.user_name = kwargs.get('user_name') def is_admin(self): - return self.user_role == 'admin' + if self.user_role == 'admin': + return True + if isinstance(self.user_role, list): + for one in self.user_role: + if one == AdminRole: + return True + return False @staticmethod def wrapper_access_check(func): diff --git a/src/backend/bisheng/api/utils.py b/src/backend/bisheng/api/utils.py index 339ad29eb..87ecdabc2 100644 --- a/src/backend/bisheng/api/utils.py +++ b/src/backend/bisheng/api/utils.py @@ -5,7 +5,7 @@ from typing import Dict, List import aiohttp -from fastapi import Request +from fastapi import Request, WebSocket from fastapi_jwt_auth import AuthJWT from platformdirs import user_cache_dir from sqlalchemy import delete @@ -420,7 +420,7 @@ async def get_url_content(url: str) -> str: return res.decode('utf-8') -def get_request_ip(request: Request) -> str: +def get_request_ip(request: Request | WebSocket) -> str: """ 获取客户端真实IP """ x_forwarded_for = request.headers.get('X-Forwarded-For') if x_forwarded_for: diff --git a/src/backend/bisheng/api/v1/schemas.py b/src/backend/bisheng/api/v1/schemas.py index 3f12eb4dc..758ccf3be 100644 --- a/src/backend/bisheng/api/v1/schemas.py +++ b/src/backend/bisheng/api/v1/schemas.py @@ -212,10 +212,12 @@ class UploadFileResponse(BaseModel): class StreamData(BaseModel): event: str - data: dict + data: dict | str def __str__(self) -> str: - return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n' + if isinstance(self.data, dict): + return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n' + return f'event: {self.event}\ndata: {self.data}\n\n' class FinetuneCreateReq(BaseModel): @@ -320,3 +322,29 @@ class CreateUserReq(BaseModel): user_name: str = Field(max_length=30, description='用户名') password: str = Field(description='密码') group_roles: List[GroupAndRoles] = Field(description='要加入的用户组和角色列表') + + +class OpenAIChatCompletionReq(BaseModel): + messages: List[dict] = Field(..., description="聊天消息列表,只支持user、assistant。system用数据库内的数据") + model: str = Field(..., description="助手的唯一ID") + n: int = Field(default=1, description="返回的答案个数, 助手侧默认为1,暂不支持多个回答") + stream: bool = Field(default=False, description="是否开启流式回复") + temperature: float = Field(default=0.0, description="模型温度, 传入0或者不传表示不覆盖") + tools: List[dict] = Field(default=[], description="工具列表, 助手暂不支持,使用助手的配置") + + +class OpenAIChoice(BaseModel): + index: int = Field(..., description="选项的索引") + message: dict = Field(default=None, description="对应的消息内容,和输入的格式一致") + finish_reason: str = Field(default='stop', description="结束原因, 助手只有stop") + delta: dict = Field(default=None, description="对应的openai流式返回消息内容") + + +class OpenAIChatCompletionResp(BaseModel): + id: str = Field(..., description="请求的唯一ID") + object: str = Field(default='chat.completion', description="返回的类型") + created: int = Field(default=..., description="返回的创建时间戳") + model: str = Field(..., description="返回的模型,对应助手的id") + choices: List[OpenAIChoice] = Field(..., description="返回的答案列表") + usage: dict = Field(default=None, description="返回的token用量, 助手此值为空") + system_fingerprint: Optional[str] = Field(default=None, description="系统指纹") diff --git a/src/backend/bisheng/api/v1/user.py b/src/backend/bisheng/api/v1/user.py index dd8557734..04f425d76 100644 --- a/src/backend/bisheng/api/v1/user.py +++ b/src/backend/bisheng/api/v1/user.py @@ -60,7 +60,9 @@ async def regist(*, user: UserCreate): # check if user already exist user_exists = UserDao.get_user_by_username(db_user.user_name) if user_exists: - raise HTTPException(status_code=500, detail='账号已存在') + raise HTTPException(status_code=500, detail='用户名已存在') + if len(db_user.user_name)>30: + raise HTTPException(status_code=500, detail='用户名最长 30 个字符') try: db_user.password = UserService.decrypt_md5_password(user.password) # 判断下admin用户是否存在 diff --git a/src/backend/bisheng/api/v2/__init__.py b/src/backend/bisheng/api/v2/__init__.py index b5befc7f2..52c1d9cc0 100644 --- a/src/backend/bisheng/api/v2/__init__.py +++ b/src/backend/bisheng/api/v2/__init__.py @@ -2,4 +2,5 @@ from bisheng.api.v2.filelib import router as knowledge_router_rpc from bisheng.api.v2.rpc import router as rpc_router_rpc from bisheng.api.v2.flow import router as flow_router -__all__ = ['knowledge_router_rpc', 'chat_router_rpc', 'rpc_router_rpc', 'flow_router'] +from bisheng.api.v2.assistant import router as assistant_router_rpc +__all__ = ['knowledge_router_rpc', 'chat_router_rpc', 'rpc_router_rpc', 'flow_router', 'assistant_router_rpc'] diff --git a/src/backend/bisheng/api/v2/assistant.py b/src/backend/bisheng/api/v2/assistant.py index eb237948c..da0bcbc87 100644 --- a/src/backend/bisheng/api/v2/assistant.py +++ b/src/backend/bisheng/api/v2/assistant.py @@ -1,5 +1,173 @@ # 免登录的助手相关接口 +import time +import uuid +from typing import Optional +from uuid import UUID +from fastapi import APIRouter, Request, HTTPException, WebSocket, WebSocketException +from fastapi import status as http_status +from fastapi.responses import StreamingResponse, ORJSONResponse +from langchain_core.messages import HumanMessage, AIMessage +from loguru import logger -router = APIRouter(prefix='/chat', tags=['AssistantOpenApi']) +from bisheng.api.services.assistant import AssistantService +from bisheng.api.services.assistant_agent import AssistantAgent +from bisheng.api.services.user_service import UserPayload +from bisheng.api.utils import get_request_ip +from bisheng.api.v1.chat import chat_manager +from bisheng.api.v1.schemas import OpenAIChatCompletionResp, OpenAIChatCompletionReq, UnifiedResponseModel, \ + AssistantInfo, OpenAIChoice +from bisheng.chat.types import WorkType +from bisheng.database.models.user import UserDao +from bisheng.settings import settings +router = APIRouter(prefix='/assistant', tags=['AssistantOpenApi']) + + +def get_default_operator(): + user_id = settings.get_from_db('default_operator').get('user') + if not user_id: + raise HTTPException(status_code=500, detail='未配置default_operator中user配置') + # 查找默认用户信息 + login_user = UserDao.get_user(user_id) + if not login_user: + raise HTTPException(status_code=500, detail='未找到默认用户信息') + return login_user + + +@router.post('/chat/completions', response_model=OpenAIChatCompletionResp) +async def assistant_chat_completions(request: Request, + req_data: OpenAIChatCompletionReq): + """ + 兼容openai接口格式,所有的错误必须返回非http200的状态码 + 和助手进行聊天 + """ + logger.info(f'act=assistant_chat_completions assistant_id={req_data.model}, ip={get_request_ip(request)}') + try: + # 获取系统配置里配置的默认用户信息 + default_user = get_default_operator() + except Exception as e: + return ORJSONResponse(status_code=500, content=str(e), media_type='application/json') + login_user = UserPayload(**{ + 'user_id': default_user.user_id, + 'user_name': default_user.user_name, + 'role': '' + }) + # 查找助手信息 + res = AssistantService.get_assistant_info(UUID(req_data.model), login_user) + if res.status_code != 200: + return ORJSONResponse(status_code=500, content=res.status_message, media_type='application/json') + + assistant_info = res.data + # 覆盖温度设置 + if req_data.temperature != 0: + assistant_info.temperature = req_data.temperature + + chat_history = [] + question = '' + # 解析出对话历史和用户最新的问题 + for one in req_data.messages: + if one['role'] == 'user': + chat_history.append(HumanMessage(content=one['content'])) + question = one['content'] + elif one['role'] == 'assistant': + chat_history.append(AIMessage(content=one['content'])) + # 在历史记录里去除用户的问题 + if chat_history and chat_history[-1].content == question: + chat_history = chat_history[:-1] + + # 初始化助手agent + agent = AssistantAgent(assistant_info, '') # 初始化agent + await agent.init_assistant() + answer = await agent.run(question, chat_history) + answer = answer[0].content + + openai_resp_id = uuid.uuid4().hex + logger.info(f'act=assistant_chat_completions_over openai_resp_id={openai_resp_id}') + # 将结果包装成openai的数据格式 + openai_resp = OpenAIChatCompletionResp( + id=openai_resp_id, + object='chat.completion', + created=int(time.time()), + model=req_data.model, + choices=[OpenAIChoice( + index=0, + message={ + 'role': 'assistant', + 'content': answer + } + )], + ) + + # 非流式直接返回结果 + if not req_data.stream: + return openai_resp + + # 流式返回最终结果, 兼容openai格式处理 + openai_resp.object = 'chat.completion.chunk' + openai_resp.choices = [ + OpenAIChoice( + index=0, + delta={ + 'content': answer + } + ) + ] + + async def _event_stream(): + # todo:zgq 后续优化成真正的流式输出,目前是出现最终答案之后直接流式返回的 + yield f'data: {openai_resp.json()}\n\n' + # 最后的[DONE] + yield 'data: [DONE]\n\n' + + try: + return StreamingResponse(_event_stream(), media_type='text/event-stream') + except Exception as exc: + logger.error(exc) + return ORJSONResponse(status_code=500, content=str(exc)) + + +@router.get('/info/{assistant_id}', response_model=UnifiedResponseModel[AssistantInfo]) +async def get_assistant_info(request: Request, assistant_id: UUID): + """ + 获取助手信息, 用系统配置里的default_operator.user的用户信息来做权限校验 + """ + logger.info(f'act=get_default_operator assistant_id={assistant_id}, ip={get_request_ip(request)}') + default_user = get_default_operator() + login_user = UserPayload(**{ + 'user_id': default_user.user_id, + 'user_name': default_user.user_name, + 'role': '' + }) + return AssistantService.get_assistant_info(assistant_id, login_user) + + +@router.websocket('/chat/{assistant_id}') +async def chat(*, + websocket: WebSocket, + assistant_id: str, + chat_id: Optional[str] = None): + """ + 助手的ws免登录接口 + """ + logger.info(f'act=assistant_chat_ws assistant_id={assistant_id}, ip={get_request_ip(websocket)}') + default_user = get_default_operator() + login_user = UserPayload(**{ + 'user_id': default_user.user_id, + 'user_name': default_user.user_name, + 'role': '' + }) + try: + request = websocket + await chat_manager.dispatch_client(request, assistant_id, chat_id, login_user, WorkType.GPTS, + websocket) + except WebSocketException as exc: + logger.error(f'Websocket exception: {str(exc)}') + await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=str(exc)) + except Exception as exc: + logger.exception(f'Error in chat websocket: {str(exc)}') + message = exc.detail if isinstance(exc, HTTPException) else str(exc) + if 'Could not validate credentials' in str(exc): + await websocket.close(code=http_status.WS_1008_POLICY_VIOLATION, reason='Unauthorized') + else: + await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=message) diff --git a/src/frontend/src/controllers/API/pro.ts b/src/frontend/src/controllers/API/pro.ts index dc2d32843..3d196f505 100644 --- a/src/frontend/src/controllers/API/pro.ts +++ b/src/frontend/src/controllers/API/pro.ts @@ -78,5 +78,16 @@ export function getUserGroupsProApi() { // GET sso URL export function getSSOurlApi() { // return Promise.resolve(url) - return axios.get(`/api/oauth2/list`).then(res => res.wx); + return axios.get(`/api/oauth2/list`) +} + +export async function getKeyApi() { + return await axios.get('/api/getkey') +} + +export async function ldapLoginApi(username:string, password:string) { + return await axios.post('/api/oauth2/ldap', { + username, + password + }) } \ No newline at end of file diff --git a/src/frontend/src/pages/LoginPage/login.tsx b/src/frontend/src/pages/LoginPage/login.tsx index 8313491fa..833113427 100644 --- a/src/frontend/src/pages/LoginPage/login.tsx +++ b/src/frontend/src/pages/LoginPage/login.tsx @@ -11,8 +11,10 @@ import { useNavigate } from 'react-router-dom'; import { getCaptchaApi, loginApi, registerApi } from "../../controllers/API/user"; import { captureAndAlertRequestErrorHoc } from "../../controllers/request"; import LoginBridge from './loginBridge'; -import { PWD_RULE, handleEncrypt } from './utils'; +import { PWD_RULE, handleEncrypt, handleLdapEncrypt } from './utils'; import { locationContext } from '@/contexts/locationContext'; +import { ldapLoginApi } from '@/controllers/API/pro'; + export const LoginPage = () => { // const { setErrorData, setSuccessData } = useContext(alertContext); const { t, i18n } = useTranslation(); @@ -41,6 +43,7 @@ export const LoginPage = () => { getCaptchaApi().then(setCaptchaData) }; + const ldapRef = useRef(false) const handleLogin = async () => { const error = [] const [mail, pwd] = [mailRef.current.value, pwdRef.current.value] @@ -58,6 +61,12 @@ export const LoginPage = () => { // }); const encryptPwd = await handleEncrypt(pwd) + if(ldapRef.current) { + const encryptLdapPwd = await handleLdapEncrypt(pwd) + captureAndAlertRequestErrorHoc(ldapLoginApi(mail, encryptLdapPwd).then(res => console.log(res))) + fetchCaptchaData() + return + } captureAndAlertRequestErrorHoc(loginApi(mail, encryptPwd, captchaData.captcha_key, captchaRef.current?.value).then((res: any) => { // setUser(res.data) localStorage.setItem('ws_token', res.access_token) @@ -196,7 +205,7 @@ export const LoginPage = () => { disabled={isLoading} onClick={handleRegister} >{t('login.registerButton')} } - {appConfig.hasSSO && } + {appConfig.hasSSO && ldapRef.current = bool} />}
v{json.version} diff --git a/src/frontend/src/pages/LoginPage/loginBridge.tsx b/src/frontend/src/pages/LoginPage/loginBridge.tsx index a1e3f93c9..d8b9a59ae 100644 --- a/src/frontend/src/pages/LoginPage/loginBridge.tsx +++ b/src/frontend/src/pages/LoginPage/loginBridge.tsx @@ -6,13 +6,16 @@ import { useEffect, useRef } from "react"; import { ReactComponent as Wxpro } from "./icons/wxpro.svg"; import { useTranslation } from "react-i18next"; -export default function LoginBridge() { +export default function LoginBridge({onHasLdap}) { const { t } = useTranslation() const urlRef = useRef('') useEffect(() => { - getSSOurlApi().then(url => urlRef.current = url) + getSSOurlApi().then((urls:any) => { + urlRef.current = urls.wx + urls.ldap && onHasLdap(true) + }) }, []) const clickQwLogin = () => { diff --git a/src/frontend/src/pages/LoginPage/utils.ts b/src/frontend/src/pages/LoginPage/utils.ts index 25fe51044..6b5ed6503 100644 --- a/src/frontend/src/pages/LoginPage/utils.ts +++ b/src/frontend/src/pages/LoginPage/utils.ts @@ -1,4 +1,5 @@ import { getPublicKeyApi } from "@/controllers/API/user"; +import { getKeyApi } from "@/controllers/API/pro"; import { JSEncrypt } from 'jsencrypt'; export const handleEncrypt = async (pwd: string): Promise => { @@ -8,4 +9,11 @@ export const handleEncrypt = async (pwd: string): Promise => { return encrypt.encrypt(pwd) as string; }; +export const handleLdapEncrypt = async (pwd: string): Promise => { + const public_key:any = await getKeyApi(); + const encrypt = new JSEncrypt(); + encrypt.setPublicKey(public_key); + return encrypt.encrypt(pwd) as string; +}; + export const PWD_RULE = /^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[\W_]).{8,}$/ \ No newline at end of file diff --git a/src/frontend/src/pages/SkillPage/l2Edit.tsx b/src/frontend/src/pages/SkillPage/l2Edit.tsx index 301b02740..0f23367ba 100644 --- a/src/frontend/src/pages/SkillPage/l2Edit.tsx +++ b/src/frontend/src/pages/SkillPage/l2Edit.tsx @@ -1,3 +1,6 @@ +import FlowSetting from "@/components/Pro/security/FlowSetting"; +import { useToast } from "@/components/bs-ui/toast/use-toast"; +import { locationContext } from "@/contexts/locationContext"; import { ArrowLeft, ChevronUp } from "lucide-react"; import { useContext, useEffect, useMemo, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; @@ -5,21 +8,14 @@ import { useNavigate, useParams } from "react-router-dom"; import L2ParameterComponent from "../../CustomNodes/GenericNode/components/parameterComponent/l2Index"; import ShadTooltip from "../../components/ShadTooltipComponent"; import { Button } from "../../components/bs-ui/button"; -import { Input } from "../../components/bs-ui/input"; +import { Input, Textarea } from "../../components/bs-ui/input"; import { Label } from "../../components/bs-ui/label"; -import { Textarea } from "../../components/bs-ui/input"; -import { alertContext } from "../../contexts/alertContext"; import { TabsContext } from "../../contexts/tabsContext"; import { userContext } from "../../contexts/userContext"; import { createCustomFlowApi, getFlowApi } from "../../controllers/API/flow"; +import { captureAndAlertRequestErrorHoc } from "../../controllers/request"; import { useHasForm } from "../../util/hook"; import FormSet from "./components/FormSet"; -import { captureAndAlertRequestErrorHoc } from "../../controllers/request"; -import { useToast } from "@/components/bs-ui/toast/use-toast"; -import { SettingIcon } from "@/components/bs-icons/setting"; -import { Switch } from "@/components/bs-ui/switch"; -import FlowSetting from "@/components/Pro/security/FlowSetting"; -import { locationContext } from "@/contexts/locationContext"; export default function l2Edit() { const { t } = useTranslation()