Skip to content

Commit

Permalink
add feat 0.1.8 (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 authored Oct 24, 2023
2 parents 013e07a + b6c257b commit 8a76084
Show file tree
Hide file tree
Showing 77 changed files with 2,024 additions and 551 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ qdrant_storage
.DS_Store

# VSCode
.vscode/
.vscode
.vscode/settings.json
.chroma
.ruff_cache
Expand Down
2 changes: 1 addition & 1 deletion docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# 数据库配置
database_url:
"mysql+pymysql://root:1234@mysql:3306/bisheng"
"mysql+pymysql://root:1234@mysql:3306/bisheng?charset=utf8mb4"
redis_url:
"redis:6379"

Expand Down
2 changes: 2 additions & 0 deletions src/backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ WORKDIR /app
# Install Poetry
RUN apt-get update && apt-get install gcc g++ curl build-essential postgresql-server-dev-all -y
RUN apt-get update && apt-get install procps -y
# opencv
RUN apt-get install -y libglib2.0-0 libsm6 libxrender1 libxext6 libgl1
RUN curl -sSL https://install.python-poetry.org | python3 -
# # Add Poetry to PATH
ENV PATH="${PATH}:/root/.local/bin"
Expand Down
26 changes: 24 additions & 2 deletions src/backend/bisheng/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from bisheng.api.v1.schemas import StreamData
from bisheng.database.base import get_session
from bisheng.database.models.role_access import AccessType, RoleAccess
from bisheng.graph.graph.base import Graph
from bisheng.utils.logger import logger
from sqlmodel import select

API_WORDS = ['api', 'key', 'token']

Expand Down Expand Up @@ -91,7 +94,10 @@ def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, ch
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id if chat_id else 1}'
elif 'index_name' in vertex.params and not vertex.params.get('index_name'):
# es
vertex.params['index_name'] = f'tmp_{flow_id}_{chat_id if chat_id else 1}'

vertex.build()
params = vertex._built_object_repr()
Expand Down Expand Up @@ -156,11 +162,14 @@ def build_flow_no_yield(graph_data: dict,
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id if chat_id else 1}'
logger.info(f"rename_vector_col col={vertex.params['collection_name']}")
if process_file:
# L1 清除Milvus历史记录
vertex.params['drop_old'] = True
elif 'index_name' in vertex.params and not vertex.params.get('index_name'):
# es
vertex.params['index_name'] = f'tmp_{flow_id}_{chat_id if chat_id else 1}'

vertex.build()
params = vertex._built_object_repr()
Expand All @@ -174,3 +183,16 @@ def build_flow_no_yield(graph_data: dict,
except Exception as exc:
raise exc
return graph


def access_check(payload: dict, owner_user_id: int, target_id: int, type: AccessType) -> bool:
if payload.get('role') != 'admin':
# role_access
session = next(get_session())
role_access = session.exec(
select(RoleAccess).where(RoleAccess.role_id.in_(payload.get('role')),
RoleAccess.type == type.value)).all()
third_ids = [access.third_id for access in role_access]
if owner_user_id != payload.get('user_id') or not third_ids or target_id not in third_ids:
return False
return True
49 changes: 49 additions & 0 deletions src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,55 @@ async def chat(client_id: str,
logger.error(str(e))


@router.websocket('/chat/ws/{client_id}')
async def union_websocket(client_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
"""Websocket endpoint for chat."""
if type and type == 'L1':
with next(get_session()) as session:
db_flow = session.get(Flow, client_id)
if not db_flow:
await websocket.accept()
message = '该技能已被删除'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
if db_flow.status != 2:
await websocket.accept()
message = '当前技能未上线,无法直接对话'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
graph_data = db_flow.data
else:
flow_data_key = 'flow_data_' + client_id
if str(flow_data_store.hget(flow_data_key, 'status'), 'utf-8') != BuildStatus.SUCCESS.value:
await websocket.accept()
message = '当前编译没通过'
await websocket.close(code=status.WS_1013_TRY_AGAIN_LATER, reason=message)
graph_data = json.loads(flow_data_store.hget(flow_data_key, 'graph_data'))

try:
graph = build_flow_no_yield(graph_data=graph_data,
artifacts={},
process_file=False,
flow_id=UUID(client_id).hex,
chat_id=chat_id)
langchain_object = graph.build()
for node in langchain_object:
key_node = get_cache_key(client_id, chat_id, node.id)
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(client_id, chat_id), node._built_object)
await chat_manager.handle_websocket(client_id, chat_id, websocket, user_id)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as e:
logger.error(str(e))


@router.post('/build/init/{flow_id}', response_model=InitResponse, status_code=201)
async def init_build(*, graph_data: dict, session: Session = Depends(get_session), flow_id: str):
"""Initialize the build by storing graph data and returning a unique session ID."""
Expand Down
47 changes: 39 additions & 8 deletions src/backend/bisheng/api/v1/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from typing import Optional

import yaml
from bisheng.api.v1.schemas import ProcessResponse, UploadFileResponse
from bisheng.cache.utils import save_uploaded_file
from bisheng.database.base import get_session
from bisheng.database.models.config import Config
from bisheng.database.models.flow import Flow
from bisheng.interface.types import langchain_types_dict
from bisheng.processing.process import process_graph_cached, process_tweaks
from bisheng.settings import parse_key
from bisheng.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from sqlmodel import Session
from sqlalchemy import delete
from sqlmodel import Session, select

# build router
router = APIRouter(tags=['Base'])
Expand All @@ -19,14 +23,43 @@ def get_all():
return langchain_types_dict


@router.get('/config')
def get_config(session: Session = Depends(get_session)):
configs = session.exec(select(Config)).all()
config_str = []
for config in configs:
config_str.append(config.key + ':')
config_str.append(config.value)
return '\n'.join(config_str)


@router.post('/config/save')
def save_config(data: dict, session: Session = Depends(get_session)):
try:
config_yaml = yaml.safe_load(data.get('data'))
session.exec(delete(Config))
keys = list(config_yaml.keys())
values = parse_key(keys, data.get('data'))

for index, key in enumerate(keys):
config = Config(key=key, value=values[index])
session.add(config)
session.commit()
except Exception as e:
session.rollback()
raise HTTPException(status_code=500, detail=f'格式不正确, {str(e)}')

return {'message': 'save success'}


# For backwards compatibility we will keep the old endpoint
@router.post('/predict/{flow_id}', response_model=ProcessResponse)
@router.post('/process/{flow_id}', response_model=ProcessResponse)
async def process_flow(
flow_id: str,
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
session: Session = Depends(get_session),
flow_id: str,
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
session: Session = Depends(get_session),
):
"""
Endpoint to process an input with a given flow_id.
Expand All @@ -46,9 +79,7 @@ async def process_flow(
except Exception as exc:
logger.error(f'Error processing tweaks: {exc}')
response = process_graph_cached(graph_data, inputs)
return ProcessResponse(
result=response,
)
return ProcessResponse(result=response,)
except Exception as e:
# Log stack trace
logger.exception(e)
Expand Down
26 changes: 21 additions & 5 deletions src/backend/bisheng/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from bisheng.api.v1.schemas import FlowListCreate, FlowListRead
from bisheng.database.base import get_session
from bisheng.database.models.flow import Flow, FlowCreate, FlowRead, FlowReadWithStyle, FlowUpdate
from bisheng.database.models.role_access import AccessType, RoleAccess
from bisheng.database.models.user import User
from bisheng.settings import settings
from bisheng.utils.logger import logger
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import func
from sqlalchemy import func, or_
from sqlmodel import Session, select

# build router
Expand Down Expand Up @@ -51,8 +53,18 @@ def read_flows(*,
sql = select(Flow)
count_sql = select(func.count(Flow.id))
if 'admin' != payload.get('role'):
sql = sql.where(Flow.user_id == payload.get('user_id'))
count_sql = count_sql.where(Flow.user_id == payload.get('user_id'))
rol_flow_id = session.exec(
select(RoleAccess).where(RoleAccess.role_id.in_(payload.get('role')))).all()
if rol_flow_id:
flow_ids = [
acess.third_id for acess in rol_flow_id if acess.type == AccessType.FLOW
]
sql = sql.where(or_(Flow.user_id == payload.get('user_id'), Flow.id.in_(flow_ids)))
count_sql = count_sql.where(
or_(Flow.user_id == payload.get('user_id'), Flow.id.in_(flow_ids)))
else:
sql = sql.where(Flow.user_id == payload.get('user_id'))
count_sql = count_sql.where(Flow.user_id == payload.get('user_id'))
if name:
sql = sql.where(Flow.name.like(f'%{name}%'))
count_sql = count_sql.where(Flow.name.like(f'%{name}%'))
Expand Down Expand Up @@ -111,9 +123,13 @@ def update_flow(*,
# 上线校验
try:
art = {}
build_flow_no_yield(graph_data=db_flow.data, artifacts=art, process_file=False)
build_flow_no_yield(graph_data=db_flow.data,
artifacts=art,
process_file=False,
flow_id=flow_id.hex)
except Exception as exc:
raise HTTPException(status_code=500, detail='Flow 编译不通过') from exc
logger.exception(exc)
raise HTTPException(status_code=500, detail=f'Flow 编译不通过, {str(exc)}')

if settings.remove_api_keys:
flow_data = remove_api_keys(flow_data)
Expand Down
Loading

0 comments on commit 8a76084

Please sign in to comment.