-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: torch custom_op support: norm (#552)
Add torch custom_op (aka, torch library, torch.compile) support for `norm.py`. It should be a no-op for PyTorch < 2.4. Testing is done by `torch.compile` -- as we expect the custom_op marks can isolate out our kernels during torch.compile. To avoid changes to tests, I introduced some magic that replaces the kernels with a `torch.compile`-ed version. For example, to run with/without torch.compile: ```bash # With torch.compile FLASHINFER_TEST_TORCH_COMPILE=1 pytest -svx tests/test_norm.py # Without torch.compile pytest -svx tests/test_norm.py ``` If this PR looks good, I'll add it to more kernels.
- Loading branch information
1 parent
47583b3
commit f6e0010
Showing
3 changed files
with
143 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os | ||
import types | ||
|
||
import flashinfer | ||
import pytest | ||
import torch | ||
from torch.torch_version import TorchVersion | ||
from torch.torch_version import __version__ as torch_version | ||
|
||
TORCH_COMPILE_FNS = [ | ||
flashinfer.norm.rmsnorm, | ||
flashinfer.norm.fused_add_rmsnorm, | ||
flashinfer.norm.gemma_rmsnorm, | ||
flashinfer.norm.gemma_fused_add_rmsnorm, | ||
] | ||
|
||
|
||
def _monkeypatch_add_torch_compile(func): | ||
""" | ||
Replace the given function with its torch.compile version. | ||
""" | ||
|
||
from torch._library.custom_ops import CustomOpDef | ||
|
||
if type(func) is types.FunctionType: | ||
fn = func | ||
elif isinstance(func, CustomOpDef): | ||
fn = func._init_fn | ||
else: | ||
raise ValueError(f"Unsupported fn type {type(func)}") | ||
|
||
components = fn.__module__.split(".") | ||
assert components[0] == "flashinfer" | ||
module = flashinfer | ||
for component in components[1:]: | ||
module = getattr(module, component) | ||
|
||
setattr( | ||
module, | ||
fn.__name__, | ||
torch.compile( | ||
func, | ||
fullgraph=True, | ||
backend="inductor", | ||
mode="max-autotune-no-cudagraphs", | ||
), | ||
) | ||
print("Applied torch.compile to", f"{fn.__module__}.{fn.__name__}") | ||
|
||
|
||
def pytest_configure(config): | ||
if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1": | ||
if torch_version < TorchVersion("2.4"): | ||
pytest.skip("torch.compile requires torch >= 2.4") | ||
for fn in TORCH_COMPILE_FNS: | ||
_monkeypatch_add_torch_compile(fn) |