Skip to content

Commit

Permalink
feat: warmup for jit kernel tests (#629)
Browse files Browse the repository at this point in the history
Currently unittests are slow when using flashinfer jit because we only
compile kernels the first time we run it, it's blocking and didn't
compile multiple ops in parallel. This PR add a warmup pre-hook to
kernel unittests, so that we compile all necessary kernels before
running the unittests in JIT mode, which greatly accelerate the
unittests.

This PR also fixes the several issues with #628 :
1. using thread-safe `make_dirs(..., exist_ok=True)` instead of relying
on `os.path.exists`
2. change the signature of `parallel_load_modules` to lists of
`(jit_module_creation_func, args)` instead of lambda function, because
lambda function captures variable by ref instead of value, which may
cause some unexpected errors.
  • Loading branch information
yzh119 authored Nov 24, 2024
1 parent 92ac440 commit 8f5f349
Show file tree
Hide file tree
Showing 16 changed files with 493 additions and 48 deletions.
3 changes: 1 addition & 2 deletions python/flashinfer/jit/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str:

def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> None:
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
os.makedirs(gen_directory, exist_ok=True)
sources = [gen_directory / f"{act_func_name}_and_mul.cu"]
write_if_different(
sources[0],
Expand Down
12 changes: 0 additions & 12 deletions python/flashinfer/jit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ def get_batch_decode_uri(

def gen_batch_decode_module(*args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
uri = get_batch_decode_uri(*args)
sources = get_batch_decode_sources(*args)
source_paths = []
Expand Down Expand Up @@ -214,8 +212,6 @@ def get_batch_decode_mla_uri(

def gen_batch_decode_mla_module(*args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
uri = get_batch_decode_mla_uri(*args)
sources = get_batch_decode_mla_sources(*args)
source_paths = []
Expand Down Expand Up @@ -275,8 +271,6 @@ def get_single_prefill_uri(

def gen_single_prefill_module(*args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
uri = get_single_prefill_uri(*args)
sources = get_single_prefill_sources(*args)
source_paths = []
Expand Down Expand Up @@ -341,8 +335,6 @@ def get_batch_prefill_uri(

def gen_batch_prefill_module(*args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
uri = get_batch_prefill_uri(*args)
sources = get_batch_prefill_sources(*args)
source_paths = []
Expand Down Expand Up @@ -518,8 +510,6 @@ def get_customize_single_prefill_sources(

def gen_customize_single_decode_module(module_name, *args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
sources = get_customize_single_decode_sources(*args)
source_paths = []
for suffix, source in zip(single_decode_suffix, sources):
Expand All @@ -532,8 +522,6 @@ def gen_customize_single_decode_module(module_name, *args):

def gen_customize_single_prefill_module(module_name, *args):
gen_directory = FLASHINFER_GEN_SRC_DIR
if not os.path.exists(gen_directory):
os.makedirs(gen_directory)
sources = get_customize_single_prefill_sources(*args)
source_paths = []
for suffix, source in zip(single_prefill_suffix, sources):
Expand Down
7 changes: 3 additions & 4 deletions python/flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR
from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR

if not os.path.exists(FLASHINFER_WORKSPACE_DIR):
os.makedirs(FLASHINFER_WORKSPACE_DIR)
os.makedirs(FLASHINFER_WORKSPACE_DIR, exist_ok=True)
os.makedirs(FLASHINFER_CSRC_DIR, exist_ok=True)


class FlashInferJITLogger(logging.Logger):
Expand Down Expand Up @@ -99,8 +99,7 @@ def load_cuda_ops(
logger.info(f"Loading JIT ops: {name}")
check_cuda_arch()
build_directory = FLASHINFER_JIT_DIR / name
if not os.path.exists(build_directory):
os.makedirs(build_directory, exist_ok=True)
os.makedirs(build_directory, exist_ok=True)
if extra_include_paths is None:
extra_include_paths = [
FLASHINFER_INCLUDE_DIR,
Expand Down
12 changes: 6 additions & 6 deletions python/flashinfer/jit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pathlib
import threading
from typing import Callable, List
from typing import Any, Callable, List, Tuple

import torch

Expand All @@ -35,19 +35,19 @@ def write_if_different(path: pathlib.Path, content: str) -> None:


def parallel_load_modules(
load_module_funcs: List[Callable],
load_module_func_args: List[Tuple[Callable, List[Any]]],
):
threads = []
exceptions = []

def wrapper(func):
def wrapper(func, args):
try:
func()
func(*args)
except Exception as e:
exceptions.append((func, e))

for func in load_module_funcs:
thread = threading.Thread(target=wrapper, args=(func,))
for func, args in load_module_func_args:
thread = threading.Thread(target=wrapper, args=(func, args))
thread.start()
threads.append(thread)

Expand Down
149 changes: 149 additions & 0 deletions tests/jit_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import itertools

import torch

import flashinfer


def jit_decode_attention_func_args(
q_dtypes,
kv_dtypes,
head_dims,
pos_encoding_modes,
use_sliding_window_options,
use_logits_soft_cap_options,
):
load_module_func_args = []

for (
q_dtype,
kv_dtype,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
) in itertools.product(
q_dtypes,
kv_dtypes,
head_dims,
pos_encoding_modes,
use_sliding_window_options,
use_logits_soft_cap_options,
):
load_module_func_args.append(
(
flashinfer.decode.get_single_decode_module,
(
q_dtype,
kv_dtype,
q_dtype,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
),
)
)
load_module_func_args.append(
(
flashinfer.decode.get_batch_decode_module,
(
q_dtype,
kv_dtype,
q_dtype,
torch.int32,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
),
)
)

return load_module_func_args


def jit_prefill_attention_func_args(
q_dtypes,
kv_dtypes,
head_dims,
pos_encoding_modes,
use_sliding_window_options,
use_logits_soft_cap_options,
allow_fp16_qk_reduction_options,
):
load_module_func_args = []

for (
q_dtype,
kv_dtype,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
allow_fp16_qk_reduction,
) in itertools.product(
q_dtypes,
kv_dtypes,
head_dims,
pos_encoding_modes,
use_sliding_window_options,
use_logits_soft_cap_options,
allow_fp16_qk_reduction_options,
):
load_module_func_args.append(
(
flashinfer.prefill.gen_single_prefill_module,
(
q_dtype,
kv_dtype,
q_dtype,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
allow_fp16_qk_reduction,
),
)
)
load_module_func_args.append(
(
flashinfer.prefill.gen_batch_prefill_module,
(
q_dtype,
kv_dtype,
q_dtype,
torch.int32,
head_dim,
pos_encoding_mode,
use_sliding_window,
use_logits_soft_cap,
allow_fp16_qk_reduction,
),
)
)

load_module_func_args.append(
(
flashinfer.quantization.get_quantization_module,
[],
) # required for attention with custom mask
)

return load_module_func_args
32 changes: 32 additions & 0 deletions tests/test_alibi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,42 @@
import pytest
import torch
from alibi_reference import alibi_attention
from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
if flashinfer.jit.has_prebuilt_ops:
return
try:
flashinfer.jit.parallel_load_modules(
jit_decode_attention_func_args(
[torch.float16], # q_dtypes
[torch.float16], # kv_dtypes
[128, 256], # head_dims
[0, 2], # pos_encoding_modes
[False], # use_sliding_windows
[False], # use_logits_soft_caps
)
+ jit_prefill_attention_func_args(
[torch.float16], # q_dtypes
[torch.float16], # kv_dtypes
[128, 256], # head_dims
[0, 2], # pos_encoding_modes
[False], # use_sliding_windows
[False], # use_logits_soft_caps
[False], # allow_fp16_qk_reductions
)
)
except Exception as e:
# abort the test session if warmup fails
pytest.exit(str(e))
finally:
yield


@pytest.mark.parametrize("seq_len", [1, 9, 81, 729])
@pytest.mark.parametrize("num_heads", [4, 8, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
Expand Down
32 changes: 32 additions & 0 deletions tests/test_batch_decode_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,42 @@

import pytest
import torch
from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
if flashinfer.jit.has_prebuilt_ops:
return
try:
flashinfer.jit.parallel_load_modules(
jit_decode_attention_func_args(
[torch.float16], # q_dtypes
[torch.float16, torch.float8_e4m3fn, torch.float8_e5m2], # kv_dtypes
[128, 256], # head_dims
[0, 1, 2], # pos_encoding_modes
[False], # use_sliding_windows
[False, True], # use_logits_soft_caps
)
+ jit_prefill_attention_func_args(
[torch.float16], # q_dtypes
[torch.float16, torch.float8_e4m3fn, torch.float8_e5m2], # kv_dtypes
[128, 256], # head_dims
[0, 1, 2], # pos_encoding_modes
[False], # use_sliding_windows
[False, True], # use_logits_soft_caps
[False], # allow_fp16_qk_reductions
)
)
except Exception as e:
# abort the test session if warmup fails
pytest.exit(str(e))
finally:
yield


@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97, 512])
@pytest.mark.parametrize("page_size", [1, 8, 16])
Expand Down
24 changes: 24 additions & 0 deletions tests/test_batch_prefill_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,34 @@

import pytest
import torch
from jit_utils import jit_prefill_attention_func_args

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
if flashinfer.jit.has_prebuilt_ops:
return
try:
flashinfer.jit.parallel_load_modules(
jit_prefill_attention_func_args(
[torch.float16], # q_dtypes
[torch.float16, torch.float8_e4m3fn, torch.float8_e5m2], # kv_dtypes
[128, 256], # head_dims
[0, 1, 2], # pos_encoding_modes
[False], # use_sliding_windows
[False, True], # use_logits_soft_caps
[False], # allow_fp16_qk_reductions
)
)
except Exception as e:
# abort the test session if warmup fails
pytest.exit(str(e))
finally:
yield


@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
Expand Down
Loading

0 comments on commit 8f5f349

Please sign in to comment.