From 01121928ded116e32b5ecf2ba4093dda38a74829 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniele=20Trifir=C3=B2?=
Date: Thu, 19 Sep 2024 18:39:48 +0200
Subject: [PATCH] add parametrized fixture

---
 tests/entrypoints/openai/test_basic.py | 58 +++++++++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py
index a7e418db30a29..d3aea533b6db9 100644
--- a/tests/entrypoints/openai/test_basic.py
+++ b/tests/entrypoints/openai/test_basic.py
@@ -1,4 +1,5 @@
 from http import HTTPStatus
+from typing import List
 
 import openai
 import pytest
@@ -12,8 +13,44 @@
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 
 
+@pytest.fixture(scope='module')
+def server_args(request: pytest.FixtureRequest) -> List[str]:
+    """ Provide extra arguments to the server via indirect parametrization
+
+    Usage:
+
+    >>> @pytest.mark.parametrize(
+    >>>     "server_args",
+    >>>     [
+    >>>         ["--disable-frontend-multiprocessing"],
+    >>>         [
+    >>>             "--model=NousResearch/Hermes-3-Llama-3.1-70B",
+    >>>             "--enable-auto-tool-choice",
+    >>>         ],
+    >>>     ],
+    >>>     indirect=True,
+    >>> )
+    >>> def test_foo(server, client):
+    >>>     ...
+
+    This will run `test_foo` twice with servers with:
+    - `--disable-frontend-multiprocessing`
+    - `--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice`.
+
+    """
+    if not hasattr(request, "param"):
+        return []
+
+    val = request.param
+
+    if isinstance(val, str):
+        return [val]
+
+    return request.param
+
+
 @pytest.fixture(scope="module")
-def server():
+def server(server_args):
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -23,6 +60,7 @@
         "--enforce-eager",
         "--max-num-seqs",
         "128",
+        *server_args,
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -35,6 +73,15 @@ async def client(server):
         yield async_client
 
 
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param([], id="default-frontend-multiprocessing"),
+        pytest.param(["--disable-frontend-multiprocessing"],
+                     id="disable-frontend-multiprocessing")
+    ],
+    indirect=True,
+)
 @pytest.mark.asyncio
 async def test_show_version(client: openai.AsyncOpenAI):
     base_url = str(client.base_url)[:-3].strip("/")
@@ -45,6 +92,15 @@
     assert response.json() == {"version": VLLM_VERSION}
 
 
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param([], id="default-frontend-multiprocessing"),
+        pytest.param(["--disable-frontend-multiprocessing"],
+                     id="disable-frontend-multiprocessing")
+    ],
+    indirect=True,
+)
 @pytest.mark.asyncio
 async def test_check_health(client: openai.AsyncOpenAI):
     base_url = str(client.base_url)[:-3].strip("/")
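
Usage sketch (not part of the patch above): assuming the patch is applied to tests/entrypoints/openai/test_basic.py, any test in that module can opt into extra server CLI flags by parametrizing `server_args` indirectly. The test name below is illustrative, and the `/health` route simply mirrors the existing `test_check_health`; the flag comes from the patch itself.

# Hypothetical example test; relies on the module's `server`/`client` fixtures.
from http import HTTPStatus

import openai
import pytest
import requests


@pytest.mark.parametrize(
    "server_args",
    # A bare string is also accepted; the fixture wraps it into a list.
    ["--disable-frontend-multiprocessing"],
    indirect=True,
)
@pytest.mark.asyncio
async def test_health_with_flag(client: openai.AsyncOpenAI):
    # The server behind `client` was started with the extra flag above.
    base_url = str(client.base_url)[:-3].strip("/")
    response = requests.get(base_url + "/health")
    assert response.status_code == HTTPStatus.OK

Tests that do not parametrize `server_args` fall into the `if not hasattr(request, "param")` branch and get an empty list, so existing tests keep running against the default server unchanged.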