# Fix compatibility issue on triton=2.3.1 (#219)
## Summary

Fix #215 

Two incompatibilities with triton 2.3.1 are fixed (minimal sketches of both follow below):

- `num_warps` can't be passed to a kernel at launch time if `num_warps` is also declared as a parameter of the kernel function.
- A `tl.constexpr` is not automatically cast to `int` when compared against an integer, so the reduction-mode comparisons now go through `.value`.
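
A minimal sketch of the first issue, using a hypothetical toy kernel rather than code from this repo: declaring `num_warps` in the signature makes the launch below fail on triton 2.3.1, while removing the declaration, as this commit does, works on both old and new versions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(
    x_ptr,
    y_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    # num_warps: tl.constexpr,  # declaring num_warps here, as these kernels
    #                           # used to, breaks the launch below on 2.3.1
):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


x = torch.randn(64, device="cuda")
y = torch.empty_like(x)
# num_warps is a launch meta-parameter consumed by the triton launcher, so it
# is only safe to pass here when the kernel signature does not also declare it.
_copy_kernel[(1,)](x, y, n_cols=64, BLOCK_SIZE=64, num_warps=4)
```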
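
And a sketch of the second issue; `_MODE_SUM` is a hypothetical stand-in for the `_REDUCTION_MODE_*` constants in `kl_div.py`:

```python
import triton.language as tl

# Hypothetical stand-in for the _REDUCTION_MODE_* constants in kl_div.py.
_MODE_SUM = tl.constexpr(1)

reduction = 1  # plain Python int, as returned by the reduction-mode lookup
# On triton 2.3.1 the tl.constexpr wrapper is not automatically cast to int
# in this host-side comparison, so `reduction == _MODE_SUM` misbehaves.
# Comparing against the unwrapped value works across versions:
assert reduction == _MODE_SUM.value
```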


## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
```bash
❯ pip show triton
Name: triton
Version: 2.3.1
Summary: A language and compiler for custom Deep Learning operations
Home-page: https://github.com/openai/triton/
Author: Philippe Tillet
Author-email: phil@openai.com
License:
Location: /home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages
Requires: filelock
Required-by: liger-kernel, torch
❯ python3 -m pytest test/transformers/test_layer_norm.py
=============================================================== test session starts ================================================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
collected 16 items

test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-2-8-64] PASSED                                                [  6%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-2-8-128] PASSED                                               [ 12%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-2-8-256] PASSED                                               [ 18%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-2-8-512] PASSED                                               [ 25%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-4-16-64] PASSED                                               [ 31%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-4-16-128] PASSED                                              [ 37%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-4-16-256] PASSED                                              [ 43%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-4-16-512] PASSED                                              [ 50%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-8-32-64] PASSED                                               [ 56%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-8-32-128] PASSED                                              [ 62%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-8-32-256] PASSED                                              [ 68%]
test/transformers/test_layer_norm.py::test_liger_layer_norm[dtype0-1e-05-1e-05-8-32-512] PASSED                                              [ 75%]
test/transformers/test_layer_norm.py::test_liger_layer_norm_functional[dtype0-1e-05-1e-05-2-2-8] PASSED                                      [ 81%]
test/transformers/test_layer_norm.py::test_liger_layer_norm_functional[dtype0-1e-05-1e-05-2-2-41] PASSED                                     [ 87%]
test/transformers/test_layer_norm.py::test_liger_layer_norm_functional[dtype0-1e-05-1e-05-9-7-8] PASSED                                      [ 93%]
test/transformers/test_layer_norm.py::test_liger_layer_norm_functional[dtype0-1e-05-1e-05-9-7-41] PASSED                                     [100%]

================================================================= warnings summary =================================================================
.venv/lib/python3.10/site-packages/_pytest/config/__init__.py:1437
  /home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/_pytest/config/__init__.py:1437: PytestConfigWarning: Unknown config option: asyncio_mode

    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================== 16 passed, 1 warning in 4.05s ===========================================================
```
Tcc0403 authored Sep 6, 2024 · 1 parent 7382a87 · commit 5151712
Showing 2 changed files with 4 additions and 10 deletions.
### src/liger_kernel/ops/kl_div.py (4 additions, 6 deletions)

```diff
@@ -46,7 +46,6 @@ def _kldiv_kernel_forward(
     loss_stride,  # int, output stride
     n_cols,  # int, number of columns in the input tensor
     BLOCK_SIZE: tl.constexpr,
-    num_warps: tl.constexpr,
     log_target: tl.constexpr = False,
     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
 ):
@@ -86,7 +85,6 @@ def _kldiv_kernel_backward(
     target_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
-    num_warps: tl.constexpr,
     log_target: tl.constexpr = False,
 ):
     pid = tl.program_id(0).to(tl.int64)
@@ -120,7 +118,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction):  # [B, S]  # [B
     grid = (B,)
     reduction = _str_to_reduction_mode[reduction]
 
-    out_size = (B, S) if reduction == _REDUCTION_MODE_NONE else (B,)
+    out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,)
     output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
 
     _kldiv_kernel_forward[grid](
@@ -140,11 +138,11 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction):  # [B, S]  # [B
     # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
     # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
     # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
-    if reduction == _REDUCTION_MODE_BATCHMEAN:
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
         return output_tensor.sum() / B
-    elif reduction == _REDUCTION_MODE_SUM:
+    elif reduction == _REDUCTION_MODE_SUM.value:
         return output_tensor.sum(dim=0)
-    elif reduction == _REDUCTION_MODE_MEAN:
+    elif reduction == _REDUCTION_MODE_MEAN.value:
         return output_tensor.mean(dim=0)
     else:
         return output_tensor
```
### src/liger_kernel/ops/layer_norm.py (0 additions, 4 deletions)

```diff
@@ -39,7 +39,6 @@ def _layer_norm_forward_kernel(
     n_cols,
     eps,
     BLOCK_SIZE: tl.constexpr,
-    num_warps: tl.constexpr,
 ):
     """
     References:
@@ -90,7 +89,6 @@ def _layer_norm_backward_kernel(
     n_cols,
     rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    num_warps: tl.constexpr,
     dtype: tl.constexpr,
 ):
     """
@@ -151,7 +149,6 @@ def layer_norm_forward(X, W, B, eps):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-
     assert (
         X.shape[1] == W.shape[0]
     ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
@@ -213,7 +210,6 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         n_cols,
         rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
         dtype=triton_dtype,
     )
 
```
