diff --git a/docs/en/llm/proxy_server.md b/docs/en/llm/proxy_server.md
index 1169794d4..26b24f534 100644
--- a/docs/en/llm/proxy_server.md
+++ b/docs/en/llm/proxy_server.md
@@ -7,12 +7,19 @@ The request distributor service can parallelize multiple api_server services. Us
 Start the proxy service:
 
 ```shell
-python3 -m lmdeploy.serve.proxy.proxy --server_name {server_name} --server_port {server_port} --strategy "min_expected_latency"
+lmdeploy serve proxy --server-name {server_name} --server-port {server_port} --strategy "min_expected_latency"
 ```
 
 After startup is successful, the URL of the proxy service will also be printed by the script. Access this URL in your browser to open the Swagger UI.
+Subsequently, when starting the `api_server` service, users can register it with the proxy service directly through the `--proxy-url` option. For example:
+`lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`.
+In this way, users can access the `api_server` services through the proxy node. The proxy node is used in exactly the same way as the `api_server`; both are compatible with the OpenAI format and expose the following endpoints:
 
-## API
+- /v1/models
+- /v1/chat/completions
+- /v1/completions
+
+## Node Management
 
 Through Swagger UI, we can see multiple APIs. Those related to api_server node management include:
 
@@ -22,13 +29,64 @@ Through Swagger UI, we can see multiple APIs. Those related to api_server node m
 - /nodes/status
 - /nodes/add
 - /nodes/remove
 
 They respectively represent viewing all api_server service nodes, adding a certain node, and deleting a certain node.
 
-APIs related to usage include:
+### Node Management through curl
 
-- /v1/models
-- /v1/chat/completions
-- /v1/completions
+```shell
+curl -X 'GET' \
+  'http://localhost:8000/nodes/status' \
+  -H 'accept: application/json'
+```
 
-The usage of these APIs is the same as that of api_server.
+```shell
+curl -X 'POST' \
+  'http://localhost:8000/nodes/add' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "url": "http://0.0.0.0:23333"
+}'
+```
+
+```shell
+curl -X 'POST' \
+  'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \
+  -H 'accept: application/json' \
+  -d ''
+```
+
+### Node Management through python
+
+```python
+# query all nodes
+import requests
+url = 'http://localhost:8000/nodes/status'
+headers = {'accept': 'application/json'}
+response = requests.get(url, headers=headers)
+print(response.text)
+```
+
+```python
+# add a new node
+import requests
+url = 'http://localhost:8000/nodes/add'
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json'
+}
+data = {"url": "http://0.0.0.0:23333"}
+response = requests.post(url, headers=headers, json=data)
+print(response.text)
+```
+
+```python
+# delete a node
+import requests
+url = 'http://localhost:8000/nodes/remove'
+headers = {'accept': 'application/json'}
+params = {'node_url': 'http://0.0.0.0:23333'}
+response = requests.post(url, headers=headers, data='', params=params)
+print(response.text)
+```
 
 ## Dispatch Strategy
 
diff --git a/docs/zh_cn/llm/proxy_server.md b/docs/zh_cn/llm/proxy_server.md
index 79d8e45f6..960ab7a74 100644
--- a/docs/zh_cn/llm/proxy_server.md
+++ b/docs/zh_cn/llm/proxy_server.md
@@ -7,12 +7,18 @@
 启动代理服务:
 
 ```shell
-python3 -m lmdeploy.serve.proxy.proxy --server_name {server_name} --server_port {server_port} --strategy "min_expected_latency"
+lmdeploy serve proxy --server-name {server_name} --server-port {server_port} --strategy "min_expected_latency"
 ```
 
 启动成功后，代理服务的 URL 也会被脚本打印。浏览器访问这个 URL，可以打开 Swagger UI。
+随后，用户可以在启动 api_server 服务的时候，通过 `--proxy-url` 命令将其直接添加到代理服务中。例如：`lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`。
+这样，用户可以通过代理节点访问 api_server 的服务，代理节点的使用方式和 api_server 一模一样，都是兼容 OpenAI 的形式。
 
-## API
+- /v1/models
+- /v1/chat/completions
+- /v1/completions
+
+## 节点管理
 
 通过 Swagger UI，我们可以看到多个 API。其中，和 api_server 节点管理相关的有：
 
@@ -20,15 +26,66 @@
 - /nodes/status
 - /nodes/add
 - /nodes/remove
 
-他们分别表示，查看所有的 api_server 服务节点，增加某个节点，删除某个节点。
+他们分别表示，查看所有的 api_server 服务节点，增加某个节点，删除某个节点。他们的使用方式，最直接的可以在浏览器里面直接操作。也可以通过命令行或者 python 操作。
 
-和使用相关的 api 有:
+### 通过 command 增删查
 
-- /v1/models
-- /v1/chat/completions
-- /v1/completions
+```shell
+curl -X 'GET' \
+  'http://localhost:8000/nodes/status' \
+  -H 'accept: application/json'
+```
 
-这些 API 的使用方式和 api_server 一样。
+```shell
+curl -X 'POST' \
+  'http://localhost:8000/nodes/add' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "url": "http://0.0.0.0:23333"
+}'
+```
+
+```shell
+curl -X 'POST' \
+  'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \
+  -H 'accept: application/json' \
+  -d ''
+```
+
+### 通过 python 脚本增删查
+
+```python
+# 查询所有节点
+import requests
+url = 'http://localhost:8000/nodes/status'
+headers = {'accept': 'application/json'}
+response = requests.get(url, headers=headers)
+print(response.text)
+```
+
+```python
+# 添加新节点
+import requests
+url = 'http://localhost:8000/nodes/add'
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json'
+}
+data = {"url": "http://0.0.0.0:23333"}
+response = requests.post(url, headers=headers, json=data)
+print(response.text)
+```
+
+```python
+# 删除某个节点
+import requests
+url = 'http://localhost:8000/nodes/remove'
+headers = {'accept': 'application/json'}
+params = {'node_url': 'http://0.0.0.0:23333'}
+response = requests.post(url, headers=headers, data='', params=params)
+print(response.text)
+```
 
 ## 分发策略
 
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index effb6e5e5..c82089b4e 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -134,6 +134,10 @@ def add_parser_api_server():
                             type=str,
                             default=['*'],
                             help='A list of allowed http headers for cors')
+        parser.add_argument('--proxy-url',
+                            type=str,
+                            default=None,
+                            help='The proxy url for api server.')
         # common args
         ArgumentHelper.backend(parser)
         ArgumentHelper.log_level(parser)
@@ -204,6 +208,32 @@ def add_parser_api_client():
             'api key will be used')
         ArgumentHelper.session_id(parser)
 
+    @staticmethod
+    def add_parser_proxy():
+        """Add parser for proxy server command."""
+        parser = SubCliServe.subparsers.add_parser(
+            'proxy',
+            formatter_class=DefaultsAndTypesHelpFormatter,
+            description=SubCliServe.proxy.__doc__,
+            help=SubCliServe.proxy.__doc__)
+        parser.set_defaults(run=SubCliServe.proxy)
+        parser.add_argument('--server-name',
+                            type=str,
+                            default='0.0.0.0',
+                            help='Host ip for proxy serving')
+        parser.add_argument('--server-port',
+                            type=int,
+                            default=8000,
+                            help='Server port of the proxy')
+        parser.add_argument(
+            '--strategy',
+            type=str,
+            choices=['random', 'min_expected_latency', 'min_observed_latency'],
+            default='min_expected_latency',
+            help='the strategy to dispatch requests to nodes')
+        ArgumentHelper.api_keys(parser)
+        ArgumentHelper.ssl(parser)
+
     @staticmethod
     def gradio(args):
         """Serve LLMs with web UI using gradio."""
@@ -311,6 +341,7 @@ def api_server(args):
                        log_level=args.log_level.upper(),
                        api_keys=args.api_keys,
                        ssl=args.ssl,
+                       proxy_url=args.proxy_url,
                        max_log_len=args.max_log_len)
 
     @staticmethod
@@ -320,8 +351,16 @@ def api_client(args):
         kwargs = convert_args(args)
         run_api_client(**kwargs)
 
+    @staticmethod
+    def proxy(args):
+        """Proxy server that manages distributed api_server nodes."""
+        from lmdeploy.serve.proxy.proxy import proxy
+        kwargs = convert_args(args)
+        proxy(**kwargs)
+
     @staticmethod
     def add_parsers():
         SubCliServe.add_parser_gradio()
         SubCliServe.add_parser_api_server()
         SubCliServe.add_parser_api_client()
+        SubCliServe.add_parser_proxy()
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index f8ff7105d..dfcdf82be 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -40,6 +40,9 @@ class VariableInterface:
     session_id: int = 0
     api_keys: Optional[List[str]] = None
     request_hosts = []
+    # following are for registering to proxy server
+    proxy_url: Optional[str] = None
+    api_server_url: Optional[str] = None
 
 
 app = FastAPI(docs_url='/')
@@ -926,6 +929,33 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return JSONResponse(ret)
 
 
+@app.on_event('startup')
+async def startup_event():
+    if VariableInterface.proxy_url is None:
+        return
+    try:
+        import requests
+        url = f'{VariableInterface.proxy_url}/nodes/add'
+        data = {
+            'url': VariableInterface.api_server_url,
+            'status': {
+                'models': get_model_list()
+            }
+        }
+        headers = {
+            'accept': 'application/json',
+            'Content-Type': 'application/json'
+        }
+        response = requests.post(url, headers=headers, json=data)
+
+        if response.status_code != 200:
+            raise HTTPException(status_code=400,
+                                detail='Service registration failed')
+        print(response.text)
+    except Exception as e:
+        print(f'Service registration failed: {e}')
+
+
 def serve(model_path: str,
           model_name: Optional[str] = None,
           backend: Literal['turbomind', 'pytorch'] = 'turbomind',
@@ -941,6 +971,7 @@ def serve(model_path: str,
           log_level: str = 'ERROR',
           api_keys: Optional[Union[List[str], str]] = None,
           ssl: bool = False,
+          proxy_url: Optional[str] = None,
           max_log_len: int = None,
           **kwargs):
     """An example to perform model inference through the command line
@@ -983,6 +1014,7 @@ def serve(model_path: str,
             api key applied.
         ssl (bool): Enable SSL. Requires OS Environment variables
             'SSL_KEYFILE' and 'SSL_CERTFILE'.
+        proxy_url (str): The proxy url to register the api_server.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
     """
@@ -1019,6 +1051,9 @@ def serve(model_path: str,
         max_log_len=max_log_len,
         **kwargs)
 
+    if proxy_url is not None:
+        VariableInterface.proxy_url = proxy_url
+        VariableInterface.api_server_url = f'{http_or_https}://{server_name}:{server_port}'  # noqa
     for i in range(3):
         print(
             f'HINT: Please open \033[93m\033[1m{http_or_https}://'
diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py
index 15d182a3d..5f05930bd 100644
--- a/lmdeploy/serve/proxy/proxy.py
+++ b/lmdeploy/serve/proxy/proxy.py
@@ -4,6 +4,7 @@
 import os
 import os.path as osp
 import random
+import threading
 import time
 from collections import deque
 from http import HTTPStatus
@@ -46,6 +47,17 @@ class Node(BaseModel):
     status: Optional[Status] = None
 
 
+CONTROLLER_HEART_BEAT_EXPIRATION = int(
+    os.getenv('LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION', 90))
+
+
+def heart_beat_controller(proxy_controller):
+    while True:
+        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+        logger.info('Start heart beat check')
+        proxy_controller.remove_stale_nodes_by_expiration()
+
+
 class NodeManager:
     """Manage all the sub nodes.
 
@@ -77,6 +89,10 @@ def __init__(self,
             for url, status in self.nodes.items():
                 status = Status(**status)
                 self.nodes[url] = status
+        self.heart_beat_thread = threading.Thread(target=heart_beat_controller,
+                                                  args=(self, ),
+                                                  daemon=True)
+        self.heart_beat_thread.start()
 
     def update_config_file(self):
         """Update the config file."""
@@ -100,6 +116,10 @@ def add(self, node_url: str, status: Optional[Status] = None):
         """
         if status is None:
             status = self.nodes.get(node_url, Status())
+        if status.models != []:  # force register directly
+            self.nodes[node_url] = status
+            self.update_config_file()
+            return
         try:
             from lmdeploy.serve.openai.api_client import APIClient
             client = APIClient(api_server_url=node_url)
@@ -115,6 +135,22 @@ def remove(self, node_url: str):
         self.nodes.pop(node_url)
         self.update_config_file()
 
+    def remove_stale_nodes_by_expiration(self):
+        """remove stale nodes."""
+        to_be_deleted = []
+        for node_url in self.nodes.keys():
+            url = f'{node_url}/health'
+            headers = {'accept': 'application/json'}
+            try:
+                response = requests.get(url, headers=headers)
+                if response.status_code != 200:
+                    to_be_deleted.append(node_url)
+            except:  # noqa
+                to_be_deleted.append(node_url)
+        for node_url in to_be_deleted:
+            self.remove(node_url)
+            logger.info(f'Removed node_url: {node_url}')
+
     @property
     def model_list(self):
         """Supported model list."""
@@ -476,7 +512,7 @@ async def completions_v1(request: CompletionRequest,
 
 
 def proxy(server_name: str = '0.0.0.0',
-          server_port: int = 10086,
+          server_port: int = 8000,
           strategy: Literal['random', 'min_expected_latency',
                             'min_observed_latency'] = 'min_expected_latency',
           api_keys: Optional[Union[List[str], str]] = None,
@@ -486,7 +522,7 @@ def proxy(server_name: str = '0.0.0.0',
 
     Args:
         server_name (str): the server name of the proxy. Default to '0.0.0.0'.
-        server_port (str): the server port. Default to 10086.
+        server_port (int): the server port. Default to 8000.
         strategy ('random' | 'min_expected_latency' | 'min_observed_latency'):
            the strategy to dispatch requests to nodes. Default to
            'min_expected_latency'