From 22639130dbd373ba8ec4f55e81da796d1aec0084 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 23 Apr 2024 08:56:34 +0200 Subject: [PATCH] tests(tgi): reduce combinations Since all deployment options are tested in the implicit env file, we only use the local neuron deployment option when testing generation. --- .../tests/integration/test_generate.py | 30 ++----------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/text-generation-inference/tests/integration/test_generate.py b/text-generation-inference/tests/integration/test_generate.py index 6f5659401..0a6abf649 100644 --- a/text-generation-inference/tests/integration/test_generate.py +++ b/text-generation-inference/tests/integration/test_generate.py @@ -1,39 +1,15 @@ -import os import Levenshtein import pytest -@pytest.fixture(params=["hub-neuron", "hub", "local-neuron"]) -async def tgi_service(request, launcher, neuron_model_config): - """Expose a TGI service corresponding to a model configuration - - For each model configuration, the service will be started using the following - deployment options: - - from the hub original model (export parameters specified as env variables), - - from the hub pre-exported neuron model, - - from a local path to the neuron model. - """ - if request.param == "hub": - export_kwargs = neuron_model_config["export_kwargs"] - # Expose export parameters as environment variables - for kwarg, value in export_kwargs.items(): - env_var = f"HF_{kwarg.upper()}" - os.environ[env_var] = str(value) - model_name_or_path = neuron_model_config["model_id"] - elif request.param == "hub-neuron": - model_name_or_path = neuron_model_config["neuron_model_id"] - else: - model_name_or_path = neuron_model_config["neuron_model_path"] +@pytest.fixture +async def tgi_service(launcher, neuron_model_config): + model_name_or_path = neuron_model_config["neuron_model_path"] service_name = neuron_model_config["name"] with launcher(service_name, model_name_or_path) as tgi_service: await tgi_service.health(600) yield tgi_service - if request.param == "hub": - # Cleanup export parameters - for kwarg in export_kwargs: - env_var = f"HF_{kwarg.upper()}" - del os.environ[env_var] @pytest.mark.asyncio