From 2efcd87b1280ccf79c5851888e681bdd936c222b Mon Sep 17 00:00:00 2001 From: S Date: Thu, 4 Apr 2024 18:23:23 +0100 Subject: [PATCH 01/14] Add Command R Plus GGUF --- convert-hf-to-gguf.py | 3 +++ gguf-py/gguf/constants.py | 2 ++ gguf-py/gguf/tensor_mapping.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 7e601170e925a..6b7ea1dfda5d8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -149,6 +149,7 @@ def write_tensors(self): # map tensor names new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) if new_name is None: + print(tensor_map) print(f"Can not map tensor {name!r}") sys.exit() @@ -2344,6 +2345,8 @@ def __init__(self, *args, **kwargs): # max_position_embeddings = 8192 in config.json but model was actually # trained on 128k context length + if "model_max_length" not in self.hparams: + self.hparams["model_max_length"] = 131072 self.hparams["max_position_embeddings"] = self.hparams["model_max_length"] def set_gguf_parameters(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5214764a9ea98..9ed7bbe183e46 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -638,6 +638,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q_NORM, ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 345b1b0c72212..4f02d298e13de 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -285,12 +285,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon + "model.layers.{bid}.self_attn.q_norm", # cohere "transformer.blocks.{bid}.attn.q_ln", # sea-lion ), MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon + "model.layers.{bid}.self_attn.k_norm", # cohere "transformer.blocks.{bid}.attn.k_ln", # sea-lion ), From fbab98497bce793b82c23002219a857864017c87 Mon Sep 17 00:00:00 2001 From: S Date: Thu, 4 Apr 2024 18:27:54 +0100 Subject: [PATCH 02/14] Add Command R Plus GGUF --- convert-hf-to-gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6b7ea1dfda5d8..9ace488890156 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -149,7 +149,6 @@ def write_tensors(self): # map tensor names new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) if new_name is None: - print(tensor_map) print(f"Can not map tensor {name!r}") sys.exit() From e4b2e2d339e68cc1c6a3fe1e71b4b8353ed2b77e Mon Sep 17 00:00:00 2001 From: S Date: Thu, 4 Apr 2024 22:30:17 +0100 Subject: [PATCH 03/14] Loading works up to LayerNorm2D --- convert-hf-to-gguf.py | 3 +-- llama.cpp | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 9ace488890156..aad073292bb7f 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2344,8 +2344,7 @@ def __init__(self, *args, **kwargs): # max_position_embeddings = 8192 in config.json but model was actually # trained on 128k context length - if "model_max_length" not in self.hparams: - self.hparams["model_max_length"] = 131072 + self.hparams["max_position_embeddings"] = self.hparams["model_max_length"] def 
set_gguf_parameters(self): diff --git a/llama.cpp b/llama.cpp index 9a1c11043b94a..d3234b28525f9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -924,6 +924,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, { @@ -5403,7 +5405,13 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - + + if(n_layer >= 64) + { + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); + } + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); @@ -9452,6 +9460,21 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } + if(model.layers[il].attn_q_norm) + { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, From c354db751e1aaf014de77c492ebbafb9f48ba402 Mon Sep 17 00:00:00 2001 From: S Date: Fri, 5 Apr 2024 00:10:07 +0100 Subject: [PATCH 04/14] Export new tensors in 1D so they are not quantized. 
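The idea behind this patch: the per-head Q/K layer-norm weights are small scale tensors, and the converter only keeps 1-dimensional tensors in float32, so flattening them before export keeps them out of the float16/quantized path; the llama.cpp loader side mirrors this by creating them with a single dimension of n_embd_head_k * n_head. A rough converter-side sketch of that rule (prepare_norm_weight is a hypothetical helper for illustration, not the exact code in the diff below):

# Sketch: flatten the per-head Q/K norm weights so the exporter's
# "1-D tensors stay float32" rule applies to them.
import torch

def prepare_norm_weight(name: str, data_torch: torch.Tensor) -> torch.Tensor:
    # q_norm / k_norm arrive as small 2-D per-head matrices; left 2-D they
    # would be converted to float16 like any other .weight tensor.
    if name.endswith((".q_norm.weight", ".k_norm.weight")):
        return data_torch.flatten()  # 1-D tensors are kept in float32
    return data_torch
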
--- convert-hf-to-gguf.py | 47 +++++++++++++++++++++++++++++++++++++++++++ llama.cpp | 10 ++++----- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index aad073292bb7f..e52ebe4b822da 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2352,6 +2352,53 @@ def set_gguf_parameters(self): self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + #Convert Q norm and K norm to 1d so they are exported in float32 and not quantized + if name.endswith((".q_norm.weight")) or name.endswith((".k_norm.weight")): + data_torch = data_torch.flatten() + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + + ###### CONVERSION LOGIC ###### diff --git a/llama.cpp b/llama.cpp index d3234b28525f9..c8ef61aeec3b6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5408,8 +5408,8 @@ static bool llm_load_tensors( if(n_layer >= 64) { - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head_kv}); } layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); @@ -9462,17 +9462,17 @@ struct llm_build_context { if(model.layers[il].attn_q_norm) { - Qcur = llm_build_norm(ctx0, Qcur, hparams, + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); - + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM, cb, il); - cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il); } Qcur = ggml_rope_custom( From 553b09ba8fb9ebf8db0b28700693d43f8fa15f71 Mon Sep 17 00:00:00 
2001 From: S Date: Fri, 5 Apr 2024 10:35:19 +0100 Subject: [PATCH 05/14] Fix embedding layer based on Noeda's example --- convert-hf-to-gguf.py | 50 +------------------------------------------ llama.cpp | 43 ++++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 59 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e52ebe4b822da..91730e0d8de8b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2351,55 +2351,7 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) - - def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) - tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): - continue - - #Convert Q norm and K norm to 1d so they are exported in float32 and not quantized - if name.endswith((".q_norm.weight")) or name.endswith((".k_norm.weight")): - data_torch = data_torch.flatten() - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - data = data_torch.squeeze().numpy() - - # map tensor names - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() - - n_dims = len(data.shape) - data_dtype = data.dtype - - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16) - - print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") - - self.gguf_writer.add_tensor(new_name, data) - - - - + ###### CONVERSION LOGIC ###### diff --git a/llama.cpp b/llama.cpp index c8ef61aeec3b6..53cf86dc02ef8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5405,13 +5405,13 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - - if(n_layer >= 64) + + if (n_layer >= 64) { - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head_kv}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); } - + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); @@ -9460,19 +9460,31 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - if(model.layers[il].attn_q_norm) + if (model.layers[il].attn_q_norm) { - Qcur = llm_build_norm(ctx0, Qcur, hparams, + + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); - + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM, cb, il); - cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il); } Qcur = ggml_rope_custom( @@ -13085,9 +13097,15 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return std::make_pair(i_layer, n_layer); }; + // Command-R+ has such a large embedding weight tensor it overflows + // 32-bit signed integers. This is band-aid until quants can deal with + // that. 
+ if (name == "token_embd.weight" && arch == LLM_ARCH_COMMAND_R && qs.model.hparams.n_layer >= 64) { + new_type = GGML_TYPE_F16; + } // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + else if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -13119,6 +13137,11 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ3_S; } } + + } else if ((arch == LLM_ARCH_COMMAND_R) && + (name.find("q_norm") != std::string::npos || + name.find("k_norm") != std::string::npos)) { + new_type = GGML_TYPE_F32; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { if (name.find("attn_v.weight") != std::string::npos) { From f3532ff80c8dbf3147cc290f0839d28ff7b3fb69 Mon Sep 17 00:00:00 2001 From: S Date: Fri, 5 Apr 2024 10:37:25 +0100 Subject: [PATCH 06/14] Whitespace --- convert-hf-to-gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 91730e0d8de8b..7e601170e925a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2344,14 +2344,14 @@ def __init__(self, *args, **kwargs): # max_position_embeddings = 8192 in config.json but model was actually # trained on 128k context length - self.hparams["max_position_embeddings"] = self.hparams["model_max_length"] def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) - + + ###### CONVERSION LOGIC ###### From ec613b856c91d219b6d6efb7852e286b862fe797 Mon Sep 17 00:00:00 2001 From: S Date: Fri, 5 Apr 2024 12:48:43 +0100 Subject: [PATCH 07/14] Add line --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index c9b0a6a0ef776..ebaf06683ca9e 100644 --- a/ggml.c +++ b/ggml.c @@ -20335,7 +20335,7 @@ size_t ggml_quantize_chunk( int nrows, int n_per_row, const float * imatrix) { - const int n = nrows * n_per_row; + const size_t n = (size_t) nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); From c2658c3ae8c077a70715c40f2ea5c2bfb0c22da6 Mon Sep 17 00:00:00 2001 From: S Date: Sat, 6 Apr 2024 11:23:38 +0100 Subject: [PATCH 08/14] Fix unexpected tokens on MPS. Re-add F16 fix. ((Noeda) --- llama.cpp | 350 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 346 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 53cf86dc02ef8..a8c6568c020a0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5734,6 +5734,340 @@ static void llm_build_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); } +static struct ggml_tensor * llama_build_mat_mul_blocked_computation( + /* + * Does (almost) same thing as ggml_mat_mul mathematically speaking, + * but splits the computation into chunks. + * + * Why would you want to do this? 
As part of Command-R+ coding, we + * discovered that quite a bit of the GPU code is not prepared for + * matrices with more than 2**31-1 elements (~2 billion). + * + * Some context: + * https://github.com/ggerganov/llama.cpp/pull/6491 + * + * This function has a limit (set to 2B) that if any constituent parts + * of it (input, output, result) would go over that limit byte-wise, + * it'll use the splitted computation. This is based on the idea that + * this minimizes the chance that somewhere downstream in GPU code, be + * it MPS or Cuda, has something like: int x = y * z; where the values + * of y and z overflow the multiplication and then silently (or not so + * silently) does something weird. At the time of writing (2024-04-05); + * it seems that CUDA code outright crashes and MPS silently gives bad + * results. + * + * This is a band-aid workaround. The ideal state of the world is that + * this function does nothing but "return ggml_mat_mul(ctx, a, b)". + * + * The last argument (forced_block_size) is for debugging. You can + * force a certain block size to use with the computation. If zero + * (default) then the block size is determined on the fly. Production + * code should always have it zero; and only set it to a non-zero value + * for debugging and testing. + */ + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const llama_model & model, + const llm_build_cb & cb, + int64_t il, + size_t forced_block_size) +{ + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + if (forced_block_size != 0) { + //fprintf(stderr, "warning: llama_build_mat_mul_blocked_computation() forced block size: %zu\n", forced_block_size); + } + + const size_t MAX_BYTES_BEFORE_SPLIT = 2000000000; + + // the actual ggml_mul_mat supports batching. But this one doesn't. + GGML_ASSERT(a->ne[2] == 1 && b->ne[2] == 1); + GGML_ASSERT(a->ne[3] == 1 && b->ne[3] == 1); + + // bail out if if the number of elements would be zero. + // nicer than getting a segfault. + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + GGML_ASSERT(a->ne[i] > 0 && "Matrix multiplication with a 0-side length matrix ('a')."); + GGML_ASSERT(b->ne[i] > 0 && "Matrix multiplication with a 0-side length matrix ('b')."); + } + + // Use the max size of: a, b, result size + const size_t a_rows = a->ne[1]; + const size_t a_cols = a->ne[0]; + + // b is transposed + const size_t b_rows = b->ne[0]; + const size_t b_cols = b->ne[1]; + + const size_t c_rows = a_rows; + const size_t c_cols = b_cols; + + // determine a size of a block that's as big as possible. + // we start with block size of the maximum size, and if that passes, + // then we just use ggml_mat_mul() + // + // the block is square. + size_t cand_block_size = a_rows; + if (a_cols > cand_block_size) { cand_block_size = a_cols; } + if (b_rows > cand_block_size) { cand_block_size = b_rows; } + if (b_cols > cand_block_size) { cand_block_size = b_cols; } + if (c_rows > cand_block_size) { cand_block_size = c_rows; } + if (c_cols > cand_block_size) { cand_block_size = c_cols; } + + size_t block_size = 1; + while (block_size < cand_block_size) { + block_size <<= 1; + } + + if (forced_block_size != 0) { + block_size = forced_block_size; + } else { + // figure out what is largest block_size we can use that will never + // have an intermediate result bigger than + // MAX_BYTES_BEFORE_SPLIT + bool ok = true; + while (block_size > 0) { + ok = true; + + // keep the byte calculations in sync with the blocked code in + // the computation part. 
+ + // Criteria: + // 1. result block size + { + const size_t i_min = 0; + const size_t j_min = 0; + size_t i_max = i_min + block_size; + size_t j_max = j_min + block_size; + if (i_max > a_rows) { i_max = a_rows; } + if (j_max > b_cols) { j_max = b_cols; } + + const size_t bytes_size = sizeof(float) * (i_max - i_min) * (j_max - j_min); + if (bytes_size > MAX_BYTES_BEFORE_SPLIT) { + ok = false; + } + } + // 2. and 3. + // Block size from 'a' and 'b' + { + const size_t i_min = 0; + const size_t j_min = 0; + const size_t k_min = 0; + + size_t i_max = i_min + block_size; + size_t j_max = j_min + block_size; + size_t k_max = k_min + block_size; + + if (i_max > a_rows) { i_max = a_rows; } + if (j_max > b_cols) { j_max = b_cols; } + if (k_max > a_cols) { k_max = a_cols; } + + const size_t bytes_size_a = sizeof(float) * (k_max - k_min) * (i_max - i_min); + const size_t bytes_size_b = sizeof(float) * (k_max - k_min) * (j_max - j_min); + + if (bytes_size_a > MAX_BYTES_BEFORE_SPLIT || bytes_size_b > MAX_BYTES_BEFORE_SPLIT) { + ok = false; + } + } + + if (!ok) { + block_size /= 2; + continue; + } + break; + } + GGML_ASSERT(block_size > 0); + } + + //fprintf(stderr, "block_size=%zu a shape: %d %d b shape: %d %d\n", block_size, a_rows, a_cols, b_rows, b_cols); + + // O(N^3) nested loop, where N is number of blocks on one of the + // constituent parts. + size_t nb_A = (a_rows + block_size - 1) / block_size; + size_t nb_B = (b_cols + block_size - 1) / block_size; + size_t nb_A2 = (a_cols + block_size - 1) / block_size; + + // make placeholder tensors for each block results. + // 2D: (row, col) -> offset is: (x, y) -> x * nb_B + y + struct ggml_tensor ** result_blocks = (struct ggml_tensor **) malloc(nb_A * nb_B * sizeof(struct ggml_tensor *)); + + for (size_t i = 0; i < nb_A; ++i) { + for (size_t j = 0; j < nb_B; ++j) { + const size_t i_min = i * block_size; + const size_t j_min = j * block_size; + size_t i_max = i_min + block_size; + size_t j_max = j_min + block_size; + + if (i_max > a_rows) { i_max = a_rows; } + if (j_max > b_cols) { j_max = b_cols; } + + struct ggml_tensor * result_block = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, i_max - i_min, j_max - j_min); + result_block = ggml_scale(ctx, result_block, 0.0f); + + cb(result_block, "result_block-fresh", il); + result_blocks[i * nb_B + j] = result_block; + } + } + + size_t num_blocks = 0; + for (size_t i = 0; i < nb_A; ++i) { + for (size_t j = 0; j < nb_B; ++j) { + for (size_t k = 0; k < nb_A2; ++k) { + num_blocks++; + + const size_t i_min = i * block_size; + const size_t j_min = j * block_size; + const size_t k_min = k * block_size; + + size_t i_max = i_min + block_size; + size_t j_max = j_min + block_size; + size_t k_max = k_min + block_size; + if (i_max > a_rows) { i_max = a_rows; } + if (j_max > b_cols) { j_max = b_cols; } + if (k_max > a_cols) { k_max = a_cols; } + + const size_t blck_size_a = (const size_t) ggml_blck_size(a->type); + const size_t blck_size_b = (const size_t) ggml_blck_size(b->type); + const size_t type_size_a = ggml_type_size(a->type); + const size_t type_size_b = ggml_type_size(b->type); + + GGML_ASSERT(k_min * type_size_a % blck_size_a == 0); + GGML_ASSERT(k_min * type_size_b % blck_size_b == 0); + + // blck_size=32 + // type_size_a=19 + // + // k_min = 4 + // + // byte_offset = (type_size_a * (k_min/blck_size)) = + // 19 * (4/32) = 2 + + struct ggml_tensor * a_slice = ggml_view_2d( + ctx, a, + k_max - k_min, // k:k_max size + i_max - i_min, // i:i_max size + ggml_row_size(a->type, a->ne[0]), + ggml_row_size(a->type, a->ne[0]) * 
i_min + k_min * type_size_a / blck_size_a); + + cb(a_slice, "a_slice", il); + + struct ggml_tensor * b_slice = ggml_view_2d( + ctx, b, + k_max - k_min, // k:k_max size + j_max - j_min, // j:j_max size + ggml_row_size(b->type, b->ne[0]), + ggml_row_size(b->type, b->ne[0]) * j_min + k_min * type_size_b / blck_size_b); + + cb(b_slice, "b_slice", il); + + struct ggml_tensor * result_slice = result_blocks[i * nb_B + j]; + + struct ggml_tensor * mm_result = ggml_mul_mat(ctx, a_slice, b_slice); + cb(mm_result, "mm_result", il); + + result_blocks[i * nb_B + j] = ggml_add(ctx, result_slice, mm_result); + cb(result_blocks[i * nb_B + j], "result_slice", il); + } + } + } + + // concate the results into one chonky tensor. + // ggml_concat goes mad if the first two dimensions are not the same. + // + // We use this strategy: find largest power of two that divides the + // size of all the tensors. Power of two to make it friendly to GPU + // code; (TODO: LCD might be better? but not sure it won't break code). + // + // Flatten all the tensors to (X, 1, N, 1). + size_t split_size = 1; + while (1) { + size_t candidate_split_size = split_size << 1; + bool bad = false; + + for (size_t i = 0; i < nb_A * nb_B; ++i) { + size_t rows = result_blocks[i]->ne[0]; + size_t cols = result_blocks[i]->ne[1]; + + if (candidate_split_size > rows * cols) { + bad = true; + break; + } + + if ((rows * cols) % candidate_split_size != 0) { + bad = true; + break; + } + } + + if (bad) { + break; + } + + split_size = candidate_split_size; + } + + struct ggml_tensor * result_final = nullptr; + const ggml_type wanted_final_type = a->type; + + // TODO: looks like concat also wants f32, so everything is casted to + // f32 here.. A datatype-agnostic concat would be nice; or ability to + // do the tensor equivalent of unsafe type cast. + // + // The Command-R+ tensor this code was written for was 6GB. So this is + // going to handle 12GB I guess. Oof. + // + // I believe you could be smarter and combine hierarchially instead of + // one by one. I.e. we are doing a concetenation like this: + // for x in range(100): + // accum = accum + [x] (copies accum every time? maybe. didn't read concat code) + // + // You could instead divide and conquer to make it a bit smarter. 
+ for (size_t i = 0; i < nb_A; ++i) { + for (size_t j = 0; j < nb_B; ++j) { + struct ggml_tensor * src_block = result_blocks[i * nb_B + j]; + + const size_t rows = src_block->ne[0]; + const size_t cols = src_block->ne[1]; + GGML_ASSERT(rows * cols % split_size == 0); + + const size_t nflattened_rows = split_size; + const size_t n3 = (rows * cols) / split_size; + + src_block = ggml_view_3d(ctx, src_block, + nflattened_rows, + 1, + n3, + nflattened_rows * ggml_element_size(src_block), + nflattened_rows * ggml_element_size(src_block), + 0); + + if (result_final == nullptr) { + if (src_block->type != GGML_TYPE_F32) { + result_final = ggml_cast(ctx, src_block, GGML_TYPE_F32); + cb(result_final, "result-upcast", il); + } else { + result_final = src_block; + } + continue; + } + + if (src_block->type != GGML_TYPE_F32) { + src_block = ggml_cast(ctx, src_block, GGML_TYPE_F32); + } + result_final = ggml_concat(ctx, result_final, src_block); + cb(result_final, "result_final-accumulator", il); + } + } + + result_final = ggml_reshape_2d(ctx, result_final, c_rows, c_cols); + cb(result_final, "result_final", il); + + free(result_blocks); + + return result_final; +} + static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -6479,7 +6813,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llama_build_mat_mul_blocked_computation(ctx0, model.output, cur, model, cb, -1, 0); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9462,7 +9796,15 @@ struct llm_build_context { if (model.layers[il].attn_q_norm) { - + struct ggml_tensor * attn_q_norm = model.layers[il].attn_q_norm; + struct ggml_tensor * attn_k_norm = model.layers[il].attn_k_norm; + + // CPU did not like F16, so cast to F32 + attn_q_norm = ggml_cast(ctx0, attn_q_norm, GGML_TYPE_F32); + cb(attn_q_norm, "attn_q_norm_cast_F32", il); + attn_k_norm = ggml_cast(ctx0, attn_k_norm, GGML_TYPE_F32); + cb(attn_k_norm, "attn_k_norm_cast_F32", il); + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur) * n_embd_head, ggml_element_size(Qcur) * n_embd_head * n_head, @@ -9475,13 +9817,13 @@ struct llm_build_context { cb(Kcur, "Kcur", il); Qcur = llm_build_norm(ctx0, Qcur, hparams, - model.layers[il].attn_q_norm, + attn_q_norm, NULL, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); Kcur = llm_build_norm(ctx0, Kcur, hparams, - model.layers[il].attn_k_norm, + attn_k_norm, NULL, LLM_NORM, cb, il); cb(Kcur, "Kcur", il); From 6745ea7a65204a558cfb4f686fe734d8a43379f2 Mon Sep 17 00:00:00 2001 From: S Date: Sat, 6 Apr 2024 14:20:27 +0100 Subject: [PATCH 09/14] dranger003: Fix block index overflow in CUDA dequantizing. 
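Background for this fix: the dequantize kernels compute a flat block index as row*ncols + col, and with the Command-R+ token embedding (on the order of 256k x 12k elements, roughly 3 billion) that product no longer fits in a 32-bit signed int, so an `int` index silently wraps; widening the index and the dequantize_kernel_t signature to int64_t avoids that. A back-of-the-envelope check, with the dimensions as rough assumptions rather than exact values:

# Rough check of the overflow this patch fixes; the dimensions are
# assumptions (approximate Command-R+ embedding shape), not read from the model.
INT32_MAX = 2**31 - 1

n_vocab = 256_000   # assumed vocabulary size
n_embd  = 12_288    # assumed embedding width

flat_index_range = n_vocab * n_embd       # ~3.1 billion elements
print(flat_index_range > INT32_MAX)       # True -> a 32-bit `int` block index wraps
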
--- ggml-cuda/common.cuh | 2 +- ggml-cuda/dequantize.cuh | 10 +++++----- ggml-cuda/dmmv.cu | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b98d7cbd0c1c5..481065b2a3484 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); ////////////////////// diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh index b54400632a5ec..bd3c2d9db9463 100644 --- a/ggml-cuda/dequantize.cuh +++ b/ggml-cuda/dequantize.cuh @@ -1,6 +1,6 @@ #include "common.cuh" -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; const dfloat d = x[ib].d; @@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; const dfloat d = x[ib].d; @@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; const dfloat d = x[ib].d; diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 0b17e3cb961f7..3097a951092ab 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; // automatic 
half -> float type cast if dfloat == float @@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/qk; // x block index + const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index const int iqs = (col%qk)/qr; // x quant index const int iybs = col - col%qk; // y block start index From 26e8f23bf3eff0380408a1543bc7d46015e37800 Mon Sep 17 00:00:00 2001 From: S Date: Sat, 6 Apr 2024 20:21:42 +0100 Subject: [PATCH 10/14] Reverted blocked multiplication code as it still has issues and could affect other Llama arches --- llama.cpp | 336 +----------------------------------------------------- 1 file changed, 1 insertion(+), 335 deletions(-) diff --git a/llama.cpp b/llama.cpp index a8c6568c020a0..934bbc25aed1f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5734,340 +5734,6 @@ static void llm_build_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); } -static struct ggml_tensor * llama_build_mat_mul_blocked_computation( - /* - * Does (almost) same thing as ggml_mat_mul mathematically speaking, - * but splits the computation into chunks. - * - * Why would you want to do this? As part of Command-R+ coding, we - * discovered that quite a bit of the GPU code is not prepared for - * matrices with more than 2**31-1 elements (~2 billion). - * - * Some context: - * https://github.com/ggerganov/llama.cpp/pull/6491 - * - * This function has a limit (set to 2B) that if any constituent parts - * of it (input, output, result) would go over that limit byte-wise, - * it'll use the splitted computation. This is based on the idea that - * this minimizes the chance that somewhere downstream in GPU code, be - * it MPS or Cuda, has something like: int x = y * z; where the values - * of y and z overflow the multiplication and then silently (or not so - * silently) does something weird. At the time of writing (2024-04-05); - * it seems that CUDA code outright crashes and MPS silently gives bad - * results. - * - * This is a band-aid workaround. The ideal state of the world is that - * this function does nothing but "return ggml_mat_mul(ctx, a, b)". - * - * The last argument (forced_block_size) is for debugging. You can - * force a certain block size to use with the computation. If zero - * (default) then the block size is determined on the fly. Production - * code should always have it zero; and only set it to a non-zero value - * for debugging and testing. - */ - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const llama_model & model, - const llm_build_cb & cb, - int64_t il, - size_t forced_block_size) -{ - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - if (forced_block_size != 0) { - //fprintf(stderr, "warning: llama_build_mat_mul_blocked_computation() forced block size: %zu\n", forced_block_size); - } - - const size_t MAX_BYTES_BEFORE_SPLIT = 2000000000; - - // the actual ggml_mul_mat supports batching. But this one doesn't. - GGML_ASSERT(a->ne[2] == 1 && b->ne[2] == 1); - GGML_ASSERT(a->ne[3] == 1 && b->ne[3] == 1); - - // bail out if if the number of elements would be zero. - // nicer than getting a segfault. 
- for (int i = 0; i < GGML_MAX_DIMS; ++i) { - GGML_ASSERT(a->ne[i] > 0 && "Matrix multiplication with a 0-side length matrix ('a')."); - GGML_ASSERT(b->ne[i] > 0 && "Matrix multiplication with a 0-side length matrix ('b')."); - } - - // Use the max size of: a, b, result size - const size_t a_rows = a->ne[1]; - const size_t a_cols = a->ne[0]; - - // b is transposed - const size_t b_rows = b->ne[0]; - const size_t b_cols = b->ne[1]; - - const size_t c_rows = a_rows; - const size_t c_cols = b_cols; - - // determine a size of a block that's as big as possible. - // we start with block size of the maximum size, and if that passes, - // then we just use ggml_mat_mul() - // - // the block is square. - size_t cand_block_size = a_rows; - if (a_cols > cand_block_size) { cand_block_size = a_cols; } - if (b_rows > cand_block_size) { cand_block_size = b_rows; } - if (b_cols > cand_block_size) { cand_block_size = b_cols; } - if (c_rows > cand_block_size) { cand_block_size = c_rows; } - if (c_cols > cand_block_size) { cand_block_size = c_cols; } - - size_t block_size = 1; - while (block_size < cand_block_size) { - block_size <<= 1; - } - - if (forced_block_size != 0) { - block_size = forced_block_size; - } else { - // figure out what is largest block_size we can use that will never - // have an intermediate result bigger than - // MAX_BYTES_BEFORE_SPLIT - bool ok = true; - while (block_size > 0) { - ok = true; - - // keep the byte calculations in sync with the blocked code in - // the computation part. - - // Criteria: - // 1. result block size - { - const size_t i_min = 0; - const size_t j_min = 0; - size_t i_max = i_min + block_size; - size_t j_max = j_min + block_size; - if (i_max > a_rows) { i_max = a_rows; } - if (j_max > b_cols) { j_max = b_cols; } - - const size_t bytes_size = sizeof(float) * (i_max - i_min) * (j_max - j_min); - if (bytes_size > MAX_BYTES_BEFORE_SPLIT) { - ok = false; - } - } - // 2. and 3. - // Block size from 'a' and 'b' - { - const size_t i_min = 0; - const size_t j_min = 0; - const size_t k_min = 0; - - size_t i_max = i_min + block_size; - size_t j_max = j_min + block_size; - size_t k_max = k_min + block_size; - - if (i_max > a_rows) { i_max = a_rows; } - if (j_max > b_cols) { j_max = b_cols; } - if (k_max > a_cols) { k_max = a_cols; } - - const size_t bytes_size_a = sizeof(float) * (k_max - k_min) * (i_max - i_min); - const size_t bytes_size_b = sizeof(float) * (k_max - k_min) * (j_max - j_min); - - if (bytes_size_a > MAX_BYTES_BEFORE_SPLIT || bytes_size_b > MAX_BYTES_BEFORE_SPLIT) { - ok = false; - } - } - - if (!ok) { - block_size /= 2; - continue; - } - break; - } - GGML_ASSERT(block_size > 0); - } - - //fprintf(stderr, "block_size=%zu a shape: %d %d b shape: %d %d\n", block_size, a_rows, a_cols, b_rows, b_cols); - - // O(N^3) nested loop, where N is number of blocks on one of the - // constituent parts. - size_t nb_A = (a_rows + block_size - 1) / block_size; - size_t nb_B = (b_cols + block_size - 1) / block_size; - size_t nb_A2 = (a_cols + block_size - 1) / block_size; - - // make placeholder tensors for each block results. 
- // 2D: (row, col) -> offset is: (x, y) -> x * nb_B + y - struct ggml_tensor ** result_blocks = (struct ggml_tensor **) malloc(nb_A * nb_B * sizeof(struct ggml_tensor *)); - - for (size_t i = 0; i < nb_A; ++i) { - for (size_t j = 0; j < nb_B; ++j) { - const size_t i_min = i * block_size; - const size_t j_min = j * block_size; - size_t i_max = i_min + block_size; - size_t j_max = j_min + block_size; - - if (i_max > a_rows) { i_max = a_rows; } - if (j_max > b_cols) { j_max = b_cols; } - - struct ggml_tensor * result_block = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, i_max - i_min, j_max - j_min); - result_block = ggml_scale(ctx, result_block, 0.0f); - - cb(result_block, "result_block-fresh", il); - result_blocks[i * nb_B + j] = result_block; - } - } - - size_t num_blocks = 0; - for (size_t i = 0; i < nb_A; ++i) { - for (size_t j = 0; j < nb_B; ++j) { - for (size_t k = 0; k < nb_A2; ++k) { - num_blocks++; - - const size_t i_min = i * block_size; - const size_t j_min = j * block_size; - const size_t k_min = k * block_size; - - size_t i_max = i_min + block_size; - size_t j_max = j_min + block_size; - size_t k_max = k_min + block_size; - if (i_max > a_rows) { i_max = a_rows; } - if (j_max > b_cols) { j_max = b_cols; } - if (k_max > a_cols) { k_max = a_cols; } - - const size_t blck_size_a = (const size_t) ggml_blck_size(a->type); - const size_t blck_size_b = (const size_t) ggml_blck_size(b->type); - const size_t type_size_a = ggml_type_size(a->type); - const size_t type_size_b = ggml_type_size(b->type); - - GGML_ASSERT(k_min * type_size_a % blck_size_a == 0); - GGML_ASSERT(k_min * type_size_b % blck_size_b == 0); - - // blck_size=32 - // type_size_a=19 - // - // k_min = 4 - // - // byte_offset = (type_size_a * (k_min/blck_size)) = - // 19 * (4/32) = 2 - - struct ggml_tensor * a_slice = ggml_view_2d( - ctx, a, - k_max - k_min, // k:k_max size - i_max - i_min, // i:i_max size - ggml_row_size(a->type, a->ne[0]), - ggml_row_size(a->type, a->ne[0]) * i_min + k_min * type_size_a / blck_size_a); - - cb(a_slice, "a_slice", il); - - struct ggml_tensor * b_slice = ggml_view_2d( - ctx, b, - k_max - k_min, // k:k_max size - j_max - j_min, // j:j_max size - ggml_row_size(b->type, b->ne[0]), - ggml_row_size(b->type, b->ne[0]) * j_min + k_min * type_size_b / blck_size_b); - - cb(b_slice, "b_slice", il); - - struct ggml_tensor * result_slice = result_blocks[i * nb_B + j]; - - struct ggml_tensor * mm_result = ggml_mul_mat(ctx, a_slice, b_slice); - cb(mm_result, "mm_result", il); - - result_blocks[i * nb_B + j] = ggml_add(ctx, result_slice, mm_result); - cb(result_blocks[i * nb_B + j], "result_slice", il); - } - } - } - - // concate the results into one chonky tensor. - // ggml_concat goes mad if the first two dimensions are not the same. - // - // We use this strategy: find largest power of two that divides the - // size of all the tensors. Power of two to make it friendly to GPU - // code; (TODO: LCD might be better? but not sure it won't break code). - // - // Flatten all the tensors to (X, 1, N, 1). 
- size_t split_size = 1; - while (1) { - size_t candidate_split_size = split_size << 1; - bool bad = false; - - for (size_t i = 0; i < nb_A * nb_B; ++i) { - size_t rows = result_blocks[i]->ne[0]; - size_t cols = result_blocks[i]->ne[1]; - - if (candidate_split_size > rows * cols) { - bad = true; - break; - } - - if ((rows * cols) % candidate_split_size != 0) { - bad = true; - break; - } - } - - if (bad) { - break; - } - - split_size = candidate_split_size; - } - - struct ggml_tensor * result_final = nullptr; - const ggml_type wanted_final_type = a->type; - - // TODO: looks like concat also wants f32, so everything is casted to - // f32 here.. A datatype-agnostic concat would be nice; or ability to - // do the tensor equivalent of unsafe type cast. - // - // The Command-R+ tensor this code was written for was 6GB. So this is - // going to handle 12GB I guess. Oof. - // - // I believe you could be smarter and combine hierarchially instead of - // one by one. I.e. we are doing a concetenation like this: - // for x in range(100): - // accum = accum + [x] (copies accum every time? maybe. didn't read concat code) - // - // You could instead divide and conquer to make it a bit smarter. - for (size_t i = 0; i < nb_A; ++i) { - for (size_t j = 0; j < nb_B; ++j) { - struct ggml_tensor * src_block = result_blocks[i * nb_B + j]; - - const size_t rows = src_block->ne[0]; - const size_t cols = src_block->ne[1]; - GGML_ASSERT(rows * cols % split_size == 0); - - const size_t nflattened_rows = split_size; - const size_t n3 = (rows * cols) / split_size; - - src_block = ggml_view_3d(ctx, src_block, - nflattened_rows, - 1, - n3, - nflattened_rows * ggml_element_size(src_block), - nflattened_rows * ggml_element_size(src_block), - 0); - - if (result_final == nullptr) { - if (src_block->type != GGML_TYPE_F32) { - result_final = ggml_cast(ctx, src_block, GGML_TYPE_F32); - cb(result_final, "result-upcast", il); - } else { - result_final = src_block; - } - continue; - } - - if (src_block->type != GGML_TYPE_F32) { - src_block = ggml_cast(ctx, src_block, GGML_TYPE_F32); - } - result_final = ggml_concat(ctx, result_final, src_block); - cb(result_final, "result_final-accumulator", il); - } - } - - result_final = ggml_reshape_2d(ctx, result_final, c_rows, c_cols); - cb(result_final, "result_final", il); - - free(result_blocks); - - return result_final; -} - static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -6813,7 +6479,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = llama_build_mat_mul_blocked_computation(ctx0, model.output, cur, model, cb, -1, 0); + cur = ggml_mul_mat(ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); From ce9413d849f255d81528bb0bdf59f02a5d884bf7 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 6 Apr 2024 22:33:36 +0200 Subject: [PATCH 11/14] export norms as f32 --- convert-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 7e601170e925a..37af6328a1705 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -160,7 +160,7 @@ def write_tensors(self): data = data.astype(np.float32) # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 From 78819c07ff21c5c248a7c37d8c9f1e8186e739e7 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 6 Apr 2024 22:37:08 +0200 Subject: [PATCH 12/14] fix overflow issues during quant and other cleanup --- ggml-quants.c | 310 +++++++++++++++++++------------------- ggml-quants.h | 164 ++++++++++---------- ggml.c | 14 +- ggml.h | 14 +- gguf-py/gguf/constants.py | 2 +- llama.cpp | 66 +++----- 6 files changed, 276 insertions(+), 294 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index f2e6c4bd1a321..32e84434a8c1b 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -544,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -581,12 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_0_reference(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -623,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_1_reference(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -671,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_0_reference(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -719,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_1_reference(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / 
QK8_0; @@ -749,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -938,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -973,7 +973,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1192,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #endif } -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -1212,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int } } -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_1; assert(k % qk == 0); @@ -1233,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int } } -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -1259,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int } } -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_1; assert(k % qk == 0); @@ -1286,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int } } -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK8_0; assert(k % qk == 0); @@ -1581,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1658,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1704,7 +1704,7 @@ void dequantize_row_q2_K(const block_q2_K * 
restrict x, float * restrict y, int } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q2_K_reference(x, vy, k); } @@ -1960,14 +1960,14 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri } } -size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, nrow*n_per_row); + quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -1978,7 +1978,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2092,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } #if QK_K == 256 -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2142,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } #else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); assert(QK_K == 64); const int nb = k / QK_K; @@ -2175,11 +2175,11 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } #endif -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q3_K_reference(x, vy, k); } -static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) { +static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q3_K_reference(x, y, n_per_row); @@ -2268,14 +2268,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri #endif } -size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, nrow*n_per_row); + quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { 
quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2286,7 +2286,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2393,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2432,19 +2432,19 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; quantize_row_q4_K_reference(x, y, k); } -static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q4_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; uint8_t L[QK_K]; uint8_t Laux[32]; @@ -2516,14 +2516,14 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri #endif } -size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, nrow*n_per_row); + quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2534,9 +2534,9 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 uint8_t L[QK_K]; @@ -2676,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict } } -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2721,19 +2721,19 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { +void 
quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; quantize_row_q5_K_reference(x, y, k); } -static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q5_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; uint8_t L[QK_K]; uint8_t Laux[32]; @@ -2825,14 +2825,14 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri #endif } -size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, nrow*n_per_row); + quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2843,9 +2843,9 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; int8_t L[QK_K]; float scales[QK_K/16]; @@ -2925,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict } } -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2972,19 +2972,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; quantize_row_q6_K_reference(x, y, k); } -static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q6_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; int8_t L[QK_K]; float scales[QK_K/16]; @@ -3067,14 +3067,14 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri #endif } -size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, 
const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, nrow*n_per_row); + quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3083,7 +3083,7 @@ size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { @@ -3098,7 +3098,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK4_0; + const int64_t nb = n_per_row/QK4_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_0 * ib; const float * qw = quant_weights + QK4_0 * ib; @@ -3111,14 +3111,14 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri } } -size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, nrow*n_per_row); + quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3126,7 +3126,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { @@ -3141,7 +3141,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK4_1; + const int64_t nb = n_per_row/QK4_1; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_1 * ib; const float * qw = quant_weights + QK4_1 * ib; @@ -3156,14 +3156,14 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri } } -size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, nrow*n_per_row); + quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * 
ggml_row_size(GGML_TYPE_Q4_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3171,7 +3171,7 @@ size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { @@ -3186,7 +3186,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK5_0; + const int64_t nb = n_per_row/QK5_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK5_0 * ib; const float * qw = quant_weights + QK5_0 * ib; @@ -3210,14 +3210,14 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri } } -size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, nrow*n_per_row); + quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3225,7 +3225,7 @@ size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { @@ -3240,7 +3240,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK5_1; + const int64_t nb = n_per_row/QK5_1; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK5_1 * ib; const float * qw = quant_weights + QK5_1 * ib; @@ -3263,14 +3263,14 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri } } -size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, nrow*n_per_row); + quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); char * qrow = (char *)dst; - for (int 
row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3278,18 +3278,18 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -size_t quantize_q8_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, nrow*n_per_row); + quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } // ====================== "True" 2-bit (de)-quantization -void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; uint32_t aux32[2]; const uint8_t * aux8 = (const uint8_t *)aux32; @@ -3315,9 +3315,9 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y // ====================== 2.3125 bpw (de)-quantization -void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float db[2]; @@ -3342,9 +3342,9 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, // ====================== 2.5625 bpw (de)-quantization -void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float db[2]; @@ -3374,9 +3374,9 @@ void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, in // ====================== 3.0625 bpw (de)-quantization -void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) { +void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; uint32_t aux32; @@ -3406,9 +3406,9 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization -void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3449,9 +3449,9 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in // ====================== 1.5625 bpw (de)-quantization -void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3474,9 +3474,9 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in } } -void 
dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) { +void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float delta[4]; uint16_t idx[4]; @@ -3535,9 +3535,9 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) { +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { assert(k % QK4_NL == 0); - const int nb = k / QK4_NL; + const int64_t nb = k / QK4_NL; for (int i = 0; i < nb; i++) { @@ -3553,12 +3553,12 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, } } -void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) { +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); #if QK_K == 64 dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k); #else - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3582,9 +3582,9 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K ============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3621,9 +3621,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict } } -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { for (int j = 0; j < QK_K; ++j) { @@ -3632,7 +3632,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int } } -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q8_K_reference(x, y, k); } @@ -10648,7 +10648,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); @@ -10664,7 +10664,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_xxs * y = vy; @@ -10821,7 +10821,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict } } -static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const 
int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); @@ -10837,7 +10837,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_xs * y = vy; @@ -11001,11 +11001,11 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v } } -size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_xxs); @@ -11013,11 +11013,11 @@ size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nro return nrow * nblock * sizeof(block_iq2_xxs); } -size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_xs); @@ -11242,7 +11242,7 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n, +static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq3_data_index(grid_size); @@ -11259,7 +11259,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v const int kMaxQ = 8; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; ggml_fp16_t * dh; uint8_t * qs; @@ -11455,11 +11455,11 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v } } -size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq3_xxs); @@ -11467,13 +11467,13 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nro return nrow * nblock * sizeof(block_iq3_xxs); } -void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; quantize_row_iq3_xxs_reference(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * 
restrict x, block_iq3_xxs * restrict y, int k) { +void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -11504,7 +11504,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo const int kMaxQ = 8; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq3_s * y = vy; @@ -11661,9 +11661,9 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } #define IQ3S_BLOCK_SIZE 32 -size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; float scales[QK_K/IQ3S_BLOCK_SIZE]; float weight[IQ3S_BLOCK_SIZE]; float xval[IQ3S_BLOCK_SIZE]; @@ -11674,7 +11674,7 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4]; uint8_t block_signs[IQ3S_BLOCK_SIZE/8]; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights, scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs); src += n_per_row; @@ -11683,13 +11683,13 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq3_s); } -void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; quantize_row_iq3_s_reference(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) { +void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -11822,7 +11822,7 @@ static int iq1_sort_helper(const void * left, const void * right) { #define IQ1S_BLOCK_SIZE 32 #define IQ1M_BLOCK_SIZE 16 -static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights, +static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, float * sumx, @@ -11846,7 +11846,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy block_iq1_s * y = vy; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; const int block_size = IQ1S_BLOCK_SIZE; @@ -11980,7 +11980,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); float scales[QK_K/IQ1S_BLOCK_SIZE]; float weight[IQ1S_BLOCK_SIZE]; @@ -11990,9 +11990,9 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, float pairs[2*IQ1S_BLOCK_SIZE]; uint16_t index[IQ1S_BLOCK_SIZE/8]; int8_t shifts[QK_K/IQ1S_BLOCK_SIZE]; - int nblock = n_per_row/QK_K; + 
int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts); src += n_per_row; qrow += nblock*sizeof(block_iq1_s); @@ -12000,7 +12000,7 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq1_s); } -static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights, +static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, float * pairs, @@ -12022,7 +12022,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy block_iq1_m * y = vy; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; const int block_size = IQ1M_BLOCK_SIZE; @@ -12265,7 +12265,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); float scales[QK_K/IQ1M_BLOCK_SIZE]; float weight[IQ1M_BLOCK_SIZE]; @@ -12273,9 +12273,9 @@ size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, float pairs[2*IQ1M_BLOCK_SIZE]; uint16_t index[IQ1M_BLOCK_SIZE/8]; int8_t shifts[QK_K/IQ1M_BLOCK_SIZE]; - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts); src += n_per_row; qrow += nblock*sizeof(block_iq1_m); @@ -12407,16 +12407,16 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } } -size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK4_NL == 0); - int nblock = n_per_row/QK4_NL; + int64_t nblock = n_per_row/QK4_NL; char * qrow = (char *)dst; uint8_t L[QK4_NL]; float weight[QK4_NL]; uint16_t unused_h; uint8_t * unused_l = NULL; float scale; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { block_iq4_nl * iq4 = (block_iq4_nl *)qrow; for (int ibl = 0; ibl < nblock; ++ibl) { const float * qw = quant_weights ? 
quant_weights + QK4_NL*ibl : NULL; @@ -12429,9 +12429,9 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow return nrow * nblock * sizeof(block_iq4_nl); } -void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { GGML_ASSERT(k%QK4_NL == 0); - int nblock = k/QK4_NL; + int64_t nblock = k/QK4_NL; uint8_t L[QK4_NL]; float weight[QK4_NL]; uint16_t unused_h; @@ -12444,22 +12444,22 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { } } -void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) { +void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } -size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { #if QK_K == 64 return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights); #else GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; uint8_t L[QK_K]; float weight[32]; float scales[QK_K/32]; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { block_iq4_xs * iq4 = (block_iq4_xs *)qrow; for (int ibl = 0; ibl < nblock; ++ibl) { const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; @@ -12473,20 +12473,20 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow #endif } -void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; quantize_row_iq4_xs_reference(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) { +void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } // =============================== 2.5625 bpw -static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); @@ -12501,7 +12501,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_s * y = vy; @@ -12654,11 +12654,11 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_s); @@ 
-12666,12 +12666,12 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int k) { +void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } -void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; quantize_row_iq2_s_reference(x, y, k); diff --git a/ggml-quants.h b/ggml-quants.h index ac1091c3d3b66..4d436a8f06b3e 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -12,70 +12,70 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int k); -void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int k); -void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int k); -void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int k); -void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int k); -void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int k); - -void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k); -void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int k); -void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int k); -void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int k); -void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k); -void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); - -void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); -void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); -void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k); -void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k); -void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k); - -void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); - -void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q6_K(const float * GGML_RESTRICT x, void 
* GGML_RESTRICT y, int k); -void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); - -void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); + +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs(const float * 
GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); - -void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); - -void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * 
GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -101,26 +101,26 @@ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); - -size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int 
n_per_row, const float * imatrix); -size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); +size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); void iq2xs_init_impl(enum ggml_type type); void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml.c b/ggml.c index ebaf06683ca9e..912a8ca2e286f 100644 --- a/ggml.c +++ b/ggml.c @@ -338,14 +338,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { return GGML_FP32_TO_FP16(x); } -void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, 
float * y, int n) { - for (int i = 0; i < n; i++) { +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { + for (int64_t i = 0; i < n; i++) { y[i] = GGML_FP16_TO_FP32(x[i]); } } -void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { - int i = 0; +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { + int64_t i = 0; #if defined(__F16C__) for (; i + 7 < n; i += 8) { __m256 x_vec = _mm256_loadu_ps(x + i); @@ -20331,9 +20331,9 @@ size_t ggml_quantize_chunk( enum ggml_type type, const float * src, void * dst, - int start, - int nrows, - int n_per_row, + int64_t start, + int64_t nrows, + int64_t n_per_row, const float * imatrix) { const size_t n = (size_t) nrows * n_per_row; diff --git a/ggml.h b/ggml.h index 5cef45c0ba4ad..abe3767f22418 100644 --- a/ggml.h +++ b/ggml.h @@ -332,8 +332,8 @@ extern "C" { GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); - GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); - GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n); struct ggml_object; struct ggml_context; @@ -2210,9 +2210,9 @@ extern "C" { enum ggml_type type, const float * src, void * dst, - int start, - int nrows, - int n_per_row, + int64_t start, + int64_t nrows, + int64_t n_per_row, const float * imatrix); // @@ -2377,8 +2377,8 @@ extern "C" { #else #define GGML_RESTRICT restrict #endif - typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); - typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, const void * GGML_RESTRICT y, size_t by, int nrc); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9ed7bbe183e46..43307a8ffb015 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -639,7 +639,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.ATTN_K_NORM, - MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_Q_NORM, ], # TODO } diff --git a/llama.cpp b/llama.cpp index 934bbc25aed1f..280131691e0e6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5406,8 +5406,7 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - if (n_layer >= 64) - { + if (n_layer >= 64){ layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); } @@ -9460,17 +9459,7 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - if (model.layers[il].attn_q_norm) - { - struct ggml_tensor * attn_q_norm = model.layers[il].attn_q_norm; - struct ggml_tensor * attn_k_norm = model.layers[il].attn_k_norm; - - // CPU did not like F16, so cast to F32 - attn_q_norm = ggml_cast(ctx0, attn_q_norm, GGML_TYPE_F32); - cb(attn_q_norm, "attn_q_norm_cast_F32", il); - attn_k_norm = 
ggml_cast(ctx0, attn_k_norm, GGML_TYPE_F32); - cb(attn_k_norm, "attn_k_norm_cast_F32", il); - + if (model.layers[il].attn_q_norm) { Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur) * n_embd_head, ggml_element_size(Qcur) * n_embd_head * n_head, @@ -9481,15 +9470,15 @@ struct llm_build_context { ggml_element_size(Kcur) * n_embd_head * n_head_kv, 0); cb(Kcur, "Kcur", il); - + Qcur = llm_build_norm(ctx0, Qcur, hparams, - attn_q_norm, + model.layers[il].attn_q_norm, NULL, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); Kcur = llm_build_norm(ctx0, Kcur, hparams, - attn_k_norm, + model.layers[il].attn_k_norm, NULL, LLM_NORM, cb, il); cb(Kcur, "Kcur", il); @@ -13105,15 +13094,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return std::make_pair(i_layer, n_layer); }; - // Command-R+ has such a large embedding weight tensor it overflows - // 32-bit signed integers. This is band-aid until quants can deal with - // that. - if (name == "token_embd.weight" && arch == LLM_ARCH_COMMAND_R && qs.model.hparams.n_layer >= 64) { - new_type = GGML_TYPE_F16; - } // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - else if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -13145,11 +13128,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ3_S; } } - - } else if ((arch == LLM_ARCH_COMMAND_R) && - (name.find("q_norm") != std::string::npos || - name.find("k_norm") != std::string::npos)) { - new_type = GGML_TYPE_F32; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { if (name.find("attn_v.weight") != std::string::npos) { @@ -13372,9 +13350,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return new_type; } -static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, const float * imatrix, std::vector & workers, const int nthread) { +static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { std::mutex mutex; - int counter = 0; + int64_t counter = 0; size_t new_size = 0; if (nthread < 2) { // single-thread @@ -13382,11 +13360,11 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa } auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix]() { - const int nrows_per_chunk = chunk_size / n_per_row; + const int64_t nrows_per_chunk = chunk_size / n_per_row; size_t local_size = 0; while (true) { std::unique_lock lock(mutex); - int first_row = counter; counter += nrows_per_chunk; + int64_t first_row = counter; counter += nrows_per_chunk; if (first_row >= nrows) { if (local_size > 0) { new_size += local_size; @@ -13394,7 +13372,7 @@ 
static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa break; } lock.unlock(); - const int this_nrow = std::min(nrows - first_row, nrows_per_chunk); + const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); } }; @@ -13583,6 +13561,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= !params->only_copy; @@ -13629,7 +13611,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_size = ggml_nbytes(tensor); LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); } else { - const size_t nelements = ggml_nelements(tensor); + const int64_t nelements = ggml_nelements(tensor); const float * imatrix = nullptr; if (imatrix_data) { @@ -13681,20 +13663,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout); - if (work.size() < nelements * 4) { + if (work.size() < (size_t)nelements * 4) { work.resize(nelements * 4); // upper bound on size } new_data = work.data(); - const int n_per_row = tensor->ne[0]; - const int nrows = tensor->ne[1]; + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; - static const int min_chunk_size = 32 * 512; - const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); - const int nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; // quantize each expert separately since they have different importance matrices new_size = 0; From d2924073ee9bdd600d22ded4e2d5fe30e69783a7 Mon Sep 17 00:00:00 2001 From: Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> Date: Sun, 7 Apr 2024 15:50:24 +0100 Subject: [PATCH 13/14] Type convention Co-authored-by: Georgi Gerganov --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 912a8ca2e286f..793b67f4c7020 100644 --- a/ggml.c +++ b/ggml.c @@ -20335,7 +20335,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const size_t n = (size_t) nrows * n_per_row; + const int64_t n = (int64_t) nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); From ea1aeba48b42f5f6ec41412ef271f5d708b0fa2f Mon Sep 17 00:00:00 2001 From: S Date: Mon, 8 Apr 2024 22:47:46 +0100 Subject: [PATCH 14/14] dranger003: Fix more int overflow during quant. 
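A minimal standalone sketch (illustrative, not part of this patch) of the overflow class these commits address. The tensor dimensions below are assumptions chosen only to be of the same order as a very large token embedding; the real sizes come from the model's hparams.

    /* Why element counts are widened from int to int64_t:
       for a large enough tensor, nrow * n_per_row exceeds INT_MAX,
       so computing it in 32-bit signed arithmetic overflows. */
    #include <inttypes.h>
    #include <limits.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t nrow      = 256000;  // illustrative vocab size
        const int64_t n_per_row = 12288;   // illustrative embedding width

        const int64_t nelements = nrow * n_per_row;  // 3,145,728,000

        // A 32-bit signed int tops out at 2,147,483,647, so the former
        // `int` expressions (e.g. nrow*n_per_row) cannot hold this count.
        printf("nelements = %" PRId64 " (fits in int32: %s)\n",
               nelements, nelements <= INT_MAX ? "yes" : "no");
        return 0;
    }

The same reasoning applies to the CUDA changes in the diff below, where flat element indices are computed from block and thread indices and can likewise exceed 32 bits for large tensors.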
--- ggml-cuda.cu | 6 ++-- ggml-cuda/convert.cu | 74 +++++++++++++++++++++--------------------- ggml-cuda/convert.cuh | 2 +- ggml-cuda/dmmv.cu | 2 +- ggml-cuda/quantize.cu | 16 ++++----- ggml-cuda/quantize.cuh | 2 +- 6 files changed, 51 insertions(+), 51 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f51b2042df359..5616e2348aa85 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas( // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = id == ctx.device ? ne0 : row_diff; + int64_t ldc = id == ctx.device ? ne0 : row_diff; const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 18a31edc34f83..ed4fa2748972b 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -4,14 +4,14 @@ #define CUDA_Q8_0_NE_ALIGN 2048 template -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { - const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { + const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x); if (i >= k) { return; } - const int ib = i/qk; // block index + const int64_t ib = i/qk; // block index const int iqs = (i%qk)/qr; // quant index const int iybs = i - i%qk; // y block start index const int y_offset = qr == 1 ? 
1 : qk/2; @@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } template -static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { +static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) { #if __CUDA_ARCH__ >= CC_PASCAL constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; @@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h template static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -313,14 +313,14 @@ template static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; + const int64_t tid = threadIdx.x; + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; dst_t * y = yy + i*QK_K + 128*ip + il; @@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t #else // assume 32 threads - const int tid = threadIdx.x; - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 + const int64_t tid = threadIdx.x; + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 dst_t * y = yy + i*QK_K + 16*ip + il; @@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst #endif template -static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { +static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); dequantize_block<<>>(vx, y, k); } -static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) { +static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN; if (k % CUDA_Q8_0_NE_ALIGN == 0) { const bool need_check = false; @@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * } template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t 
stream) { +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb32 = k / 32; const int nb = (k + 255) / 256; dequantize_block_q4_0<<>>(vx, y, nb32); } template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb32 = k / 32; const int nb = (k + 255) / 256; dequantize_block_q4_1<<>>(vx, y, nb32); } template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_xxs<<>>(vx, y); } template -static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_xs<<>>(vx, y); } template -static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_s<<>>(vx, y); } template -static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq3_xxs<<>>(vx, y); } template -static void 
dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq3_s<<>>(vx, y); } template -static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq1_s<<>>(vx, y); } template -static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq4_nl<<>>(vx, y); } template -static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq1_m<<>>(vx, y); } template -static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; #if QK_K == 64 dequantize_block_iq4_nl<<>>(vx, y); @@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, } template -static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res } template -static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { +static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; convert_unary<<>>(vx, y, k); } diff --git a/ggml-cuda/convert.cuh b/ggml-cuda/convert.cuh index db34c0be96449..5394be9f161b3 100644 --- a/ggml-cuda/convert.cuh +++ b/ggml-cuda/convert.cuh @@ -3,7 +3,7 @@ #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 template -using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); +using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream); typedef to_t_cuda_t to_fp32_cuda_t; typedef to_t_cuda_t to_fp16_cuda_t; diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 3097a951092ab..7313e3e175367 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -577,7 +577,7 @@ template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index 
a1fbc9932122f..7578c4b6c7cab 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,20 +1,20 @@ #include "quantize.cuh" -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { - const int ix = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) { + const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (ix >= kx_padded) { return; } - const int iy = blockDim.y*blockIdx.y + threadIdx.y; + const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y; - const int i_padded = iy*kx_padded + ix; + const int64_t i_padded = (int64_t)iy*kx_padded + ix; block_q8_1 * y = (block_q8_1 *) vy; - const int ib = i_padded / QK8_1; // block index - const int iqs = i_padded % QK8_1; // quant index + const int64_t ib = i_padded / QK8_1; // block index + const int64_t iqs = i_padded % QK8_1; // quant index const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; float amax = fabsf(xi); @@ -36,8 +36,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast<half&>(y[ib].ds.y) = sum; } -void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { - const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; +void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) { + const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ky, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded); diff --git a/ggml-cuda/quantize.cuh b/ggml-cuda/quantize.cuh index adb89c83aac85..b37a4752f2d24 100644 --- a/ggml-cuda/quantize.cuh +++ b/ggml-cuda/quantize.cuh @@ -2,4 +2,4 @@ #define CUDA_QUANTIZE_BLOCK_SIZE 256 -void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream); +void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
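A closing note on the CUDA index hunks above: blockDim.x, blockIdx.x and threadIdx.x are 32-bit unsigned values, so an expression such as blockDim.x*blockIdx.x + threadIdx.x is evaluated in 32-bit arithmetic regardless of the type it is assigned to; casting the first operand to int64_t, as these hunks do, promotes the whole computation to 64 bits. A host-side sketch of the same arithmetic follows (plain C++; the launch values are made up, chosen only so the flat index exceeds INT_MAX):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Stand-ins for the CUDA built-ins, which are 32-bit unsigned.
        const unsigned int blockDim_x  = 256;
        const unsigned int blockIdx_x  = 9000000;
        const unsigned int threadIdx_x = 7;

        // Old pattern: 32-bit arithmetic stored in a 32-bit int.
        const int     i32 = blockDim_x*blockIdx_x + threadIdx_x;          // typically wraps to a negative value
        // New pattern: promote before multiplying, as in the hunks above.
        const int64_t i64 = (int64_t)blockDim_x*blockIdx_x + threadIdx_x; // 2304000007

        printf("index as int:     %d\n", i32);
        printf("index as int64_t: %lld\n", (long long) i64);
        return 0;
    }

A wrapped, negative index slips past the "if (i >= k)" early-outs in the kernels above and then reads or writes out of bounds, which is why k itself is widened to int64_t as well.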