diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py
index f8bcfd2a2..fb3f4b80e 100644
--- a/test/transformers/test_qwen2vl_mrope.py
+++ b/test/transformers/test_qwen2vl_mrope.py
@@ -2,16 +2,25 @@
 import pytest
 import torch
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    Qwen2VLRotaryEmbedding,
-    apply_multimodal_rotary_pos_emb,
-)
+
+try:
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLRotaryEmbedding,
+        apply_multimodal_rotary_pos_emb,
+    )
+
+    IS_QWEN_AVAILABLE = True
+except Exception:
+    IS_QWEN_AVAILABLE = False
 
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.transformers.functional import liger_qwen2vl_mrope
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 
 
+@pytest.mark.skipif(
+    not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
+)
 @pytest.mark.parametrize("bsz", [1, 2])
 @pytest.mark.parametrize("seq_len", [128, 131])
 @pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)])
@@ -87,6 +96,9 @@ def test_correctness(
     torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)
 
 
+@pytest.mark.skipif(
+    not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
+)
 @pytest.mark.parametrize(
     "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section",
     [
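
For reference, the pattern this diff introduces is a module-level import guard combined with `pytest.mark.skipif`, so the test module still collects cleanly on transformers versions that predate Qwen2-VL. Below is a minimal standalone sketch of the same idea; the test function name is hypothetical and not part of the PR.

import pytest

try:
    # Importing a Qwen2-VL symbol fails on transformers releases that do not ship the model.
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding  # noqa: F401

    IS_QWEN_AVAILABLE = True
except Exception:
    # Broad except: older versions may raise ImportError or ModuleNotFoundError at
    # collection time; either way the guarded tests are skipped rather than erroring.
    IS_QWEN_AVAILABLE = False


@pytest.mark.skipif(not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers.")
def test_requires_qwen2vl():
    # Runs only when the guarded import above succeeded.
    assert Qwen2VLRotaryEmbedding is not None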