diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..06be7e62a1ab9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,13 +43,15 @@ steps: fast_check: true source_file_dependencies: - vllm/ + - tests/mq_llm_engine - tests/async_engine - tests/test_inputs - tests/multimodal - tests/test_utils - tests/worker commands: - - pytest -v -s async_engine # Async Engine + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index e22d547293445..9e8b2f1817567 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. - Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. - ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_TIMEOUT=1800000`` Example commands and usage: =========================== diff --git a/tests/async_engine/test_openapi_server.py b/tests/async_engine/test_openapi_server.py deleted file mode 100644 index 9e5c7c04287eb..0000000000000 --- a/tests/async_engine/test_openapi_server.py +++ /dev/null @@ -1,106 +0,0 @@ -import openai # use the official client for correctness check -import pytest -import pytest_asyncio - -from ..utils import VLLM_PATH, RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "facebook/opt-125m" -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - - -@pytest.fixture(scope="module") -def server(): - args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--chat-template", - str(chatml_jinja_path), - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.asyncio -async def test_single_completion(client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - 
prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 5 - - -@pytest.mark.asyncio -async def test_single_chat_session(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] - - # test single completion - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert len(chat_completion.choices) == 1 - - choice = chat_completion.choices[0] - assert choice.finish_reason == "length" - assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=55, total_tokens=65) - - message = choice.message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py deleted file mode 100644 index cafd125c5a598..0000000000000 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -import tempfile -import unittest -import unittest.mock -import uuid - -import pytest -import pytest_asyncio - -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, - RPCClientClosedError) -from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest_asyncio.fixture(scope="function") -async def dummy_server(tmp_socket, monkeypatch): - dummy_engine = unittest.mock.AsyncMock() - - def dummy_engine_builder(*args, **kwargs): - return dummy_engine - - with monkeypatch.context() as m: - m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) - server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - try: - yield server - finally: - server_task.cancel() - server.cleanup() - - -@pytest_asyncio.fixture(scope="function") -async def client(tmp_socket): - client = AsyncEngineRPCClient(rpc_path=tmp_socket) - # Sanity check: the server is connected - await client._wait_for_server_rpc() - - try: - yield client - finally: - client.close() - - -@pytest.mark.asyncio -async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server _not_ reply with a model config - m.setattr(dummy_server, "get_config", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # And ensure the task completes anyway - # (client.setup() invokes server.get_config()) - client_task = asyncio.get_running_loop().create_task(client.setup()) - with pytest.raises(TimeoutError, match="Server didn't reply within"): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def 
test_client_aborts_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Hang all abort requests - m.setattr(dummy_server, "abort", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # The client should suppress timeouts on `abort`s - # and return normally, assuming the server will eventually - # abort the request. - client_task = asyncio.get_running_loop().create_task( - client.abort("test request id")) - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_data_methods_reraise_exceptions( - monkeypatch, dummy_server, client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server raise some random exception - exception = RuntimeError("Client test exception") - - def raiser(): - raise exception - - m.setattr(dummy_server.engine, "get_model_config", raiser) - m.setattr(client, "_data_timeout", 10) - - client_task = asyncio.get_running_loop().create_task(client.setup()) - # And ensure the task completes, raising the exception - with pytest.raises(RuntimeError, match=str(exception)): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_errors_after_closing(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - - client.close() - - # Healthchecks and generate requests will fail with explicit errors - with pytest.raises(RPCClientClosedError): - await client.check_health() - with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): - pass - - # But no-ops like aborting will pass - await client.abort("test-request-id") - await client.do_log_stats() diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b442a903c33ae..2ad8460023c25 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -18,38 +18,32 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] -@pytest.fixture(scope="module") -def server(): - args = [ - "--max-model-len", "4096", "--enable-chunked-prefill", - "--disable-log-requests", "--enforce-eager" - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_data(server): - return { - "url": f"{server.url_for('v1')}/completions", - } +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy(more_args): + args = list(DEFAULT_ARGS) + args.extend(more_args) + print(f"Running with: {args}") -def test_lm_eval_accuracy(server_data): - model_args = (f"model={MODEL_NAME}," - f"base_url={server_data['url']}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") - - results = lm_eval.simple_evaluate( - model="local-completions", - model_args=model_args, - tasks=TASK, - ) - - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + url = f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + 
model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/async_engine/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py similarity index 99% rename from tests/async_engine/test_chat_template.py rename to tests/entrypoints/openai/test_chat_template.py index 61a6d77cd8756..b98ab2e30d78d 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -5,7 +5,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH +from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py deleted file mode 100644 index fbfe0db19dd03..0000000000000 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -import pytest - -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser - - -@pytest.mark.asyncio -async def test_mp_crash_detection(): - - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c3a6c65be1d90..de2a932199a01 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from vllm.config import MultiModalConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer @@ -52,8 +52,9 @@ def test_async_serving_chat_init(): def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 325bc03434287..6d9e620b4af7d 100644 --- 
a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -4,7 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) @@ -18,7 +18,7 @@ async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_engine_client = MagicMock(spec=EngineClient) mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 73ecb74007272..25ab91ef69333 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path): prompt="Hello, my name is") # Now the server should shut down - return_code = remote_server.proc.wait(timeout=3) + return_code = remote_server.proc.wait(timeout=8) assert return_code is not None diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/mq_llm_engine/__init__.py similarity index 100% rename from tests/entrypoints/openai/rpc/__init__.py rename to tests/mq_llm_engine/__init__.py diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py new file mode 100644 index 0000000000000..782b508a57149 --- /dev/null +++ b/tests/mq_llm_engine/test_abort.py @@ -0,0 +1,67 @@ +"""Test that aborting is handled properly.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" +EXPECTED_TOKENS = 250 + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_id_to_be_aborted = "request-aborted" + request_ids_a = [f"request-a-{idx}" for idx in range(10)] + request_ids_b = [f"request-b-{idx}" for idx in range(10)] + + # Requests started before one to be aborted. + tasks = [] + for request_id in request_ids_a: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Aborted. + task_aborted = asyncio.create_task( + generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) + + # Requests started after one to be aborted. + for request_id in request_ids_b: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Actually abort. + await asyncio.sleep(0.5) + await client.abort(request_id_to_be_aborted) + + # Confirm that we got all the EXPECTED tokens from the requests. + for task in tasks: + count, request_id = await task + assert count == EXPECTED_TOKENS, ( + f"{request_id} generated only {count} tokens") + + # Cancel task (this will hang indefinitely if not). + task_aborted.cancel() + + # Shutdown. 
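+        # close() destroys the client's ZMQ context and cancels its background output/health tasks.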
+ client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py new file mode 100644 index 0000000000000..49cfc5aa04c36 --- /dev/null +++ b/tests/mq_llm_engine/test_error_handling.py @@ -0,0 +1,244 @@ +"""Test that various errors are handled properly.""" + +import asyncio +import tempfile +import time +import uuid +from unittest.mock import Mock + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.lora.request import LoRARequest +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR(RAISED_VALUE)) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_evil_forward(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_forward) as engine: + + client = await engine.make_client() + + # Server should be healthy after initial probe. + await asyncio.sleep(2.0) + await client.check_health() + + # Throws an error in first forward pass. + with pytest.raises(RAISED_ERROR): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + # Engine is errored, should get ENGINE_DEAD_ERROR. + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + await asyncio.sleep(1.0) + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Shutdown. + client.close() + + +def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, + ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_health_check(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: + + client = await engine.make_client() + assert client.is_running + + # Health probe should throw RAISED_ERROR. + await asyncio.sleep(15.) 
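+        # Give the client's background health loop time to observe the mocked failure.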
+ + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Generate call should throw ENGINE_DEAD_ERROR + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + + client.close() + + +def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during abort call. + engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # Firsh check health should work. + await client.check_health() + + # Trigger an abort on the client side. + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") + + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR + # with reference to the original KeyError("foo") + with pytest.raises(MQEngineDeadError) as execinfo: + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=2000), + request_id=uuid.uuid4()): + pass + assert "KeyError" in repr(execinfo.value) + assert client.errored + + await abort_task + + # This should raise the original error. + with pytest.raises(RAISED_ERROR): + await client.check_health() + + client.close() + + +@pytest.mark.asyncio +async def test_bad_request(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + # Invalid request should fail, but not crash the server. + with pytest.raises(ValueError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest( + "invalid-lora", 1, + "invalid-path")): + pass + + # This request should be okay. + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): + pass + + # Shutdown. + client.close() + + +@pytest.mark.asyncio +async def test_mp_crash_detection(monkeypatch): + + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + # When LLMEngine is loaded, it will crash. 
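+    # The patched constructor below makes engine startup fail immediately.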
+ def mock_init(): + raise ValueError + + monkeypatch.setattr(LLMEngine, "__init__", mock_init) + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") + + +@pytest.mark.asyncio +async def test_mp_cuda_init(): + # it should not crash, when cuda is initialized + # in the API server process + import torch + torch.cuda.init() + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py new file mode 100644 index 0000000000000..630c112d0f0c9 --- /dev/null +++ b/tests/mq_llm_engine/test_load.py @@ -0,0 +1,57 @@ +"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +NUM_EXPECTED_TOKENS = 10 +NUM_REQUESTS = 10000 + +# Scenarios to test for num generated token. +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_load(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(client, request_id, NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py new file mode 100644 index 0000000000000..e27fd77923412 --- /dev/null +++ b/tests/mq_llm_engine/utils.py @@ -0,0 +1,78 @@ +import asyncio +import multiprocessing +from typing import Callable, Tuple, Union + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + + +async def generate( + client: MQLLMEngineClient, + request_id: str, + num_tokens: int, + return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + + final_output = None + count = 0 + async for out in client.generate( + request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams(max_tokens=num_tokens, + temperature=0)): + + count += 1 + final_output = out + await asyncio.sleep(0.) + + if return_output: + return final_output + + # Confirm we generated all the tokens we expected. 
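+    # Callers compare `count` against the number of tokens they requested (see test_load / test_abort).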
+ return count, request_id + + +def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Run engine. + engine.start() + + +class RemoteMQLLMEngine: + + def __init__(self, + engine_args: AsyncEngineArgs, + ipc_path: str, + run_fn: Callable = run_normal) -> None: + + self.engine_args = engine_args + self.ipc_path = ipc_path + context = multiprocessing.get_context("spawn") + self.proc = context.Process(target=run_fn, + args=(engine_args, ipc_path)) + self.proc.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.kill() + + async def make_client(self) -> MQLLMEngineClient: + engine_config = self.engine_args.create_engine_config() + client = MQLLMEngineClient(self.ipc_path, engine_config) + while True: + try: + await client.setup() + break + except TimeoutError: + assert self.proc.is_alive() + return client diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 7f3fb595321ad..69ab67abdd12b 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,12 @@ +import os + from ..utils import compare_two_settings +# --enforce-eager on TPU causes graph compilation +# this times out default Health Check in the MQLLMEngine, +# so we set the timeout here to 30s +os.environ["VLLM_RPC_TIMEOUT"] = "30000" + def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", diff --git a/tests/utils.py b/tests/utils.py index f6c2be17ebdcf..81442cad78da2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,7 +119,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() try: - self.proc.wait(3) + self.proc.wait(8) except subprocess.TimeoutExpired: # force kill if needed self.proc.kill() diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 410e6ffaa2d50..5c2244145955d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -601,9 +601,12 @@ def errored(self) -> bool: return self._errored_with is not None @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8b5009b2c6668..6e77065475b07 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1289,6 +1289,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. 
+ logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000..ba5c6e15fc821 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCHealthRequest: + pass + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, + RPCStartupRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000000000..18b620c74ddf9 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,452 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
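+        # The tokenizer lives client-side, so get_tokenizer() needs no round trip to the engine process.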
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + if engine_args.pipeline_parallel_size > 1: + return True + + is_embedding = ModelConfig( + model=engine_args.model, + revision=engine_args.revision, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + + return is_embedding + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). + """ + + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) + else: + # Server sent a health status message unprompted. + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. 
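+                    # ENGINE_DEAD_ERROR is pushed into every pending per-request queue so blocked generate() calls fail fast.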
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
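+        # The reply is pickled data of `expected_type`, or a pickled exception that is re-raised below.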
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # 1) Create output queue for this requests. 
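+        # The run_output_handler_loop task routes RequestOutputs (and exceptions) for this request_id into the queue.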
+ queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000000000..70cd6e5cb6000 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,321 @@ +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, LLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. 
+ + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + executor_class = LLMEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
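+        # The LLMEngine reference is dropped below as part of shutdown.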
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() + else: + raise ValueError("Unknown RPCRequest Type: {request}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. 
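+            # Record whether the engine itself was already errored so the client can distinguish a bad request from a dead engine.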
+ is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_health_request(self): + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + engine.start() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,8 +14,8 @@ @runtime_checkable -class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncLLMEngine""" +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -30,8 +30,8 @@ def errored(self) -> bool: ... @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" + def dead_error(self) -> BaseException: + ... 
def generate( self, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 47d227010c075..5dcf50bd1b0a1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -1,21 +1,21 @@ import asyncio import signal from http import HTTPStatus -from typing import Any, Optional +from typing import Any import uvicorn from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, limit_concurrency: Optional[int], - **uvicorn_kwargs: Any): +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int], logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server) @@ -63,7 +54,7 @@ async def dummy_shutdown() -> None: logger.debug( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() @@ -90,7 +81,7 @@ async def runtime_error_handler(request: Request, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. It will not handle any further requests.""" if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -99,3 +90,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b50fc6a265f8d..b263384dd3778 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -25,7 +25,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -43,8 +45,6 @@ TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -66,29 +66,16 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str], - revision: Optional[str]) -> bool: - return ModelConfig(model=model_name, - revision=revision, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): try: if app.state.log_stats: - async_engine_client = app.state.engine_client + engine_client: EngineClient = app.state.engine_client async def _force_log(): while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() + await asyncio.sleep(10.) + await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) _running_tasks.add(task) @@ -107,9 +94,9 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) @@ -122,19 +109,18 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncLLMEngine Directly - multiprocess using AsyncLLMEngine RPC Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization, engine_args.revision) + # Fall back + # TODO: fill out feature matrix. 
+    if (MQLLMEngineClient.is_unsupported_config(engine_args)
             or disable_frontend_multiprocessing):
         engine_config = engine_args.create_engine_config()
         uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
@@ -172,56 +158,60 @@ async def build_async_engine_client_from_engine_args(
                 "and vLLM will properly handle cleanup.")
 
         # Select random path for IPC.
-        rpc_path = get_open_zmq_ipc_path()
-        logger.info("Multiprocessing frontend to use %s for RPC Path.",
-                    rpc_path)
-
-        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
-        # NOTE: Actually, this is not true yet. We still need to support
-        # embedding models via RPC (see TODO above)
-        rpc_client = AsyncEngineRPCClient(rpc_path)
+        ipc_path = get_open_zmq_ipc_path()
+        logger.info("Multiprocessing frontend to use %s for IPC Path.",
+                    ipc_path)
 
-        # Start RPCServer in separate process (holds the AsyncLLMEngine).
-        context = multiprocessing.get_context("spawn")
+        # Start RPCServer in separate process (holds the LLMEngine).
         # the current process might have CUDA context,
         # so we need to spawn a new process
-        rpc_server_process = context.Process(
-            target=run_rpc_server,
-            args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
-        rpc_server_process.start()
-        logger.info("Started engine process with PID %d",
-                    rpc_server_process.pid)
+        context = multiprocessing.get_context("spawn")
+
+        engine_process = context.Process(target=run_mp_engine,
+                                         args=(engine_args,
+                                               UsageContext.OPENAI_API_SERVER,
+                                               ipc_path))
+        engine_process.start()
+        logger.info("Started engine process with PID %d", engine_process.pid)
+
+        # Build RPCClient, which conforms to EngineClient Protocol.
+        # NOTE: Actually, this is not true yet. We still need to support
+        # embedding models via RPC (see TODO above)
+        engine_config = engine_args.create_engine_config()
+        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
 
         try:
             while True:
                 try:
-                    await rpc_client.setup()
+                    await mp_engine_client.setup()
                     break
                 except TimeoutError:
-                    if not rpc_server_process.is_alive():
-                        logger.error(
-                            "RPCServer process died before responding "
-                            "to readiness probe")
+                    if not engine_process.is_alive():
+                        logger.error("Engine process died before responding "
+                                     "to readiness probe")
                         yield None
                         return
 
-            yield rpc_client  # type: ignore[misc]
+            yield mp_engine_client  # type: ignore[misc]
         finally:
             # Ensure rpc server process was terminated
-            rpc_server_process.terminate()
+            engine_process.terminate()
 
             # Close all open connections to the backend
-            rpc_client.close()
+            mp_engine_client.close()
 
-            # Wait for server process to join
-            rpc_server_process.join()
+            # Wait for engine process to join
+            engine_process.join(4)
+            if engine_process.exitcode is None:
+                # Kill if taking longer than 4 seconds to stop
+                engine_process.kill()
 
             # Lazy import for prometheus multiprocessing.
             # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
             # before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() @@ -269,7 +259,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding: return request.app.state.openai_serving_embedding -def engine_client(request: Request) -> AsyncEngineClient: +def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -466,7 +456,7 @@ async def authentication(request: Request, call_next): def init_app_state( - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, state: State, args: Namespace, @@ -481,11 +471,11 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) - state.engine_client = async_engine_client + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, served_model_names, args.response_role, @@ -497,7 +487,7 @@ def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) state.openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -506,13 +496,13 @@ def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) state.openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + engine_client, model_config, served_model_names, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -531,19 +521,18 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as async_engine_client: + async with build_async_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return app = build_app(args) - model_config = await async_engine_client.get_model_config() - init_app_state(async_engine_client, model_config, app.state, args) + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) shutdown_task = await serve_http( app, - limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py deleted file mode 100644 index efc7e43afdcc9..0000000000000 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
-VLLM_RPC_ZMQ_HWM = 0 - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py deleted file mode 100644 index 9b88db746be5c..0000000000000 --- a/vllm/entrypoints/openai/rpc/client.py +++ /dev/null @@ -1,451 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, - VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
- - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). 
- self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - return pickle.loads(frame.buffer) - - # Make a new socket connection. - if socket is None: - with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request) - - # Use existing socket connection. - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def 
_is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - finished = False - try: - with self.to_proxy_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - assert isinstance(message, Frame) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output - - finally: - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index 460ff0636b6e9..0000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,243 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - # Clear the engine reference so that it can be GC'ed. 
- del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - return 
self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. 
- server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("AsyncEngineRPCServer terminated") - - signal.signal(signal.SIGTERM, signal_handler) - - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58e42fb5363fb..e4f1c834b9105 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, @@ -45,7 +45,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -57,7 +57,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -105,6 +105,12 @@ async def create_chat_completion( logger.error("Error with model %s", error_check_ret) return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + try: ( lora_request, @@ -112,8 +118,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -206,8 +211,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -215,7 +220,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 42142efb5f23e..14fa60243c584 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -52,7 +52,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -78,6 +78,12 @@ async def create_completion( if error_check_ret is not None: return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + # Return error for unsupported features. 
if request.suffix is not None: return self.create_error_response( @@ -95,8 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -124,8 +129,8 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -133,7 +138,7 @@ async def create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..f111a3a8277b5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -8,7 +8,7 @@ from typing_extensions import assert_never from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -118,8 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -144,7 +143,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd9..72f9381abc7db 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -75,7 +75,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = 
engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -159,7 +159,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 6e802b71ae2b4..8f8862897fc4e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (apply_hf_chat_template, apply_mistral_chat_template, load_chat_template, @@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -37,7 +37,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -66,7 +66,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): @@ -132,7 +132,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 2003ede95d2d8..262e56869e885 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False @@ -392,8 +392,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # a list of plugin names to load, separated by commas. 
     # if this is not set, it means all plugins will be loaded
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 7380b73ad6548..9ad240ef60820 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -106,6 +106,7 @@ def _init_executor(self) -> None:
             )) for rank in range(1, world_size)
         ]
 
+        self.worker_monitor = None
        if world_size != 1 or is_async:
             if is_async:
                 async_worker_list = self.workers + [self.driver_worker]
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index aa2a16c04d08d..5bef76b90d332 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -168,6 +168,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
         self.tasks[task_id] = future
         try:
             self._task_queue.put((task_id, method, args, kwargs))
+        except SystemExit:
+            raise
         except BaseException as e:
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e
@@ -222,6 +224,8 @@ def _run_worker_process(
         try:
             executor = getattr(worker, method)
             output = executor(*args, **kwargs)
+        except SystemExit:
+            raise
         except KeyboardInterrupt:
             break
         except BaseException as e: