diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py
index 742e44e5..df9a77c7 100644
--- a/python/flashinfer/norm.py
+++ b/python/flashinfer/norm.py
@@ -14,9 +14,12 @@
 limitations under the License.
 """
 
+from typing import Optional
+
 import torch
 
-from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .utils import register_custom_op, register_fake_op
 
 _norm_module = None
 
@@ -43,7 +46,7 @@
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Root mean square normalization.
 
     Parameters
@@ -65,13 +68,28 @@
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().rmsnorm(out, input, weight, eps)
+    _rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
+def _rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::rmsnorm")
+def _rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Fused add root mean square normalization.
 
     Parameters
@@ -88,12 +106,19 @@
     get_norm_module().fused_add_rmsnorm(input, residual, weight, eps)
 
 
+@register_fake_op("flashinfer::fused_add_rmsnorm")
+def _fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
+
+
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
-):
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     r"""Gemma Root mean square normalization.
 
     Parameters
@@ -114,13 +139,30 @@
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+    _gemma_rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
+def _gemma_rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_rmsnorm")
+def _gemma_rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op(
+    "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
+)
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Gemma Fused add root mean square normalization.
 
     Parameters
@@ -135,3 +177,10 @@
         Epsilon for numerical stability.
     """
     get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
+def _gemma_fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py
index 0affeb5d..40ad522f 100644
--- a/python/flashinfer/utils.py
+++ b/python/flashinfer/utils.py
@@ -14,10 +14,13 @@
 limitations under the License.
""" -import torch import math from enum import Enum -from typing import Optional, Tuple, Union, Dict +from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union + +import torch +from torch.torch_version import TorchVersion +from torch.torch_version import __version__ as torch_version class PosEncodingMode(Enum): @@ -197,3 +200,28 @@ def _check_cached_qkv_data_type( raise ValueError( f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function." ) + + +def register_custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: Optional[Union[str, Sequence[str]]] = None, + schema: Optional[str] = None, +) -> Callable: + if TorchVersion(torch_version) < TorchVersion("2.4"): + return fn + return torch.library.custom_op( + name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema + ) + + +def register_fake_op( + name: str, + fn: Optional[Callable] = None, +) -> Callable: + if TorchVersion(torch_version) < TorchVersion("2.4"): + return fn + return torch.library.register_fake(name, fn) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..e8948378 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,56 @@ +import os +import types + +import flashinfer +import pytest +import torch +from torch.torch_version import TorchVersion +from torch.torch_version import __version__ as torch_version + +TORCH_COMPILE_FNS = [ + flashinfer.norm.rmsnorm, + flashinfer.norm.fused_add_rmsnorm, + flashinfer.norm.gemma_rmsnorm, + flashinfer.norm.gemma_fused_add_rmsnorm, +] + + +def _monkeypatch_add_torch_compile(func): + """ + Replace the given function with its torch.compile version. + """ + + from torch._library.custom_ops import CustomOpDef + + if type(func) is types.FunctionType: + fn = func + elif isinstance(func, CustomOpDef): + fn = func._init_fn + else: + raise ValueError(f"Unsupported fn type {type(func)}") + + components = fn.__module__.split(".") + assert components[0] == "flashinfer" + module = flashinfer + for component in components[1:]: + module = getattr(module, component) + + setattr( + module, + fn.__name__, + torch.compile( + func, + fullgraph=True, + backend="inductor", + mode="max-autotune-no-cudagraphs", + ), + ) + print("Applied torch.compile to", f"{fn.__module__}.{fn.__name__}") + + +def pytest_configure(config): + if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1": + if torch_version < TorchVersion("2.4"): + pytest.skip("torch.compile requires torch >= 2.4") + for fn in TORCH_COMPILE_FNS: + _monkeypatch_add_torch_compile(fn)