diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index cb6694b3036e9..2c8d007d8719f 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -86,6 +86,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ + f(in_T, out_T, W_T, narrow, 49408) \ f(in_T, out_T, W_T, narrow, 60544) \ f(in_T, out_T, W_T, narrow, 60672) \ f(in_T, out_T, W_T, narrow, 64000) \ @@ -182,6 +183,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 36864, narrow) \ f(in_T, out_T, W_T, 43264, narrow) \ f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 49408, narrow) \ f(in_T, out_T, W_T, 60544, narrow) \ f(in_T, out_T, W_T, 60672, narrow) \ f(in_T, out_T, W_T, 64000, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 110c9b243507d..dbeb16cb21ad3 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -111,6 +111,7 @@ def _lora_ref_impl( 36864, 43264, 49152, + 49408, 60544, 60672, 64000,