ggml-cuda : add rope f16, restore performance with parallel decoding #3272

Merged: 4 commits into custom-attention-mask on Sep 20, 2023

Conversation

slaren (Collaborator) commented Sep 19, 2023

With this change, all the offloaded leaves are assumed to be inputs and are copied to VRAM. This allows offloading KQ_mask and simplifies the logic for offloading KQ_pos. Replacing a conditional in the RoPE kernel with a template parameter, as suggested by @JohannesGaessler, also improves performance.
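
As an illustration of the template-parameter idea (a generic sketch with made-up names, not the actual rope kernel): the per-element runtime branch moves into a bool template parameter, so the compiler emits two specialized kernels and the branch is taken once on the host.

template<bool has_pos>
static __global__ void add_row_pos(const float * x, float * dst, const int * pos, const int ncols) {
    const int row = blockIdx.x;
    const int col = blockIdx.y*blockDim.x + threadIdx.x;
    if (col >= ncols) {
        return;
    }
    const int i = row*ncols + col;
    // resolved at compile time: the <false> instantiation has no branch and no load here
    const float p = has_pos ? (float) pos[row] : 0.0f;
    dst[i] = x[i] + p;
}

static void add_row_pos_cuda(const float * x, float * dst, const int * pos,
                             const int nrows, const int ncols, cudaStream_t stream) {
    const dim3 grid_dims(nrows, (ncols + 255)/256, 1);
    // pick the specialization once on the host instead of branching per element on the device
    if (pos != nullptr) {
        add_row_pos<true><<<grid_dims, 256, 0, stream>>>(x, dst, pos, ncols);
    } else {
        add_row_pos<false><<<grid_dims, 256, 0, stream>>>(x, dst, pos, ncols);
    }
}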

Performance compared to master:

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

| model | size | backend | test | master t/s | PR t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B mostly Q2_K | 2.63 GiB | CUDA | pp 512 | 1805.63 ± 10.64 | 1828.37 ± 5.60 | 1.01259 |
| llama 7B mostly Q3_K - Small | 2.75 GiB | CUDA | pp 512 | 1907.34 ± 5.34 | 1932.17 ± 5.03 | 1.01301 |
| llama 7B mostly Q3_K - Medium | 3.07 GiB | CUDA | pp 512 | 1997.51 ± 2.89 | 2017.52 ± 3.51 | 1.01001 |
| llama 7B mostly Q3_K - Large | 3.35 GiB | CUDA | pp 512 | 1928.15 ± 4.06 | 1951.13 ± 4.13 | 1.01191 |
| llama 7B mostly Q4_0 | 3.56 GiB | CUDA | pp 512 | 2404.71 ± 3.79 | 2439.23 ± 5.24 | 1.01435 |
| llama 7B mostly Q4_K - Small | 3.59 GiB | CUDA | pp 512 | 2202.34 ± 3.30 | 2226.35 ± 5.70 | 1.01090 |
| llama 7B mostly Q4_K - Medium | 3.80 GiB | CUDA | pp 512 | 2207.27 ± 3.36 | 2229.99 ± 11.03 | 1.01029 |
| llama 7B mostly Q4_1 | 3.95 GiB | CUDA | pp 512 | 2102.01 ± 3.47 | 2123.36 ± 3.03 | 1.01015 |
| llama 7B mostly Q5_0 | 4.33 GiB | CUDA | pp 512 | 2172.19 ± 4.07 | 2202.96 ± 5.92 | 1.01416 |
| llama 7B mostly Q5_K - Small | 4.33 GiB | CUDA | pp 512 | 2035.79 ± 4.42 | 2059.83 ± 5.22 | 1.01180 |
| llama 7B mostly Q5_K - Medium | 4.45 GiB | CUDA | pp 512 | 2057.20 ± 4.03 | 2065.73 ± 24.31 | 1.00414 |
| llama 7B mostly Q5_1 | 4.72 GiB | CUDA | pp 512 | 1931.67 ± 2.74 | 1948.27 ± 4.01 | 1.00859 |
| llama 7B mostly Q6_K | 5.15 GiB | CUDA | pp 512 | 2107.53 ± 1.66 | 2133.80 ± 7.40 | 1.01246 |
| llama 7B mostly Q8_0 | 6.67 GiB | CUDA | pp 512 | 2353.88 ± 3.43 | 2386.22 ± 9.66 | 1.01373 |
| llama 7B mostly F16 | 12.55 GiB | CUDA | pp 512 | 1659.42 ± 3.10 | 1679.42 ± 2.03 | 1.01205 |
| llama 7B mostly Q2_K | 2.63 GiB | CUDA | tg 128 | 105.91 ± 0.33 | 105.08 ± 1.66 | 0.99216 |
| llama 7B mostly Q3_K - Small | 2.75 GiB | CUDA | tg 128 | 101.85 ± 0.14 | 100.75 ± 1.97 | 0.98919 |
| llama 7B mostly Q3_K - Medium | 3.07 GiB | CUDA | tg 128 | 108.07 ± 0.30 | 107.95 ± 0.90 | 0.99888 |
| llama 7B mostly Q3_K - Large | 3.35 GiB | CUDA | tg 128 | 105.34 ± 0.25 | 105.65 ± 0.36 | 1.00294 |
| llama 7B mostly Q4_0 | 3.56 GiB | CUDA | tg 128 | 131.00 ± 0.30 | 130.90 ± 0.85 | 0.99923 |
| llama 7B mostly Q4_K - Small | 3.59 GiB | CUDA | tg 128 | 119.67 ± 0.23 | 120.00 ± 0.42 | 1.00275 |
| llama 7B mostly Q4_K - Medium | 3.80 GiB | CUDA | tg 128 | 114.40 ± 0.14 | 114.15 ± 0.54 | 0.99781 |
| llama 7B mostly Q4_1 | 3.95 GiB | CUDA | tg 128 | 124.94 ± 0.04 | 124.71 ± 0.67 | 0.99815 |
| llama 7B mostly Q5_0 | 4.33 GiB | CUDA | tg 128 | 112.05 ± 0.43 | 111.67 ± 0.34 | 0.99660 |
| llama 7B mostly Q5_K - Small | 4.33 GiB | CUDA | tg 128 | 110.85 ± 0.11 | 111.08 ± 0.69 | 1.00207 |
| llama 7B mostly Q5_K - Medium | 4.45 GiB | CUDA | tg 128 | 107.38 ± 0.06 | 106.77 ± 0.57 | 0.99431 |
| llama 7B mostly Q5_1 | 4.72 GiB | CUDA | tg 128 | 108.43 ± 0.23 | 108.22 ± 0.78 | 0.99806 |
| llama 7B mostly Q6_K | 5.15 GiB | CUDA | tg 128 | 92.17 ± 0.07 | 91.27 ± 0.35 | 0.99023 |
| llama 7B mostly Q8_0 | 6.67 GiB | CUDA | tg 128 | 88.62 ± 0.04 | 87.79 ± 0.48 | 0.99063 |
| llama 7B mostly F16 | 12.55 GiB | CUDA | tg 128 | 56.14 ± 0.07 | 55.90 ± 0.16 | 0.99572 |

@@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
     }

     // make a view of the destination
-    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+    struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);

slaren (Collaborator, Author) commented:

This is to avoid creating a new leaf with ggml_cpy(.., ggml_new_tensor(..)).
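
For context, a sketch of the pattern in question, using the KQV_merged copy quoted later in this thread (the comments are one reading of the diff, not an authoritative description):

// without the diff: the freshly created destination stays in the graph as a
// leaf (the second source of the CPY node), and with this PR every offloaded
// leaf is assumed to be an input and copied to VRAM
cur = ggml_cpy(ctx0, KQV_merged,
        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));

// with the diff: a destination whose op is still GGML_OP_NONE becomes the CPY
// node itself instead of being wrapped in a view, so no extra leaf is created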

ggerganov (Owner) commented:

After this change, the lora test is failing:

make -j && ./bin/perplexity --model ../tmp/mnt/models/open-llama/3B-v2/ggml-model-q8_0.gguf -f ../tmp/mnt/models/shakespeare/shakespeare.txt --lora ../tmp/mnt/models/open-llama/3B-v2/lora/ggml-adapter-model.bin  --lora-base ../tmp/mnt/models/open-llama/3B-v2/ggml-model-f16.gguf -c 128 -b 128 --chunks 2
Segmentation fault: 11

It crashes in ggml_build_forward_expand() for the Vcur cpy:

ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));

slaren (Collaborator, Author) commented Sep 20, 2023:

Yeah, this change to ggml_cpy is not going to work. Maybe we could add a ggml_cont_reshape, or just a ggml_cont_4d that does the same thing?

Such as:

struct ggml_tensor * ggml_cont_4d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2,
        int64_t               ne3) {
    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));

    bool is_node = false;

    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
    ggml_format_name(result, "%s (cont)", a->name);

    result->op   = GGML_OP_CONT;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;

    return result;
}

Then replace:

cur = ggml_cpy(ctx0,
        KQV_merged,
        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));

with

cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
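
The 2-D variant used above is not spelled out; a minimal sketch, assuming it simply forwards to the 4-D helper with the trailing dimensions set to 1:

struct ggml_tensor * ggml_cont_2d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1) {
    // assumed: same checks and node setup as ggml_cont_4d, with ne2 = ne3 = 1
    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
}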

ggerganov (Owner) commented:

I'm testing this branch with the following command to fill up the context, and the results become incoherent after the first rope shift:

make -j && ./bin/main -m ../models/llama-7b/ggml-model-q8_0.gguf -p "I believe the meaning of life is" --ignore-eos -c 256 -n -1 -t 8 -ngl 35

It works OK with -ngl 34.
The following patch prints a marker whenever a K-cache shift occurs, for easier debugging:

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1ed543c..2bf7628 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -502,6 +502,8 @@ int main(int argc, char ** argv) {
                 const int n_left    = n_past - params.n_keep - 1;
                 const int n_discard = n_left/2;
 
+                printf("\n\n XXXXXXXXXXXXXXXXXXXXXX\n\n");
+
                 LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
llm_load_tensors: ggml ctx size =    0,09 MB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: mem required  =  132,91 MB (+  128,00 MB per state)
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloading v cache to GPU
llm_load_tensors: offloading k cache to GPU
llm_load_tensors: offloaded 35/35 layers to GPU
llm_load_tensors: VRAM used: 6824 MB
...................................................................................................
llama_new_context_with_model: kv self size  =  128,00 MB
llama_new_context_with_model: compute buffer total size =   36,72 MB
llama_new_context_with_model: VRAM scratch buffer: 35,25 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 256, n_batch = 512, n_predict = -1, n_keep = 0


 I believe the meaning of life is to love and serve others. The ultimate goal is becoming a servant-leader, and the best way to do that is through service and action.
As a young child in Vietnam, I had many friends who were not as fortunate as I was because of our war situation. In my teenage years, this realization made me feel guilty for being one of the lucky ones. My family fled to the United States when I was 14 years old and I’ve been living in the USA ever since. As a refugee myself, I understand the struggles of other refugees better than others because of my personal experience.
Teaching English is my passion, and it’s what keeps me inspired to serve. I started volunteering for a local nonprofit organization teaching English as a second language in 2015, and it has been the most rewarding and fulfilling part of my life since then. It helps me feel happy, proud, and productive!
My experience with refugees is not limited to students from Vietnam only; I have taught students from different countries including: Afghanistan, Burma (Myanmar), Bhutan, Laos,

 XXXXXXXXXXXXXXXXXXXXXX

 I love teaching English as a volunteer for nearly three years in the Portland Learning Center. As an ESL teacher at the same center since 2015; it has enriched my life and career-wise, even more than ever! Every time when I see them succeeding and improving their lives is very important to me teaching English as a volunteering for over five years ago – some places such as Sudanese, but also Syria, Iraq, Somalia, Congo, Afghanistan, Iran, Pakistan, Nepal, Vietnam, Eritia (Laos, Mynamia, Russia

 XXXXXXXXXXXXXXXXXXXXXX

, have learn about culture and have helped students. I’s now. The center in Seattle with a volunteer. My job is very same time as 16. When i, given me with so much I am able to my students are improving their lives. Thankfully!
and progressing, and also; because they improve my lives. Now.
I, I have as well. Thailand, Somalia, Kenya, Burma, Sudan, Sri Laba, and Sudh, Syria, Bhutan, Nepali, and Iraq and China, India, Tibet, Y

 XXXXXXXXXXXXXXXXXXXXXX

v I, their 1 have also been working as a year, and have a year-I’ve. I as.
Bes. I have also am I the, am happy for me in. Thaning.
Kathing and with a my English is much a. Also, and, I am volunteered at, soy, a have worked at afr.
Sudan, and Rwia, Eritisnaa and I have worked for. The work for 10, Bhutan.
Viety. e. and Vietnand I have and

 XXXXXXXXXXXXXXXXXXXXXX

 5, I the the last 6 to as an the same that they have work in, a to be a and as a
BH and the 3 and my, an as the
My havee work have tto be.
I am I, I
The are work. asI and work,
I have a I am with the ean and India, am
I am and in India. have I’ve worked an me and
have I
I am my have a, I have 1 I, am. am, I have.
I be the
am have I,

 XXXXXXXXXXXXXXXXXXXXXX

 in
inH, and work, asI work
The as. I have
be the last year,
I I, I h1, last
5,
I. I the work. The same. an, be the. I am a have the
My 6,be the,
to 7.

slaren (Collaborator, Author) commented Sep 20, 2023:

Should be fixed now. Unfortunately, the fix also makes the copies of K_pos, K_mask and K_shift synchronous, because they may be overwritten by other operations once the computation starts (probably by the get_rows, which still runs on the CPU). It shouldn't affect performance too much, but it should be fixed in the future.
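
To illustrate why these copies need to block (generic CUDA, not the actual ggml-cuda code path; the names below are placeholders):

#include <cuda_runtime.h>

// illustrative only: d_pos/h_pos/nbytes are placeholders, not ggml names
static void upload_input(int * d_pos, const int * h_pos, size_t nbytes, cudaStream_t stream) {
    // asynchronous variant: the call only enqueues the transfer, so a later
    // host-side write to h_pos (e.g. a CPU op that runs while the GPU graph
    // executes) can race with the copy
    //   cudaMemcpyAsync(d_pos, h_pos, nbytes, cudaMemcpyHostToDevice, stream);

    // synchronous variant: returns only after the data is on the device, so
    // the host buffer can be safely reused right away
    cudaMemcpy(d_pos, h_pos, nbytes, cudaMemcpyHostToDevice);
    (void) stream; // unused in the synchronous version
}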

ggerganov merged commit e04dc51 into custom-attention-mask on Sep 20, 2023
slaren deleted the cam-cuda-2 branch on September 20, 2023 at 11:20
ggerganov added a commit that referenced this pull request Sep 28, 2023
…3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable parameters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close #3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <slarengh@gmail.com>
yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023