From 13566dc521f6ce187aa10081cb7e89a61060f033 Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:38:40 +0200 Subject: [PATCH] [Bugfix] fix OpenAI API server startup with --disable-frontend-multiprocessing (#8537) Signed-off-by: Amit Garg --- tests/entrypoints/openai/test_basic.py | 58 +++++++++++++++++++++++++- vllm/entrypoints/openai/api_server.py | 10 +++-- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index a7e418db30a29..d3aea533b6db9 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,4 +1,5 @@ from http import HTTPStatus +from typing import List import openai import pytest @@ -12,8 +13,44 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +@pytest.fixture(scope='module') +def server_args(request: pytest.FixtureRequest) -> List[str]: + """ Provide extra arguments to the server via indirect parametrization + + Usage: + + >>> @pytest.mark.parametrize( + >>> "server_args", + >>> [ + >>> ["--disable-frontend-multiprocessing"], + >>> [ + >>> "--model=NousResearch/Hermes-3-Llama-3.1-70B", + >>> "--enable-auto-tool-choice", + >>> ], + >>> ], + >>> indirect=True, + >>> ) + >>> def test_foo(server, client): + >>> ... + + This will run `test_foo` twice with servers with: + - `--disable-frontend-multiprocessing` + - `--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice`. + + """ + if not hasattr(request, "param"): + return [] + + val = request.param + + if isinstance(val, str): + return [val] + + return request.param + + @pytest.fixture(scope="module") -def server(): +def server(server_args): args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -23,6 +60,7 @@ def server(): "--enforce-eager", "--max-num-seqs", "128", + *server_args, ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -35,6 +73,15 @@ async def client(server): yield async_client +@pytest.mark.parametrize( + "server_args", + [ + pytest.param([], id="default-frontend-multiprocessing"), + pytest.param(["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing") + ], + indirect=True, +) @pytest.mark.asyncio async def test_show_version(client: openai.AsyncOpenAI): base_url = str(client.base_url)[:-3].strip("/") @@ -45,6 +92,15 @@ async def test_show_version(client: openai.AsyncOpenAI): assert response.json() == {"version": VLLM_VERSION} +@pytest.mark.parametrize( + "server_args", + [ + pytest.param([], id="default-frontend-multiprocessing"), + pytest.param(["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing") + ], + indirect=True, +) @pytest.mark.asyncio async def test_check_health(client: openai.AsyncOpenAI): base_url = str(client.base_url)[:-3].strip("/") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cda1601549e9e..ae44b26a6c55a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -537,8 +537,11 @@ async def run_server(args, **uvicorn_kwargs) -> None: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valide_tool_parses)} }})") - temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - temp_socket.bind(("", args.port)) + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", args.port)) def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing @@ -552,8 +555,6 @@ def signal_handler(*_) -> None: model_config = await engine_client.get_model_config() init_app_state(engine_client, model_config, app.state, args) - temp_socket.close() - shutdown_task = await serve_http( app, host=args.host, @@ -564,6 +565,7 @@ def signal_handler(*_) -> None: ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + fd=sock.fileno(), **uvicorn_kwargs, )