diff --git a/llama-box/patches/llama.cpp/ggml-metal.patch b/llama-box/patches/llama.cpp/ggml-metal.patch index f3cd413..3f26124 100644 --- a/llama-box/patches/llama.cpp/ggml-metal.patch +++ b/llama-box/patches/llama.cpp/ggml-metal.patch @@ -12,7 +12,7 @@ index 669c1f84..f5d1892c 100644 } #endif diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m -index 34fe5778..d6d4a91a 100644 +index 34fe5778..13103351 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2062,7 +2062,21 @@ static void ggml_metal_encode_node( @@ -38,18 +38,29 @@ index 34fe5778..d6d4a91a 100644 // first try to use small-batch mat-mv kernels // these should be efficient for BS [2, ~8] -@@ -2071,8 +2085,8 @@ static void ggml_metal_encode_node( +@@ -2071,20 +2085,8 @@ static void ggml_metal_encode_node( ( ( src0t == GGML_TYPE_F16 || // TODO: helper function - src0t == GGML_TYPE_Q4_0 || - src0t == GGML_TYPE_Q4_1 || -+ (src0t == GGML_TYPE_Q4_0 && ne11 >=4 ) || -+ (src0t == GGML_TYPE_Q4_1 && ne11 >=4 ) || - src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || - src0t == GGML_TYPE_Q8_0 || -@@ -2096,7 +2110,7 @@ static void ggml_metal_encode_node( +- src0t == GGML_TYPE_Q5_0 || +- src0t == GGML_TYPE_Q5_1 || +- src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || +- false) && (ne11 >= 2 && ne11 <= 8) +- ) || +- ( +- ( +- src0t == GGML_TYPE_Q4_K || +- src0t == GGML_TYPE_Q5_K || +- src0t == GGML_TYPE_Q6_K || +- false) && (ne11 >= 4 && ne11 <= 8) ++ false) && (ne11 >= 2 && ne11 <= 5) + ) + ) + ) { +@@ -2096,7 +2098,7 @@ static void ggml_metal_encode_node( // my current hypothesis is that the work grid is not evenly divisible for different nsg // values and there can be some tail effects when nsg is high. need to confirm this // @@ -58,7 +69,7 @@ index 34fe5778..d6d4a91a 100644 const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup -@@ -2921,7 +2935,7 @@ static void ggml_metal_encode_node( +@@ -2921,7 +2923,7 @@ static void ggml_metal_encode_node( } break; case GGML_OP_RMS_NORM: { @@ -67,7 +78,7 @@ index 34fe5778..d6d4a91a 100644 GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; -@@ -4869,4 +4883,8 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { +@@ -4869,4 +4871,8 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { return &g_ggml_backend_metal_reg; }