Partial GPU offload broken for certain number of offloaded layers #5137
Comments
The problem appears to be CUDA-specific. Repeating the same experiment on a Mac using Metal, I get a very similar PPL for full offload and for 30 layers offloaded to the GPU:
The problem is related to the ggml-backend integration. If I check out 584d674, the last commit before PR #4766 was merged, I get a meaningful result with 30 layers offloaded to the GPU.
main: build = 1842 (584d674b)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed = 1706261433
ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9, VMM: yes
llama_model_loader: loaded meta data with 25 key-value pairs and 995 tensors from ../cuda/junk.bin (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.name str = hf
llama_model_loader: - kv 2: llama.context_length u32 = 32768
llama_model_loader: - kv 3: llama.embedding_length u32 = 4096
llama_model_loader: - kv 4: llama.block_count u32 = 32
llama_model_loader: - kv 5: llama.feed_forward_length u32 = 14336
llama_model_loader: - kv 6: llama.rope.dimension_count u32 = 128
llama_model_loader: - kv 7: llama.attention.head_count u32 = 32
llama_model_loader: - kv 8: llama.attention.head_count_kv u32 = 8
llama_model_loader: - kv 9: llama.expert_count u32 = 8
llama_model_loader: - kv 10: llama.expert_used_count u32 = 2
llama_model_loader: - kv 11: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 12: llama.rope.freq_base f32 = 1000000.000000
llama_model_loader: - kv 13: general.file_type u32 = 19
llama_model_loader: - kv 14: tokenizer.ggml.model str = llama
llama_model_loader: - kv 15: tokenizer.ggml.tokens arr[str,32000] = ["", "
Have you seen this issue in any model other than Mixtral?
As a workaround, increasing the alignment to 4096 seems to avoid the issue.
No, I haven't seen this on another model. Yes, changing the alignment to 4096 works around it in my tests with Mixtral-8x7B.
Are there any downsides to having this set to 4096? I did see quite significant changes in PPL for Mixtral-8x7B after PR #4766; see my comments there.
My reasoning when testing the increased alignment was that, if there is a buffer overflow somewhere, adding a gap between the tensors may mask the issue by preventing it from corrupting the data of other tensors; increasing the alignment effectively does that. I still don't know what the source of the issue is, but at least there are fewer possibilities now. The downside of increasing the alignment is a slight increase in memory usage, but ultimately this is not a solution; it just hides the real issue.
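To make the padding effect concrete, here is a minimal C sketch (illustrative constants and names, not the actual ggml-alloc code) of how rounding each allocation up to a larger alignment leaves a gap after every tensor, so a small overflow past the end of one tensor lands in padding rather than in the next tensor's data:

#include <stddef.h>
#include <stdio.h>

// round a size up to the next multiple of align (align must be a power of 2)
static size_t aligned_size(size_t size, size_t align) {
    return (size + align - 1) & ~(align - 1);
}

int main(void) {
    // hypothetical tensor of 18000 bytes placed at offset 0 in a buffer
    size_t tensor_size = 18000;
    // with a 32-byte alignment the next tensor starts 16 bytes after the end;
    // with a 4096-byte alignment it starts 2480 bytes later, so a small
    // out-of-bounds write hits padding instead of the next tensor's data
    printf("align 32:   next tensor at offset %zu\n", aligned_size(tensor_size, 32));   // 18016
    printf("align 4096: next tensor at offset %zu\n", aligned_size(tensor_size, 4096)); // 20480
    return 0;
}

That padding is also where the slight extra memory usage comes from: every allocation is rounded up to a 4096-byte boundary instead of a much smaller one.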
Thanks. I didn't see the reported VRAM increase. Yes, I agree with you that there is an issue somewhere that leads to overwriting buffers. It all started with me being curious about what happens if the number of experts in a MoE model is changed from the default. I did a ...
So this was caused by an underestimation of the allocation size of non-contiguous tensors. Normally, the only non-contiguous tensors are views, and these don't have to be allocated because they share the memory of their parent tensor. However, when copying data between backends by ...
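To illustrate that failure mode, here is a small, self-contained C sketch (a toy struct with made-up sizes, not the real ggml_tensor) comparing an allocation size computed from the element count against one computed from the strides. For a non-contiguous tensor, such as a row-padded view, the element-count estimate falls short of the memory the tensor actually spans:

#include <stdint.h>
#include <stdio.h>

#define N_DIMS 4

struct toy_tensor {
    int64_t ne[N_DIMS]; // number of elements per dimension
    size_t  nb[N_DIMS]; // stride in bytes per dimension
    size_t  type_size;  // bytes per element
};

// naive estimate: product of the dimensions times the element size
static size_t size_from_elements(const struct toy_tensor *t) {
    size_t n = t->type_size;
    for (int i = 0; i < N_DIMS; i++) n *= (size_t) t->ne[i];
    return n;
}

// stride-aware estimate: byte offset of the last element plus one element
static size_t size_from_strides(const struct toy_tensor *t) {
    size_t end = t->type_size;
    for (int i = 0; i < N_DIMS; i++) end += (size_t)(t->ne[i] - 1) * t->nb[i];
    return end;
}

int main(void) {
    // an f32 tensor with 4096 elements per row, but rows spaced 5120 floats
    // apart (e.g. a view into a larger tensor), so it is non-contiguous
    struct toy_tensor t = {
        .ne = {4096, 8, 1, 1},
        .nb = {4, 5120*4, 5120*4*8, 5120*4*8},
        .type_size = 4,
    };
    printf("naive size:        %zu bytes\n", size_from_elements(&t)); // 131072
    printf("stride-aware size: %zu bytes\n", size_from_strides(&t));  // 159744
    return 0;
}

The stride-aware bound is the safe amount to reserve when such a tensor has to be materialized in another backend's buffer; reserving only the element-count size leaves the tail of the data overlapping whatever is allocated next.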
Steps to reproduce

1. Quantize a model to IQ2_XXS and IQ2_XS (the runs below use Mixtral-8x7B).
2. Run a perplexity calculation with the model fully offloaded to the GPU. A few chunks is enough.
3. Run the same calculation with -ngl 30, and observe how PPL is 2-3 times higher than in step 2.

Here are some example runs
All layers on GPU
main: build = 1971 (1182cf4d)
...
llm_load_print_meta: model ftype = IQ2_XXS - 2.0625 bpw
llm_load_print_meta: model params = 46.70 B
llm_load_print_meta: model size = 11.44 GiB (2.10 BPW)
llm_load_print_meta: general.name = hf
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.76 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloaded 32/33 layers to GPU
llm_load_tensors: CPU buffer size = 11712.97 MiB
llm_load_tensors: CUDA0 buffer size = 11586.00 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA0 KV buffer size = 64.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 9.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 109.03 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 70.50 MiB
llama_new_context_with_model: graph splits (measure): 4

system_info: n_threads = 32 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 567.194 ms
perplexity: calculating perplexity over 642 chunks, batch_size=512
perplexity: 1.32 seconds per pass - ETA 14.08 minutes
[1]4.0990,[2]4.9914,[3]5.6483,[4]6.3020,[5]6.2826,[6]6.2130,[7]6.4030,[8]6.4265,[9]6.5435,[10]6.8596,[11]7.0488,[12]7.0107,[13]7.0517,[14]7.0914
30 layers offloaded to GPU
main: build = 1971 (1182cf4d)
...
llm_load_print_meta: model ftype = IQ2_XXS - 2.0625 bpw
llm_load_print_meta: model params = 46.70 B
llm_load_print_meta: model size = 11.44 GiB (2.10 BPW)
llm_load_print_meta: general.name = hf
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.76 MiB
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloaded 30/33 layers to GPU
llm_load_tensors: CPU buffer size = 11712.97 MiB
llm_load_tensors: CUDA0 buffer size = 10806.75 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 4.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 60.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 9.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 108.03 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 108.03 MiB
llama_new_context_with_model: graph splits (measure): 5

system_info: n_threads = 32 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 564.023 ms
perplexity: calculating perplexity over 642 chunks, batch_size=512
perplexity: 1.44 seconds per pass - ETA 15.45 minutes
[1]9.7855,[2]10.1005,[3]12.9574,[4]13.0298,[5]12.7318,[6]11.8905,[7]11.7408,[8]11.9335,[9]11.8980,[10]12.4120,[11]12.8212,[12]13.9232,[13]13.9312,[14]14.1171
All on CPU
main: build = 1971 (1182cf4)
...
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = IQ2_XXS - 2.0625 bpw
llm_load_print_meta: model params = 46.70 B
llm_load_print_meta: model size = 11.44 GiB (2.10 BPW)
llm_load_print_meta: general.name = hf
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.38 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/33 layers to GPU
llm_load_tensors: CPU buffer size = 11712.97 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CPU KV buffer size = 64.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CPU input buffer size = 9.01 MiB
llama_new_context_with_model: CPU compute buffer size = 114.53 MiB
llama_new_context_with_model: graph splits (measure): 1

system_info: n_threads = 32 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 566.337 ms
perplexity: calculating perplexity over 642 chunks, batch_size=512
perplexity: 72.40 seconds per pass - ETA 12 hours 54.68 minutes
[1]4.1341,[2]5.0092,[3]5.6687,[4]6.3300,[5]6.3044,[6]6.2292,[7]6.4185,[8]6.4343,[9]6.5516,[10]6.8710,[11]7.0630,[12]7.0260,[13]7.0671,
29 layers on GPU
main: build = 1971 (1182cf4d)
...
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = IQ2_XXS - 2.0625 bpw
llm_load_print_meta: model params = 46.70 B
llm_load_print_meta: model size = 11.44 GiB (2.10 BPW)
llm_load_print_meta: general.name = hf
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.76 MiB
llm_load_tensors: offloading 29 repeating layers to GPU
llm_load_tensors: offloaded 29/33 layers to GPU
llm_load_tensors: CPU buffer size = 11712.97 MiB
llm_load_tensors: CUDA0 buffer size = 10417.12 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 6.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 58.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 9.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 108.03 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 108.03 MiB
llama_new_context_with_model: graph splits (measure): 5

system_info: n_threads = 32 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 566.749 ms
perplexity: calculating perplexity over 642 chunks, batch_size=512
perplexity: 1.52 seconds per pass - ETA 16.28 minutes
[1]4.0521,[2]4.9624,[3]5.5985,[4]6.2678,[5]6.2614,[6]6.2038,[7]6.4056,[8]6.4241,[9]6.5409,[10]6.8630,[11]7.0622,[12]7.0257,[13]7.0661,[14]7.1006,
31 layers on GPU
main: build = 1971 (1182cf4d)
...
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = IQ2_XXS - 2.0625 bpw
llm_load_print_meta: model params = 46.70 B
llm_load_print_meta: model size = 11.44 GiB (2.10 BPW)
llm_load_print_meta: general.name = hf
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.76 MiB
llm_load_tensors: offloading 31 repeating layers to GPU
llm_load_tensors: offloaded 31/33 layers to GPU
llm_load_tensors: CPU buffer size = 11712.97 MiB
llm_load_tensors: CUDA0 buffer size = 11196.38 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 2.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 62.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 9.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 108.03 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 108.03 MiB
llama_new_context_with_model: graph splits (measure): 5

system_info: n_threads = 32 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 548.178 ms
perplexity: calculating perplexity over 642 chunks, batch_size=512
perplexity: 1.39 seconds per pass - ETA 14.82 minutes
[1]4.8836,[2]6.0415,[3]6.4471,[4]7.0981,[5]6.9666,[6]6.8581,[7]7.1009,[8]7.0858,[9]7.2431,[10]7.5545,[11]7.7723,[12]7.6741,[13]7.7159,[14]7.7463,