TST Enable test_vera_dtypes on XPU with bf16 (#2017)
faaany authored Aug 20, 2024
1 parent f71e89f commit eb5eb6e
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/test_vera.py
@@ -22,6 +22,7 @@
 from torch import nn
 
 from peft import PeftModel, VeraConfig, get_peft_model
+from peft.utils import infer_device
 
 
 class MLP(nn.Module):
@@ -284,9 +285,12 @@ def test_vera_different_shapes(self, mlp):
 
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
     def test_vera_dtypes(self, dtype):
-        # 1872
-        if (dtype == torch.bfloat16) and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
-            pytest.skip("bfloat16 not supported on this system, skipping the test")
+        if dtype == torch.bfloat16:
+            # skip if bf16 is not supported on hardware, see #1872
+            is_xpu = infer_device() == "xpu"
+            is_cuda_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+            if not (is_xpu or is_cuda_bf16):
+                pytest.skip("bfloat16 not supported on this system, skipping the test")
 
         model = MLP().to(dtype)
         config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False)
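The added check is device-agnostic: peft.utils.infer_device() reports "xpu" on Intel GPUs, while CUDA remains gated on torch.cuda.is_bf16_supported(). Below is a minimal sketch, not part of this commit, of how the same bf16 skip logic could be factored into a reusable test helper; the helper name require_bf16 is hypothetical.

# Sketch only: a hypothetical helper mirroring the commit's bf16 skip logic.
import pytest
import torch

from peft.utils import infer_device


def require_bf16():
    """Skip the calling test when bfloat16 is not usable on the current accelerator."""
    if infer_device() == "xpu":
        return  # Intel XPU is treated as bf16-capable, mirroring the commit's check
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return  # CUDA device with bf16 support
    pytest.skip("bfloat16 not supported on this system, skipping the test")


# Usage inside a test, analogous to test_vera_dtypes:
#     if dtype == torch.bfloat16:
#         require_bf16()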
