-
Notifications
You must be signed in to change notification settings - Fork 10.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce bfloat16 support #6412
Conversation
436956a
to
e52d5e5
Compare
IEEE 754 half precision floats can store values in the range
I think this is not due to any change in the weights but rather due to a difference in rounding error in the accumulator. I expect this improvement to not be consistent across models/text corpuses and I also expect there to be no statistically significant improvement at all for a large enough sample size. |
There are some different between quant from BF16-FP32 to BF16-FP16. |
@JohannesGaessler Only 13% of bf16 numbers can be represented accurately by a bf16 -> fp16 conversion. https://justine.lol/tmp/bf16-to-fp16.txt Yes, the vast majority of weights cluster within that 13%. By my calculation, only 0.29101% of Mistral 7b's numbers are broken. I want those numbers. I also don't want to accept limits on what's possible based on what's normal. Someone might find those broken intervals useful. But if that doesn't persuade you, consider this. I recently bought a Threadripper and it offers hardware acceleration for bf16 but not fp16. So this change is not just good for accuracy, it can be good for performance too. |
Broken in what sense? Numbers being flushed to zero is not an issue because the difference between 0 and almost 0 is negligible for matrix multiplication.
The performance point is valid. In terms of numerical precision, this is the bottom line for me: I very much expect the difference between IEEE 754 half precision and bfloat to be completely negligible. I'm not telling you this out of malice but because I want contributors to spend their time in a way that is useful. If it turns out I'm wrong I will happily accept it. |
You might find the differences negligible, but it's important to me. I want llamafile to be able to deliver, to the best of its ability, whatever number of bits are claimed, even if those extra bits are only good for audiophiles. In my day-to-day work as a developer, I feel more comfortable being able to compare my tradeoffs with the master copies. Furthermore, I need this data type in order to be able to exploit the full capabilities of my hardware. Am I correct in understanding you won't merge this? That surprises me. This project recently accepted nine novel "IQ" quantization formats, which I know very little about. So I was under the impression there was a certain level of inclusiveness. Why would you not support the data type that companies like Mistral and Google widely use? |
The ultimate decision of what gets merged is not up to me. And I am not at all opposed to adding bfloat support. I only want to stress that I do not expect the gains from this feature to be in any way proportional to the amount of effort it will take. As such I personally will not invest time into bfloat support by e.g. modifying the CUDA code. If other devs want to do it that is their decision. |
This comment was marked as off-topic.
This comment was marked as off-topic.
I don't hold any demands on your time. In terms of resources, Mozilla is sponsoring me to help llama.cpp so you've got a lot more resources than before. At the moment, I only need this to work on CPU however I'll likely get personal enjoyment at some point in getting this to work on CUDA and Metal too. Particularly Metal, since I've been looking for a good reason to learn it for some time. |
I would imagine older cuda hardware wouldn't support it due to bf16 unsupport on Pascal. What's solution to that? |
Here's the decoding process for bfloat16: typedef struct {
uint16_t x;
} ggml_bf16_t;
/**
* Converts brain16 to float32.
*/
static inline float ggml_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.x << 16;
return u.f;
} So the only thing old CUDA needs to do, is left shift the bf16 number by 16 bits, and then it becomes a float. |
I think bf16 support is nice to have in GGUF, if only because it makes quantizing a lot of models much less I/O intensive. Consider changing |
Relevant for discussion: Mozilla-Ocho/llamafile@ef0307e It seems there seem to be at least some values above the maximum value representable by IEEE 754 half precision floats. @jart do you know in which specific matrices these weights show up? Depending on where they are relative to softmax this could be an issue. |
Is there anything special needed to see performance gains? I cloned/built/tested this PR branch and am seeing no change in performance on CPU (CUDA support flags disabled at compile time) |
For CPU, I think you need something that support bf16 acceleration like AVX512VNNI? |
system_info: n_threads = 55 / 128 | AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | Hardware-wise I think I have what's needed. |
https://justine.lol/matmul/ |
What should be expected in llama.cpp from this patch specifically? I'm seeing about 6% speed increase on prompt processing and inference and I've pulled and built the master, avx512vnni, sgemm and bf16 branches. Each of them perform almost identically on a Q8 70b. |
@Artefact2 I've updated |
Nice. I'll keep an eye out for them. Is there a relevant branch on your llama.cpp fork I can test prior to a PR, or do you still need to merge changes already in llamafile? |
@cpumaxx Could you download https://huggingface.co/jartine/Mistral-7B-Instruct-v0.2-llamafile/blob/main/mistral-7b-instruct-v0.2.BF16.gguf and then build the code in the branch I just created https://github.com/jart/llama.cpp/tree/unified which unifies #6412 and #6414? Thanks! |
Here's an example of what you should expect to see with that branch.
EPYC is for servers so I've heard they generally run at much lower clock rates than Threadripper Pro. So if you get a lower number than 530 tok/sec then try comparing it to llama.cpp at HEAD using the Mistral 7b f16 weights. |
My system is a dual 64 core 9334 running with a 3.9ghz boost clock
vs
This was with identical build flags and after dropping all caches for a level playing field. Anything else I should be trying in order to see the speedup? |
Could you pass the flag |
browser? webgpu? webassembly? mesh networking w/ rtcdatachannel? |
@jart, Further we had tried to run the prompt speedup code from https://github.com/jart/llama.cpp/tree/unified . With the current code in the fork, the code was going through second input(operand) as GGML_TYPE_F32 for mulmat functions. We tried to modify the code such that the second input is in GGML_TYPE_BF16 for mulmat kernels and removes the GGML_TYPE_F32 case, which enables the input of second operand (Btype) to get quantized to BF16 format and hence uses BF16 intrinsics in turn for dot product operation. Significant speedup was observed while comparing the code with original version in the fork where the second operand of mulmat operation is in FP32 format.
The code was tested in AMD Raphael 7600X machine which has AVX512_BF16 support in Linux platform. The original unquantized model is taken from https://huggingface.co/TheBloke/wizardLM-7B-HF . Please find the updated code in PR 2 of your fork of llama.cpp - jart#2. Changes in jart#1 (PR 1) was included while testing the same Could you please share your thoughts here? Is prompt speedup for BF16 models planned to be included in future commits of prompt speedup changes/ BF16 model PR? Thanks |
@Srihari-mcw this change doesn't modify So BF16 optimizations are blocked on review. As for your pull request, the canonical location of the code you're modifying is here:
I've done a lot of work in the past month identifying other performance opportunities. |
Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───┐ 0b0000000000000000 brain16 This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───────────────────┐ 0b00000000000000000000000000000000 IEEE binary32 The issue is that converting bf16 to fp16 can result in information loss. Only 13% of bf16 numbers can be precisely represented in fp16 which in practice ends up being 99.71% of Mistral 7b v0.2's weights however there is currently no way other than fp32 to get the others ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌─┴─┐┌─┴──────┐ 0b0000000000000000 IEEE binary16 This change fixes that, by adding a bf16 data type to GGML. Support for CPU inference has been implemented along with optimizations for the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16
So happy to see this land! Will convert.py and convert-hf-to-gguf.py need to be updated? |
I'm wondering the same thing |
The Python scripts do need to be updated. I was only able to add the IDs. I wasn't able to successfully figure out how to get the raw bfloat16 data from Torch because Numpy doesn't support it. Someone who knows more than me will need to figure that out. So happy to see this merged @ggerganov! Thank you! |
By the way, the workaround I'm currently using is to:
|
We'll need to use a custom wrapper to implement. I tried doing this last year with pure python and it was a no go. Probably |
Note my implementation of
Important EDIT: I've made a proper implementation in #7158 which does properly handle subnormals, and rounding. |
Bring `GGMLQuantizationType` up to date; adds `I8`, `I16`, `I32`, `I64`, `F64`, `IQ1_M` and `BF16`. Added in: * ggerganov/llama.cpp#6045 * ggerganov/llama.cpp#6062 * ggerganov/llama.cpp#6302 * ggerganov/llama.cpp#6412
Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format.
This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left.
The issue is that converting weights from bf16 to fp16 will cause 3 bits of knowledge to be lost. There is currently no way to evaluate models like Mistral at full fidelity, without f32, using llama.cpp.
This change fixes that, by adding a bf16 data type to GGML. Support for CPU inference has been implemented along with optimizations for the AVX2, AVX512F, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16