Better perplexity for 2- and 3-bit quantization for LLaMA-v2-70B #2807

Merged: 2 commits merged into master from ik/refine_70B on Aug 26, 2023

Conversation

ikawrakow (Contributor)

In LLaMA-v2-70B, eight attention heads share the same K and V tensors (grouped-query attention), so these tensors are 8x smaller than the attention Q tensor. The attention V tensor is quite important for generation quality, so the k_quants logic often quantizes it with more bits. Given this, we can get a nice improvement in perplexity (as a measure of generation quality) with a negligible increase in quantized model size by quantizing the entire attention V tensor with 5 bits whenever the k_quants logic would have chosen 3 or 4 bits; a sketch of the idea follows the table. The table shows the PPL change for a subset of the k_quants:

| Quantization | Model size (master) | Model size (PR) | PPL (master) | PPL (PR) |
|--------------|---------------------|-----------------|--------------|----------|
| Q2_K         | 27.11 GiB           | 27.27 GiB       | 3.8164       | 3.7339   |
| Q3_K_S       | 27.70 GiB           | 27.86 GiB       | 3.7800       | 3.7019   |
| Q4_K_S       | 36.31 GiB           | 36.39 GiB       | 3.4923       | 3.4852   |
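
For readers who want to see roughly where this fits: the per-tensor type selection in llama_model_quantize_internal() only needs to promote whatever type the k_quants logic chose for attn_v.weight. Below is a minimal, self-contained sketch of that rule, using hypothetical names (adjust_attn_v_type, kq_type, is_70b) rather than the actual llama.cpp/ggml identifiers:

```cpp
#include <string>

// Hypothetical illustration of the rule described above; the enum and helper
// are stand-ins, not the real llama.cpp / ggml identifiers.
enum class kq_type { Q3_K, Q4_K, Q5_K, OTHER };

// If k_quants picked a 3- or 4-bit type for attn_v.weight in the 70B model,
// promote it to 5 bits. Because 8 heads share one V tensor, attn_v is 8x
// smaller than attn_q, so the extra bits barely change the total model size.
static kq_type adjust_attn_v_type(const std::string & tensor_name, bool is_70b, kq_type chosen) {
    if (is_70b && tensor_name.find("attn_v.weight") != std::string::npos) {
        if (chosen == kq_type::Q3_K || chosen == kq_type::Q4_K) {
            return kq_type::Q5_K;
        }
    }
    return chosen;
}
```

In llama.cpp itself the equivalent check is keyed off the model type and ggml's quantization type enums; the point is simply that only 3- and 4-bit choices for attn_v are bumped, so higher-bit variants are left as they are.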

@IgnacioFDM (Contributor)

I'd assume the same should apply to 34B?

llama.cpp (comment on an outdated diff):

```
@@ -4678,6 +4682,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
            ++n_feed_forward_w2;
        }
    }
    if (n_attention_wv != n_feed_forward_w2 || (uint32_t)n_attention_wv != model.hparams.n_layer) {
        fprintf(stderr, "============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
                n_attention_wv, n_feed_forward_w2, model.hparams.n_layer);
    }
```
ggerganov (Owner):
Use LLAMA_LOG_WARN with the __func__ prefix, as in all other logs.
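
For context, a sketch of what the suggested change would look like, assuming LLAMA_LOG_WARN takes a printf-style format string like the other logging calls in llama.cpp:

```cpp
LLAMA_LOG_WARN("%s: ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
        __func__, n_attention_wv, n_feed_forward_w2, model.hparams.n_layer);
```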

ikawrakow merged commit 7592375 into master on Aug 26, 2023
ikawrakow deleted the ik/refine_70B branch on Aug 26, 2023 at 14:27
mattgauf added a commit to mattgauf/llama.cpp that referenced this pull request Aug 26, 2023
* master: (773 commits)
  server : add `/detokenize` endpoint (ggerganov#2802)
  convert.py : advanced option (ggerganov#2753)
  llama : use Unicode Escape Sequence to replace encoded characters (ggerganov#2814)
  flake.nix : add rocm support and cleanup (ggerganov#2808)
  llama : move #includes out of _GNU_SOURCE conditional (ggerganov#2817)
  main : fix bug (penalize_nl=false doesn't work) + suppress warning on mingw (ggerganov#1528)
  llama : use std::abs in llama_sample_tail_free (ggerganov#2800)
  k-quants : remove unnecessary tensor shape restrictions (ggerganov#2811)
  Better perplexity for 2- and 3-bit quantization for LLaMA-v2-70B (ggerganov#2807)
  Fix HellaSwag (ggerganov#2805)
  flake : build llama.cpp on Intel with nix (ggerganov#2795)
  Handle null rope scaling value (ggerganov#2793)
  Fix spm whitespaces (ggerganov#2806)
  examples : skip unnecessary external lib in server README.md how-to (ggerganov#2804)
  llama : fix struct decl (ggerganov#2790)
  Faster perplexity computation (ggerganov#2786)
  llama : add llama_beam_search() (ggerganov#2267)
  convert.py : Get rope scale from HuggingFace models (ggerganov#2772)
  llama-bench : add model sizes (ggerganov#2771)
  convert.py : export rope freq_base when converting CodeLlama from an HF model (ggerganov#2773)
  ...
akawrykow pushed a commit to akawrykow/llama.cpp that referenced this pull request Aug 29, 2023
Better perplexity for 2- and 3-bit quantization for LLaMA-v2-70B (ggerganov#2807)

* Better perplexity for 2- and 3-bit quantization for the 70B model

* PR comment

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@Nexesenex (Contributor)

How long is the context for the perplexity values in the table, @ikawrakow?

@ikawrakow (Contributor, Author)

> How long is the context for the perplexity values in the table, @ikawrakow?

512 tokens

Nexesenex mentioned this pull request on Jan 21, 2024