Introduce QLoRA-like technique #19356
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #19356 +/- ##
==========================================
+ Coverage 75.86% 75.91% +0.04%
==========================================
Files 366 366
Lines 40479 40532 +53
Branches 7869 7884 +15
==========================================
+ Hits 30711 30768 +57
+ Misses 8068 8066 -2
+ Partials 1700 1698 -2
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
We should wait for #19302
Thanks for the PR! The custom gradient PR has been merged. Please fix merge conflicts.
Is there a way we could try this out with an LLM? E.g. adapt this guide -> https://ai.google.dev/gemma/docs/lora_tuning

IIUC, this should probably allow us to fine-tune a Gemma 7B model checkpoint on less than 16 GB of GPU RAM, because LoRA will essentially zero out the size of the optimizer variables relative to the model weights, and quantizing our weights to int8 should bring us to a little over 8 GB of space. An end-to-end test with a massive model would validate a lot.
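For reference, a minimal sketch of what that end-to-end test could look like with KerasNLP; the `gemma_7b_en` preset name and the quantize-then-`enable_lora` ordering are assumptions here, not something verified in this thread:

```python
import keras_nlp

# Load the Gemma checkpoint, freeze its weights as int8, then attach LoRA deltas.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
gemma_lm.quantize("int8")              # frozen int8 base weights + float scales
gemma_lm.backbone.enable_lora(rank=4)  # only the small rank-4 deltas stay trainable

# Optimizer state now only covers the LoRA variables, so memory should be
# dominated by the int8 weights rather than float copies plus optimizer slots.
gemma_lm.summary()
```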
How do we want to handle embeddings and quantization? Embeddings are usually the biggest individual memory hogs for models that might want quantization. We might want to add some quant support to our `Embedding` layer (though it does not need to be in this PR!). https://github.com/google/gemma_pytorch/blob/cf8658c186255379194ba5b62612321eacde1b6b/gemma/model.py#L132-L154
Since
Yeah, I think it should be pretty simple. Then there's the question of LoRA + quantization + an embedding layer. I don't think doing LoRA + quantization there will be that important in practice, as most people probably use LoRA with any embeddings frozen, but it might be worth adding for consistency (since we have `enable_lora` on the layer).
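To make the embedding discussion concrete, here is an illustrative sketch (not this PR's implementation) of an int8 embedding lookup with a per-row scale, analogous to what the `Dense`/`EinsumDense` path does; all names here are made up for the example:

```python
import numpy as np
from keras import ops

vocab_size, dim = 1000, 64
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(vocab_size, dim)).astype("float32")

# Per-row absmax quantization: one float scale per vocabulary entry.
scale = np.max(np.abs(embeddings), axis=-1, keepdims=True) / 127.0
q_embeddings = np.round(embeddings / scale).astype("int8")

def lookup(token_ids):
    # Gather int8 rows and their scales, then dequantize only the gathered rows.
    rows = ops.take(q_embeddings, token_ids, axis=0)
    row_scale = ops.take(scale, token_ids, axis=0)
    return ops.cast(rows, "float32") * row_scale

print(lookup(np.array([1, 5, 42])).shape)  # (3, 64)
```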
ops.cast(kernel, dtype=self.compute_dtype),
kernel_scale,
)
# From https://stackoverflow.com/a/47609896 |
Wow. Brilliant! Never thought of it this way!! A very beautiful way to exploit Einsteinian Tensor Summation.
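For readers following along, here is a small sketch of the identity being exploited, assuming per-output-channel absmax quantization (an illustration, not the exact PR code): scaling the int8 kernel before the matmul is equivalent to rescaling the matmul output, because the scale is constant per output channel.

```python
import numpy as np
from keras import ops

x = np.random.default_rng(0).normal(size=(2, 8)).astype("float32")
kernel = np.random.default_rng(1).normal(size=(8, 4)).astype("float32")

# Quantize the kernel to int8 with one scale per output channel.
kernel_scale = np.max(np.abs(kernel), axis=0, keepdims=True) / 127.0
q_kernel = np.round(kernel / kernel_scale).astype("int8")

# (a) Dequantize the frozen kernel on the fly, then matmul.
y_a = ops.matmul(x, ops.multiply(ops.cast(q_kernel, "float32"), kernel_scale))

# (b) Matmul against the casted int8 kernel, then rescale the output.
y_b = ops.multiply(ops.matmul(x, ops.cast(q_kernel, "float32")), kernel_scale)

np.testing.assert_allclose(
    ops.convert_to_numpy(y_a), ops.convert_to_numpy(y_b), rtol=1e-5, atol=1e-6
)
```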
LGTM, thank you!
I have encountered an issue when trying this PR with KerasNLP. The issue will cause the following to fail:

```python
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=128
)
lora_model = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)
lora_model.quantize("int8")
lora_model.save("model_int8.keras")
reloaded_model = keras.saving.load_model("model_int8.keras")  # <- this line
```

The above is necessary to accurately record the peak GPU memory usage, because TensorFlow doesn't release GPU memory after using
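In case it helps, here is a sketch of how that peak could be recorded with the TensorFlow backend; the single `"GPU:0"` device name is an assumption.

```python
import tensorflow as tf
import keras

# Reset the stats, run the reloaded model, then read the peak allocation.
tf.config.experimental.reset_memory_stats("GPU:0")
reloaded_model = keras.saving.load_model("model_int8.keras")
# ... run reloaded_model.generate()/fit() here ...
peak_bytes = tf.config.experimental.get_memory_info("GPU:0")["peak"]
print(f"Peak GPU memory: {peak_bytes / 2**30:.2f} GiB")
```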
I can add
I just realized that I need to manually call
Thanks for the update and for the GPT-2 numbers!
Is this something we should do in the framework code, then? Have you been able to try Gemma 2B?
This should be a harmless addition. I have updated the code and verified its effectiveness.
Unfortunately, I failed to fit Gemma 2B on my rig (12GB 4070...). BTW, the training speed has improved after #19368.

Fine-tuning a quantized model might be faster in float32 and is competitive in bfloat16.
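A sketch of how the compute dtype could be switched for such a run; this assumes the global dtype policy is set before the model is built, and uses the quantize-then-`enable_lora` ordering from this PR:

```python
import keras

keras.config.set_dtype_policy("bfloat16")  # or "float32"

model = keras.Sequential([
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(10),
])
model.build((None, 64))
model.quantize("int8")         # freeze kernels as int8
for layer in model.layers:
    layer.enable_lora(rank=4)  # train only the LoRA deltas
```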
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM -- thank you for the great contribution. Let's merge!
Highlights
This PR enables training with frozen int8 weights for `Dense` and `EinsumDense`. Overall, we will have a similar training speed and lower memory usage (about 50~68%) compared to floating-point LoRA.
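A minimal sketch of that workflow at the layer level, assuming the `quantize()` followed by `enable_lora()` ordering that this PR enables:

```python
import numpy as np
import keras

layer = keras.layers.Dense(16, use_bias=False)
layer.build((None, 8))

layer.quantize("int8")     # kernel becomes a frozen int8 variable plus a float scale
layer.enable_lora(rank=2)  # adds small trainable lora_kernel_a / lora_kernel_b

print([v.name for v in layer.trainable_variables])  # only the LoRA deltas
y = layer(np.random.rand(4, 8).astype("float32"))   # int8 kernel is dequantized on the fly
print(y.shape)  # (4, 16)
```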
Notes
Similar to QLoRA, but this PR lacks the following:
The training speed with the torch backend is slower due to the lack of hardware-accelerated matmul/einsum.
Results
(Benchmark results table: `Dense` and `EinsumDense` across `compute_dtype` settings.)
Standalone benchmark script:
benchmark.py
GPT2 & Gemma Finetuning
Try this PR with GPT2/Gemma and LoRA
(Finetuning results table across `compute_dtype` settings.)
qlora.py
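For anyone who wants to reproduce something along these lines before looking at the attached script, here is a rough sketch of a GPT-2 run; the hyperparameters, dataset, and reliance on `backbone.enable_lora` covering the attention projections are assumptions, not taken from `qlora.py`:

```python
import keras
import keras_nlp

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=128
)
lora_model = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)
lora_model.quantize("int8")              # frozen int8 base weights
lora_model.backbone.enable_lora(rank=4)  # trainable low-rank deltas only

lora_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# lora_model.fit(train_ds, epochs=1)  # train_ds: any dataset of raw strings
# print(lora_model.generate("My trip to Yosemite was", max_length=64))
```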