llamafile : use 64-bit integers in sgemm #6928

Merged (1 commit) on Apr 26, 2024

Conversation

@jart (Contributor) commented Apr 26, 2024

This PR fixes a regression with Command-R-Plus support, as reported by @fairydreaming. I'm still waiting for CRP to download and quantize so I can confirm it myself, but this should be a safe merge. See #6796 for the original discussion.
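For illustration, a minimal sketch (with made-up dimensions, not the exact Command-R-Plus shapes) of why 32-bit index arithmetic can overflow for very large weight matrices, which is presumably what the move to 64-bit integers in sgemm guards against:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes chosen only to exceed INT32_MAX when multiplied;
    // very large models have weight matrices in this range.
    int64_t rows = 256000;   // e.g. a very large output dimension
    int64_t lda  = 12288;    // e.g. a large leading dimension

    int64_t idx64 = rows * lda;      // 3,145,728,000 -- fine in 64 bits
    int32_t idx32 = (int32_t)idx64;  // wraps around when stored in 32 bits

    printf("64-bit index: %lld\n", (long long)idx64);
    printf("32-bit index: %d\n", (int)idx32);  // negative -> out-of-bounds access
    return 0;
}
```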

cc: @ggerganov who wants this fast-tracked.

@ggerganov (Owner) left a comment

Tested Command-R Plus and it no longer crashes. Thanks

@ggerganov merged commit 4b1c3c9 into ggerganov:master on Apr 26, 2024
44 of 55 checks passed
@b4rtaz commented Apr 27, 2024

Hello @jart, sorry for the off-topic question. I'm trying to use the sgemm.cpp file in the Distributed Llama project. I'm able to observe some acceleration on my Mac (F32 * F32), but unfortunately I cannot observe any acceleration on Raspberry Pi 5 (Q80 * Q40). In fact, I'm seeing a slowdown.

My setup: Raspberry Pi 5 8GB, Llama 2 7B quantized to Q40, matmul inputs quantized to Q80.

llamafile-sgemm branch with the sgemm.cpp implementation, tested on Raspberry Pi 5 8GB:

b4rtaz@raspberrypi3:~/distributed-llama $ sudo nice -n -20 ./main inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../dllama-llama2-tokenizer.t --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4
...
🔶 G  646 ms I  641 ms T    5 ms S      0 kB R      0 kB  while
Generated tokens:    16
Avg generation time: 656.06 ms

main branch, matmul is basically the same as in llama.cpp:

...
🔶 G  452 ms I  439 ms T   13 ms S      0 kB R      0 kB  one
Generated tokens:    16
Avg generation time: 448.56 ms

Were you able to observe any acceleration on your side?

@jart (Contributor, Author) commented Apr 27, 2024

@b4rtaz I haven't done anything to speed up token prediction. My code only makes evaluation faster. This is actually what you want. Prompt processing is where the real ring of power is at in AI. Try passing a prompt that's longer than two words and watch what happens. Here's what I get with a 221 token prompt on RPI5.

[image: prompt eval timings with a 221-token prompt on RPi 5, llamafile sgemm enabled]

Now compare that to llama.cpp back in March.

[image: the same prompt on llama.cpp from March, for comparison]

As we can see, you can now analyze information 47% faster using Q4_0 than you could before, thanks to the llamafile_sgemm() function. If you were to use F16 weights, then I've seen RPI5 go as high as 80 tokens per second with TinyLLaMA, which is almost too good to be true.

@b4rtaz commented Apr 27, 2024

Ahhh, for some reason I assumed that the acceleration also applies to matrix multiplication with a single input row (output = weights * 1 input row). It seems llamafile_sgemm is faster ONLY for 2 and 3 input rows.

I created a simple F32 x F32 benchmark (tested on Raspberry Pi 5 8GB) to double-check it.

(1024,1) * (1024,1024) = (1024,1) -- 0.000168ms / single output number // i.e. 0.000168 ms to compute a single output value
(1024,2) * (1024,1024) = (1024,2) -- 0.000097ms / single output number // 🚀 FAST
(1024,3) * (1024,1024) = (1024,3) -- 0.000076ms / single output number // 🚀 FASTEST
(1024,4) * (1024,1024) = (1024,4) -- 0.000181ms / single output number
(1024,5) * (1024,1024) = (1024,5) -- 0.000180ms / single output number
(1024,6) * (1024,1024) = (1024,6) -- 0.000178ms / single output number
(1024,7) * (1024,1024) = (1024,7) -- 0.000156ms / single output number
(1024,8) * (1024,1024) = (1024,8) -- 0.000141ms / single output number
(1024,9) * (1024,1024) = (1024,9) -- 0.000180ms / single output number
(1024,10) * (1024,1024) = (1024,10) -- 0.000180ms / single output number
(1024,11) * (1024,1024) = (1024,11) -- 0.000179ms / single output number
(1024,12) * (1024,1024) = (1024,12) -- 0.000166ms / single output number

(2024,1) * (2024,2024) = (2024,1) -- 0.000331ms / single output number
(2024,2) * (2024,2024) = (2024,2) -- 0.000192ms / single output number // 🚀 FAST
(2024,3) * (2024,2024) = (2024,3) -- 0.000149ms / single output number // 🚀 FASTEST
(2024,4) * (2024,2024) = (2024,4) -- 0.000352ms / single output number
(2024,5) * (2024,2024) = (2024,5) -- 0.000348ms / single output number
(2024,6) * (2024,2024) = (2024,6) -- 0.000345ms / single output number
(2024,7) * (2024,2024) = (2024,7) -- 0.000303ms / single output number
(2024,8) * (2024,2024) = (2024,8) -- 0.000274ms / single output number
(2024,9) * (2024,2024) = (2024,9) -- 0.000358ms / single output number
Source code of the benchmark
#include "funcs.hpp"
#include "utils.hpp"
#include "llamafile-sgemm.hpp"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cassert>

void test(int n, int d, int p) {
    float* x = new float[n * p];   // input: p vectors of length n
    float* w = new float[n * d];   // weights: d rows of length n
    float* out = new float[d * p]; // output: d x p

    unsigned long t0 = timeMs();

    // Repeat the multiplication many times to get a stable timing.
    const int attempts = 1024;
    for (int attempt = 0; attempt < attempts; attempt++)
        assert(llamafile_sgemm(
            p, d, n,        // m, n, k
            x, n,           // A, lda
            w, n,           // B, ldb
            out, p,         // C, ldc
            0, 1, 0,        // ith, nth, task
            F32, F32, F32   // A, B and C types
        ) == true);

    unsigned long t1 = timeMs();
    printf("(%d,%d) * (%d,%d) = (%d,%d) -- %fms / single output number\n", n, p, d, n, d, p, (t1 - t0) / (float)(d * p * attempts));

    delete[] x;
    delete[] w;
    delete[] out;
}

int main() {
    initQuants();

    for (int p = 1; p < 32; p++)
        test(1024, 1024, p);
    return EXIT_SUCCESS;
}

Thanks for your answer!


Edit: interestingly, this behaviour is not visible for Q80 x Q80 multiplication.

(2048,1) * (2048,2048) = (2048,1) -- 0.000552ms / single output number
(2048,2) * (2048,2048) = (2048,2) -- 0.000529ms / single output number
(2048,3) * (2048,2048) = (2048,3) -- 0.000499ms / single output number
(2048,4) * (2048,2048) = (2048,4) -- 0.000512ms / single output number
(2048,5) * (2048,2048) = (2048,5) -- 0.000510ms / single output number

@b4rtaz commented Apr 27, 2024

@jart BTW: if llamafile_sgemm is fastest for (p,3) input vectors, should this part be chunked into groups of 3 output rows? 🤔

        for (int64_t i13 = 0; i13 < ne13; i13++)
            for (int64_t i12 = 0; i12 < ne12; i12++)
                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),

Edit: this effect probably occurs only for F32 x F32 multiplication.
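For illustration, here is a rough sketch of what chunking the first dimension into groups of 3 might look like. This is only a sketch: the llamafile_sgemm signature below is assumed from the 15-argument call in the benchmark above, and sgemm_chunked and f32_type are hypothetical names.

```cpp
#include <algorithm>
#include <cstdint>

// Assumed signature, mirroring the 15-argument call in the benchmark above
// (linked against the existing sgemm implementation).
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda,
                     const void *B, int64_t ldb,
                     void *C, int64_t ldc,
                     int ith, int nth, int task,
                     int Atype, int Btype, int Ctype);

// Hypothetical helper: split the first dimension into groups of at most 3,
// the shape that was fastest per output value in the F32 benchmark above.
// Layout as in that benchmark: A is m columns of length k (lda = k),
// B is n columns of length k (ldb = k), C is n columns of length m (ldc = m).
void sgemm_chunked(int64_t m, int64_t n, int64_t k,
                   const float *A, const float *B, float *C, int f32_type) {
    for (int64_t i0 = 0; i0 < m; i0 += 3) {
        const int64_t mi = std::min<int64_t>(3, m - i0);
        llamafile_sgemm(mi, n, k,
                        A + i0 * k, k,   // next (at most) 3 columns of A
                        B, k,            // B is reused unchanged by every chunk
                        C + i0, m,       // matching 3 rows of C, ldc stays m
                        0, 1, 0,         // ith, nth, task as in the benchmark
                        f32_type, f32_type, f32_type);
    }
}
```

Whether this actually helps would have to be measured, and as noted above the effect may be specific to F32 x F32.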

@jart (Contributor, Author) commented Apr 28, 2024

I'm afraid I don't understand. You're doing matrix-vector multiplications. The llamafile_sgemm() function is intended for matrix-matrix multiplications. Please see https://justine.lol/matmul/. You should also measure speed in terms of flops, i.e. m*n*k. The more computations it does per second, the faster it goes.
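As a concrete example, a small self-contained sketch of converting the per-output timings above into GFLOPS using the usual 2*m*n*k operation count (the 0.000076 ms figure comes from the (1024,3) F32 case above; the result is only a rough estimate):

```cpp
#include <cstdint>
#include <cstdio>

// A GEMM with dimensions m, n, k does 2*m*n*k floating-point operations
// (one multiply and one add per term of each dot product).
double gemm_gflops(int64_t m, int64_t n, int64_t k, double elapsed_ms) {
    const double flops = 2.0 * (double)m * (double)n * (double)k;
    return flops / (elapsed_ms * 1e-3) / 1e9;
}

int main() {
    // The (1024,3) * (1024,1024) case above reported ~0.000076 ms per output
    // value, i.e. roughly 0.000076 * 1024 * 3 ms for the whole call.
    const double elapsed_ms = 0.000076 * 1024 * 3;
    printf("~%.1f GFLOPS\n", gemm_gflops(3, 1024, 1024, elapsed_ms));
    return 0;
}
```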

@b4rtaz commented Apr 28, 2024

Sorry for all the noise from my side. I've just confirmed a speedup on a Raspberry Pi 5 8GB using the latest llama.cpp version and the Llama 2 7B Q4_0 model. Great job! 👌

Llamafile enabled

llama_print_timings: prompt eval time =   13458.83 ms /   109 tokens (  123.48 ms per token,     8.10 tokens per second)
logs
b4rtaz@raspberrypi3:~/llama.cpp $ make main
...
b4rtaz@raspberrypi3:~/llama.cpp $ ./main -m ../ggml-model-f32_q40_1.gguf -p "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur." -n 128 --threads 4
Log start
main: build = 2751 (4dba7e81)
main: built with cc (Debian 12.2.0-14) 12.2.0 for aarch64-linux-gnu
main: seed  = 1714293910
llama_model_loader: loaded meta data with 17 key-value pairs and 291 tensors from ../ggml-model-f32_q40_1.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                           llama.vocab_size u32              = 32000
llama_model_loader: - kv   3:                       llama.context_length u32              = 4096
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                          llama.block_count u32              = 32
llama_model_loader: - kv   6:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   7:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv  10:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  11:                          general.file_type u32              = 2
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_0:  225 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 11008
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_0
llm_load_print_meta: model params     = 6.74 B
llm_load_print_meta: model size       = 3.56 GiB (4.54 BPW) 
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.15 MiB
llm_load_tensors:        CPU buffer size =  3647.87 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =   256.00 MiB
llama_new_context_with_model: KV self size  =  256.00 MiB, K (f16):  128.00 MiB, V (f16):  128.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.12 MiB
llama_new_context_with_model:        CPU compute buffer size =    70.50 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 4 / 4 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LAMMAFILE = 1 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 128, n_keep = 1


<s> Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
The Best of the Web's Free Web Design Resources
The Best of the Web's Free Web Design Resources Here are some of the best free web design resources available on the web today. They're all free to download and use on your own websites. If you are just starting out in web design or are an experienced web designer looking for new ideas to inspire you, then you need to download these. The Internet is a big place and it
llama_print_timings:        load time =   10343.57 ms
llama_print_timings:      sample time =       5.95 ms /   128 runs   (    0.05 ms per token, 21508.99 tokens per second)
llama_print_timings: prompt eval time =   13458.83 ms /   109 tokens (  123.48 ms per token,     8.10 tokens per second)
llama_print_timings:        eval time =   53722.72 ms /   127 runs   (  423.01 ms per token,     2.36 tokens per second)
llama_print_timings:       total time =   67220.32 ms /   236 tokens
Log end

Llamafile disabled

llama_print_timings: prompt eval time =   20388.32 ms /   109 tokens (  187.05 ms per token,     5.35 tokens per second)
logs
b4rtaz@raspberrypi3:~/llama.cpp $ LLAMA_NO_LLAMAFILE=1 make main
...
b4rtaz@raspberrypi3:~/llama.cpp $ ./main -m ../ggml-model-f32_q40_1.gguf -p "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur." -n 128 --threads 4
Log start
main: build = 2751 (4dba7e81)
main: built with cc (Debian 12.2.0-14) 12.2.0 for aarch64-linux-gnu
main: seed  = 1714294134
llama_model_loader: loaded meta data with 17 key-value pairs and 291 tensors from ../ggml-model-f32_q40_1.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                           llama.vocab_size u32              = 32000
llama_model_loader: - kv   3:                       llama.context_length u32              = 4096
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                          llama.block_count u32              = 32
llama_model_loader: - kv   6:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   7:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv  10:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  11:                          general.file_type u32              = 2
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_0:  225 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 11008
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_0
llm_load_print_meta: model params     = 6.74 B
llm_load_print_meta: model size       = 3.56 GiB (4.54 BPW) 
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.15 MiB
llm_load_tensors:        CPU buffer size =  3647.87 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =   256.00 MiB
llama_new_context_with_model: KV self size  =  256.00 MiB, K (f16):  128.00 MiB, V (f16):  128.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.12 MiB
llama_new_context_with_model:        CPU compute buffer size =    70.50 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 4 / 4 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LAMMAFILE = 0 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 128, n_keep = 1


<s> Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit
llama_print_timings:        load time =     733.14 ms
llama_print_timings:      sample time =       6.49 ms /   128 runs   (    0.05 ms per token, 19719.61 tokens per second)
llama_print_timings: prompt eval time =   20388.32 ms /   109 tokens (  187.05 ms per token,     5.35 tokens per second)
llama_print_timings:        eval time =   50432.82 ms /   127 runs   (  397.11 ms per token,     2.52 tokens per second)
llama_print_timings:       total time =   70857.12 ms /   236 tokens
Log end

nopperl pushed a commit to nopperl/llama.cpp that referenced this pull request May 5, 2024