Update quantization to use tensor subclasses #1403
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1403
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d29833d with merge base 7c51100. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchtune/utils/quantization.py (Outdated)

```
@@ -22,7 +22,19 @@
if TORCH_VERSION_AFTER_2_3:
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
    from torchao.quantization.quant_api import (
```
@ebsmothers Can you remind me what our principles are around version guards? Now that 2.4 has launched, do we just claim we work with stable and remove such guards? Or what's the downside of doing that?
Yeah, we can remove them; our assumption is that users are at least on the latest stable version of PyTorch (so at this moment, 2.4).
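For concreteness, a minimal sketch of what dropping the guard could look like, assuming everyone is on the latest stable PyTorch. The full import list in the PR diff is truncated above, so only the name visible in the original guarded line is shown:

```python
# Sketch only: with the TORCH_VERSION_AFTER_2_3 guard removed, the import is
# unconditional. The complete import list from the PR is not reproduced here.
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
```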
Force-pushed from 33a3ac9 to d020da3.
torchtune/utils/quantization.py (Outdated)

```
# importing TORCH_VERSION_AFTER_2_3 because `Int8DynActInt4WeightQuantizer`
# is only available after 2.3 so we have to guard the pytorch versions to decide
# the list of supported quantizers
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
```
cc @msaroufim @ebsmothers Just wanted to confirm this is OK
Sorry to ask for the 100th time, but TORCH_VERSION_AFTER_2_4 returns True if and only if the PyTorch version is >= 2.4, right? If so, we can remove both of these, since we assume everyone is on at least the latest stable PyTorch.
Mark can confirm, but yes, I think TORCH_VERSION_AFTER_2_4 means >= 2.4.
I believe that is only true in torchao nightlies, where this was fixed. Before, it meant > 2.4.
Overall the changes look reasonable to me. My main question is whether there is any logic that's BC breaking (i.e. can I run this as is on ao's 0.3.1, latest nightly, and anything in between?)
Force-pushed from d020da3 to 5a500bc.
Good question. I haven't tested an older version of ao, but this change is needed to run the full end-to-end flow after the QAT subclass refactor. If there are concerns about breaking BC, maybe we should wait until we upgrade to torchao 0.5.0? (expected maybe mid-September)
Force-pushed from 5a500bc to 02fce45.
**Summary:** In torchao, we are migrating our quantization flows from module swap to tensor subclasses. The existing `Int8DynActInt4WeightQuantizer` will be deprecated in the near future in favor of quantizing using the `quantize_` API, so we should do the same in torchtune. This quantizer is currently only used by QAT, which also recently migrated to a tensor subclass implementation.

This also changes the eval script slightly, since models quantized through the torchao tensor subclasses are expected to be loaded with `assign=True`: https://github.com/pytorch/ao/blob/9a56e80cb6070599701b8f5f587bd8187c8dccb4/test/quantization/test_quant_api.py#L610. We should load the model similarly in torchtune.

**Test Plan:** Quantized and evaluated the base Llama3-8B model on 1 A100 GPU:

```
CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/my_quantization.yaml \
  model._component_=torchtune.models.llama3.llama3_8b \
  checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
  checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
  checkpointer.checkpoint_files=[consolidated.00.pth] \
  checkpointer.model_type=LLAMA3

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
  model._component_=torchtune.models.llama3.llama3_8b \
  checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
  checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
  checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
  checkpointer.checkpoint_files=[consolidated-8da4w.pt] \
  checkpointer.model_type=LLAMA3 \
  tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
  tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```

Reviewers: ebsmothers, kartikayk, RdoubleA

Subscribers: ebsmothers, kartikayk, RdoubleA
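For context, here is a minimal, self-contained sketch of the tensor-subclass flow described above, assuming a recent torchao release; the toy model and checkpoint path are placeholders rather than torchtune code:

```python
# Sketch only: illustrates torchao's quantize_ API and the assign=True load
# described in the summary. The toy model and path are placeholders.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight

# Toy model; in torchtune this would be e.g. llama3_8b().
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Swap Linear weights to tensor subclasses in place (this replaces the old
# module-swap Int8DynActInt4WeightQuantizer flow).
quantize_(model, int8_dynamic_activation_int4_weight())
torch.save(model.state_dict(), "/tmp/quantized_sketch.pt")

# Checkpoints produced this way should be loaded with assign=True so the
# tensor-subclass objects are assigned rather than copied into plain params.
fresh = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
quantize_(fresh, int8_dynamic_activation_int4_weight())
state_dict = torch.load("/tmp/quantized_sketch.pt", weights_only=False)
fresh.load_state_dict(state_dict, assign=True)
```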
Force-pushed from 02fce45 to d29833d.
Hi @ebsmothers, I think this is ready from my side. I saw that the torchao version is actually removed from pyproject.toml. Do we expect torchtune to only work with the latest version? Either way, I tested the quantization and eval with both the module-swap QAT (old torchao version) and the tensor subclass QAT (latest torchao version) and both work. Please take another look. Thanks.
@andrewor14 thanks for the ping. Yeah, currently we will install whatever the latest stable version of torchao is. This is what we do in our CI and what we recommend our users to do as well (ofc we also support nightlies). So yes, the latest version of torchao should be sufficient here.
This looks good to me, thanks! Once CI is green I think we're good to merge
Just curious, does this work with FSDP2 recipes? I'm trying out INT8 mixed-precision with torchtune, and I saw that in distributed recipes, load state dict uses:

torchtune/torchtune/training/_distributed.py, line 346 (at ee343e6)

This is because we call torchao's [...]. One fix for this is to add DTensor support for [...].

Update: I tried manually swapping the tensor subclass after [...]
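To make the ordering concrete, here is a simplified single-device sketch of what I take the workaround to be (an assumption on my part: apply the tensor-subclass swap after loading plain weights rather than before); the FSDP2/DTensor specifics from the truncated comment above are not reproduced:

```python
# Sketch only, single device (assumption): load unquantized weights first,
# then apply the tensor-subclass swap, instead of quantizing before loading.
# The FSDP2 / DTensor details from the comment above are not reproduced.
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight

model = nn.Sequential(nn.Linear(128, 128))
plain_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# 1) Load plain (unquantized) weights as usual...
model.load_state_dict(plain_state_dict)
# 2) ...then swap nn.Linear weights to torchao tensor subclasses in place.
quantize_(model, int8_dynamic_activation_int4_weight())
```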