Skip to content

Commit

Permalink
fix embedding auto (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 committed Sep 15, 2023
2 parents 51a8969 + 991b8cc commit 738ce40
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 552 deletions.
15 changes: 10 additions & 5 deletions src/backend/bisheng/interface/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from bisheng.settings import settings
from bisheng.template.frontend_node.chains import ChainFrontendNode
from bisheng.utils.logger import logger
from bisheng.utils.util import (build_template_from_class, build_template_from_method)
from langchain import chains
from bisheng.utils.util import build_template_from_class, build_template_from_method
from bisheng_langchain import chains as bisheng_chains
from langchain import chains

# Assuming necessary imports for Field, Template, and FrontendNode classes


Expand All @@ -32,7 +33,8 @@ def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
# langchain
self.type_dict: dict[str, Any] = {
chain_name: import_class(f'langchain.chains.{chain_name}') for chain_name in chains.__all__
chain_name: import_class(f'langchain.chains.{chain_name}')
for chain_name in chains.__all__
}
# bisheng-langchain
bisheng = {
Expand All @@ -46,7 +48,9 @@ def type_to_loader_dict(self) -> Dict:
self.type_dict.update(CUSTOM_CHAINS)
# Filter according to settings.chains
self.type_dict = {
name: chain for name, chain in self.type_dict.items() if name in settings.chains or settings.dev
name: chain
for name, chain in self.type_dict.items()
if name in settings.chains or settings.dev
}
return self.type_dict

Expand All @@ -71,7 +75,8 @@ def get_signature(self, name: str) -> Optional[Dict]:
def to_list(self) -> List[str]:
names = []
for _, chain in self.type_to_loader_dict.items():
chain_name = (chain.function_name() if hasattr(chain, 'function_name') else chain.__name__)
chain_name = (chain.function_name()
if hasattr(chain, 'function_name') else chain.__name__)
names.append(chain_name)
return names

Expand Down
134 changes: 48 additions & 86 deletions src/backend/bisheng/interface/initialize/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
return custom_node(**params)

class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(
class_object, base_type, node_type, params
)
return instantiate_based_on_type(class_object, base_type, node_type, params)


def convert_params_to_sets(params):
Expand All @@ -56,9 +54,7 @@ def convert_params_to_sets(params):
def convert_kwargs(params):
# if *kwargs are passed as a string, convert to dict
# first find any key that has kwargs or config in it
kwargs_keys = [
key for key in params.keys() if 'kwargs' in key or 'config' in key
]
kwargs_keys = [key for key in params.keys() if 'kwargs' in key or 'config' in key]
for key in kwargs_keys:
if isinstance(params[key], str):
params[key] = json.loads(params[key])
Expand Down Expand Up @@ -136,8 +132,7 @@ def instantiate_llm(node_type, class_object, params: Dict):
return initialize_vertexai(class_object=class_object, params=params)
# max_tokens sometimes is a string and should be an int
if 'max_tokens' in params:
if isinstance(params['max_tokens'],
str) and params['max_tokens'].isdigit():
if isinstance(params['max_tokens'], str) and params['max_tokens'].isdigit():
params['max_tokens'] = int(params['max_tokens'])
elif not isinstance(params.get('max_tokens'), int):
params.pop('max_tokens', None)
Expand All @@ -155,21 +150,17 @@ def instantiate_memory(node_type, class_object, params):
params.pop(key)

try:
if 'retriever' in params and hasattr(
params['retriever'], 'as_retriever'
):
if 'retriever' in params and hasattr(params['retriever'], 'as_retriever'):
params['retriever'] = params['retriever'].as_retriever()
return class_object(**params)
# I want to catch a specific attribute error that happens
# when the object does not have a cursor attribute
except Exception as exc:
if "object has no attribute 'cursor'" in str(
exc
) or 'object has no field "conn"' in str(exc):
raise AttributeError((
'Failed to build connection to database.'
f' Please check your connection string and try again. Error: {exc}'
)) from exc
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(
exc):
raise AttributeError(
('Failed to build connection to database.'
f' Please check your connection string and try again. Error: {exc}')) from exc
raise exc


Expand All @@ -185,6 +176,12 @@ def instantiate_retriever(node_type, class_object, params):


def instantiate_chains(node_type, class_object: Type[Chain], params: Dict):
if node_type == 'SequentialChain':
chains = params['chains']
for index, chain in enumerate(chains):
chain.__setattr__('output_key', chain.output_keys[0] + str(index))
params['input_variables'] = sum((chain.input_keys for chain in chains), [])

if 'retriever' in params and hasattr(params['retriever'], 'as_retriever'):
params['retriever'] = params['retriever'].as_retriever()
if node_type in chain_creator.from_method_nodes:
Expand All @@ -196,17 +193,15 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict):
return class_object(**params)


def instantiate_agent(
node_type, class_object: Type[agent_module.Agent], params: Dict
):
def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: Dict):
if node_type in agent_creator.from_method_nodes:
method = agent_creator.from_method_nodes[node_type]
if class_method := getattr(class_object, method, None):
agent = class_method(**params)
tools = params.get('tools', [])
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, handle_parsing_errors=True
)
return AgentExecutor.from_agent_and_tools(agent=agent,
tools=tools,
handle_parsing_errors=True)
return load_agent_executor(class_object, params)


Expand All @@ -217,9 +212,7 @@ def instantiate_prompt(node_type, class_object, params: Dict):
return ZeroShotAgent.create_prompt(**params)
elif 'MessagePromptTemplate' in node_type:
# Then we only need the template
from_template_params = {
'template': params.pop('prompt', params.pop('template', ''))
}
from_template_params = {'template': params.pop('prompt', params.pop('template', ''))}

if not from_template_params.get('template'):
raise ValueError('Prompt template is required')
Expand All @@ -236,21 +229,16 @@ def instantiate_prompt(node_type, class_object, params: Dict):
variable = params[input_variable]
if isinstance(variable, str):
format_kwargs[input_variable] = variable
elif isinstance(variable, BaseOutputParser) and hasattr(
variable, 'get_format_instructions'
):
format_kwargs[input_variable
] = variable.get_format_instructions()
elif isinstance(variable, BaseOutputParser) and hasattr(variable,
'get_format_instructions'):
format_kwargs[input_variable] = variable.get_format_instructions()
elif isinstance(variable, List) and all(
isinstance(item, Document) for item in variable
):
isinstance(item, Document) for item in variable):
# Format document to contain page_content and metadata
# as one string separated by a newline
if len(variable) > 1:
content = '\n'.join([
item.page_content for item in variable
if item.page_content
])
content = '\n'.join(
[item.page_content for item in variable if item.page_content])
else:
if not variable:
format_kwargs[input_variable] = ''
Expand All @@ -265,13 +253,10 @@ def instantiate_prompt(node_type, class_object, params: Dict):
# handle_keys will be a list but it does not exist yet
# so we need to create it

if (
isinstance(variable, List)
and all(isinstance(item, Document) for item in variable)
) or (
isinstance(variable, BaseOutputParser)
and hasattr(variable, 'get_format_instructions')
):
if (isinstance(variable, List) and
all(isinstance(item, Document)
for item in variable)) or (isinstance(variable, BaseOutputParser) and
hasattr(variable, 'get_format_instructions')):
if 'handle_keys' not in format_kwargs:
format_kwargs['handle_keys'] = []

Expand Down Expand Up @@ -301,9 +286,7 @@ def instantiate_tool(node_type, class_object: Type[BaseTool], params: Dict):
return class_object(**params)


def instantiate_toolkit(
node_type, class_object: Type[BaseToolkit], params: Dict
):
def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict):
loaded_toolkit = class_object(**params)
# Commenting this out for now to use toolkits as normal tools
# if toolkits_creator.has_create_function(node_type):
Expand All @@ -318,10 +301,7 @@ def instantiate_embedding(class_object, params: Dict):
try:
return class_object(**params)
except ValidationError:
params = {
key: value
for key, value in params.items() if key in class_object.__fields__
}
params = {key: value for key, value in params.items() if key in class_object.__fields__}
return class_object(**params)


Expand Down Expand Up @@ -353,17 +333,13 @@ def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict):
# in x and if it is, we will return True
file_filter = params.pop('file_filter')
extensions = file_filter.split(',')
params['file_filter'] = lambda x: any(
extension.strip() in x for extension in extensions
)
params['file_filter'] = lambda x: any(extension.strip() in x for extension in extensions)
metadata = params.pop('metadata', None)
if metadata and isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError as exc:
raise ValueError(
'The metadata you provided is not a valid JSON string.'
) from exc
raise ValueError('The metadata you provided is not a valid JSON string.') from exc
# make it success when file not present
if 'file_path' in params and not params['file_path']:
return []
Expand All @@ -389,21 +365,16 @@ def instantiate_textsplitter(
if not documents:
return []
except KeyError as exc:
raise ValueError(
'The source you provided did not load correctly or was empty.'
'Try changing the chunk_size of the Text Splitter.'
) from exc

if (
'separator_type' in params and params['separator_type'] == 'Text'
) or 'separator_type' not in params:
raise ValueError('The source you provided did not load correctly or was empty.'
'Try changing the chunk_size of the Text Splitter.') from exc

if ('separator_type' in params and
params['separator_type'] == 'Text') or 'separator_type' not in params:
params.pop('separator_type', None)
# separators might come in as an escaped string like \\n
# so we need to convert it to a string
if 'separators' in params:
params['separators'] = (
params['separators'].encode().decode('unicode-escape')
)
params['separators'] = (params['separators'].encode().decode('unicode-escape'))
text_splitter = class_object(**params)
else:
from langchain.text_splitter import Language
Expand All @@ -428,34 +399,27 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
if node['data']['type'] == 'ZeroShotPrompt':
# Build Prompt Template
tools = [
tool for tool in nodes if tool['type'] != 'chatOutputNode'
and 'Tool' in tool['data']['node']['base_classes']
tool for tool in nodes if tool['type'] != 'chatOutputNode' and
'Tool' in tool['data']['node']['base_classes']
]
node['data'] = build_prompt_template(
prompt=node['data'], tools=tools
)
node['data'] = build_prompt_template(prompt=node['data'], tools=tools)
break
return nodes


def load_agent_executor(
agent_class: type[agent_module.Agent], params, **kwargs
):
def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs):
"""Load agent executor from agent class, tools and chain"""
allowed_tools: Sequence[BaseTool] = params.get('allowed_tools', [])
llm_chain = params['llm_chain']
# agent has hidden args for memory. might need to be support
# memory = params["memory"]
# if allowed_tools is not a list or set, make it a list
if not isinstance(allowed_tools,
(list, set)) and isinstance(allowed_tools, BaseTool):
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
allowed_tools = [allowed_tools]
tool_names = [tool.name for tool in allowed_tools]
# Agent class requires an output_parser but Agent classes
# have a default output_parser.
agent = agent_class(
allowed_tools=tool_names, llm_chain=llm_chain
) # type: ignore
agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain) # type: ignore
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=allowed_tools,
Expand All @@ -475,12 +439,10 @@ def build_prompt_template(prompt, tools):
"""Build PromptTemplate from ZeroShotPrompt"""
prefix = prompt['node']['template']['prefix']['value']
suffix = prompt['node']['template']['suffix']['value']
format_instructions = prompt['node']['template']['format_instructions'][
'value']
format_instructions = prompt['node']['template']['format_instructions']['value']

tool_strings = '\n'.join([
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
for tool in tools
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools
])
tool_names = ', '.join([tool['data']['node']['name'] for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
Expand Down
28 changes: 15 additions & 13 deletions src/backend/bisheng/interface/initialize/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
def docs_in_params(params: dict) -> bool:
"""Check if params has documents OR texts and one of them is not an empty list,
If any of them is not an empty list, return True, else return False"""
return ('documents' in params and params['documents']) or ('texts' in params and params['texts'])
return ('documents' in params and params['documents']) or ('texts' in params and
params['texts'])


def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
Expand Down Expand Up @@ -207,21 +208,22 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict):


def initial_milvus(class_object: Type[Milvus], params: dict):
if 'connection_args' not in params:
params['connection_args'] = settings.knowledges.get('vectorstores').get('Milvus')
if not params['connection_args']:
params['connection_args'] = settings.knowledges.get('vectorstores').get('Milvus').get(
'connection_args')
if 'embedding' not in params:
# 匹配知识库的embedding
col = params['collection_name']
with get_session() as session:
knowledge = session.exec(select(Knowledge).where(Knowledge.collection_name == col)).first()
if not knowledge:
raise Exception(f'不能找到知识库collection={col}')
model_param = settings.knowledges.get('embeddings').get(knowledge.model)
if Knowledge.model == 'text-embedding-ada-002':
embedding = OpenAIEmbeddings(**model_param)
else:
embedding = HostEmbeddings(**model_param)
params['embedding'] = embedding
session = next(get_session())
knowledge = session.exec(select(Knowledge).where(Knowledge.collection_name == col)).first()
if not knowledge:
raise Exception(f'不能找到知识库collection={col}')
model_param = settings.knowledges.get('embeddings').get(knowledge.model)
if knowledge.model == 'text-embedding-ada-002':
embedding = OpenAIEmbeddings(**model_param)
else:
embedding = HostEmbeddings(**model_param)
params['embedding'] = embedding

elif isinstance(params.get('connection_args'), str):
print(f"milvus before params={params} type={type(params['connection_args'])}")
Expand Down
Loading

0 comments on commit 738ce40

Please sign in to comment.