ldapRef.current = bool} />}
v{json.version}
diff --git a/src/frontend/src/pages/LoginPage/loginBridge.tsx b/src/frontend/src/pages/LoginPage/loginBridge.tsx
index a1e3f93c9..d8b9a59ae 100644
--- a/src/frontend/src/pages/LoginPage/loginBridge.tsx
+++ b/src/frontend/src/pages/LoginPage/loginBridge.tsx
@@ -6,13 +6,16 @@ import { useEffect, useRef } from "react";
import { ReactComponent as Wxpro } from "./icons/wxpro.svg";
import { useTranslation } from "react-i18next";
-export default function LoginBridge() {
+export default function LoginBridge({onHasLdap}) {
const { t } = useTranslation()
const urlRef = useRef('')
useEffect(() => {
- getSSOurlApi().then(url => urlRef.current = url)
+ getSSOurlApi().then((urls:any) => {
+ urlRef.current = urls.wx
+ urls.ldap && onHasLdap(true)
+ })
}, [])
const clickQwLogin = () => {
diff --git a/src/frontend/src/pages/LoginPage/utils.ts b/src/frontend/src/pages/LoginPage/utils.ts
index 25fe51044..6b5ed6503 100644
--- a/src/frontend/src/pages/LoginPage/utils.ts
+++ b/src/frontend/src/pages/LoginPage/utils.ts
@@ -1,4 +1,5 @@
import { getPublicKeyApi } from "@/controllers/API/user";
+import { getKeyApi } from "@/controllers/API/pro";
import { JSEncrypt } from 'jsencrypt';
export const handleEncrypt = async (pwd: string): Promise => {
@@ -8,4 +9,11 @@ export const handleEncrypt = async (pwd: string): Promise => {
return encrypt.encrypt(pwd) as string;
};
+export const handleLdapEncrypt = async (pwd: string): Promise => {
+ const public_key:any = await getKeyApi();
+ const encrypt = new JSEncrypt();
+ encrypt.setPublicKey(public_key);
+ return encrypt.encrypt(pwd) as string;
+};
+
export const PWD_RULE = /^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[\W_]).{8,}$/
\ No newline at end of file
diff --git a/src/frontend/src/pages/SkillPage/l2Edit.tsx b/src/frontend/src/pages/SkillPage/l2Edit.tsx
index 301b02740..0f23367ba 100644
--- a/src/frontend/src/pages/SkillPage/l2Edit.tsx
+++ b/src/frontend/src/pages/SkillPage/l2Edit.tsx
@@ -1,3 +1,6 @@
+import FlowSetting from "@/components/Pro/security/FlowSetting";
+import { useToast } from "@/components/bs-ui/toast/use-toast";
+import { locationContext } from "@/contexts/locationContext";
import { ArrowLeft, ChevronUp } from "lucide-react";
import { useContext, useEffect, useMemo, useRef, useState } from "react";
import { useTranslation } from "react-i18next";
@@ -5,21 +8,14 @@ import { useNavigate, useParams } from "react-router-dom";
import L2ParameterComponent from "../../CustomNodes/GenericNode/components/parameterComponent/l2Index";
import ShadTooltip from "../../components/ShadTooltipComponent";
import { Button } from "../../components/bs-ui/button";
-import { Input } from "../../components/bs-ui/input";
+import { Input, Textarea } from "../../components/bs-ui/input";
import { Label } from "../../components/bs-ui/label";
-import { Textarea } from "../../components/bs-ui/input";
-import { alertContext } from "../../contexts/alertContext";
import { TabsContext } from "../../contexts/tabsContext";
import { userContext } from "../../contexts/userContext";
import { createCustomFlowApi, getFlowApi } from "../../controllers/API/flow";
+import { captureAndAlertRequestErrorHoc } from "../../controllers/request";
import { useHasForm } from "../../util/hook";
import FormSet from "./components/FormSet";
-import { captureAndAlertRequestErrorHoc } from "../../controllers/request";
-import { useToast } from "@/components/bs-ui/toast/use-toast";
-import { SettingIcon } from "@/components/bs-icons/setting";
-import { Switch } from "@/components/bs-ui/switch";
-import FlowSetting from "@/components/Pro/security/FlowSetting";
-import { locationContext } from "@/contexts/locationContext";
export default function l2Edit() {
const { t } = useTranslation()
From 839b321a960eab0a602127badbf1a99f04e28948 Mon Sep 17 00:00:00 2001
From: GuoQing Zhang
Date: Thu, 11 Jul 2024 10:52:05 +0800
Subject: [PATCH 2/3] =?UTF-8?q?feat:=20=E5=8A=A9=E6=89=8B=E6=8F=90?=
=?UTF-8?q?=E4=BE=9B=E5=85=BC=E5=AE=B9openai=E6=A0=BC=E5=BC=8F=E7=9A=84?=
=?UTF-8?q?=E5=85=8D=E9=89=B4=E6=9D=83=E6=8E=A5=E5=8F=A3=E8=B0=83=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/backend/bisheng/api/router.py | 3 +-
.../bisheng/api/services/assistant_agent.py | 84 ++++++----
.../bisheng/api/services/user_service.py | 9 +-
src/backend/bisheng/api/utils.py | 1 +
src/backend/bisheng/api/v1/schemas.py | 32 +++-
src/backend/bisheng/api/v1/user.py | 4 +-
src/backend/bisheng/api/v2/__init__.py | 3 +-
src/backend/bisheng/api/v2/assistant.py | 147 +++++++++++++++++-
8 files changed, 248 insertions(+), 35 deletions(-)
diff --git a/src/backend/bisheng/api/router.py b/src/backend/bisheng/api/router.py
index c10cdb626..7b2800cdf 100644
--- a/src/backend/bisheng/api/router.py
+++ b/src/backend/bisheng/api/router.py
@@ -3,7 +3,7 @@
finetune_router, flows_router, group_router, knowledge_router,
qa_router, report_router, server_router, skillcenter_router,
user_router, validate_router, variable_router, audit_router, evaluation_router)
-from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc, flow_router
+from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc, flow_router, assistant_router_rpc
from fastapi import APIRouter
router = APIRouter(prefix='/api/v1', )
@@ -30,3 +30,4 @@
router_rpc.include_router(chat_router_rpc)
router_rpc.include_router(rpc_router_rpc)
router_rpc.include_router(flow_router)
+router_rpc.include_router(assistant_router_rpc)
diff --git a/src/backend/bisheng/api/services/assistant_agent.py b/src/backend/bisheng/api/services/assistant_agent.py
index 5646d22ed..6d0bc6ea1 100644
--- a/src/backend/bisheng/api/services/assistant_agent.py
+++ b/src/backend/bisheng/api/services/assistant_agent.py
@@ -3,7 +3,7 @@
import time
import uuid
from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Any
from uuid import UUID
import httpx
@@ -338,51 +338,79 @@ def choose_tools(self, tool_list: List[Dict[str, str]], prompt: str) -> List[str
tool_selector = ToolSelector(llm=self.llm, tools=tool_list)
return tool_selector.select(self.assistant.name, prompt)
+ async def fake_callback(self, callback: Callbacks):
+ if not callback:
+ return
+ # 假回调,将已下线的技能回调给前端
+ for one in self.offline_flows:
+ run_id = uuid.uuid4()
+ await callback[0].on_tool_start({
+ 'name': one,
+ }, input_str='flow if offline', run_id=run_id)
+ await callback[0].on_tool_end(output='flow is offline', name=one, run_id=run_id)
+
+ async def record_chat_history(self, message: List[Any]):
+ # 记录助手的聊天历史
+ if not os.getenv("BISHENG_RECORD_HISTORY"):
+ return
+ try:
+ os.makedirs("/app/data/history", exist_ok=True)
+ with open(f"/app/data/history/{self.assistant.id}_{time.time()}.json", "w", encoding="utf-8") as f:
+ json.dump({
+ "system": self.assistant.prompt,
+ "message": message,
+ "tools": [format_tool_to_openai_tool(t) for t in self.tools]
+ }, f, ensure_ascii=False)
+ except Exception as e:
+ logger.error(f"record assistant history error: {str(e)}")
+
+ async def arun(self, query: str, chat_history: List = None, callback: Callbacks = None):
+ await self.fake_callback(callback)
+
+ if chat_history:
+ chat_history.append(HumanMessage(content=query))
+ inputs = chat_history
+ else:
+ inputs = [HumanMessage(content=query)]
+
+ async for one in self.agent.astream(inputs, config=RunnableConfig(callbacks=callback)):
+ yield one
+
async def run(self, query: str, chat_history: List = None, callback: Callbacks = None):
"""
运行智能体对话
"""
- # 假回调,将已下线的技能回调给前端
- for one in self.offline_flows:
- if callback is not None:
- run_id = uuid.uuid4()
- await callback[0].on_tool_start({
- 'name': one,
- }, input_str='flow if offline', run_id=run_id)
- await callback[0].on_tool_end(output='flow is offline', name=one, run_id=run_id)
- if self.current_agent_executor == 'ReAct':
- return await self.react_run(query, chat_history, callback)
+ await self.fake_callback(callback)
if chat_history:
chat_history.append(HumanMessage(content=query))
inputs = chat_history
else:
inputs = [HumanMessage(content=query)]
- result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback))
+
+ if self.current_agent_executor == 'ReAct':
+ result = await self.react_run(inputs, callback)
+ else:
+ result = await self.agent.ainvoke(inputs, config=RunnableConfig(callbacks=callback))
# 包含了history,将history排除, 默认取最后一个为最终结果
res = [result[-1]]
- # 记录助手的聊天历史
- if os.getenv("BISHENG_RECORD_HISTORY"):
- try:
- os.makedirs("/app/data/history", exist_ok=True)
- with open(f"/app/data/history/{self.assistant.id}_{time.time()}.json", "w", encoding="utf-8") as f:
- json.dump({
- "system": self.assistant.prompt,
- "message": [one.to_json() for one in result],
- "tools": [format_tool_to_openai_tool(t) for t in self.tools]
- }, f, ensure_ascii=False)
- except Exception as e:
- logger.error(f"record assistant history error: {str(e)}")
+
+ # 记录聊天历史
+ await self.record_chat_history([one.to_json() for one in result])
+
return res
- async def react_run(self, query: str, chat_history: List = None, callback: Callbacks = None):
+ async def react_run(self, inputs: List, callback: Callbacks = None):
""" react 模式的输入和执行 """
result = await self.agent.ainvoke({
- 'input': query,
- 'chat_history': chat_history
+ 'input': inputs[-1].content,
+ 'chat_history': inputs[:-1],
}, config=RunnableConfig(callbacks=callback))
logger.debug(f"react_run result: {result}")
output = result['agent_outcome'].return_values['output']
if isinstance(output, dict):
output = list(output.values())[0]
- return [AIMessage(content=output)]
+ for one in result['intermediate_steps']:
+ inputs.append(one[0])
+ inputs.append(AIMessage(content=output))
+ return inputs
diff --git a/src/backend/bisheng/api/services/user_service.py b/src/backend/bisheng/api/services/user_service.py
index 8e82891c4..ae9619c4f 100644
--- a/src/backend/bisheng/api/services/user_service.py
+++ b/src/backend/bisheng/api/services/user_service.py
@@ -16,6 +16,7 @@
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
+from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
@@ -35,7 +36,13 @@ def __init__(self, **kwargs):
self.user_name = kwargs.get('user_name')
def is_admin(self):
- return self.user_role == 'admin'
+ if self.user_role == 'admin':
+ return True
+ if isinstance(self.user_role, list):
+ for one in self.user_role:
+ if one == AdminRole:
+ return True
+ return False
@staticmethod
def wrapper_access_check(func):
diff --git a/src/backend/bisheng/api/utils.py b/src/backend/bisheng/api/utils.py
index 339ad29eb..b813e98da 100644
--- a/src/backend/bisheng/api/utils.py
+++ b/src/backend/bisheng/api/utils.py
@@ -1,5 +1,6 @@
import hashlib
import json
+import time
import xml.dom.minidom
from pathlib import Path
from typing import Dict, List
diff --git a/src/backend/bisheng/api/v1/schemas.py b/src/backend/bisheng/api/v1/schemas.py
index 3f12eb4dc..758ccf3be 100644
--- a/src/backend/bisheng/api/v1/schemas.py
+++ b/src/backend/bisheng/api/v1/schemas.py
@@ -212,10 +212,12 @@ class UploadFileResponse(BaseModel):
class StreamData(BaseModel):
event: str
- data: dict
+ data: dict | str
def __str__(self) -> str:
- return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n'
+ if isinstance(self.data, dict):
+ return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n'
+ return f'event: {self.event}\ndata: {self.data}\n\n'
class FinetuneCreateReq(BaseModel):
@@ -320,3 +322,29 @@ class CreateUserReq(BaseModel):
user_name: str = Field(max_length=30, description='用户名')
password: str = Field(description='密码')
group_roles: List[GroupAndRoles] = Field(description='要加入的用户组和角色列表')
+
+
+class OpenAIChatCompletionReq(BaseModel):
+ messages: List[dict] = Field(..., description="聊天消息列表,只支持user、assistant。system用数据库内的数据")
+ model: str = Field(..., description="助手的唯一ID")
+ n: int = Field(default=1, description="返回的答案个数, 助手侧默认为1,暂不支持多个回答")
+ stream: bool = Field(default=False, description="是否开启流式回复")
+ temperature: float = Field(default=0.0, description="模型温度, 传入0或者不传表示不覆盖")
+ tools: List[dict] = Field(default=[], description="工具列表, 助手暂不支持,使用助手的配置")
+
+
+class OpenAIChoice(BaseModel):
+ index: int = Field(..., description="选项的索引")
+ message: dict = Field(default=None, description="对应的消息内容,和输入的格式一致")
+ finish_reason: str = Field(default='stop', description="结束原因, 助手只有stop")
+ delta: dict = Field(default=None, description="对应的openai流式返回消息内容")
+
+
+class OpenAIChatCompletionResp(BaseModel):
+ id: str = Field(..., description="请求的唯一ID")
+ object: str = Field(default='chat.completion', description="返回的类型")
+ created: int = Field(default=..., description="返回的创建时间戳")
+ model: str = Field(..., description="返回的模型,对应助手的id")
+ choices: List[OpenAIChoice] = Field(..., description="返回的答案列表")
+ usage: dict = Field(default=None, description="返回的token用量, 助手此值为空")
+ system_fingerprint: Optional[str] = Field(default=None, description="系统指纹")
diff --git a/src/backend/bisheng/api/v1/user.py b/src/backend/bisheng/api/v1/user.py
index dd8557734..04f425d76 100644
--- a/src/backend/bisheng/api/v1/user.py
+++ b/src/backend/bisheng/api/v1/user.py
@@ -60,7 +60,9 @@ async def regist(*, user: UserCreate):
# check if user already exist
user_exists = UserDao.get_user_by_username(db_user.user_name)
if user_exists:
- raise HTTPException(status_code=500, detail='账号已存在')
+ raise HTTPException(status_code=500, detail='用户名已存在')
+ if len(db_user.user_name)>30:
+ raise HTTPException(status_code=500, detail='用户名最长 30 个字符')
try:
db_user.password = UserService.decrypt_md5_password(user.password)
# 判断下admin用户是否存在
diff --git a/src/backend/bisheng/api/v2/__init__.py b/src/backend/bisheng/api/v2/__init__.py
index b5befc7f2..52c1d9cc0 100644
--- a/src/backend/bisheng/api/v2/__init__.py
+++ b/src/backend/bisheng/api/v2/__init__.py
@@ -2,4 +2,5 @@
from bisheng.api.v2.filelib import router as knowledge_router_rpc
from bisheng.api.v2.rpc import router as rpc_router_rpc
from bisheng.api.v2.flow import router as flow_router
-__all__ = ['knowledge_router_rpc', 'chat_router_rpc', 'rpc_router_rpc', 'flow_router']
+from bisheng.api.v2.assistant import router as assistant_router_rpc
+__all__ = ['knowledge_router_rpc', 'chat_router_rpc', 'rpc_router_rpc', 'flow_router', 'assistant_router_rpc']
diff --git a/src/backend/bisheng/api/v2/assistant.py b/src/backend/bisheng/api/v2/assistant.py
index eb237948c..b33cbccfc 100644
--- a/src/backend/bisheng/api/v2/assistant.py
+++ b/src/backend/bisheng/api/v2/assistant.py
@@ -1,5 +1,150 @@
# 免登录的助手相关接口
+import time
+import uuid
+from typing import Optional
+from uuid import UUID
+from fastapi import APIRouter, Request, HTTPException, WebSocket
+from fastapi.responses import StreamingResponse, ORJSONResponse
+from langchain_core.messages import HumanMessage, AIMessage
+from loguru import logger
-router = APIRouter(prefix='/chat', tags=['AssistantOpenApi'])
+from bisheng.api.services.assistant import AssistantService
+from bisheng.api.services.assistant_agent import AssistantAgent
+from bisheng.api.services.user_service import UserPayload
+from bisheng.api.utils import get_request_ip
+from bisheng.api.v1.schemas import OpenAIChatCompletionResp, OpenAIChatCompletionReq, UnifiedResponseModel, \
+ AssistantInfo, OpenAIChoice
+from bisheng.database.models.user import UserDao
+from bisheng.settings import settings
+router = APIRouter(prefix='/assistant', tags=['AssistantOpenApi'])
+
+
+def get_default_operator():
+ user_id = settings.get_from_db('default_operator').get('user')
+ if not user_id:
+ raise HTTPException(status_code=500, detail='未配置default_operator中user配置')
+ # 查找默认用户信息
+ login_user = UserDao.get_user(user_id)
+ if not login_user:
+ raise HTTPException(status_code=500, detail='未找到默认用户信息')
+ return login_user
+
+
+@router.post('/chat/completions', response_model=OpenAIChatCompletionResp)
+async def assistant_chat_completions(request: Request,
+ req_data: OpenAIChatCompletionReq):
+ """
+ 兼容openai接口格式,所有的错误必须返回非http200的状态码
+ 和助手进行聊天
+ """
+ logger.info(f'act=assistant_chat_completions assistant_id={req_data.model}, ip={get_request_ip(request)}')
+ try:
+ # 获取系统配置里配置的默认用户信息
+ default_user = get_default_operator()
+ except Exception as e:
+ return ORJSONResponse(status_code=500, content=str(e), media_type='application/json')
+ login_user = UserPayload(**{
+ 'user_id': default_user.user_id,
+ 'user_name': default_user.user_name,
+ 'role': ''
+ })
+ # 查找助手信息
+ res = AssistantService.get_assistant_info(UUID(req_data.model), login_user)
+ if res.status_code != 200:
+ return ORJSONResponse(status_code=500, content=res.status_message, media_type='application/json')
+
+ assistant_info = res.data
+ # 覆盖温度设置
+ if req_data.temperature != 0:
+ assistant_info.temperature = req_data.temperature
+
+ chat_history = []
+ question = ''
+ # 解析出对话历史和用户最新的问题
+ for one in req_data.messages:
+ if one['role'] == 'user':
+ chat_history.append(HumanMessage(content=one['content']))
+ question = one['content']
+ elif one['role'] == 'assistant':
+ chat_history.append(AIMessage(content=one['content']))
+ # 在历史记录里去除用户的问题
+ if chat_history and chat_history[-1].content == question:
+ chat_history = chat_history[:-1]
+
+ # 初始化助手agent
+ agent = AssistantAgent(assistant_info, '') # 初始化agent
+ await agent.init_assistant()
+ answer = await agent.run(question, chat_history)
+ answer = answer[0].content
+
+ openai_resp_id = uuid.uuid4().hex
+ logger.info(f'act=assistant_chat_completions_over openai_resp_id={openai_resp_id}')
+ # 将结果包装成openai的数据格式
+ openai_resp = OpenAIChatCompletionResp(
+ id=openai_resp_id,
+ object='chat.completion',
+ created=int(time.time()),
+ model=req_data.model,
+ choices=[OpenAIChoice(
+ index=0,
+ message={
+ 'role': 'assistant',
+ 'content': answer
+ }
+ )],
+ )
+
+ # 非流式直接返回结果
+ if not req_data.stream:
+ return openai_resp
+
+ # 流式返回最终结果, 兼容openai格式处理
+ openai_resp.object = 'chat.completion.chunk'
+ openai_resp.choices = [
+ OpenAIChoice(
+ index=0,
+ delta={
+ 'content': answer
+ }
+ )
+ ]
+
+ async def _event_stream():
+ # todo:zgq 后续优化成真正的流式输出,目前是出现最终答案之后直接流式返回的
+ yield f'data: {openai_resp.json()}\n\n'
+ # 最后的[DONE]
+ yield 'data: [DONE]\n\n'
+
+ try:
+ return StreamingResponse(_event_stream(), media_type='text/event-stream')
+ except Exception as exc:
+ logger.error(exc)
+ return ORJSONResponse(status_code=500, content=str(exc))
+
+
+@router.get('/info/{assistant_id}', response_model=UnifiedResponseModel[AssistantInfo])
+async def get_assistant_info(request: Request, assistant_id: UUID):
+ """
+ 获取助手信息, 用系统配置里的default_operator.user的用户信息来做权限校验
+ """
+ logger.info('act=assistant_info')
+ default_user = get_default_operator()
+ login_user = UserPayload(**{
+ 'user_id': default_user.user_id,
+ 'user_name': default_user.user_name,
+ 'role': ''
+ })
+ return AssistantService.get_assistant_info(assistant_id, login_user)
+
+
+@router.websocket('/chat/{assistant_id}')
+async def chat(*,
+ assistant_id: str,
+ websocket: WebSocket,
+ chat_id: Optional[str] = None):
+ """
+ 助手的ws免登录接口
+ """
+ pass
From 9e9e44ae36e61f0e39bffdb6f55b3f4dcb642710 Mon Sep 17 00:00:00 2001
From: GuoQing Zhang
Date: Thu, 11 Jul 2024 11:07:35 +0800
Subject: [PATCH 3/3] =?UTF-8?q?feat:=20=E7=BB=99=E5=8A=A9=E6=89=8B?=
=?UTF-8?q?=E7=9A=84=E5=85=8D=E7=99=BB=E5=BD=95=E6=8F=90=E4=BE=9B=E5=AF=B9?=
=?UTF-8?q?=E5=BA=94=E7=9A=84=E6=8E=A5=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/backend/bisheng/api/utils.py | 5 ++--
src/backend/bisheng/api/v2/assistant.py | 31 +++++++++++++++++++++----
2 files changed, 29 insertions(+), 7 deletions(-)
diff --git a/src/backend/bisheng/api/utils.py b/src/backend/bisheng/api/utils.py
index b813e98da..87ecdabc2 100644
--- a/src/backend/bisheng/api/utils.py
+++ b/src/backend/bisheng/api/utils.py
@@ -1,12 +1,11 @@
import hashlib
import json
-import time
import xml.dom.minidom
from pathlib import Path
from typing import Dict, List
import aiohttp
-from fastapi import Request
+from fastapi import Request, WebSocket
from fastapi_jwt_auth import AuthJWT
from platformdirs import user_cache_dir
from sqlalchemy import delete
@@ -421,7 +420,7 @@ async def get_url_content(url: str) -> str:
return res.decode('utf-8')
-def get_request_ip(request: Request) -> str:
+def get_request_ip(request: Request | WebSocket) -> str:
""" 获取客户端真实IP """
x_forwarded_for = request.headers.get('X-Forwarded-For')
if x_forwarded_for:
diff --git a/src/backend/bisheng/api/v2/assistant.py b/src/backend/bisheng/api/v2/assistant.py
index b33cbccfc..da0bcbc87 100644
--- a/src/backend/bisheng/api/v2/assistant.py
+++ b/src/backend/bisheng/api/v2/assistant.py
@@ -4,7 +4,8 @@
from typing import Optional
from uuid import UUID
-from fastapi import APIRouter, Request, HTTPException, WebSocket
+from fastapi import APIRouter, Request, HTTPException, WebSocket, WebSocketException
+from fastapi import status as http_status
from fastapi.responses import StreamingResponse, ORJSONResponse
from langchain_core.messages import HumanMessage, AIMessage
from loguru import logger
@@ -13,8 +14,10 @@
from bisheng.api.services.assistant_agent import AssistantAgent
from bisheng.api.services.user_service import UserPayload
from bisheng.api.utils import get_request_ip
+from bisheng.api.v1.chat import chat_manager
from bisheng.api.v1.schemas import OpenAIChatCompletionResp, OpenAIChatCompletionReq, UnifiedResponseModel, \
AssistantInfo, OpenAIChoice
+from bisheng.chat.types import WorkType
from bisheng.database.models.user import UserDao
from bisheng.settings import settings
@@ -129,7 +132,7 @@ async def get_assistant_info(request: Request, assistant_id: UUID):
"""
获取助手信息, 用系统配置里的default_operator.user的用户信息来做权限校验
"""
- logger.info('act=assistant_info')
+ logger.info(f'act=get_default_operator assistant_id={assistant_id}, ip={get_request_ip(request)}')
default_user = get_default_operator()
login_user = UserPayload(**{
'user_id': default_user.user_id,
@@ -141,10 +144,30 @@ async def get_assistant_info(request: Request, assistant_id: UUID):
@router.websocket('/chat/{assistant_id}')
async def chat(*,
- assistant_id: str,
websocket: WebSocket,
+ assistant_id: str,
chat_id: Optional[str] = None):
"""
助手的ws免登录接口
"""
- pass
+ logger.info(f'act=assistant_chat_ws assistant_id={assistant_id}, ip={get_request_ip(websocket)}')
+ default_user = get_default_operator()
+ login_user = UserPayload(**{
+ 'user_id': default_user.user_id,
+ 'user_name': default_user.user_name,
+ 'role': ''
+ })
+ try:
+ request = websocket
+ await chat_manager.dispatch_client(request, assistant_id, chat_id, login_user, WorkType.GPTS,
+ websocket)
+ except WebSocketException as exc:
+ logger.error(f'Websocket exception: {str(exc)}')
+ await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=str(exc))
+ except Exception as exc:
+ logger.exception(f'Error in chat websocket: {str(exc)}')
+ message = exc.detail if isinstance(exc, HTTPException) else str(exc)
+ if 'Could not validate credentials' in str(exc):
+ await websocket.close(code=http_status.WS_1008_POLICY_VIOLATION, reason='Unauthorized')
+ else:
+ await websocket.close(code=http_status.WS_1011_INTERNAL_ERROR, reason=message)