From ba4994443afc6a8249ed726c5ebd09b2c57a3b00 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 28 Jun 2024 20:48:25 -0600 Subject: [PATCH] [Kernel] Add punica dimensions for Granite 3b and 8b (#5930) Signed-off-by: Joe Runde --- csrc/punica/bgmv/bgmv_config.h | 2 ++ tests/lora/test_punica.py | 1 + 2 files changed, 3 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index cb6694b3036e9..2c8d007d8719f 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -86,6 +86,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ + f(in_T, out_T, W_T, narrow, 49408) \ f(in_T, out_T, W_T, narrow, 60544) \ f(in_T, out_T, W_T, narrow, 60672) \ f(in_T, out_T, W_T, narrow, 64000) \ @@ -182,6 +183,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 36864, narrow) \ f(in_T, out_T, W_T, 43264, narrow) \ f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 49408, narrow) \ f(in_T, out_T, W_T, 60544, narrow) \ f(in_T, out_T, W_T, 60672, narrow) \ f(in_T, out_T, W_T, 64000, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 110c9b243507d..dbeb16cb21ad3 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -111,6 +111,7 @@ def _lora_ref_impl( 36864, 43264, 49152, + 49408, 60544, 60672, 64000,