Switch all httpx requests to use a Client, improving efficiency and making it easier to configure proxies later.
Add this project's own services to the no-proxy list to avoid request errors against the fastchat servers. (Not effective on Windows.)
liunux4odoo committed Sep 21, 2023
1 parent 818cb1a commit 13e8f69
Showing 6 changed files with 213 additions and 102 deletions.
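Every call site follows the same mechanical rewrite: a bare module-level httpx call becomes a short-lived Client obtained from the new get_httpx_client() helper in server/utils.py. A minimal sketch of the pattern, using the controller's /list_models endpoint from server/llm_api.py (the literal address is an assumption standing in for fschat_controller_address()):

    # before: module-level call, httpx's 5-second default timeout, system proxy rules
    import httpx
    r = httpx.post("http://127.0.0.1:20001/list_models")

    # after: a Client preconfigured with the project timeout and no-proxy list;
    # the `with` block closes the underlying connection pool when done
    from server.utils import get_httpx_client

    with get_httpx_client() as client:
        r = client.post("http://127.0.0.1:20001/list_models")
        print(r.json()["models"])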
33 changes: 18 additions & 15 deletions server/llm_api.py
@@ -1,7 +1,7 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address, list_llm_models
-import httpx
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
+


 def list_running_models(
@@ -13,8 +13,9 @@ def list_running_models(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(controller_address + "/list_models")
-        return BaseResponse(data=r.json()["models"])
+        with get_httpx_client() as client:
+            r = client.post(controller_address + "/list_models")
+            return BaseResponse(data=r.json()["models"])
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -41,11 +42,12 @@ def stop_llm_model(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(
-            controller_address + "/release_worker",
-            json={"model_name": model_name},
-        )
-        return r.json()
+        with get_httpx_client() as client:
+            r = client.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name},
+            )
+            return r.json()
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -64,12 +66,13 @@ def change_llm_model(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(
-            controller_address + "/release_worker",
-            json={"model_name": model_name, "new_model_name": new_model_name},
-            timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
-        )
-        return r.json()
+        with get_httpx_client() as client:
+            r = client.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name, "new_model_name": new_model_name},
+                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
+            )
+            return r.json()
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
35 changes: 18 additions & 17 deletions server/model_workers/minimax.py
@@ -2,7 +2,7 @@
 from fastchat import conversation as conv
 import sys
 import json
-import httpx
+from server.utils import get_httpx_client
 from pprint import pprint
 from typing import List, Dict

@@ -63,22 +63,23 @@ def generate_stream_gate(self, params):
         }
         print("request data sent to minimax:")
         pprint(data)
-        response = httpx.stream("POST",
-                                self.BASE_URL.format(pro=pro, group_id=group_id),
-                                headers=headers,
-                                json=data)
-        with response as r:
-            text = ""
-            for e in r.iter_text():
-                if e.startswith("data: "):  # what an excellent response format
-                    data = json.loads(e[6:])
-                    if not data.get("usage"):
-                        if choices := data.get("choices"):
-                            chunk = choices[0].get("delta", "").strip()
-                            if chunk:
-                                print(chunk)
-                                text += chunk
-                                yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
+        with get_httpx_client() as client:
+            response = client.stream("POST",
+                                     self.BASE_URL.format(pro=pro, group_id=group_id),
+                                     headers=headers,
+                                     json=data)
+            with response as r:
+                text = ""
+                for e in r.iter_text():
+                    if e.startswith("data: "):  # what an excellent response format
+                        data = json.loads(e[6:])
+                        if not data.get("usage"):
+                            if choices := data.get("choices"):
+                                chunk = choices[0].get("delta", "").strip()
+                                if chunk:
+                                    print(chunk)
+                                    text += chunk
+                                    yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"

     def get_embeddings(self, params):
         # TODO: support embeddings
22 changes: 12 additions & 10 deletions server/model_workers/qianfan.py
@@ -5,7 +5,7 @@
 import json
 import httpx
 from cachetools import cached, TTLCache
-from server.utils import get_model_worker_config
+from server.utils import get_model_worker_config, get_httpx_client
 from typing import List, Literal, Dict


@@ -54,7 +54,8 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
     url = "https://aip.baidubce.com/oauth/2.0/token"
     params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
     try:
-        return httpx.get(url, params=params).json().get("access_token")
+        with get_httpx_client() as client:
+            return client.get(url, params=params).json().get("access_token")
     except Exception as e:
         print(f"failed to get token from baidu: {e}")
@@ -91,14 +92,15 @@ def request_qianfan_api(
         'Accept': 'application/json',
     }

-    with httpx.stream("POST", url, headers=headers, json=payload) as response:
-        for line in response.iter_lines():
-            if not line.strip():
-                continue
-            if line.startswith("data: "):
-                line = line[6:]
-            resp = json.loads(line)
-            yield resp
+    with get_httpx_client() as client:
+        with client.stream("POST", url, headers=headers, json=payload) as response:
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                if line.startswith("data: "):
+                    line = line[6:]
+                resp = json.loads(line)
+                yield resp


 class QianFanWorker(ApiModelWorker):
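Both streaming workers rely on the fact that httpx's Client.stream() returns a context manager rather than a finished Response: the body is only consumed incrementally inside the with block. A standalone sketch of that pattern (the URL and payload are placeholders, not part of the commit):

    from server.utils import get_httpx_client

    with get_httpx_client() as client:
        # stream() defers reading the body; iter_lines() yields it piecewise
        with client.stream("POST", "https://example.com/v1/stream", json={"q": "hi"}) as response:
            for line in response.iter_lines():
                if line.startswith("data: "):  # server-sent-events framing
                    print(line[6:])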
107 changes: 100 additions & 7 deletions server/utils.py
@@ -7,11 +7,12 @@
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
-                     FSCHAT_MODEL_WORKERS)
+                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT, HTTPX_PROXY)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable
+import httpx
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union


 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -376,19 +377,63 @@ def get_prompt_template(name: str) -> Optional[str]:
     return prompt_config.PROMPT_TEMPLATES.get(name)


-def set_httpx_timeout(timeout: float = None):
+def set_httpx_config(
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
+        proxy: Union[str, Dict] = None,
+):
     '''
-    Set the default httpx timeout.
-    httpx's default timeout is 5 seconds, which is not enough when requesting an LLM answer.
+    Set the default httpx timeout. httpx's default timeout is 5 seconds, which is not enough when requesting an LLM answer.
+    Add this project's own services to the no-proxy list to avoid request errors against the fastchat servers. (Not effective on Windows.)
+    For online APIs such as ChatGPT, a proxy has to be configured manually if one is needed. How to handle proxies for search engines still needs thought.
     '''
     import httpx
-    from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+    import os

-    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

+    # set system-wide proxies for the whole process
+    proxies = {}
+    if isinstance(proxy, str):
+        for n in ["http", "https", "all"]:
+            proxies[n + "_proxy"] = proxy
+    elif isinstance(proxy, dict):
+        for n in ["http", "https", "all"]:
+            if p := proxy.get(n):
+                proxies[n + "_proxy"] = p
+            elif p := proxy.get(n + "_proxy"):
+                proxies[n + "_proxy"] = p
+
+    for k, v in proxies.items():
+        os.environ[k] = v
+
+    # set hosts to bypass proxy
+    no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
+    no_proxy += [
+        # do not use proxy for localhost
+        "http://127.0.0.1",
+        "http://localhost",
+    ]
+    # do not use proxy for user-deployed fastchat servers
+    for x in [
+        fschat_controller_address(),
+        fschat_model_worker_address(),
+        fschat_openai_api_address(),
+    ]:
+        host = ":".join(x.split(":")[:2])
+        if host not in no_proxy:
+            no_proxy.append(host)
+    os.environ["NO_PROXY"] = ",".join(no_proxy)
+
+    # TODO: simply clearing the system proxies is not a good choice; it affects too much.
+    # Modifying the proxy server's bypass list seems a better option.
+    # patch requests to use custom proxies instead of system settings
+    # def _get_proxies():
+    #     return {}
+
+    # import urllib.request
+    # urllib.request.getproxies = _get_proxies


 # Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM do not need torch installed.
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
@@ -436,3 +481,51 @@ def run_in_thread_pool(
     for obj in as_completed(tasks):
         yield obj.result()

+
+def get_httpx_client(
+        use_async: bool = False,
+        proxies: Union[str, Dict] = None,
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
+        **kwargs,
+) -> Union[httpx.Client, httpx.AsyncClient]:
+    '''
+    helper to get an httpx client with default proxies that bypass local addresses.
+    '''
+    default_proxies = {
+        # do not use proxy for localhost
+        "all://127.0.0.1": None,
+        "all://localhost": None,
+    }
+    # do not use proxy for user-deployed fastchat servers
+    for x in [
+        fschat_controller_address(),
+        fschat_model_worker_address(),
+        fschat_openai_api_address(),
+    ]:
+        host = ":".join(x.split(":")[:2])
+        default_proxies.update({host: None})
+
+    # get proxies from the system environment
+    default_proxies.update({
+        "http://": os.environ.get("http_proxy"),
+        "https://": os.environ.get("https_proxy"),
+        "all://": os.environ.get("all_proxy"),
+    })
+    for host in os.environ.get("no_proxy", "").split(","):
+        if host := host.strip():
+            default_proxies.update({host: None})
+
+    # merge default proxies with user-provided proxies
+    if isinstance(proxies, str):
+        proxies = {"all://": proxies}
+
+    if isinstance(proxies, dict):
+        default_proxies.update(proxies)
+
+    # construct Client
+    kwargs.update(timeout=timeout, proxies=default_proxies)
+    if use_async:
+        return httpx.AsyncClient(**kwargs)
+    else:
+        return httpx.Client(**kwargs)

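The proxies mapping handed to httpx is built in three layers, each overriding the last: hard-coded bypasses for localhost and the fastchat services, then the http_proxy/https_proxy/all_proxy and no_proxy environment variables, then whatever the caller passes in. Note that ":".join(x.split(":")[:2]) reduces an address such as http://127.0.0.1:20001 to the scheme-plus-host key http://127.0.0.1 that httpx's proxy routing expects. A hedged usage sketch (the proxy address and target URLs are assumptions):

    from server.utils import get_httpx_client

    # synchronous client: outbound traffic goes through the given proxy,
    # but localhost and the fastchat servers are still contacted directly
    with get_httpx_client(proxies="http://127.0.0.1:7890", timeout=60.0) as client:
        r = client.get("https://example.com")

    # async flavour for use inside coroutines
    import asyncio

    async def fetch() -> int:
        async with get_httpx_client(use_async=True) as client:
            r = await client.get("https://example.com")
            return r.status_code

    print(asyncio.run(fetch()))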
27 changes: 19 additions & 8 deletions startup.py
@@ -31,7 +31,7 @@
     HTTPX_DEFAULT_TIMEOUT,
 )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout,
+                          fschat_openai_api_address, set_httpx_config, get_httpx_client,
                           get_model_worker_config, get_all_model_worker_configs,
                           MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 import argparse
@@ -203,7 +203,6 @@ def create_openai_api_app(
 def _set_app_event(app: FastAPI, started_event: mp.Event = None):
     @app.on_event("startup")
     async def on_startup():
-        set_httpx_timeout()
         if started_event is not None:
             started_event.set()

@@ -214,6 +213,8 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     from fastapi import Body
     import time
     import sys
+    from server.utils import set_httpx_config
+    set_httpx_config()

     app = create_controller_app(
         dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
@@ -251,12 +252,13 @@ def release_worker(
             logger.error(msg)
             return {"code": 500, "msg": msg}

-        r = httpx.post(worker_address + "/release",
-                       json={"new_model_name": new_model_name, "keep_origin": keep_origin})
-        if r.status_code != 200:
-            msg = f"failed to release model: {model_name}"
-            logger.error(msg)
-            return {"code": 500, "msg": msg}
+        with get_httpx_client() as client:
+            r = client.post(worker_address + "/release",
+                            json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+            if r.status_code != 200:
+                msg = f"failed to release model: {model_name}"
+                logger.error(msg)
+                return {"code": 500, "msg": msg}

         if new_model_name:
             timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
@@ -299,6 +301,8 @@ def run_model_worker(
     import uvicorn
     from fastapi import Body
     import sys
+    from server.utils import set_httpx_config
+    set_httpx_config()

     kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
@@ -337,6 +341,8 @@ def release_model(
 def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import sys
+    from server.utils import set_httpx_config
+    set_httpx_config()

     controller_addr = fschat_controller_address()
     app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
@@ -353,6 +359,8 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
 def run_api_server(started_event: mp.Event = None):
     from server.api import create_app
     import uvicorn
+    from server.utils import set_httpx_config
+    set_httpx_config()

     app = create_app()
     _set_app_event(app, started_event)
@@ -364,6 +372,9 @@ def run_api_server(started_event: mp.Event = None):


 def run_webui(started_event: mp.Event = None):
+    from server.utils import set_httpx_config
+    set_httpx_config()
+
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

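A note on why set_httpx_config() is now called at the top of every run_* entry point rather than once in the parent: each service runs in its own process, and the monkeypatched httpx._config defaults are per-process state that would not necessarily survive into a child started with the spawn method. The set_httpx_timeout() call removed from _set_app_event() is subsumed by these earlier, per-process calls.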