Add grok-1 support #6204

Merged Mar 23, 2024 (6 commits)
Conversation

@arki05 (Contributor) commented Mar 21, 2024

This pull request adds grok-1 support to llama.cpp (#6120).

I've added a separate MODEL_ARCH_GROK so as not to clutter the LLAMA arch too much.
convert-hf-to-gguf.py can now convert from keyfan/grok-1-hf to GGUF. I've started uploading quants in split-GGUF format to Arki05/Grok-1-GGUF; it might take a while due to the size.

For now, the graph includes a few hardcoded values, such as attn_output_multiplier, that were part of the original implementation. Maybe we should move those to a separate parameter, but I'm not sure what the policy/guidelines on that are.
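For illustration only, here is a rough Python sketch of what "moving those to a separate parameter" could look like on the conversion side: the converter would write the scaling constants as GGUF key/value metadata instead of leaving them hardcoded in the graph. The `grok.*` key names and the config field names are assumptions for the sketch, not something this PR implements; only `gguf.GGUFWriter.add_float32()` is the existing gguf-py API.

```python
# Hypothetical sketch (not part of this PR): emit Grok's scaling constants as
# GGUF key/value metadata so the loader can read them instead of hardcoding.
import gguf

def write_grok_scales(gguf_writer: gguf.GGUFWriter, hparams: dict) -> None:
    # The config field names and the "grok.*" keys are assumptions for
    # illustration; GGUFWriter.add_float32() is the real gguf-py call.
    gguf_writer.add_float32("grok.attn_output_multiplier",
                            float(hparams.get("attn_output_multiplier", 1.0)))
    gguf_writer.add_float32("grok.embedding_multiplier_scale",
                            float(hparams.get("embedding_multiplier_scale", 1.0)))
    gguf_writer.add_float32("grok.output_multiplier_scale",
                            float(hparams.get("output_multiplier_scale", 1.0)))
```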

Would a script to convert from the base JAX weights to GGUF be helpful? If so, I can work on that next.
PS: Please be gentle, it's my first pull request on here.

@ggerganov (Owner)

Good job (I noticed the fork earlier 😉 )

Would a Script to convert from the base JAX weights to gguf be helpful?

I think it is important to have such a script, because using the F16 weights to re-quantize to Q8_0 would lead to some precision loss compared to converting the JAX 8-bit data straight to Q8_0.

But this is already a great start, and we can add the script in another PR.

@ggerganov (Owner)

Hm, we definitely need the JAX -> GGUF script - converting with convert-hf-to-gguf.py requires loading the entire model into memory. The process gets killed on my Mac Studio:

$ python3 convert-hf-to-gguf.py ~/Data/huggingface/grok-1-hf/ --outfile ~/Data/huggingface/grok-1-hf/ggml-model-f16.gguf --outtype f16

...

blk.18.ffn_gate.3.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_up.4.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_down.4.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_gate.4.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_up.5.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_down.5.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_gate.5.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_up.6.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_down.6.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_gate.6.weight, n_dims = 2, torch.bfloat16 --> float16
blk.18.ffn_up.7.weight, n_dims = 2, torch.bfloat16 --> float16
Killed: 9

@arki05 (Contributor, Author) commented Mar 22, 2024

Setting use_temp_file=True in the gguf_writer should solve this (it's what I used as well). I just didn't want to change the default values/behavior for people with way more RAM.
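For context, a minimal sketch of that setting, assuming the gguf-py GGUFWriter is constructed directly (the output path and architecture string below are illustrative):

```python
import gguf

# use_temp_file=True makes the writer spool tensor data to a temporary file on
# disk instead of holding everything in RAM until the final write; it trades
# memory for extra disk I/O.
gguf_writer = gguf.GGUFWriter(
    "ggml-model-f16.gguf",   # illustrative output path
    "grok",                  # architecture name written into the GGUF header
    use_temp_file=True,
)
```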

But yes, I'm already working on a JAX-to-GGUF script.
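To show why a direct JAX-to-GGUF path sidesteps the memory problem, here is a rough streaming sketch; `load_checkpoint_tensors()` is a hypothetical placeholder for reading the ckpt-0 tensor files one at a time, and nothing here is the actual script (which is still in progress).

```python
import gguf
import numpy as np

def load_checkpoint_tensors(ckpt_dir: str):
    # Hypothetical placeholder: should yield (tensor_name, np.ndarray) pairs
    # one at a time, so only a single tensor lives in RAM at any moment.
    raise NotImplementedError("depends on the JAX checkpoint layout")

def convert_streaming(ckpt_dir: str, out_path: str) -> None:
    writer = gguf.GGUFWriter(out_path, "grok", use_temp_file=True)
    for name, data in load_checkpoint_tensors(ckpt_dir):
        writer.add_tensor(name, np.ascontiguousarray(data))
    # Standard gguf-py write sequence: header, metadata, then tensor data.
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
```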

@ggerganov (Owner)

IQ3_S does not look very coherent. Not sure if it's the low-bit quantization without an imatrix, or if something else is wrong:

./main -m ./models/grok-1/ggml-model-iq3_s.gguf -p "The answer to life the universe and everything is of course" -s 1 -n 64 -ngl 99
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M2 Ultra
ggml_metal_init: picking default device: Apple M2 Ultra
ggml_metal_init: default.metallib not found, loading from source
ggml_metal_init: GGML_METAL_PATH_RESOURCES = nil
ggml_metal_init: loading '/Users/ggerganov/development/github/llama.cpp/ggml-metal.metal'
ggml_metal_init: GPU name:   Apple M2 Ultra
ggml_metal_init: GPU family: MTLGPUFamilyApple8  (1008)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 154618.82 MB
ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size =   128.00 MiB, (131511.92 / 147456.00)
llama_kv_cache_init:      Metal KV buffer size =   128.00 MiB
llama_new_context_with_model: KV self size  =  128.00 MiB, K (f16):   64.00 MiB, V (f16):   64.00 MiB
llama_new_context_with_model:        CPU  output buffer size =   256.00 MiB
ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size =   356.03 MiB, (131867.95 / 147456.00)
llama_new_context_with_model:      Metal compute buffer size =   356.03 MiB
llama_new_context_with_model:        CPU compute buffer size =    13.00 MiB
llama_new_context_with_model: graph nodes  = 3782
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 16 / 24 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 64, n_keep = 0


 The answer to life the universe and everything is of course 42 and the question is 42.

What is the meaning of life, the universe and everything.

But why 42, and what is the question.

The answer is 42, the question, what is the answer to life the universe and everything.

The question is what
llama_print_timings:        load time =    6511.70 ms
llama_print_timings:      sample time =       2.79 ms /    64 runs   (    0.04 ms per token, 22955.52 tokens per second)
llama_print_timings: prompt eval time =    2219.04 ms /    11 tokens (  201.73 ms per token,     4.96 tokens per second)
llama_print_timings:        eval time =    6752.31 ms /    63 runs   (  107.18 ms per token,     9.33 tokens per second)
llama_print_timings:       total time =    8992.90 ms /    74 tokens

@arki05 (Contributor, Author) commented Mar 22, 2024

With Q8_0 it seems a bit better:

./main -m ../grok-1/checkpoints/ggml-model-Q8_0.gguf -p "The answer to life the universe and everything is of course" -s 1 -n 64
llama_kv_cache_init:        CPU KV buffer size =   128.00 MiB
llama_new_context_with_model: KV self size  =  128.00 MiB, K (f16):   64.00 MiB, V (f16):   64.00 MiB
llama_new_context_with_model:        CPU  output buffer size =   256.00 MiB
llama_new_context_with_model:        CPU compute buffer size =   356.03 MiB
llama_new_context_with_model: graph nodes  = 3782
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 64, n_keep = 0


 The answer to life the universe and everything is of course 42!
The problem with the question "What is the meaning of life, the universe and everything?"

In fact 42 is not a question but an answer.
What is the meaning of life, the universe and everything?
What is the meaning of life, the universe and everything?

llama_print_timings:        load time =  697926.23 ms
llama_print_timings:      sample time =       4.41 ms /    64 runs   (    0.07 ms per token, 14519.06 tokens per second)
llama_print_timings: prompt eval time =  169129.22 ms /    11 tokens (15375.38 ms per token,     0.07 tokens per second)
llama_print_timings:        eval time = 1933799.63 ms /    63 runs   (30695.23 ms per token,     0.03 tokens per second)
llama_print_timings:       total time = 2103010.48 ms /    74 tokens

(I don't have enough RAM for Q8_0, so this was ~3.7TiB of I/O reads xD)

Looking at other inference runs (on the official xai-org/grok-1 GitHub), the quality seems comparable imo.

I can't find any issues with the implementation, but I can't really rule them out either.

@ggerganov (Owner)

Something is definitely wrong because the perplexity is through the roof:

./perplexity -m ./models/grok-1/ggml-model-iq3_s.gguf -f build/wikitext-2-raw/wiki.test.raw -ngl 99
system_info: n_threads = 16 / 24 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 625.506 ms
perplexity: calculating perplexity over 569 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 44.28 seconds per pass - ETA 1 hours 44.97 minutes
[1]38.0053,[2]75.8298,[3]67.8807,[4]66.3816,[5]61.0398,[6]63.8503,[7]65.4388,[8]69.9127,^C

For comparison, Mixtral 8x7B IQ3_S:

[1]3.7628,[2]4.5510,[3]5.2835,[4]5.6738,[5]5.6631,[6]5.5793,[7]5.7355,[8]5.8053,[9]5.9741,[10]6.2084,[11]6.4799,[12]6.5136,^C

I'll find some time in the next few days to investigate, since I can run the model much more efficiently. But we should fix this before merging.

@ggerganov (Owner) commented Mar 22, 2024

The rope type was incorrect. It works now:

The answer to life the universe and everything is of course, 42.

The answer to how to make a great website?

Well that is a lot more complicated and not quite as easy as a one word answer.

This is a topic that I have been asked about a lot recently and I will say that there is no magic formula to creating the perfect website
[1]2.8530,[2]3.8739,[3]2.9291,[4]3.0394,[5]3.1812,[6]3.1148,[7]3.2956,[8]3.1129,
[video attachment: grok-1.mp4]

@trholding

Can you share your hardware / RAM usage and other specs? @ggerganov and @arki05

@gelim (Contributor) commented Mar 23, 2024

Can you share your hardware / RAM usage and other specs? @ggerganov and @arki05

(answering for Georgi based on the bits here and there)
Mac Studio with an M2 Ultra and 192GB of unified RAM.

@arki05 (Contributor, Author) commented Mar 23, 2024

I'm working on a Threadripper 3955WX with 256GB RAM. As long as I use a version that fits into 256GB, I'm getting a reasonable 0.5 tokens per second.

Log:

llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = grok
llm_load_print_meta: vocab type = SPM
llm_load_print_meta: n_vocab = 131072
llm_load_print_meta: n_merges = 0
llm_load_print_meta: n_ctx_train = 8192
llm_load_print_meta: n_embd = 6144
llm_load_print_meta: n_head = 48
llm_load_print_meta: n_head_kv = 8
llm_load_print_meta: n_layer = 64
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 6
llm_load_print_meta: n_embd_k_gqa = 1024
llm_load_print_meta: n_embd_v_gqa = 1024
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-05
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 0.0e+00
llm_load_print_meta: n_ff = 32768
llm_load_print_meta: n_expert = 8
llm_load_print_meta: n_expert_used = 2
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 2
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 8192
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 314B
llm_load_print_meta: model ftype = IQ3_XS - 3.3 bpw
llm_load_print_meta: model params = 316.49 B
llm_load_print_meta: model size = 120.73 GiB (3.28 BPW)
llm_load_print_meta: general.name = Grok
llm_load_print_meta: BOS token = 1 '[BOS]'
llm_load_print_meta: EOS token = 2 '[EOS]'
llm_load_print_meta: UNK token = 0 '[PAD]'
llm_load_print_meta: PAD token = 0 '[PAD]'
llm_load_print_meta: LF token = 79 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.81 MiB
llm_load_tensors: CPU buffer size = 16716.66 MiB
llm_load_tensors: CPU buffer size = 14592.75 MiB
llm_load_tensors: CPU buffer size = 14484.75 MiB
llm_load_tensors: CPU buffer size = 14901.35 MiB
llm_load_tensors: CPU buffer size = 14714.18 MiB
llm_load_tensors: CPU buffer size = 14493.75 MiB
llm_load_tensors: CPU buffer size = 14484.75 MiB
llm_load_tensors: CPU buffer size = 15250.88 MiB
llm_load_tensors: CPU buffer size = 3990.96 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: n_batch = 512
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CPU KV buffer size = 128.00 MiB
llama_new_context_with_model: KV self size = 128.00 MiB, K (f16): 64.00 MiB, V (f16): 64.00 MiB
llama_new_context_with_model: CPU output buffer size = 256.00 MiB
llama_new_context_with_model: CPU compute buffer size = 356.03 MiB
llama_new_context_with_model: graph nodes = 3782
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |
sampling:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 512, n_batch = 2048, n_predict = 128, n_keep = 0

User: In the context of LLMs, what is a sparse tensor?\nAssistant: It's a type of tensor that stores data in a non-contiguous way, allowing for efficient storage and computation of large data sets.\n\nUser: In the context of LLMs, what is a dense tensor?\nAssistant: It's a type of tensor that stores data in a contiguous way, allowing for efficient computation of small data sets.\n\nUser: In the context of LLMs, what is a data pipeline?\nAssistant: It's a series of steps that are used to process and analyze large amounts of data, including data cleaning, feature extraction, and model training.\n\nUser
llama_print_timings: load time = 13056.75 ms
llama_print_timings: sample time = 8.90 ms / 128 runs ( 0.07 ms per token, 14385.26 tokens per second)
llama_print_timings: prompt eval time = 18305.56 ms / 18 tokens ( 1016.98 ms per token, 0.98 tokens per second)
llama_print_timings: eval time = 240625.12 ms / 127 runs ( 1894.69 ms per token, 0.53 tokens per second)


@liuq4360

It's a great job; tomorrow I will try a Q2 quant of grok-1 on my M3 Max.

@arki05 (Contributor, Author) commented Mar 23, 2024

FYI for anyone testing with the quants in Arki05/Grok-1-GGUF:
I made those right before this naming issue, but with the new split/shard GGUFs.
There's already a fix in #6192.

Edit: Merged now, no need for additional branches.

ggerganov merged commit 476b025 into ggerganov:master on Mar 23, 2024
37 checks passed
@ggerganov (Owner)

@arki05 Thank you once again - great work!

@phymbert (Collaborator)

For convenience, I've created a branch with just this fix and a rebase (in case someone wants to just test grok). I don't think we need to change anything here; a few PRs down the line this will all be resolved.

server \
    --hf-repo Arki05/Grok-1-GGUF \
    --hf-file grok-1-IQ3_XS-split-00001-of-00009.gguf \
    --model models/grok-1-IQ3_XS-split-00001-of-00009.gguf \
    -ngl 999

@ggerganov (Owner) commented Mar 23, 2024

Here is the full ppl for IQ3_S:

./perplexity -m ./models/grok-1/ggml-model-iq3_s.gguf -f build/wikitext-2-raw/wiki.test.raw -ngl 99
system_info: n_threads = 16 / 24 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 614.268 ms
perplexity: calculating perplexity over 569 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 44.04 seconds per pass - ETA 1 hours 44.40 minutes
[1]2.8530,[2]3.8739,[3]2.9291,[4]3.0394,[5]3.1812,[6]3.1148,[7]3.2956,[8]3.1129,[9]3.0586,[10]3.0574,[11]2.9916,[12]3.1702,[13]3.4848,[14]3.4859,[15]3.6010,[16]3.7078,[17]3.6744,[18]3.8611,[19]3.8382,[20]3.7163,[21]3.6687,[22]3.6875,[23]3.5952,[24]3.5157,[25]3.4820,[26]3.3930,[27]3.3222,[28]3.2753,[29]3.2679,[30]3.3174,[31]3.3493,[32]3.3926,[33]3.4342,[34]3.4781,[35]3.5157,[36]3.5952,[37]3.6541,[38]3.6646,[39]3.7330,[40]3.7672,[41]3.7673,[42]3.8155,[43]3.8271,[44]3.8282,[45]3.8378,[46]3.8890,[47]3.9467,[48]3.9786,[49]3.8862,[50]3.8125,[51]3.7772,[52]3.7450,[53]3.7667,[54]3.7822,[55]3.8030,[56]3.7956,[57]3.8184,[58]3.8385,[59]3.8529,[60]3.8764,[61]3.9103,[62]3.9553,[63]3.9758,[64]4.0049,[65]4.0264,[66]4.0698,[67]4.0615,[68]4.0384,[69]4.0194,[70]4.0101,[71]3.9862,[72]3.9756,[73]3.9367,[74]3.9216,[75]3.8863,[76]3.8777,[77]3.8417,[78]3.7991,[79]3.7728,[80]3.7565,[81]3.7591,[82]3.7877,[83]3.7835,[84]3.7902,[85]3.8036,[86]3.7710,[87]3.7849,[88]3.7934,[89]3.8370,[90]3.8369,[91]3.8419,[92]3.8422,[93]3.8496,[94]3.8588,[95]3.8543,[96]3.8606,[97]3.8741,[98]3.9146,[99]3.9524,[100]3.9736,[101]3.9881,[102]3.9938,[103]3.9811,[104]3.9997,[105]4.0361,[106]4.0826,[107]4.0751,[108]4.1130,[109]4.1516,[110]4.1673,[111]4.2100,[112]4.2566,[113]4.2786,[114]4.2655,[115]4.2725,[116]4.2821,[117]4.2789,[118]4.3008,[119]4.3094,[120]4.3085,[121]4.3024,[122]4.3038,[123]4.3078,[124]4.2834,[125]4.2821,[126]4.2775,[127]4.2701,[128]4.2674,[129]4.2695,[130]4.2771,[131]4.2808,[132]4.3010,[133]4.3057,[134]4.3000,[135]4.3064,[136]4.3145,[137]4.3285,[138]4.3469,[139]4.3578,[140]4.3627,[141]4.3772,[142]4.3630,[143]4.3500,[144]4.3317,[145]4.3144,[146]4.2956,[147]4.2841,[148]4.2550,[149]4.2249,[150]4.1971,[151]4.1857,[152]4.1662,[153]4.1470,[154]4.1258,[155]4.1076,[156]4.0861,[157]4.0761,[158]4.0719,[159]4.0535,[160]4.0472,[161]4.0332,[162]4.0162,[163]4.0098,[164]4.0145,[165]4.0260,[166]4.0261,[167]4.0461,[168]4.0665,[169]4.0893,[170]4.1025,[171]4.1301,[172]4.1608,[173]4.1889,[174]4.2129,[175]4.1896,[176]4.1637,[177]4.1583,[178]4.1514,[179]4.1418,[180]4.1349,[181]4.1234,[182]4.1083,[183]4.1149,[184]4.1301,[185]4.1477,[186]4.1690,[187]4.1825,[188]4.1911,[189]4.2082,[190]4.2261,[191]4.2389,[192]4.2463,[193]4.2429,[194]4.2462,[195]4.2478,[196]4.2497,[197]4.2599,[198]4.2687,[199]4.2866,[200]4.2966,[201]4.2959,[202]4.3017,[203]4.2952,[204]4.3143,[205]4.3025,[206]4.3096,[207]4.3147,[208]4.3221,[209]4.3224,[210]4.3293,[211]4.3244,[212]4.3224,[213]4.3180,[214]4.3146,[215]4.3110,[216]4.3102,[217]4.3125,[218]4.3131,[219]4.3117,[220]4.3112,[221]4.2977,[222]4.2914,[223]4.2892,[224]4.2848,[225]4.2832,[226]4.2801,[227]4.2791,[228]4.2848,[229]4.2861,[230]4.2620,[231]4.2686,[232]4.2723,[233]4.2736,[234]4.2807,[235]4.2877,[236]4.2926,[237]4.2926,[238]4.3087,[239]4.3140,[240]4.3263,[241]4.3457,[242]4.3628,[243]4.3767,[244]4.3900,[245]4.4062,[246]4.4218,[247]4.4372,[248]4.4533,[249]4.4770,[250]4.4930,[251]4.4937,[252]4.4832,[253]4.4602,[254]4.4379,[255]4.4192,[256]4.4063,[257]4.3981,[258]4.3843,[259]4.3677,[260]4.3500,[261]4.3339,[262]4.3172,[263]4.2974,[264]4.2822,[265]4.2729,[266]4.2624,[267]4.2541,[268]4.2378,[269]4.2201,[270]4.2025,[271]4.1917,[272]4.1839,[273]4.1741,[274]4.1644,[275]4.1493,[276]4.1354,[277]4.1384,[278]4.1325,[279]4.1287,[280]4.1233,[281]4.1231,[282]4.1289,[283]4.1374,[284]4.1398,[285]4.1446,[286]4.1491,[287]4.1586,[288]4.1575,[289]4.1659,[290]4.1598,[291]4.1613,[292]4.1613,[293]4.1630,[294]4.1573,[295]4.1562,[296]4.1637,[297]4.1689,[298]4.1704,[299]4.1750,[300]4.1783,[301]4.1798,[302]4.1862,[303]4.1939,[304]4.1972,[305]4.2039,[30
6]4.2058,[307]4.2077,[308]4.2072,[309]4.2163,[310]4.2158,[311]4.2268,[312]4.2168,[313]4.2296,[314]4.2384,[315]4.2529,[316]4.2692,[317]4.2736,[318]4.2709,[319]4.2766,[320]4.2748,[321]4.2759,[322]4.2763,[323]4.2822,[324]4.2830,[325]4.2851,[326]4.2817,[327]4.2797,[328]4.2834,[329]4.2939,[330]4.2901,[331]4.2889,[332]4.2810,[333]4.2761,[334]4.2704,[335]4.2700,[336]4.2671,[337]4.2603,[338]4.2542,[339]4.2478,[340]4.2399,[341]4.2353,[342]4.2324,[343]4.2326,[344]4.2296,[345]4.2294,[346]4.2328,[347]4.2391,[348]4.2366,[349]4.2340,[350]4.2382,[351]4.2460,[352]4.2501,[353]4.2347,[354]4.2284,[355]4.2242,[356]4.2122,[357]4.1993,[358]4.1859,[359]4.1739,[360]4.1620,[361]4.1509,[362]4.1381,[363]4.1238,[364]4.1123,[365]4.1006,[366]4.0879,[367]4.0755,[368]4.0649,[369]4.0552,[370]4.0437,[371]4.0357,[372]4.0236,[373]4.0137,[374]4.0019,[375]3.9933,[376]3.9844,[377]3.9741,[378]3.9659,[379]3.9562,[380]3.9494,[381]3.9419,[382]3.9316,[383]3.9224,[384]3.9192,[385]3.9212,[386]3.9245,[387]3.9287,[388]3.9338,[389]3.9364,[390]3.9410,[391]3.9474,[392]3.9482,[393]3.9365,[394]3.9324,[395]3.9247,[396]3.9219,[397]3.9194,[398]3.9147,[399]3.9145,[400]3.9110,[401]3.9060,[402]3.8950,[403]3.8866,[404]3.8843,[405]3.8729,[406]3.8635,[407]3.8551,[408]3.8446,[409]3.8352,[410]3.8251,[411]3.8172,[412]3.8105,[413]3.8034,[414]3.7980,[415]3.7992,[416]3.7932,[417]3.7869,[418]3.7824,[419]3.7755,[420]3.7669,[421]3.7705,[422]3.7610,[423]3.7578,[424]3.7563,[425]3.7498,[426]3.7425,[427]3.7377,[428]3.7300,[429]3.7247,[430]3.7183,[431]3.7127,[432]3.7025,[433]3.7034,[434]3.7000,[435]3.6903,[436]3.6839,[437]3.6784,[438]3.6687,[439]3.6672,[440]3.6690,[441]3.6716,[442]3.6741,[443]3.6779,[444]3.6818,[445]3.6866,[446]3.6954,[447]3.6937,[448]3.6900,[449]3.6816,[450]3.6732,[451]3.6652,[452]3.6573,[453]3.6486,[454]3.6412,[455]3.6378,[456]3.6301,[457]3.6217,[458]3.6153,[459]3.6137,[460]3.6165,[461]3.6175,[462]3.6235,[463]3.6293,[464]3.6315,[465]3.6328,[466]3.6286,[467]3.6282,[468]3.6400,[469]3.6438,[470]3.6432,[471]3.6447,[472]3.6483,[473]3.6512,[474]3.6512,[475]3.6510,[476]3.6538,[477]3.6580,[478]3.6608,[479]3.6614,[480]3.6634,[481]3.6645,[482]3.6636,[483]3.6636,[484]3.6641,[485]3.6634,[486]3.6659,[487]3.6637,[488]3.6665,[489]3.6643,[490]3.6729,[491]3.6764,[492]3.6833,[493]3.6820,[494]3.6836,[495]3.6875,[496]3.6906,[497]3.6925,[498]3.6966,[499]3.6950,[500]3.6951,[501]3.6970,[502]3.6980,[503]3.6998,[504]3.7003,[505]3.6999,[506]3.7031,[507]3.7067,[508]3.7097,[509]3.7101,[510]3.7121,[511]3.7155,[512]3.7213,[513]3.7229,[514]3.7252,[515]3.7219,[516]3.7160,[517]3.7139,[518]3.7148,[519]3.7125,[520]3.7117,[521]3.7117,[522]3.7117,[523]3.7087,[524]3.7088,[525]3.7045,[526]3.7082,[527]3.7122,[528]3.7102,[529]3.7111,[530]3.7168,[531]3.7133,[532]3.7108,[533]3.7110,[534]3.7081,[535]3.7090,[536]3.7072,[537]3.7058,[538]3.7032,[539]3.6976,[540]3.6962,[541]3.6996,[542]3.7033,[543]3.7062,[544]3.7120,[545]3.7169,[546]3.7211,[547]3.7253,[548]3.7308,[549]3.7355,[550]3.7377,[551]3.7381,[552]3.7427,[553]3.7381,[554]3.7399,[555]3.7367,[556]3.7356,[557]3.7375,[558]3.7387,[559]3.7390,[560]3.7407,[561]3.7453,[562]3.7392,[563]3.7357,[564]3.7358,[565]3.7299,[566]3.7241,[567]3.7206,[568]3.7154,[569]3.7100,
Final estimate: PPL = 3.7100 +/- 0.02069

llama_print_timings:        load time =    7435.73 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time = 6284344.27 ms / 291328 tokens (   21.57 ms per token,    46.36 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time = 6288945.69 ms / 291329 tokens

Hellaswag@400 is 80.25

 ./perplexity --hellaswag -f build/hellaswag_val_full.txt -m models/grok-1/ggml-model-iq3_s.gguf --hellaswag-tasks 400

We will later compare this with the properly converted models that do not go through the F16 dequantization step.

@ggerganov (Owner)

Any ideas how to compute an imatrix for this model?

If somebody has a 700GB RAM machine, the following should do the job:

./imatrix -m ./models/grok-1/ggml-model-fp16.gguf -f ./wikitext-2-raw/wiki.train.raw -ngl 0 -b 512

Though it might take a while :)

@RichardErkhov commented Mar 24, 2024

@ggerganov if you give me guidelines I can do anything for you (1.2TB RAM, 4x E7-8890 v4, Ubuntu 22.04).
Just give me the steps.

@ggerganov (Owner)

@RichardErkhov Here are the commands:

git clone https://huggingface.co/keyfan/grok-1-hf
git clone https://github.com/ggerganov/llama.cpp

cd llama.cpp
pip install -r requirements.txt

python convert-hf-to-gguf.py ../grok-1-hf/ --outfile models/grok-1-f16.gguf --outtype f16

./scripts/get-wikitext-2.sh
unzip wikitext-2-raw-v1.zip

make -j imatrix
./imatrix -m ./models/grok-1-f16.gguf -f ./wikitext-2-raw/wiki.train.raw -b 512

The last command will run for a few days. When done, upload the generated .imatrix file

@RichardErkhov

ok, see you in a few days haha. I first need to download the HF model, my internet is just 80 Mbps. Hopefully no electricity shutdown.

@miron commented Mar 24, 2024

ok, see you in a few days haha. I first need to download the HF model, my internet is just 80 Mbps. Hopefully no electricity shutdown.

I'd gladly trade my 500 Mbps line for your 1.2TB of RAM :)

@ggerganov (Owner) commented Mar 25, 2024

In the meantime, I generated an imatrix using the IQ3_S model:

https://huggingface.co/ggml-org/imatrix/blob/main/grok-1-iq3_s.imatrix

It seems to help - here is a summary of zero-shot Hellaswag scores at 400 tasks for Grok-1 and Mixtral:

Grok-1  IQ3_S         no imatrix: 80.25
Grok-1  IQ3_S with IQ3_S imatrix: 83.00

Mixtral IQ3_S         no imatrix: 77.00
Mixtral IQ3_S with IQ3_S imatrix: 80.75
Mixtral F16:                      81.00

PPL:

Grok-1  IQ3_S         no imatrix: 3.7100
Grok-1  IQ3_S with IQ3_S imatrix: 3.4487

Mixtral IQ3_S with F16   imatrix: 4.3577
Mixtral IQ3_S with IQ3_S imatrix: 4.3498
Mixtral F16:                      4.1009

@RichardErkhov

@ggerganov It's running... well, walking. ETA just 400 hours. Anything else you want to run later?

@ggerganov (Owner)

No worries - seems it would need a lot of time, so feel free to stop it. Moreover the IQ3_S imatrix seems to be good enough, so probably not worth computing an F16 one

@RichardErkhov

No worries - seems it would need a lot of time, so feel free to stop it. Moreover the IQ3_S imatrix seems to be good enough, so probably not worth computing an F16 one

@ggerganov I can keep it running and just publish it when it finishes. If you need anything else just text me, I'm always open to help =)

@RichardErkhov commented Mar 27, 2024

Update from "that crazy guy with 1.2TB of RAM who will run some random stuff for fun":
[image]
Apparently if you run the grok-1 imatrix calculation, 2 days later it will reduce the CPU usage for some reason? It's 100% not throttling, it has quite a big-boy cooler. What could cause that? @ggerganov

@RichardErkhov

Ah, it decided to disappear, how cool xD. Never mind, I guess Ubuntu is having some fun with the long-running task.

@foldl (Contributor) commented Mar 28, 2024

Hi there. I made an implementation in foldl/chatllm.cpp@912bacc.

I don't have enough compute resources, so maybe we can only export a subset of the experts. A test with the first 4 experts shows some meaningful but not impressive results, while with the first 2 experts it is worse.

Would someone like to run a test with all 8 experts? Doc.

@phymbert (Collaborator)

Would someone like to run a test with all 8 experts? Doc.

But the model is not instruction-tuned? How can you chat with it?

@foldl (Contributor) commented Mar 28, 2024

@phymbert ChatLLM.cpp can work in completion mode.

@RichardErkhov

Hi there. I made an implementation in foldl/chatllm.cpp@912bacc.

I don't have enough compute resources, so maybe we can only export a subset of the experts. A test with the first 4 experts shows some meaningful but not impressive results, while with the first 2 experts it is worse.

Would someone like to run a test with all 8 experts? Doc.

@foldl give me step-by-step instructions on what to execute and ask, and I can run it for you.

@foldl (Contributor) commented Mar 28, 2024

@RichardErkhov Thank you!

Here is the step-by-step guide:

https://github.com/foldl/chatllm.cpp/blob/master/docs/grok.md

-i /path/to/model/ckpt-0 gives the directory containing those tensor_...._.... files, and --vocab_dir /path/to/repository specifies the directory containing tokenizer.model.

@RichardErkhov

@RichardErkhov Thank you!

Here is the step-by-step guide:

https://github.com/foldl/chatllm.cpp/blob/master/docs/grok.md

-i /path/to/model/ckpt-0 gives the directory containing those tensor_...._.... files, and --vocab_dir /path/to/repository specifies the directory containing tokenizer.model.

OK, if everything works you will get results tomorrow, as I need to download the repo. Tell me what to ask the model.

@foldl (Contributor) commented Mar 28, 2024

I have no idea. Maybe ask for the answer to everything?

@RichardErkhov

@foldl idk, we will see haha. It's going at 6.5 MB/s for a 300GB download, which is ~13 hours haha, so I guess tomorrow morning I will convert and run it.

@RichardErkhov

[image]
@foldl it's quite slow haha, around 0.3 tokens per second, which is to be expected lol. But the generation is messed up for sure. Anything else you want me to run? Just text me.
[image]

@foldl (Contributor) commented Mar 30, 2024

@RichardErkhov THANK YOU! I am satisfied with the output, and no more experiments are needed.
I owe you so much.

Later, I will run a full layer and compare the output against the JAX implementation, as a double check.

The ggml community is awesome.

@RichardErkhov

I own you so much.

Lol, this is a very funny typo xD @foldl

@foldl (Contributor) commented Mar 30, 2024

@RichardErkhov #@@#@!@ a funny typo. Why hadn't the AI-powered Edge corrected it for me? Lol.

@RichardErkhov

@foldl want anything else to run? I can help with projects. You can contact me on Discord if you want: ganza2309

@foldl (Contributor) commented Mar 30, 2024

@RichardErkhov No, thanks. It's time to free up your disk space :)

@RichardErkhov

@RichardErkhov No, thanks. It's time to free up your disk space :)

Yeah, the electricity went down and it cleaned itself lol

hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* Add support for Grok model architecture

* Revert convert-hf-to-gguf to default options

* Fixed f_norm_rms_eps bug

* Fix whitespaces

* llama : fix grok rope type

* llama : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024

phymbert mentioned this pull request Apr 6, 2024

tybalex pushed a commit to rubra-ai/tools.cpp that referenced this pull request Apr 17, 2024