Skip to content

Commit

Permalink
运行startup.py时,如果不加参数直接显示配置和帮助信息后退出 (#1284)
Browse files Browse the repository at this point in the history
* 统一XX_kb_service.add_doc/do_add_doc接口,不再需要embeddings参数

* 运行startup.py时,如果不加参数直接显示配置和帮助信息后退出

---------

Co-authored-by: liunux4odoo <liunu@qq.com>
  • Loading branch information
liunux4odoo and liunux4odoo authored Aug 28, 2023
1 parent 3acbf4d commit ca0ae29
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 35 deletions.
4 changes: 1 addition & 3 deletions server/knowledge_base/kb_service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):

if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings=embeddings, **kwargs)
self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
else:
status = False
Expand Down Expand Up @@ -181,7 +180,6 @@ def do_search(self,
@abstractmethod
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
):
"""
向知识库添加文档子类实自己逻辑
Expand Down
2 changes: 1 addition & 1 deletion server/knowledge_base/kb_service/default_kb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def do_create_kb(self):
def do_drop_kb(self):
pass

def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document]):
pass

def do_clear_vs(self):
Expand Down
1 change: 0 additions & 1 deletion server/knowledge_base/kb_service/faiss_kb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def do_search(self,

def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
vector_store = self.load_vector_store()
Expand Down
2 changes: 1 addition & 1 deletion server/knowledge_base/kb_service/milvus_kb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def do_search(self, query: str, top_k: int, score_threshold: float, embeddings:
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
def do_add_doc(self, docs: List[Document], **kwargs):
self.milvus.add_documents(docs)

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion server/knowledge_base/kb_service/pg_kb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def do_search(self, query: str, top_k: int, score_threshold: float, embeddings:
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))

def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
def do_add_doc(self, docs: List[Document], **kwargs):
self.pg_vector.add_documents(docs)

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
Expand Down
62 changes: 34 additions & 28 deletions startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def parse_args() -> argparse.ArgumentParser:
dest="webui",
)
args = parser.parse_args()
return args
return args, parser


def dump_server_info(after_start=False):
Expand All @@ -330,7 +330,7 @@ def dump_server_info(after_start=False):
import fastchat
from configs.server_config import api_address, webui_address

print("\n\n")
print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本:{sys.version}")
Expand All @@ -351,15 +351,15 @@ def dump_server_info(after_start=False):
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
print("\n")


if __name__ == "__main__":
import time

mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
args, parser = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
Expand All @@ -379,8 +379,10 @@ def dump_server_info(after_start=False):
args.webui = False

dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")

if len(sys.argv) > 1:
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")

processes = {}

Expand Down Expand Up @@ -435,28 +437,32 @@ def dump_server_info(after_start=False):
process.start()
processes["webui"] = process

try:
# log infors
while True:
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
break
else:
queue.put(no)

if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
if len(processes) == 0:
parser.print_help()
else:
try:
# log infors
while True:
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
break
else:
queue.put(no)

if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()


# 服务启动后接口调用示例:
# import openai
Expand Down

0 comments on commit ca0ae29

Please sign in to comment.