Skip to content

Commit

Permalink
FIX: tl.program_id() does indeed not have a cast method in triton2.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
wizyoung committed Sep 26, 2024
1 parent 5c38d09 commit cd15004
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/liger_kernel/ops/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def _geglu_tanh_forward_kernel(
a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
program_id = tl.program_id(0).to(tl.int64)

# locate start index
a += program_id * stride
Expand All @@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
def _geglu_tanh_backward_kernel(
dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
program_id = tl.program_id(0).to(tl.int64)

# locate start index
dc += program_id * stride
Expand Down
4 changes: 2 additions & 2 deletions src/liger_kernel/ops/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def silu(x):
def _swiglu_forward_kernel(
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
program_id = tl.program_id(0).to(tl.int64)

# locate start index
a_ptr += program_id * stride
Expand All @@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
def _swiglu_backward_kernel(
dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
program_id = tl.program_id(0).to(tl.int64)

# locate start index
dc_ptr += program_id * stride
Expand Down

0 comments on commit cd15004

Please sign in to comment.