Eval bug: DeepScaleR-1.5B-Preview produces random tokens on AArch64 with fp16fml #11920

Open
icecream95 opened this issue Feb 17, 2025 · 1 comment

icecream95 commented Feb 17, 2025

Name and Version

version: 4733 (faaa9b93)
built with cc (Ubuntu 14.2.0-4ubuntu2) 14.2.0 for aarch64-linux-gnu

Operating systems

Linux

GGML backends

CPU

Hardware

Tested on Snapdragon X Elite and Cortex-A76.

Models

https://huggingface.co/bartowski/agentica-org_DeepScaleR-1.5B-Preview-GGUF/blob/main/agentica-org_DeepScaleR-1.5B-Preview-IQ4_NL.gguf

Other formats are broken as well, for example Q4_0 and F16.

Problem description & steps to reproduce

When llama.cpp is compiled with the fp16fml CPU flag (-DGGML_NATIVE=0 -DGGML_CPU_ARM_ARCH=armv8.2-a+fp16fml), the "DeepScaleR" model can start outputting random tokens, for example:

Certainly! Here's an organized and elegant presentation of various delicious ways to enjoy cake, categorized for clarity:
1;@0#H1D*()"H2-+G<#0--/8<=.5(F&:=$

I've found that ggml_vec_dot_f16 can return values as high as 321000 for this model. This is harmless without fp16fml, because the accumulators are then single precision.

But with fp16fml, the intermediate values are summed in fp16, whose largest finite value is 65504, so values like the one above are roughly five times too large to represent (321000 / 65504 ≈ 4.9). Sometimes it still works. The accumulation goes into the sum vector, so even though the final sumf would overflow a single half-precision variable, the total is spread across several vector elements. Other times, though, one of the 32 accumulator elements overflows on its own, and ggml_vec_dot_f16 returns inf.

(The code in sgemm.cpp exhibits similar issues, but for this testing I've completely disabled the FP16 case there.)

What are possible solutions?

I think the most sensible fix is to apply a scale factor somewhere, for example dividing by 8 or 16 for this model.

There are a few other possibilities I can think of:

  • Set the "Alternate half-precision control" bit (which will effectively saturate instead of returning infinite values)
  • Armv8.4 FEAT_FHM which accumulates to single precision (but might be twice as slow)
  • Armv8.6 BFloat16
  • Use isfinite to detect overflow and handle it when it happens (the exact recovery strategy is still open)
  • Increase GGML_F16_STEP and make GGML_F16x8_REDUCE do everything in single-precision

First Bad Commit

No response

Relevant log output

$ ./bin/llama-cli -t 2 -m agentica-org_DeepScaleR-1.5B-Preview-IQ4_NL.gguf -no-cnv -p Hello -n 100
build: 4731 (0f2bbe65) with cc (Ubuntu 14.2.0-4ubuntu2) 14.2.0 for aarch64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 51 key-value pairs and 339 tensors from agentica-org_DeepScaleR-1.5B-Preview-IQ4_NL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = DeepScaleR 1.5B Preview
llama_model_loader: - kv   3:                       general.organization str              = Agentica Org
llama_model_loader: - kv   4:                           general.finetune str              = Preview
llama_model_loader: - kv   5:                           general.basename str              = DeepScaleR
llama_model_loader: - kv   6:                         general.size_label str              = 1.5B
llama_model_loader: - kv   7:                            general.license str              = mit
llama_model_loader: - kv   8:                   general.base_model.count u32              = 1
llama_model_loader: - kv   9:                  general.base_model.0.name str              = DeepSeek R1 Distill Qwen 1.5B
llama_model_loader: - kv  10:          general.base_model.0.organization str              = Deepseek Ai
llama_model_loader: - kv  11:              general.base_model.0.repo_url str              = https://huggingface.co/deepseek-ai/De...
llama_model_loader: - kv  12:                      general.dataset.count u32              = 4
llama_model_loader: - kv  13:                     general.dataset.0.name str              = NuminaMath CoT
llama_model_loader: - kv  14:             general.dataset.0.organization str              = AI MO
llama_model_loader: - kv  15:                 general.dataset.0.repo_url str              = https://huggingface.co/AI-MO/NuminaMa...
llama_model_loader: - kv  16:                     general.dataset.1.name str              = Omni MATH
llama_model_loader: - kv  17:             general.dataset.1.organization str              = KbsdJames
llama_model_loader: - kv  18:                 general.dataset.1.repo_url str              = https://huggingface.co/KbsdJames/Omni...
llama_model_loader: - kv  19:                     general.dataset.2.name str              = STILL 3 Preview RL Data
llama_model_loader: - kv  20:             general.dataset.2.organization str              = RUC AIBOX
llama_model_loader: - kv  21:                 general.dataset.2.repo_url str              = https://huggingface.co/RUC-AIBOX/STIL...
llama_model_loader: - kv  22:                     general.dataset.3.name str              = Competition_Math
llama_model_loader: - kv  23:             general.dataset.3.organization str              = Hendrycks
llama_model_loader: - kv  24:                 general.dataset.3.repo_url str              = https://huggingface.co/hendrycks/comp...
llama_model_loader: - kv  25:                          general.languages arr[str,1]       = ["en"]
llama_model_loader: - kv  26:                          qwen2.block_count u32              = 28
llama_model_loader: - kv  27:                       qwen2.context_length u32              = 131072
llama_model_loader: - kv  28:                     qwen2.embedding_length u32              = 1536
llama_model_loader: - kv  29:                  qwen2.feed_forward_length u32              = 8960
llama_model_loader: - kv  30:                 qwen2.attention.head_count u32              = 12
llama_model_loader: - kv  31:              qwen2.attention.head_count_kv u32              = 2
llama_model_loader: - kv  32:                       qwen2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  33:     qwen2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  34:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  35:                         tokenizer.ggml.pre str              = deepseek-r1-qwen
llama_model_loader: - kv  36:                      tokenizer.ggml.tokens arr[str,151936]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  37:                  tokenizer.ggml.token_type arr[i32,151936]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  38:                      tokenizer.ggml.merges arr[str,151387]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  39:                tokenizer.ggml.bos_token_id u32              = 151646
llama_model_loader: - kv  40:                tokenizer.ggml.eos_token_id u32              = 151643
llama_model_loader: - kv  41:            tokenizer.ggml.padding_token_id u32              = 151643
llama_model_loader: - kv  42:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  43:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  44:                    tokenizer.chat_template str              = {% if not add_generation_prompt is de...
llama_model_loader: - kv  45:               general.quantization_version u32              = 2
llama_model_loader: - kv  46:                          general.file_type u32              = 25
llama_model_loader: - kv  47:                      quantize.imatrix.file str              = /models_out/DeepScaleR-1.5B-Preview-G...
llama_model_loader: - kv  48:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav3.txt
llama_model_loader: - kv  49:             quantize.imatrix.entries_count i32              = 196
llama_model_loader: - kv  50:              quantize.imatrix.chunks_count i32              = 128
llama_model_loader: - type  f32:  141 tensors
llama_model_loader: - type q5_K:   28 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type iq4_nl:  169 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = IQ4_NL - 4.5 bpw
print_info: file size   = 1012.47 MiB (4.78 BPW) 
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 22
load: token to piece cache size = 0.9310 MB
print_info: arch             = qwen2
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 1536
print_info: n_layer          = 28
print_info: n_head           = 12
print_info: n_head_kv        = 2
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 6
print_info: n_embd_k_gqa     = 256
print_info: n_embd_v_gqa     = 256
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: n_ff             = 8960
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 2
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 131072
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 0
print_info: ssm_d_inner      = 0
print_info: ssm_d_state      = 0
print_info: ssm_dt_rank      = 0
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = 1.5B
print_info: model params     = 1.78 B
print_info: general.name     = DeepScaleR 1.5B Preview
print_info: vocab type       = BPE
print_info: n_vocab          = 151936
print_info: n_merges         = 151387
print_info: BOS token        = 151646 '<|begin▁of▁sentence|>'
print_info: EOS token        = 151643 '<|end▁of▁sentence|>'
print_info: EOT token        = 151643 '<|end▁of▁sentence|>'
print_info: PAD token        = 151643 '<|end▁of▁sentence|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 151659 '<|fim_prefix|>'
print_info: FIM SUF token    = 151661 '<|fim_suffix|>'
print_info: FIM MID token    = 151660 '<|fim_middle|>'
print_info: FIM PAD token    = 151662 '<|fim_pad|>'
print_info: FIM REP token    = 151663 '<|repo_name|>'
print_info: FIM SEP token    = 151664 '<|file_sep|>'
print_info: EOG token        = 151643 '<|end▁of▁sentence|>'
print_info: EOG token        = 151662 '<|fim_pad|>'
print_info: EOG token        = 151663 '<|repo_name|>'
print_info: EOG token        = 151664 '<|file_sep|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size =  1012.47 MiB
........................................................................
llama_init_from_model: n_seq_max     = 1
llama_init_from_model: n_ctx         = 4096
llama_init_from_model: n_ctx_per_seq = 4096
llama_init_from_model: n_batch       = 2048
llama_init_from_model: n_ubatch      = 512
llama_init_from_model: flash_attn    = 0
llama_init_from_model: freq_base     = 10000.0
llama_init_from_model: freq_scale    = 1
llama_init_from_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 28, can_shift = 1
llama_kv_cache_init:        CPU KV buffer size =   112.00 MiB
llama_init_from_model: KV self size  =  112.00 MiB, K (f16):   56.00 MiB, V (f16):   56.00 MiB
llama_init_from_model:        CPU  output buffer size =     0.58 MiB
llama_init_from_model:        CPU compute buffer size =   299.75 MiB
llama_init_from_model: graph nodes  = 986
llama_init_from_model: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 2

system_info: n_threads = 2 (n_threads_batch = 2) / 12 | CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 

sampler seed: 1846766335
sampler params: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = 100, n_keep = 1

Hello,!4CC".)E(-797C;H@CF!=<9C#!-)"FDB>@%4$5-15F5-7$E>&:@7=%':ED'5E*>"";D!HE*#@6:C9->=20B-8"49-(5GH=1H,G/

llama_perf_sampler_print:    sampling time =       9.37 ms /   102 runs   (    0.09 ms per token, 10881.16 tokens per second)
llama_perf_context_print:        load time =     563.06 ms
llama_perf_context_print: prompt eval time =     136.93 ms /     2 tokens (   68.47 ms per token,    14.61 tokens per second)
llama_perf_context_print:        eval time =    7192.07 ms /    99 runs   (   72.65 ms per token,    13.77 tokens per second)
llama_perf_context_print:       total time =    7367.57 ms /   101 tokens

icecream95 (issue author) commented:

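This patch shows that scaling the values fixes the problem. It divides x by 8 before the fp16 multiply-adds, multiplies sumf by 8 after the reduction, and disables the FP16 case in sgemm.cpp entirely. Ideally, though, the scale would be applied once at load time rather than during every computation.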

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index dbef5df2..326ebcf2 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1469,9 +1469,13 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
     GGML_F16_VEC ax[GGML_F16_ARR];
     GGML_F16_VEC ay[GGML_F16_ARR];
 
+    ggml_float scaleup = 8.0;
+
+    GGML_F16_VEC scaledown = GGML_F16_VEC_SET1(1.0 / scaleup);
+
     for (int i = 0; i < np; i += GGML_F16_STEP) {
         for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ax[j] = GGML_F16_VEC_MUL(GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j), scaledown);
             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
 
             sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
@@ -1481,6 +1485,8 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
     // reduce sum0..sum3 to sum0
     GGML_F16_VEC_REDUCE(sumf, sum);
 
+    sumf *= scaleup;
+
     // leftovers
     for (int i = np; i < n; ++i) {
         sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index e0482c59..6b7c8785 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -2446,6 +2446,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
         return false;
     }
     case GGML_TYPE_F16: {
+        return false;
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {
             tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
