Feat/0.2.2.2 (#281)
Bug fixes submitted for QA testing
yaojin3616 authored Jan 19, 2024
2 parents 28464d5 + b6dfa61 commit 91791a8
Showing 27 changed files with 253 additions and 89 deletions.
2 changes: 1 addition & 1 deletion docker/docker-compose.yml
@@ -2,7 +2,7 @@ services:
mysql:
container_name: bisheng-mysql
image: mysql:8.0
pull_policy: always

ports:
- "3306:3306"
environment:
66 changes: 43 additions & 23 deletions src/backend/bisheng/api/v1/knowledge.py
@@ -174,7 +174,9 @@ def create_knowledge(*,
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
""" 创建知识库. """
db_knowldge = Knowledge.from_orm(knowledge)
knowledge.is_partition = knowledge.is_partition or settings.vectorstores.get('Milvus', {}).get(
'is_partition', True)
db_knowldge = Knowledge.model_validate(knowledge)
know = session.exec(
select(Knowledge).where(Knowledge.name == knowledge.name,
knowledge.user_id == payload.get('user_id'))).all()
@@ -183,9 +185,9 @@
if not db_knowldge.collection_name:
if knowledge.is_partition:
embedding = re.sub(r'[^\w]', '_', knowledge.model)
id = settings.get_knowledge().get('vectorstores').get('Milvus',
{}).get('partition_suffix', 1)
db_knowldge.collection_name = f'partition_{embedding}_knowledge_{id}'
suffix_id = settings.get_knowledge().get('vectorstores').get('Milvus', {}).get(
'partition_suffix', 1)
db_knowldge.collection_name = f'partition_{embedding}_knowledge_{suffix_id}'
else:
# default collection name
db_knowldge.collection_name = f'col_{int(time.time())}_{str(uuid4())[:8]}'
@@ -413,16 +415,33 @@ def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chu
error_msg = error_msg + 'ESException:' + str(e)
logger.exception(e)

if not vectore_client and not es_client:
raise ValueError('no vectordb present')

callback_obj = {}
for index, path in enumerate(file_paths):
ts1 = time.time()
knowledge_file = knowledge_files[index]
logger.info('process_file_begin knowledge_id={} file_name={} file_size={} ',
knowledge_files[0].knowledge_id, knowledge_file.file_name, len(file_paths))

if not vectore_client and not es_client:
# mark the file as failed
with session_getter() as session:
db_file = session.get(KnowledgeFile, knowledge_file.id)
setattr(db_file, 'status', 3)
setattr(db_file, 'remark', error_msg[:500])
session.add(db_file)
callback_obj = db_file.copy()
session.commit()
if callback:
inp = {
'file_name': knowledge_file.file_name,
'file_status': knowledge_file.status,
'file_id': callback_obj.id,
'error_msg': callback_obj.remark
}
logger.error('add_fail callback={} file_name={} status={}', callback,
callback_obj.file_name, callback_obj.status)
requests.post(url=callback, json=inp, timeout=3)
continue
try:
# persist to MySQL
with session_getter() as session:
Expand All @@ -449,16 +468,17 @@ def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chu
for metadata in metadatas:
metadata.update({'file_id': knowledge_file.id, 'knowledge_id': f'{knowledge_id}'})

vectore_client.add_texts(texts=texts, metadatas=metadatas)
if vectore_client:
vectore_client.add_texts(texts=texts, metadatas=metadatas)

# store in Elasticsearch
if es_client:
es_client.add_texts(texts=texts, metadatas=metadatas)

callback_obj = db_file.copy()
logger.info(
f'process_file_done file_name={knowledge_file.file_name} file_id={knowledge_file.id} time_cost={time.time()-ts1}' # noqa
)
logger.info('process_file_done file_name={} file_id={} time_cost={}',
knowledge_file.file_name, knowledge_file.id,
time.time() - ts1)
except Exception as e:
logger.error(e)
session = next(get_session())
@@ -468,18 +488,18 @@ def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chu
session.add(db_file)
callback_obj = db_file.copy()
session.commit()
if callback:
# async callback
inp = {
'file_name': callback_obj.file_name,
'file_status': callback_obj.status,
'file_id': callback_obj.id,
'error_msg': callback_obj.remark
}
logger.info(
f'add_complete callback={callback} file_name={callback_obj.file_name} status={callback_obj.status}'
)
requests.post(url=callback, json=inp, timeout=3)
if callback:
# async callback
inp = {
'file_name': callback_obj.file_name,
'file_status': callback_obj.status,
'file_id': callback_obj.id,
'error_msg': callback_obj.remark
}
logger.info(
f'add_complete callback={callback} file_name={callback_obj.file_name} status={callback_obj.status}'
)
requests.post(url=callback, json=inp, timeout=3)


def _read_chunk_text(input_file, file_name, size, chunk_overlap, separator):
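For context, a hypothetical sketch (not part of this commit) of a receiver for the callback POST that addEmbedding issues above. The route path and success handling are assumptions; only the payload fields (file_name, file_status, file_id, error_msg) and the failure status value 3 come from the diff.

```python
# Hypothetical callback receiver illustrating the payload contract used above.
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class EmbeddingCallback(BaseModel):
    file_name: str
    file_status: int  # the hunk above sets status=3 on failure
    file_id: int
    error_msg: Optional[str] = None


@app.post('/v1/knowledge/embedding_callback')  # assumed path, not from the repo
def embedding_callback(payload: EmbeddingCallback):
    if payload.file_status == 3:
        # failure path: error_msg carries the truncated remark written to KnowledgeFile
        print(f'file {payload.file_name} failed: {payload.error_msg}')
    return {'status_code': 200}
```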
20 changes: 8 additions & 12 deletions src/backend/bisheng/api/v2/chat.py
@@ -7,7 +7,6 @@
from bisheng.database.base import get_session, session_getter
from bisheng.database.models.flow import Flow
from bisheng.database.models.message import ChatMessage
from bisheng.graph.graph.base import Graph
from bisheng.processing.process import process_tweaks
from bisheng.settings import settings
from bisheng.utils.logger import logger
@@ -44,17 +43,14 @@ async def union_websocket(flow_id: str,
if tweak:
tweak = json.loads(tweak)
graph_data = process_tweaks(graph_data, tweak)
graph = Graph.from_payload(graph_data)
for node in graph.vertices:
if node.base_type == 'vectorstores':
if 'collection_name' in node.data.get('node').get('template').keys():
node.data.get('node').get(
'template')['collection_name']['collection_id'] = knowledge_id
elif 'index_name' in node.data.get('node').get('template').keys():
node.data.get('node').get(
'template')['index_name']['collection_id'] = knowledge_id

graph_data = graph.raw_graph_data
# vector database update
for node in graph_data['nodes']:
if 'VectorStore' in node['data']['node']['base_classes']:
if 'collection_name' in node['data'].get('node').get('template').keys():
node['data']['node']['template']['collection_name'][
'collection_id'] = knowledge_id
if 'index_name' in node['data'].get('node').get('template').keys():
node['data']['node']['template']['index_name']['collection_id'] = knowledge_id
trace_id = str(uuid4().hex)
with logger.contextualize(trace_id=trace_id):
await chat_manager.handle_websocket(
2 changes: 1 addition & 1 deletion src/backend/bisheng/chat/utils.py
@@ -108,7 +108,7 @@ async def judge_source(result, source_document, chat_id, extra: Dict):
else:
source = 1

if source:
if source == 1:
for doc in source_document:
# make sure every chunk can be traced back to its source
if 'bbox' not in doc.metadata or not doc.metadata['bbox'] or not json.loads(
2 changes: 1 addition & 1 deletion src/backend/bisheng/database/models/knowledge.py
@@ -37,4 +37,4 @@ class KnowledgeUpdate(KnowledgeBase):


class KnowledgeCreate(KnowledgeBase):
is_partition: bool = True
is_partition: Optional[bool] = None
2 changes: 1 addition & 1 deletion src/backend/bisheng/database/models/message.py
@@ -24,7 +24,7 @@ class MessageBase(SQLModelSerializable):
receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的发送方')
intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
remark: Optional[str] = Field(sa_column=Column(String(length=1024)), description='备注')
remark: Optional[str] = Field(sa_column=Column(String(length=10000)), description='备注')
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
update_time: Optional[datetime] = Field(
2 changes: 2 additions & 0 deletions src/backend/bisheng/default_node.yaml
@@ -147,6 +147,8 @@ embeddings:
documentation: ""
HostEmbeddings:
documentation: ""
CustomHostEmbedding:
documentation: ""
llms:
AzureChatOpenAI:
documentation: ""
3 changes: 1 addition & 2 deletions src/backend/bisheng/graph/vertex/base.py
@@ -240,8 +240,7 @@ def _build_params(self):
if not value.get('required') and params.get(key) is None:
if value.get('default'):
params[key] = value.get('default')
else:
params.pop(key, None)

# Add _type to params
self._raw_params = params
self.params = params
3 changes: 3 additions & 0 deletions src/backend/bisheng/initdb_config.yaml
@@ -15,6 +15,9 @@ knowledges:
# Milvus needs at least 4 CPU cores / 8 GB RAM; 4 cores / 16 GB recommended
Milvus: # to switch to another vector DB, make sure that service is already running, then set its parameters here
connection_args: {'host': 'milvus', 'port': '19530', 'user': '', 'password': '', 'secure': False}
# partition-key mode, the field used for partitioning; defaults to True if not configured. With partitioning on, new partitions do not create a new collection; increase the suffix to force a new one.
is_partition: True
partition_suffix: 1
# optional: in some scenarios, ES keyword search improves recall
ElasticKeywordsSearch:
elasticsearch_url: 'http://elasticsearch:9200'
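As a reading aid, a minimal sketch (not from the commit) of how these two new keys feed the collection naming logic shown in the create_knowledge hunk earlier in this diff; the helper name and example model string are illustrative.

```python
# Minimal sketch of the partition naming scheme used in create_knowledge above.
# The function name and example arguments are illustrative, not from the repo.
import re
import time
from uuid import uuid4


def resolve_collection_name(model: str, is_partition: bool, milvus_cfg: dict) -> str:
    if is_partition:
        embedding = re.sub(r'[^\w]', '_', model)            # sanitize the model name
        suffix_id = milvus_cfg.get('partition_suffix', 1)   # new config key above
        return f'partition_{embedding}_knowledge_{suffix_id}'
    # default: a dedicated collection per knowledge base
    return f'col_{int(time.time())}_{str(uuid4())[:8]}'


print(resolve_collection_name('multilingual-e5-large', True, {'partition_suffix': 1}))
# -> partition_multilingual_e5_large_knowledge_1
```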
2 changes: 2 additions & 0 deletions src/backend/bisheng/template/frontend_node/vectorstores.py
@@ -307,6 +307,8 @@ def format_field(field: TemplateField, name: Optional[str] = None) -> None:
field.placeholder = ':memory:'
elif field.name == 'collection_name' and name == 'Milvus':
field.value = ''
elif field.name == 'index_name':
field.value = ''

elif field.name in advanced_fields:
field.show = True
32 changes: 15 additions & 17 deletions src/bisheng-langchain/bisheng_langchain/chat_models/host_llm.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import json
import logging
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

@@ -13,18 +14,19 @@
from langchain.schema.messages import (AIMessage, BaseMessage, ChatMessage, FunctionMessage,
HumanMessage, SystemMessage)
from langchain.utils import get_from_dict_or_env
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import Field, root_validator
from loguru import logger

# from requests.exceptions import HTTPError
from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
wait_exponential)

# from .interface import MinimaxChatCompletion
# from .interface.types import ChatInput

if TYPE_CHECKING:
import tiktoken

logger = logging.getLogger(__name__)


def _import_tiktoken() -> Any:
try:
@@ -36,19 +38,15 @@ def _import_tiktoken() -> Any:
return tiktoken


def _create_retry_decorator(llm: BaseHostChatLLM) -> Callable[[Any], Any]:
def _create_retry_decorator(
llm: BaseHostChatLLM,
run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:

min_seconds = 1
max_seconds = 20
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(Exception)),
before_sleep=before_sleep_log(logger, logger.level('WARNING')),
)
errors = [requests.exceptions.ReadTimeout, ValueError]
return create_base_retry_decorator(error_types=errors,
max_retries=llm.max_retries,
run_manager=run_manager)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
@@ -215,7 +213,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
# print('messages:', messages)
# print('functions:', kwargs.get('functions', []))
if self.verbose:
logger.info(f'payload={params}')
logger.info('payload=%s', json.dumps(params, indent=2))
try:
resp = self.client.post(url=self.host_base_url, json=params)
if resp.text.startswith('data:'):
@@ -297,7 +295,7 @@ async def _acompletion_with_retry(**kwargs: Any) -> Any:
if text.startswith('{'):
yield (is_error, response[len('data:'):])
else:
logger.info('agenerate_no_json text={}', text)
logger.info('agenerate_no_json text=%s', text)
if is_error:
break
elif response.startswith('{'):
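A minimal sketch of how the decorator returned by create_base_retry_decorator is typically applied, matching the error types configured in the hunk above; the wrapped function, URL, and retry count are placeholders, not part of the commit.

```python
# Sketch of applying the retry decorator built above; only the decorator call
# mirrors the diff, the wrapped request function is a placeholder.
import requests
from langchain_core.language_models.llms import create_base_retry_decorator

retry_decorator = create_base_retry_decorator(
    error_types=[requests.exceptions.ReadTimeout, ValueError],
    max_retries=6,  # placeholder for llm.max_retries
)


@retry_decorator
def _completion(url: str, payload: dict) -> dict:
    resp = requests.post(url, json=payload, timeout=30)
    resp.raise_for_status()
    return resp.json()
```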
2 changes: 1 addition & 1 deletion src/frontend/package.json
@@ -1,6 +1,6 @@
{
"name": "bisheng",
"version": "0.2.2.1",
"version": "0.2.2.2",
"private": true,
"dependencies": {
"@emotion/react": "^11.11.1",
4 changes: 3 additions & 1 deletion src/frontend/public/locales/en/bs.json
@@ -126,7 +126,8 @@
"fileStorageFailure": "File storage failure!",
"confirmDeleteChat": "Confirm deletion of this chat?",
"roundOver": "This round is over",
"chatDialogTip": "Set the input variables defined in the prompt template. Interact with agents and chains."
"chatDialogTip": "Set the input variables defined in the prompt template. Interact with agents and chains.",
"feedback": "Feedback"
},
"model": {
"modelConfiguration": "Model Configuration",
@@ -292,6 +293,7 @@
"close": "Close",
"cancel": "Cancel",
"save": "Save",
"submit": "Submit",
"operations": "Operations",
"previousPage": "Previous Page",
"nextPage": "Next Page",
4 changes: 3 additions & 1 deletion src/frontend/public/locales/zh/bs.json
@@ -122,7 +122,8 @@
"fileStorageFailure": " 文件地址失效!",
"confirmDeleteChat": "确认删除该会话?",
"roundOver": "本轮结束",
"chatDialogTip": "设置提示模板中定义的输入变量。与代理和链互动"
"chatDialogTip": "设置提示模板中定义的输入变量。与代理和链互动",
"feedback": "反馈"
},
"model": {
"modelConfiguration": "模型配置",
@@ -288,6 +289,7 @@
"close": "关闭",
"cancel": "取消",
"save": "保存",
"submit": "提交",
"operations": "操作",
"previousPage": "上一页",
"nextPage": "下一页",
34 changes: 34 additions & 0 deletions src/frontend/src/components/ui/editLabel.tsx
@@ -0,0 +1,34 @@
import { PenLine } from "lucide-react";
import { useRef, useState } from "react";

export default function EditLabel({ str, children, onChange }) {
const [value, setValue] = useState(str)

const [edit, setEdit] = useState(false)
const inputRef = useRef(str)

if (edit) return <div className="">
<input
type="text"
ref={inputRef}
defaultValue={str}
onKeyDown={(e) => {
e.key === 'Enter' && (setEdit(false), onChange(inputRef.current.value));
e.code === 'Space' && e.preventDefault();
}}
className="flex h-6 w-full rounded-xl border border-input bg-background px-3 py-1 text-sm shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50"
/>
</div>


return <div className="flex items-center text-gray-900 dark:text-gray-300 group">
{children(inputRef.current.value || str)}
<button
className="hidden transition-all group-hover:block"
// title={t('flow.editAlias')}
onClick={() => setEdit(true)}
>
<PenLine size={18} className="ml-2 cursor-pointer" />
</button>
</div >
};
3 changes: 2 additions & 1 deletion src/frontend/src/contexts/locationContext.tsx
@@ -71,7 +71,8 @@ export function LocationProvider({ children }: { children: ReactNode }) {
setAppConfig({
isDev: res.env === 'dev',
libAccepts: res.uns_support,
officeUrl: res.office_url
officeUrl: res.office_url,
dialogTips: res.dialog_tips
})
})
}, [])