From e2f99deb9e27594a9a79504c9d78e8cf405a2cf4 Mon Sep 17 00:00:00 2001 From: Trevor Gale Date: Tue, 19 Sep 2023 19:49:00 +0000 Subject: [PATCH] Adding more configs to try for DSD. --- stk/backend/triton_kernels.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/stk/backend/triton_kernels.py b/stk/backend/triton_kernels.py index e208358..93abb54 100644 --- a/stk/backend/triton_kernels.py +++ b/stk/backend/triton_kernels.py @@ -67,9 +67,14 @@ def _sdd_kernel(A, B, C, M, N, K, @triton.autotune( configs=[ # Configs for A100. - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=5, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=6, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'BLOCK_SIZE': 128}, num_stages=7, num_warps=4), # Configs for H100. - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_SIZE': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_SIZE': 128}, num_stages=5, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_SIZE': 128}, num_stages=6, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_SIZE': 128}, num_stages=7, num_warps=4), ], key=['M', 'N', 'K'],