Prepare for open source 1.0
Kilig947 committed Apr 16, 2024
1 parent e15ade0 commit 4a1f736
Showing 23 changed files with 733 additions and 477 deletions.
392 changes: 72 additions & 320 deletions README.md

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions __main__.py
@@ -388,16 +388,21 @@ def signals_input_setting(self):
         self.know_combo = [self.kb_input_select, self.vector_search_score, self.vector_search_top_k]
         self.input_combo.extend(self.know_combo)
         # Advanced settings
+        self.models_combo = [self.input_models, self.vision_models, self.project_models, self.vector_search_to_history]
         self.input_models.input(func_signals.update_models,
-                                inputs=[self.input_models, self.vision_models, self.project_models],
+                                inputs=self.models_combo,
                                 outputs=[self.models_box])
         self.vision_models.input(func_signals.update_models,
-                                 inputs=[self.input_models, self.vision_models, self.project_models],
+                                 inputs=self.models_combo,
                                  outputs=[self.models_box])
         self.project_models.input(func_signals.update_models,
-                                  inputs=[self.input_models, self.vision_models, self.project_models],
+                                  inputs=self.models_combo,
                                   outputs=[self.models_box])
-        self.setting_combo = [self.models_box, self.history_round_num, self.default_worker_num, self.ocr_identifying_trust]
+        self.vector_search_to_history.input(func_signals.update_models,
+                                            inputs=self.models_combo,
+                                            outputs=[self.models_box])
+        self.setting_combo = [self.models_box, self.history_round_num, self.default_worker_num,
+                              self.ocr_identifying_trust]
         self.input_combo.extend(self.setting_combo)
         # Personal information
         self.user_combo = [self.openai_keys, self.wps_cookie, self.qq_cookie, self.feishu_cookie,
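The hunk above routes several Gradio selectors through one shared `inputs` list, so wiring in the new `vector_search_to_history` source only means extending `models_combo` and registering one more listener. A minimal sketch of that pattern, with hypothetical component names and a stand-in for `func_signals.update_models`:

```python
import gradio as gr

def update_models(*selections):
    # Stand-in for func_signals.update_models: summarize every selector's value.
    return ", ".join(str(s) for s in selections if s)

with gr.Blocks() as demo:
    # Four selectors share one inputs list, mirroring models_combo above.
    models_combo = [gr.Dropdown(choices=["a", "b"], label=f"model-{i}") for i in range(4)]
    models_box = gr.Textbox(label="models_box")
    for component in models_combo:
        # Every .input event re-reads the whole combo, so the handlers stay in sync.
        component.input(update_models, inputs=models_combo, outputs=[models_box])

# demo.launch()  # uncomment to try it locally
```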
@@ -538,7 +543,7 @@ def init_gradio_app():
     # --- --- replace gradio endpoint to forbid access to sensitive files --- ---
     dependencies = []
     endpoint = None
-    gradio_app: fastapi # add a type hint to avoid a warning
+    gradio_app: fastapi  # add a type hint to avoid a warning
     for route in list(gradio_app.router.routes):
         if route.path == "/file/{path:path}":
             gradio_app.router.routes.remove(route)
@@ -554,6 +559,7 @@ async def file(path_or_url: str, request: fastapi.Request):
         if not file_authorize_user(path_or_url, request, gradio_app):
             return {"detail": "Hack me? How dare you?"}
         return await endpoint(path_or_url, request)
+
 server_app = create_app()
 
 server_app.mount(CUSTOM_PATH, gradio_app)
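These two hunks belong to a guard around Gradio's static-file endpoint: the stock `/file/{path:path}` route is removed from the FastAPI router, its original handler is kept, and a wrapper that calls `file_authorize_user` first is registered in its place. A reduced sketch of the same trick, with a placeholder authorizer standing in for the project's real check:

```python
import fastapi

app = fastapi.FastAPI()

@app.get("/file/{path_or_url:path}")
async def stock_file(path_or_url: str, request: fastapi.Request):
    return {"serving": path_or_url}  # stands in for gradio's file endpoint

endpoint = None
for route in list(app.router.routes):
    if getattr(route, "path", "") == "/file/{path_or_url:path}":
        endpoint = route.endpoint          # keep the original handler
        app.router.routes.remove(route)    # drop the unguarded route

def file_authorize_user(path_or_url, request, app) -> bool:
    return ".." not in path_or_url         # placeholder policy

@app.get("/file/{path_or_url:path}")
async def file(path_or_url: str, request: fastapi.Request):
    if not file_authorize_user(path_or_url, request, app):
        return {"detail": "Hack me? How dare you?"}
    return await endpoint(path_or_url, request)
```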
@@ -591,5 +597,3 @@ def init_start():
 
 if __name__ == '__main__':
     init_start()
-
-
2 changes: 1 addition & 1 deletion common/db/repository/prompt_repository.py
@@ -13,7 +13,7 @@
 
 
 def _read_user_auth(source):
-    return or_(PromptModel.source == source, source in _sys)
+    return or_(PromptModel.source == source, PromptModel.source.in_(_sys))
 
 
 def add_prompt(in_class, name, value, source):
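This one-line change fixes a real query bug: `source in _sys` is evaluated by Python at call time and hands `or_()` a plain boolean, whereas `PromptModel.source.in_(_sys)` builds the intended SQL `IN` predicate. A self-contained sketch of the corrected expression, with a hypothetical model and `_sys` list:

```python
from sqlalchemy import Column, Integer, String, or_
from sqlalchemy.orm import declarative_base

Base = declarative_base()
_sys = ["system", "builtin"]  # hypothetical system-owned sources

class PromptModel(Base):
    __tablename__ = "prompt"
    id = Column(Integer, primary_key=True)
    source = Column(String)

def _read_user_auth(source):
    # Matches rows owned by the user OR by any system source.
    return or_(PromptModel.source == source, PromptModel.source.in_(_sys))

print(_read_user_auth("alice"))
# roughly: prompt.source = :source_1 OR prompt.source IN (...)
```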
1 change: 1 addition & 0 deletions common/func_box.py
@@ -730,6 +730,7 @@ def _format(title, content: str = '', status=''):
         if isinstance(content, dict):
             content = json.dumps(content, indent=4, ensure_ascii=False)
         content = f'\n```\n{content.replace("```", "").strip()}\n```\n'
+        title = title.replace('\n', '').strip()
         return fold_html.format(title=f"<p>{title}</p>", content=content, status=status)
 
     return _format
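The added line protects the folded-HTML title: a title containing newlines would otherwise land inside the `<p>` element verbatim. Note that `replace('\n', '')` joins the fragments with no separating space, which is fine for the one-line titles used here. A tiny illustration:

```python
fold_title = "向量召回完成\n "       # a title with a stray newline and padding
fold_title = fold_title.replace('\n', '').strip()
print(f"<p>{fold_title}</p>")        # <p>向量召回完成</p>
```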
14 changes: 14 additions & 0 deletions common/init_database.py
@@ -3,6 +3,7 @@
 sys.path.append(".")
 from common.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                            folder2db, prune_db_docs, prune_folder_files)
+from common.db.repository import prompt_repository
 from common.configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 
@@ -38,6 +39,14 @@
         "--import-db",
         help="import tables from specified sqlite database"
     )
+    parser.add_argument(
+        "--import-pdb",
+        help="初始化提示词"
+    )
+    parser.add_argument(
+        "--export-pdb",
+        help="导出提示词"
+    )
     parser.add_argument(
         "-u",
         "--update-in-db",
@@ -106,6 +115,11 @@
         create_tables()
         print("recreating all vector stores")
         folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
+    elif args.import_pdb:
+        create_tables()
+        prompt_repository.batch_import_prompt_dir()
+    elif args.export_pdb:
+        prompt_repository.batch_export_path()
     elif args.import_db:
         import_from_db(args.import_db)
     elif args.update_in_db:
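One usage note on the new flags: `add_argument` without an `action` creates value-taking options, so the `elif args.import_pdb:` dispatch fires only when a value is actually supplied on the command line (e.g. `--import-pdb 1`). A minimal sketch of the same dispatch, assuming the repository helpers named in the diff:

```python
import argparse

parser = argparse.ArgumentParser()
# As in the diff: no action=, so each flag expects a value on the command line.
parser.add_argument("--import-pdb", help="import prompt presets")
parser.add_argument("--export-pdb", help="export prompt presets")
args = parser.parse_args(["--import-pdb", "1"])

if args.import_pdb:
    print("create_tables(); prompt_repository.batch_import_prompt_dir()")
elif args.export_pdb:
    print("prompt_repository.batch_export_path()")
```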
51 changes: 36 additions & 15 deletions common/knowledge_base/kb_func.py
@@ -39,7 +39,7 @@ def get_kb_key_value(kb_names: list):
 def user_intent_recognition(user_input, history, llm_kwargs) -> tuple[bool, Any] | bool | Any:
     kb_names = llm_kwargs['kb_config']['names']
     llm, response_format = llm_accelerate_init(llm_kwargs)
-    ipaddr = func_box.user_client_mark(llm_kwargs['ipaddr'])
+    ipaddr = llm_kwargs['ipaddr']
     prompt = prompt_repository.query_prompt('意图识别', '知识库提示词', ipaddr, quote_num=True)
     if prompt:
         prompt = prompt.value
@@ -67,10 +67,11 @@ def user_intent_recognition(user_input, history, llm_kwargs) -> tuple[bool, Any]
 def get_vector_to_dict(vector_list):
     data = {}
     for i in vector_list:
-        if not data.get(i.metadata['source'], False):
-            data[i.metadata['source']] = ''
+        key_work = str(i.score)[:4]+i.metadata['source']
+        if not data.get(key_work, False):
+            data[key_work] = ''
         try:
-            data[i.metadata['source']] += f"{i.page_content}\n"
+            data[key_work] += f"{i.page_content}\n"
         except TypeError:
             pass
     return data
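The new `key_work` key prefixes each source path with the first four characters of its similarity score, so two chunks recalled from the same file at different scores no longer collapse into one entry. A runnable sketch with a stand-in result type:

```python
from dataclasses import dataclass

@dataclass
class Hit:  # stand-in for a vector-store search result
    score: float
    metadata: dict
    page_content: str

def get_vector_to_dict(vector_list):
    data = {}
    for i in vector_list:
        key_work = str(i.score)[:4] + i.metadata['source']
        if not data.get(key_work, False):
            data[key_work] = ''
        data[key_work] += f"{i.page_content}\n"
    return data

hits = [Hit(0.91, {'source': 'a.md'}, 'first chunk'),
        Hit(0.42, {'source': 'a.md'}, 'second chunk')]
print(get_vector_to_dict(hits))  # keys '0.91a.md' and '0.42a.md' stay separate
```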
@@ -99,22 +100,42 @@ def vector_recall_by_input(user_input, chatbot, history, llm_kwargs, kb_prompt_c
         vector_content += f"\n向量召回:{json.dumps(data, indent=4, ensure_ascii=False)}"
     if not source_data:
         vector_content += '无数据,转发到普通对话'
-        return user_input
+        return user_input, history
     chatbot[-1][1] = vector_fold_format(title='向量召回完成', content=vector_content, status='Done')
     yield from update_ui(chatbot, history)
-    ipaddr = func_box.user_client_mark(llm_kwargs['ipaddr'])
-    prompt = prompt_repository.query_prompt(kb_prompt_name, kb_prompt_cls, ipaddr, quote_num=True)
-    if prompt:
-        prompt = prompt.value
-    source_text = "\n".join([f"## {i}\n{source_data[i]}" for i in source_data])
-    kb_prompt = func_box.replace_expected_text(prompt, source_text, '{{{v}}}')
-    user_input = func_box.replace_expected_text(kb_prompt, user_input, '{{{q}}}')
-    return user_input
+    repeat_recall = ''
+    source_text = ''
+    for data in source_data:
+        if any([i for i in history if data in i]):
+            repeat_recall += f'过滤重复召回片段: {data}\n'
+        else:
+            source_text += f"## {data}\n{source_data[data]}"
+    if '专注力转移' in llm_kwargs['input_models'] and source_text:
+        title = '向量召回完成, 当前模式为专注力转移,不采用提示词'
+        user_show = user_input[:20] + "..." + user_input[-20:] + '\n找到以上文本相关文档片段'
+        history = history + [f'{user_show}', source_text]
+    elif source_text:
+        title = f'向量召回完成, 当前模式专注模式,使用`{kb_prompt_name}`提示词进行对话'
+        prompt = prompt_repository.query_prompt(kb_prompt_name, kb_prompt_cls, llm_kwargs['ipaddr'], quote_num=True)
+        if prompt:
+            prompt = prompt.value
+        kb_prompt = func_box.replace_expected_text(prompt, source_text, '{{{v}}}')
+        user_input = func_box.replace_expected_text(kb_prompt, user_input, '{{{q}}}')
+    else:
+        chatbot[-1][1] = vector_fold_format(title='检测召回完全相同文档,转发到普通对话',
+                                            content=repeat_recall,
+                                            status='Done')
+        return user_input, history
+    chatbot[-1][1] = vector_fold_format(title=title,
+                                        content=f"# 可用文档片段\n{source_text}\n"
+                                                f"# 重复文档片段\n{repeat_recall}",
+                                        status='Done')
+    return user_input, history
     chatbot[-1][1] = vector_fold_format(title='无法找到可用知识库, 转发到普通对话',
-                                        content=str(user_intent) + f"prompt: \n{prompt}",
+                                        content=str(user_intent) + f"\nprompt: \n{prompt}",
                                         status='Done')
     yield from update_ui(chatbot, history)
-    return user_input
+    return user_input, history
 
 
 if __name__ == '__main__':
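The rewritten tail of `vector_recall_by_input` does three things: it filters out recalled chunks whose key already appears in the history, it can inject fresh chunks straight into `history` when the '专注力转移' (attention-shift) mode is selected instead of templating a prompt, and it now returns a `(user_input, history)` pair everywhere so callers unpack consistently. A reduced sketch of the filter-and-branch logic under those assumptions:

```python
def recall_branch(user_input, history, source_data, attention_shift=False):
    repeat_recall, source_text = '', ''
    for key in source_data:
        # Drop a chunk if its key already shows up anywhere in the history.
        if any(key in turn for turn in history):
            repeat_recall += f'filtered duplicate chunk: {key}\n'
        else:
            source_text += f"## {key}\n{source_data[key]}"
    if attention_shift and source_text:
        # Attention-shift mode: append the evidence as a synthetic prior turn.
        user_show = user_input[:20] + "..." + user_input[-20:]
        history = history + [user_show, source_text]
    elif source_text:
        # Focus mode: fold the evidence into the prompt instead (elided here).
        user_input = f"{source_text}\n\n{user_input}"
    return user_input, history

print(recall_branch("what changed?", ["earlier turn"], {"0.91a.md": "body\n"}))
```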
1 change: 0 additions & 1 deletion common/path_handler.py
@@ -18,7 +18,6 @@ def __init__(self):
         self.private_knowledge_path = os.path.join(self.users_private_path, 'knowledge')
         self.private_files_path = os.path.join(self.users_private_path, 'files')
         self.private_history_path = os.path.join(self.users_private_path, 'history')
-        self.private_db_path = os.path.join(self.users_private_path, 'db')
 
     def __getattribute__(self, name):
         """
21 changes: 12 additions & 9 deletions common/toolbox.py
@@ -121,7 +121,7 @@ def decorated(cookies, max_length, llm_model, txt,  # tuning parameters
             'top_p': top_p, 'temperature': temperature, 'n_choices': n_choices, 'stop': stop_sequence,
             'max_context': max_context, 'max_generation': max_generation, 'presence_penalty': presence_penalty,
             'frequency_penalty': frequency_penalty, 'logit_bias': logit_bias, 'user_identifier': user_identifier,
-            'response_format': response_format,
+            'response_format': response_format, 'input_models': models,
             'system_prompt': system_prompt, 'ipaddr': func_box.user_client_mark(ipaddr),
             'kb_config': {"names": kb_selects, 'score': vector_score, 'top-k': vector_top_k},
         }
@@ -152,8 +152,8 @@ def decorated(cookies, max_length, llm_model, txt,  # tuning parameters
         txt_proc, func_redirect = yield from model_selection(txt, models, llm_kwargs, plugin_kwargs, cookies,
                                                              chatbot_with_cookie, history, args, func)
         # Decide from args how to process the submission and the chat history
-        txt_proc, history = yield from plugins_selection(txt_proc, history, plugin_kwargs,
-                                                         args, cookies, chatbot_with_cookie, llm_kwargs)
+        txt_proc, history, func_redirect = yield from plugins_selection(txt_proc, history, plugin_kwargs,
+                                                                        args, cookies, chatbot_with_cookie, llm_kwargs, func)
         # Decide the next step from the cookie or the conversation config
         yield from func_decision_tree(func_redirect, cookies, single_mode, agent_mode,
                                       txt_proc, llm_kwargs, plugin_kwargs, chatbot_with_cookie,
@@ -219,7 +219,7 @@ def _vision_select_model(llm_kwargs, models):
     return vision_llm
 
 
-def plugins_selection(txt_proc, history, plugin_kwargs, args, cookies, chatbot_with_cookie, llm_kwargs):
+def plugins_selection(txt_proc, history, plugin_kwargs, args, cookies, chatbot_with_cookie, llm_kwargs, func):
     # Plugins pass extra arguments; for a plugin call, refresh the knowledge base and default advanced parameters
     if len(args) > 1:
         plugin_kwargs['advanced_arg'] = ''
@@ -233,10 +233,13 @@ def plugins_selection(txt_proc, history, plugin_kwargs, args, cookies, chatbot_w
             plugin_kwargs['advanced_arg'] = ''
     from common.knowledge_base.kb_func import vector_recall_by_input
     if llm_kwargs['kb_config']['names']:
-        txt_proc = yield from vector_recall_by_input(txt_proc, chatbot_with_cookie, history,
-                                                     llm_kwargs, '知识库提示词_sys',
-                                                     '引用知识库回答')
-    return txt_proc, history
+        from crazy_functions.reader_fns.crazy_box import submit_no_use_ui_task
+        unpacking_input = yield from vector_recall_by_input(txt_proc, chatbot_with_cookie, history,
+                                                            llm_kwargs, '知识库提示词',
+                                                            '引用知识库回答')
+        txt_proc, history = unpacking_input
+        func = submit_no_use_ui_task
+    return txt_proc, history, func
 
 
 def func_decision_tree(func, cookies, single_mode, agent_mode,
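`plugins_selection` now returns a third value: once knowledge-base recall has produced context, the downstream handler is swapped for `submit_no_use_ui_task`, so the caller's decision tree dispatches to a non-UI submission path. A toy sketch of the handler-redirect pattern, all names illustrative:

```python
def default_chat(text):
    return f"chat: {text}"

def submit_no_use_ui_task(text):
    return f"background submit: {text}"

def plugins_selection(txt_proc, history, kb_names, func):
    if kb_names:
        txt_proc = f"[recalled context]\n{txt_proc}"
        func = submit_no_use_ui_task  # redirect the downstream dispatch
    return txt_proc, history, func

txt, history, func = plugins_selection("hello", [], ["kb1"], default_chat)
print(func(txt))  # background submit: [recalled context] hello
```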
@@ -358,7 +361,7 @@ def decorated(main_input: str, llm_kwargs: dict, plugin_kwargs: dict,
             if len(chatbot_with_cookie) == 0:
                 chatbot_with_cookie.clear()
                 chatbot_with_cookie.append(["插件调度异常", "异常原因"])
-            chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
+            chatbot_with_cookie[-1][1] += f"\n\n[Local Message] 插件调用出错: \n\n{tb_str} \n"
             yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}')  # refresh the UI
 
     return decorated
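The last hunk changes failure reporting from replacing the whole chat row with a tuple to appending onto the existing answer; since rows are mutable `[question, answer]` lists, any partial output produced before the exception is preserved. A minimal illustration:

```python
chatbot = [["plugin question", "partial answer produced before the crash"]]
tb_str = "Traceback (most recent call last): ..."  # placeholder traceback
chatbot[-1][1] += f"\n\n[Local Message] plugin call failed: \n\n{tb_str} \n"
print(chatbot[-1][1])  # partial answer is kept, error appended below it
```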