-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fused DoRA kernels #216
Merged
Merged
Fused DoRA kernels #216
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
facebook-github-bot
added
the
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
label
May 5, 2024
msaroufim
approved these changes
May 7, 2024
HDCharles
added a commit
that referenced
this pull request
May 8, 2024
* Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * Fused DoRA kernels (#216) * add dora kernels * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Unified AffineQuantizedTensor subclass (#214) Summary: Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine and dequantize_affine ops) only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w Reviewers: Subscribers: Tasks: Tags: Co-authored-by: Mark Saroufim <marksaroufim@meta.com> * add expecttest to requirements.txt (#225) * add expecttest to requirements.txt * update * Install dev-requirements.txt in doc build (#224) Install dev-requirements.txt --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com> * Fix an error in subclass impl (#226) Summary: Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing Test Plan: CI Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Some follow up fixes for quant primitives (#220) Summary: att Test Plan: python test/quantization/test_quant_primitives.py -k test_raises Reviewers: Subscribers: Tasks: Tags: * Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: jeromeku <jerome.ku@gmail.com> Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> Co-authored-by: Mark Saroufim <marksaroufim@meta.com> Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
dbyoung18
pushed a commit
to dbyoung18/ao
that referenced
this pull request
Jul 31, 2024
* add dora kernels
dbyoung18
pushed a commit
to dbyoung18/ao
that referenced
this pull request
Jul 31, 2024
* Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * Fused DoRA kernels (pytorch#216) * add dora kernels * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Unified AffineQuantizedTensor subclass (pytorch#214) Summary: Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine and dequantize_affine ops) only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w Reviewers: Subscribers: Tasks: Tags: Co-authored-by: Mark Saroufim <marksaroufim@meta.com> * add expecttest to requirements.txt (pytorch#225) * add expecttest to requirements.txt * update * Install dev-requirements.txt in doc build (pytorch#224) Install dev-requirements.txt --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com> * Fix an error in subclass impl (pytorch#226) Summary: Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing Test Plan: CI Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Some follow up fixes for quant primitives (pytorch#220) Summary: att Test Plan: python test/quantization/test_quant_primitives.py -k test_raises Reviewers: Subscribers: Tasks: Tags: * Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: jeromeku <jerome.ku@gmail.com> Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> Co-authored-by: Mark Saroufim <marksaroufim@meta.com> Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fused DoRA Kernels
Fused DoRA layer implementation that reduces number of individual kernels from ~10 -> 5.
Contents
Background
DoRA (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the LoRA update into magnitude and vector components.
The DoRA layer is roughly as follows:
where:
lora_A
andlora_B
arelinear
layers with weight shapesrank x in_features
andout_features x rank
.base_weight
is the weight of the frozenlinear
layer of shapeout_features x in_features
.magnitude_vector
is initialized as the columnwise2-norm
of the frozen weight (shapeout-features
).x
are the inputs of shapebatch_size x seqlen x in_features
Optimization
After initial profiling, and as outlined above, the
DoRA
update layer requires multiple kernels.In order of compute intensity:
x @ base_weight
lora_B(lora_A(x))
lora_B.weight @ lora_A.weight
2-norm
While
torch.compile
(andCUDA
graphs) can partially mitigate the overhead of multiple small kernels and improve compute efficiency of individual kernels, there remains room for additional optimization by reordering the computations to facilitate fusions, and more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel.Key Contributions
1 - Small K Fused Kernel
Note that the
lora_B.weight @ lora_A.weight
has a specific shape, whereK << {M, N}
. That is,lora_B.weight
isout_features x lora_rank
andlora_A.weight
islora_rank x in_features
.Since
lora_rank
is typically< 64
while{in,out}-features
are typically> 4096
(e.g.,Llama MLP / QKV projections
), thisGEMM
is inefficient, since eachCTA
loads a block, only to perform a fewMAC
iterations given smallK
.Moreover, note that the result of this
GEMM
is not needed -- we only need the2-norm
of this computation.Combining these two observations, we can write a fused kernel where:
CTA
computes an entire row of the output matrix, with the key assumption thatBLOCK_K = K
. That is, eachCTA
does a single MAC iteration to compute aBLOCK_M x BLOCK_N
output, then iterates across dimensionN
.axis=1
into the kernel. In this case, we can directly fold the2-norm
computation into theGEMM
.base_weight
elementwise addition andmagnitude_vector
multiplication into theGEMM
epilogue.Altogether, this allows us to fuse the following computation into a single kernel:
2 - Fused Epilogue GEMM
Additionally, instead of computing the base layer output before the
DoRA / LoRA
updates, we can compute the latter (loRA layer
andmagnitude_scale
) first, and fold these into the epilogue of the base layerGEMM
:Usage
The fused kernels can be used to implement
DoRA
/QDoRA
layers.A reference implementation is provided in
dora.dora_layer.DoRALinear
, which defines a baseQDoRA
linear layer (with a stubdequantize
method) along with correspondingBNBDoRALinear
andHQQDoRALinear
subclasses, which overridedequantize
with their respective methods.Example
See
test/test_dora_layer.py
andbenchmarks/dora_bench.py
for more detailed usage.Also, note that these are reference implementations and are not fully optimized. See Next Steps for follow-up plans.
Tests
See
test/dora/test*
, for correctness checks of the fused kernels and layers.Benchmarks
See
benchmarks/dora_bench.py
.Run with flag
--kernel
set to one of{dora-colnorm,dora-mm-epilogue}
, to benchmark the respective fused kernels against a referencetorch
/torch.compile
implementation, or--kernel=dora-full
to bench against the entireDoRA
computation.Additionally, passing either
--kernel={dora-bnb, dora-hqq}
will bench a referenceQDoRA
layer against their fused implementations.Profiling
The reference
DoRALinear
layer described above also has an instrumented forward pass with annotated regions for each of theDoRA
ops.An example script for running a profiled forward pass is provided in
dora/dora_profile.py
.To run with
torch.profiler
:which outputs chrome trace to default folder
dora_profiles
.To run with
nsys
:where
...
are other desirednsys
options.Note that
--capture_range=cudaProfilerApi
is required.Next Steps
torch.compile
, re-ordering computations, etc.torch.autograd.Function
FSDP-LoRA
)triton
autotunergalore
,hqq
, anddora
can now be refactored into single module. Separate PR?