Optimize server_config configuration options #1293

Merged · 14 commits · Aug 29, 2023
3 changes: 3 additions & 0 deletions configs/model_config.py.example
@@ -80,6 +80,9 @@ llm_model_dict = {
# LLM name
LLM_MODEL = "chatglm2-6b"

# Number of dialogue history turns
HISTORY_LEN = 3

# Device to run the LLM on
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

36 changes: 2 additions & 34 deletions configs/server_config.py.example
@@ -34,8 +34,8 @@ FSCHAT_MODEL_WORKERS = {
"port": 20002,
"device": LLM_DEVICE,
# todo: 多卡加载需要配置的参数
"gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
"num_gpus": 1, # 使用GPU的数量
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
# "num_gpus": 1, # 使用GPU的数量
# 以下为非常用参数,可根据需要配置
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
# "load_8bit": False, # 开启8bit量化
@@ -66,35 +66,3 @@ FSCHAT_CONTROLLER = {
"port": 20001,
"dispatch_method": "shortest_queue",
}


# Do not modify the following
def fschat_controller_address() -> str:
    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    if model := FSCHAT_MODEL_WORKERS.get(model_name):
        host = model["host"]
        port = model["port"]
        return f"http://{host}:{port}"


def fschat_openai_api_address() -> str:
    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}"


def api_address() -> str:
    host = API_SERVER["host"]
    port = API_SERVER["port"]
    return f"http://{host}:{port}"


def webui_address() -> str:
    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    return f"http://{host}:{port}"
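To re-enable multi-GPU loading after this change, uncomment the keys in a per-model worker entry. A minimal sketch, assuming a worker entry for "chatglm2-6b"; the port and GPU values are illustrative:

FSCHAT_MODEL_WORKERS = {
    "chatglm2-6b": {
        "host": "0.0.0.0",
        "port": 20002,
        "device": LLM_DEVICE,
        "gpus": "0,1",              # GPUs to use, specified as a str, e.g. "0,1"
        "num_gpus": 2,              # number of GPUs to use
        "max_gpu_memory": "20GiB",  # maximum VRAM used per GPU
    },
}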
9 changes: 1 addition & 8 deletions server/llm_api.py
@@ -5,7 +5,7 @@

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
from server.utils import MakeFastAPIOffline, set_httpx_timeout


host_ip = "0.0.0.0"
@@ -15,13 +15,6 @@
base_url = "http://127.0.0.1:{}"


def set_httpx_timeout(timeout=60.0):
    import httpx
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

def create_controller_app(
    dispatch_method="shortest_queue",
):
70 changes: 70 additions & 0 deletions server/utils.py
@@ -5,6 +5,7 @@
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL
from typing import Any, Optional


@@ -186,3 +187,72 @@ async def redoc_html(request: Request) -> HTMLResponse:
        with_google_fonts=False,
        redoc_favicon_url=favicon,
    )


# Fetch service information from server_config
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
    '''
    Load the configuration for a model worker.
    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
    '''
    from configs.server_config import FSCHAT_MODEL_WORKERS
    from configs.model_config import llm_model_dict

    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(llm_model_dict.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
    return config


def fschat_controller_address() -> str:
    from configs.server_config import FSCHAT_CONTROLLER

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    if model := get_model_worker_config(model_name):
        host = model["host"]
        port = model["port"]
        return f"http://{host}:{port}"
    return ""


def fschat_openai_api_address() -> str:
    from configs.server_config import FSCHAT_OPENAI_API

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}"


def api_address() -> str:
    from configs.server_config import API_SERVER

    host = API_SERVER["host"]
    port = API_SERVER["port"]
    return f"http://{host}:{port}"


def webui_address() -> str:
    from configs.server_config import WEBUI_SERVER

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    return f"http://{host}:{port}"


def set_httpx_timeout(timeout: float = None):
    '''
    Set the default httpx timeout.
    The httpx default of 5 seconds is not enough when requesting answers from an LLM.
    '''
    import httpx
    from configs.server_config import HTTPX_DEFAULT_TIMEOUT

    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
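
For reference, a minimal usage sketch of the relocated helpers, assuming the config files above are in place; the model name "chatglm2-6b" is illustrative:

from server.utils import (get_model_worker_config, fschat_controller_address,
    fschat_model_worker_address, set_httpx_timeout)

set_httpx_timeout()  # no argument: falls back to HTTPX_DEFAULT_TIMEOUT from server_config

# Per-model worker keys override llm_model_dict, which overrides the "default" entry.
config = get_model_worker_config("chatglm2-6b")
print(config.get("device"), config.get("port"))

# The address helpers now resolve through the merged config instead of
# reading FSCHAT_MODEL_WORKERS directly.
print(fschat_controller_address())                 # e.g. http://127.0.0.1:20001
print(fschat_model_worker_address("chatglm2-6b"))  # e.g. http://127.0.0.1:20002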
14 changes: 4 additions & 10 deletions startup.py
@@ -17,21 +17,15 @@
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
    logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
    FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
    fschat_openai_api_address, )
    FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
    fschat_openai_api_address, set_httpx_timeout)
from server.utils import MakeFastAPIOffline, FastAPI
import argparse
from typing import Tuple, List
from configs import VERSION


def set_httpx_timeout(timeout=60.0):
    import httpx
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


def create_controller_app(
    dispatch_method: str,
) -> FastAPI:
@@ -328,7 +322,7 @@ def dump_server_info(after_start=False):
    import platform
    import langchain
    import fastchat
    from configs.server_config import api_address, webui_address
    from server.utils import api_address, webui_address

    print("\n")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
2 changes: 1 addition & 1 deletion tests/api/test_kb_api.py
@@ -6,7 +6,7 @@

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path

2 changes: 1 addition & 1 deletion tests/api/test_stream_chat_api.py
@@ -5,7 +5,7 @@

sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs.server_config import API_SERVER, api_address
from server.utils import api_address

from pprint import pprint

8 changes: 3 additions & 5 deletions webui_pages/dialogue/dialogue.py
@@ -59,9 +59,7 @@ def on_mode_change():
            on_change=on_mode_change,
            key="dialogue_mode",
        )
        history_len = st.number_input("历史对话轮数:", 0, 10, 3)

        # todo: support history len
        history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

        def on_kb_change():
            st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -75,7 +73,7 @@ def on_kb_change():
            on_change=on_kb_change,
            key="selected_kb",
        )
        kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
        kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
        score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
        # chunk_content = st.checkbox("关联上下文", False, disabled=True)
        # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -87,7 +85,7 @@
            options=search_engine_list,
            index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
        )
        se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
        se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)

    # Display chat messages from history on app rerun

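With HISTORY_LEN now driving the widget default above, the page still has to trim the accumulated history itself. A minimal sketch of such a trimming step, assuming a flat message list with two messages (user and assistant) per dialogue turn; trim_history is a hypothetical helper, not part of this diff:

# Hypothetical helper: keep only the last `history_len` dialogue turns,
# where each turn contributes a user message and an assistant reply.
def trim_history(history: list, history_len: int) -> list:
    if history_len <= 0:
        return []
    return history[-2 * history_len:]

# e.g. messages = trim_history(st.session_state.get("history", []), history_len)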
13 changes: 2 additions & 11 deletions webui_pages/utils.py
@@ -6,6 +6,7 @@
    DEFAULT_VS_TYPE,
    KB_ROOT_PATH,
    LLM_MODEL,
    HISTORY_LEN,
    SCORE_THRESHOLD,
    VECTOR_SEARCH_TOP_K,
    SEARCH_ENGINE_TOP_K,
@@ -20,24 +21,14 @@
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
from server.utils import run_async, iter_over_async, set_httpx_timeout

from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from pprint import pprint


def set_httpx_timeout(timeout=60.0):
    '''
    Set the default httpx timeout to 60 seconds.
    The httpx default of 5 seconds is not enough when requesting answers from an LLM.
    '''
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
