Merge pull request #14 from sashaDoubov/add_k_assert
Add assert that K % BLOCK_K == 0
tgale96 authored Aug 9, 2024
2 parents a1ddf98 + 84d48d0 commit 7b8f0df
Showing 1 changed file with 40 additions and 4 deletions.
stk/backend/triton_kernels.py (+40, −4)
@@ -1,12 +1,32 @@
 import torch
 import triton
 import triton.language as tl
 
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+    BLOCK_M: int = 128
+    BLOCK_N: int = 128
+    BLOCK_K: int = 32
+    BLOCK_SIZE: int = 128
+    NUM_STAGES: int = 4
+    NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+    assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+    assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+    assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
     ],
     key=['M', 'N', 'K'],
 )
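For illustration, this is how the new check behaves with the defaults above (a usage sketch; TritonConfig, _validate_matmul_dims, and the error string are taken verbatim from this diff, and the import path is inferred from the changed file's location):

from stk.backend.triton_kernels import _validate_matmul_dims

# Passes: 256 % 128 == 0, 64 % 32 == 0, and 128 % 128 == 0.
_validate_matmul_dims(M=256, K=64, N=128)

# Fails on the K check this commit adds:
# AssertionError: incompatible dimensions: tensor has dim with length: 100,
# which must be divisible by 32
_validate_matmul_dims(M=256, K=100, N=128)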
@@ -50,7 +70,12 @@ def _sdd_kernel(A, B, C, M, N, K,
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
     ],
     key=['M', 'N', 'K'],
 )
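For context on the decorator being edited in each of these hunks: triton.autotune benchmarks every entry in configs the first time the kernel is launched with a new value of key (here the problem shape (M, N, K)) and caches the fastest; with a single entry, as in this file, it effectively pins the launch parameters to the TritonConfig values. A minimal self-contained sketch of the same pattern on a toy kernel (illustrative names, not from this repository):

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK': 256}, num_stages=3, num_warps=8),
    ],
    key=['n'],  # re-benchmark whenever n changes
)
@triton.jit
def _scale_kernel(x_ptr, n, s, BLOCK: tl.constexpr):
    # Each program scales one BLOCK-sized chunk of x in place.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * s, mask=mask)

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    n = x.numel()
    # BLOCK comes from whichever Config wins, so the grid is a function of it.
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    _scale_kernel[grid](x, n, s)
    return x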
@@ -122,7 +147,12 @@ def _dsd_kernel(A, B, C, M, N, K,
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_M': TritonConfig.BLOCK_M,
+            'BLOCK_N': TritonConfig.BLOCK_N,
+            'BLOCK_K': TritonConfig.BLOCK_K,
+            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
     ],
     key=['M', 'N', 'K'],
 )
@@ -214,6 +244,8 @@ def dsd(shape,
     M, K = shape
     _, N = rhs.shape
 
+    _validate_matmul_dims(M, K, N)
+
     # accumulator types
     ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32

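The ACC_TYPE line shown as context here (and in the dds and sdd hunks below) encodes one rule: floating-point inputs accumulate in fp32, since summing fp16/bf16 products across the K dimension would lose precision or overflow, while anything else (e.g. int8) accumulates in int32. Restated in isolation (a sketch mirroring those lines, not a new API):

import torch
import triton.language as tl

def _acc_type(dtype: torch.dtype):
    # fp16/bf16/fp32 inputs -> fp32 accumulator; integer inputs -> int32.
    return tl.float32 if dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32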
@@ -268,6 +300,8 @@ def dds(lhs,
     M, K = lhs.shape
     _, N = shape
 
+    _validate_matmul_dims(M, K, N)
+
     # accumulator types
     ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32

@@ -318,6 +352,8 @@ def sdd(lhs,
     M, K = lhs.shape
     _, N = rhs.shape
 
+    _validate_matmul_dims(M, K, N)
+
     # accumulator types
     ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32

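One practical consequence of this change: shapes whose dimensions are not multiples of the block sizes (in particular K % 32 != 0) now fail fast in dsd, dds, and sdd instead of reaching the kernels with ill-fitting tiles. A caller that encounters such shapes can zero-pad the inner dimension first; a hypothetical helper, not part of this commit:

import torch

def pad_k(x: torch.Tensor, block_k: int = 32) -> torch.Tensor:
    # Zero-pad the last dimension up to the next multiple of block_k; the
    # extra zeros contribute nothing to the matmul accumulation.
    pad = (-x.shape[-1]) % block_k
    return torch.nn.functional.pad(x, (0, pad)) if pad else x

Note this only covers the dense operand: padding the sparse operand would have to respect the block-sparse topology as well.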
