Skip to content

Commit

Permalink
refactor: metal patch
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <thxcode0824@gmail.com>
  • Loading branch information
thxCode committed Dec 13, 2024
1 parent 5999fef commit 621ef0f
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions llama-box/patches/llama.cpp/ggml-metal.patch
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ index 669c1f84..f5d1892c 100644
}
#endif
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 34fe5778..d6d4a91a 100644
index 34fe5778..13103351 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -2062,7 +2062,21 @@ static void ggml_metal_encode_node(
Expand All @@ -38,18 +38,29 @@ index 34fe5778..d6d4a91a 100644

// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
@@ -2071,8 +2085,8 @@ static void ggml_metal_encode_node(
@@ -2071,20 +2085,8 @@ static void ggml_metal_encode_node(
(
(
src0t == GGML_TYPE_F16 || // TODO: helper function
- src0t == GGML_TYPE_Q4_0 ||
- src0t == GGML_TYPE_Q4_1 ||
+ (src0t == GGML_TYPE_Q4_0 && ne11 >=4 ) ||
+ (src0t == GGML_TYPE_Q4_1 && ne11 >=4 ) ||
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
@@ -2096,7 +2110,7 @@ static void ggml_metal_encode_node(
- src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 ||
- src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_IQ4_NL ||
- false) && (ne11 >= 2 && ne11 <= 8)
- ) ||
- (
- (
- src0t == GGML_TYPE_Q4_K ||
- src0t == GGML_TYPE_Q5_K ||
- src0t == GGML_TYPE_Q6_K ||
- false) && (ne11 >= 4 && ne11 <= 8)
+ false) && (ne11 >= 2 && ne11 <= 5)
)
)
) {
@@ -2096,7 +2098,7 @@ static void ggml_metal_encode_node(
// my current hypothesis is that the work grid is not evenly divisible for different nsg
// values and there can be some tail effects when nsg is high. need to confirm this
//
Expand All @@ -58,7 +69,7 @@ index 34fe5778..d6d4a91a 100644
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
@@ -2921,7 +2935,7 @@ static void ggml_metal_encode_node(
@@ -2921,7 +2923,7 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_RMS_NORM:
{
Expand All @@ -67,7 +78,7 @@ index 34fe5778..d6d4a91a 100644
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
@@ -4869,4 +4883,8 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) {
@@ -4869,4 +4871,8 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) {
return &g_ggml_backend_metal_reg;
}

Expand Down

0 comments on commit 621ef0f

Please sign in to comment.