diff --git a/tests/test_alibi.py b/tests/test_alibi.py index f01811ec..2b15106b 100644 --- a/tests/test_alibi.py +++ b/tests/test_alibi.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py index 4d2d67c6..834d8ef3 100644 --- a/tests/test_batch_decode_kernels.py +++ b/tests/test_batch_decode_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index f9ceadee..11ba55f2 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_prefill_attention_func_args( diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py index 8672dbb0..682a4ada 100644 --- a/tests/test_block_sparse.py +++ b/tests/test_block_sparse.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_logits_cap.py b/tests/test_logits_cap.py index c42278aa..9bcf882a 100644 --- a/tests/test_logits_cap.py +++ b/tests/test_logits_cap.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py index 22db5f87..f83449dc 100644 --- a/tests/test_non_contiguous_decode.py +++ b/tests/test_non_contiguous_decode.py @@ -8,7 +8,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py index a45c09ad..601d1caa 100644 --- a/tests/test_non_contiguous_prefill.py +++ b/tests/test_non_contiguous_prefill.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_prefill_attention_func_args( diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py index 5a8bbf2c..77338840 100644 --- a/tests/test_shared_prefix_kernels.py +++ b/tests/test_shared_prefix_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py index c552f73b..0b4f6fda 100644 --- a/tests/test_sliding_window.py +++ b/tests/test_sliding_window.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py index bf312fb8..66309f45 100644 --- a/tests/test_tensor_cores_decode.py +++ b/tests/test_tensor_cores_decode.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args(