Skip to content

Commit

Permalink
fix bug (#27)
Browse files Browse the repository at this point in the history
1. es 参数对齐
  • Loading branch information
yaojin3616 committed Sep 13, 2023
2 parents e0df4bd + 17f634c commit 075199e
Show file tree
Hide file tree
Showing 50 changed files with 1,222 additions and 446 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
pip install Cython
pip install wheel
pip install twine
cd ./src/bisheng_langchain
cd ./src/bisheng-langchain
python setup.py bdist_wheel
twine upload dist/* -u ${{ secrets.PYPI_USER }} -p ${{ secrets.PYPI_PASSWORD }} --repository pypi
Expand Down Expand Up @@ -86,4 +86,4 @@ jobs:
tags: |
${{ env.DOCKERHUB_REPO }}bisheng-frontend:latest
${{ env.DOCKERHUB_REPO }}bisheng-frontend:${{ steps.get_version.outputs.VERSION }}
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"jinja": true,
"justMyCode": false,
"env": {
"PYTHONPATH": "${workspaceRoot}/src/backend/"
"PYTHONPATH": "${workspaceRoot}/src/backend/",
"config": "/Users/huangly/Code/bisheng/src/backend/bisheng/config.dev.yaml"
},
"cwd": "${workspaceFolder}/src/backend/",
}
Expand Down
6 changes: 5 additions & 1 deletion docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ knowledges:
vectorstores:
# Milvus 最低要求cpu 4C 8G 推荐4C 16G
Milvus: # 如果需要切换其他vectordb,确保其他服务已经启动,然后配置对应参数
connection_args = {'host': '110.16.193.170', 'port': '50032', 'user': '', 'password': '', 'secure': False}
connection_args: {'host': '110.16.193.170', 'port': '50032', 'user': '', 'password': '', 'secure': False}

agents:
ZeroShotAgent:
Expand All @@ -32,6 +32,10 @@ agents:
SQLAgent:
documentation: ""
chains:
MultiRetrievalQA:
documentation: ""
SequentialChain:
documentation: ""
LLMChain:
documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
LLMMathChain:
Expand Down
39 changes: 13 additions & 26 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from typing import Any, Dict, List, Union

from bisheng.api.v1.schemas import ChatResponse
from bisheng.utils.logger import logger
from fastapi import WebSocket
Expand All @@ -20,35 +19,25 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type='stream', intermediate_steps='')
await self.websocket.send_json(resp.dict())

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
"""Run when LLM starts running."""

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""

async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when LLM errors."""

async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain starts running."""

async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""

async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when chain errors."""

async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message='',
Expand All @@ -60,8 +49,9 @@ async def on_tool_start(
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')
from langchain.docstore.document import Document # noqa
result = eval(output).get('result')
# from langchain.docstore.document import Document # noqa
# result = eval(output).get('result')
result = output

# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'
Expand All @@ -79,9 +69,7 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
except Exception as e:
logger.error(e)

async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when tool errors."""

async def on_text(self, text: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -154,9 +142,7 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)

def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message='',
Expand All @@ -171,8 +157,9 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')

from langchain.docstore.document import Document # noqa
result = eval(output).get('result')
# from langchain.docstore.document import Document # noqa
# result = eval(output).get('result')
result = output
# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'

Expand Down
22 changes: 9 additions & 13 deletions src/backend/bisheng/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List
from uuid import UUID

from sqlalchemy import func

from bisheng.api.utils import build_flow_no_yield, remove_api_keys
from bisheng.api.v1.schemas import FlowListCreate, FlowListRead
from bisheng.database.base import get_session
Expand All @@ -24,10 +26,6 @@ def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate, Au
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

if flow.flow_id:
# copy from template
temp_flow = session.get(Flow, flow.flow_id)
flow.data = temp_flow.data
flow.user_id = payload.get('user_id')
db_flow = Flow.from_orm(flow)
session.add(db_flow)
Expand All @@ -36,7 +34,7 @@ def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate, Au
return db_flow


@router.get('/', response_model=list[FlowReadWithStyle], status_code=200)
@router.get('/', status_code=200)
def read_flows(*,
session: Session = Depends(get_session),
name: str = Query(default=None, description='根据name查找数据库'),
Expand All @@ -50,14 +48,17 @@ def read_flows(*,

try:
sql = select(Flow)
count_sql = select(func.count(Flow.id))
if 'admin' != payload.get('role'):
sql = sql.where(Flow.user_id == payload.get('user_id'))
count_sql = count_sql.where(Flow.user_id == payload.get('user_id'))
if name:
sql = sql.where(Flow.name.like(f'%{name}%'))
count_sql = count_sql.where(Flow.name.like(f'%{name}%'))
if status:
sql = sql.where(Flow.status == status)
# count = session.exec(sql.count())
# total = count.scalar()
count_sql = count_sql.where(Flow.status == status)
total_count = session.scalar(count_sql)

sql = sql.order_by(Flow.update_time.desc())
if page_num and page_size:
Expand All @@ -72,7 +73,7 @@ def read_flows(*,
for r in res:
r['user_name'] = userMap[r['user_id']]

return res
return {'data': res, 'total': total_count}

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down Expand Up @@ -134,11 +135,6 @@ def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID, Autho
if 'admin' != payload.get('role') and flow.user_id != payload.get('user_id'):
raise HTTPException(status_code=500, detail='没有权限删除此技能')

# 判断是否属于模板
db_template = session.exec(select(Template).where(Template.flow_id == flow_id)).first()
if db_template:
session.delete(db_template)

session.delete(flow)
session.commit()
return {'message': 'Flow deleted successfully'}
Expand Down
44 changes: 28 additions & 16 deletions src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import asyncio
from ctypes import Union
import json
import time
from typing import List
from typing import List, Optional
from uuid import uuid4

from sqlalchemy import func

from bisheng.api.v1.schemas import UploadFileResponse
from bisheng.cache.utils import save_uploaded_file
from bisheng.database.base import get_session
from bisheng.database.models.knowledge import Knowledge, KnowledgeCreate, KnowledgeRead
from bisheng.database.models.knowledge_file import KnowledgeFile, KnowledgeFileRead
from bisheng.database.models.knowledge_file import KnowledgeFile
from bisheng.database.models.user import User
from bisheng.interface.importing.utils import import_vectorstore
from bisheng.interface.initialize.loading import instantiate_vectorstore
Expand Down Expand Up @@ -120,42 +123,51 @@ def create_knowledge(*,
return db_knowldge


@router.get('/', response_model=List[KnowledgeRead], status_code=200)
def get_knowledge(*, session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
@router.get('/', status_code=200)
def get_knowledge(*,
session: Session = Depends(get_session),
page_size: Optional[int],
page_num: Optional[str],
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
""" 读取所有知识库信息. """

try:
sql = select(Knowledge)
count_sql = select(func.count(Knowledge.id))
if 'admin' != payload.get('role'):
knowledges = session.exec(
select(Knowledge).where(Knowledge.user_id == payload.get('user_id')).order_by(
Knowledge.update_time.desc())).all()
else:
knowledges = session.exec(select(Knowledge).order_by(Knowledge.update_time.desc())).all()
sql = sql.where(Knowledge.user_id == payload.get('user_id'))
count_sql = count_sql.where(Knowledge.user_id == payload.get('user_id'))
total_count = session.scalar(count_sql)

if page_num and page_size and page_num != 'undefined':
page_num = int(page_num)
sql = sql.offset((page_num - 1) * page_size).limit(page_size)

knowledges = session.exec(sql).all()
res = [jsonable_encoder(flow) for flow in knowledges]
if knowledges:
db_user_ids = {flow.user_id for flow in knowledges}
db_user = session.exec(select(User).where(User.user_id.in_(db_user_ids))).all()
userMap = {user.user_id: user.user_name for user in db_user}
for r in res:
r['user_name'] = userMap[r['user_id']]
return res
return {'data': res, 'total': total_count}

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e


@router.get('/file_list/{knowledge_id}', response_model=List[KnowledgeFileRead], status_code=200)
@router.get('/file_list/{knowledge_id}', status_code=200)
def get_filelist(*, session: Session = Depends(get_session), knowledge_id: int, page_size: int = 10, page_num: int = 1):
""" 删除知识库信息. """
knowledge = session.get(Knowledge, knowledge_id)
if not knowledge:
raise HTTPException(status_code=404, detail='knowledge not found')
""" 获取知识库文件信息. """
# 查找上传的文件信息
total_count = session.scalar(select(func.count(KnowledgeFile.id)).where(KnowledgeFile.knowledge_id == knowledge_id))
files = session.exec(
select(KnowledgeFile).where(KnowledgeFile.knowledge_id == knowledge_id).order_by(
KnowledgeFile.update_time.desc()).offset(page_size * (page_num - 1)).limit(page_size)).all()
return [jsonable_encoder(knowledgefile) for knowledgefile in files]
return {"data": [jsonable_encoder(knowledgefile) for knowledgefile in files], "total": total_count}


@router.delete('/{knowledge_id}', status_code=200)
Expand Down
Loading

0 comments on commit 075199e

Please sign in to comment.