[Bugfix]: During testing, use pytest monkeypatch for safely overriding the env var that indicates the vLLM backend #5210

Merged
tests/kernels/test_attention_selector.py (27 changes: 10 additions & 17 deletions)
@@ -1,21 +1,22 @@
-import os
 from unittest.mock import patch
 
 import pytest
 import torch
 
+from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
+                                 override_backend_env_variable)
 from vllm.attention.selector import which_attn_to_use
 
 
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
-def test_env(name: str, device: str):
+def test_env(name: str, device: str, monkeypatch):
     """Test that the attention selector can be set via environment variable.
     Note that we do not test FlashAttn because it is the default backend.
     """
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = name
+
+    override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
@@ -32,14 +33,11 @@ def test_env(name: str, device: str):
                                  torch.float16, 16)
         assert backend.name == name
 
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
-
 
-def test_flash_attn():
+def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
+
+    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
@@ -71,14 +69,9 @@ def test_flash_attn():
     backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
     assert backend.name != "FLASH_ATTN"
 
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
-
 
-def test_invalid_env():
+def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
         which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
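
For context on the pattern being removed: the manual backup only restored the variable if the test body ran to completion, and when there was no prior value (name_backup is None) the first two tests never unset the override at all, so a forced backend could leak into every later test in the process; the last test even assigned name_backup back unconditionally, which raises TypeError when the variable was previously unset. pytest's monkeypatch fixture instead registers an undo for each setenv() call and runs it at teardown whether the test passes or fails. A minimal sketch of that behavior, using an illustrative variable name rather than anything from this PR:

import os


def test_override(monkeypatch):
    # setenv records the variable's prior state (including its absence)
    # and registers an undo that runs at teardown, pass or fail.
    monkeypatch.setenv("DEMO_BACKEND", "XFORMERS")
    assert os.environ["DEMO_BACKEND"] == "XFORMERS"


def test_no_leak():
    # Runs after test_override in file order: the override has been
    # undone, so the variable is absent again (assuming it was not set
    # in the surrounding environment to begin with).
    assert "DEMO_BACKEND" not in os.environ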
tests/kernels/utils.py (22 changes: 22 additions & 0 deletions)
@@ -0,0 +1,22 @@
+"""Kernel test utils"""
+
+import pytest
+
+STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
+STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
+STR_INVALID_VAL: str = "INVALID"
+
+
+def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
+                                  backend_name: str) -> None:
+    '''
+    Override the environment variable indicating the vLLM backend temporarily,
+    using pytest monkeypatch to ensure that the env var gets
+    reset once the test context exits.
+
+    Arguments:
+
+    * mpatch: pytest monkeypatch instance
+    * backend_name: attention backend name to force
+    '''
+    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
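
With the helper in place, a test that needs to force a specific backend just accepts pytest's built-in monkeypatch fixture and delegates to it; no manual cleanup is required. A hypothetical usage sketch (the test name and assertion are illustrative, not part of the PR):

import os

from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                                 override_backend_env_variable)


def test_forced_backend(monkeypatch):
    # monkeypatch is injected by pytest per test; the helper routes the
    # env-var write through it, so teardown restores the prior value.
    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
    assert os.environ[STR_BACKEND_ENV_VAR] == STR_FLASH_ATTN_VAL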