Adding quantization support in torchtune #632
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/632
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4abcdb9 with merge base 45031b3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
recipes/quantize.py
Outdated
# Ensure the cache is setup on the right device
with self._device:
    model.setup_caches(max_batch_size=1, dtype=self._dtype)
We don't need to set up caches for quantization?
torchtune/utils/quantization.py
Outdated
@@ -0,0 +1,47 @@
from typing import Any
Can we add some tests for this file? I'm nervous about adding utilities without any associated tests
We should also add some tests to make sure the model is as expected after quantization. This will help catch any breakages in APIs/behaviors if torchao changes.
@rohan-varma any ideas around testing?
@kartikayk I changed these to just use classes directly and they will be tested in torchao, is that OK?
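For context, a minimal sketch (config values assumed, not taken from this PR) of what "using the classes directly" looks like from the recipe side: the quantizer named in the config is instantiated as-is, so the quantizer implementations and their unit tests can live in torchao rather than torchtune.

```python
from omegaconf import OmegaConf
from torchtune import config

# Hypothetical config fragment mirroring the quantizer block in recipes/configs/quantize.yaml
cfg = OmegaConf.create(
    {
        "quantizer": {
            "_component_": "torchtune.utils.Int4WeightOnlyQuantizer",
            "groupsize": 256,
        }
    }
)

# The recipe just builds whatever class the config points at; there is no extra
# torchtune-side wrapper that would need its own tests.
quantizer = config.instantiate(cfg.quantizer)
```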
recipes/configs/quant_generate.yaml
Outdated
@@ -0,0 +1,29 @@
I don't think we need a separate config here. We should make it clear in our documentation and tutorial how to load a quantized model. It's just a checkpointer change, so I'd remove this.
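As a rough illustration of "just a checkpointer change" (the paths and filenames below are hypothetical), loading a quantized model would only mean pointing the existing checkpointer at the quantized checkpoint and disabling weights-only loading:

```python
from torchtune import utils

# Hypothetical checkpoint locations; only the checkpoint_files entry differs
# from the ordinary generate/eval setup.
checkpointer = utils.FullModelTorchTuneCheckpointer(
    checkpoint_dir="/tmp/llama2",
    checkpoint_files=["meta_model_0-4w.pt"],  # hypothetical quantized checkpoint
    model_type="LLAMA2",
    output_dir="/tmp/llama2",
)

# Quantized tensors are not plain torch.Tensor payloads, hence weights_only=False.
checkpoint_dict = checkpointer.load_checkpoint(weights_only=False)
```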
recipes/eleuther_eval.py
Outdated
@@ -124,20 +124,25 @@ class EleutherEvalRecipe(EvalRecipeInterface):
     def __init__(self, cfg: DictConfig) -> None:
         self._cfg = cfg

-    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
+    def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
@joecummings can I get a review for this file? Trying to think about the best way to integrate the quantization changes for eval
Noob q: why would we want to quantize the model in eval?
This is mostly to load a quantized model for evaluation. The flow here is something like this (a rough sketch of the quantize step follows this list):
- Finetune + Eval
- Quantize
- Eval with the quantized model to make sure it is still performant
- Run inference with the quantized model to make sure it's not doing something crazy
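A rough sketch of the "Quantize" step in this flow (the helper name, `cfg` fields, and output filename are assumptions for illustration): load the finetuned weights, apply the quantizer named in the config, and save the quantized state dict so the eval and generate recipes can load it afterwards.

```python
import torch
from omegaconf import DictConfig
from torchtune import config


def quantize_checkpoint(cfg: DictConfig, checkpoint_dict: dict, out_path: str) -> None:
    """Hypothetical helper showing the core of the quantize step."""
    model = config.instantiate(cfg.model)          # same model component used for finetuning
    model.load_state_dict(checkpoint_dict)         # finetuned, still-unquantized weights
    quantizer = config.instantiate(cfg.quantizer)  # e.g. Int4WeightOnlyQuantizer
    model = quantizer.quantize(model)
    torch.save(model.state_dict(), out_path)       # e.g. "meta_model_0-4w.pt"
```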
any updates on this? @kartikayk @joecummings
recipes/quantize.py
Outdated
def __init__(self, cfg: DictConfig) -> None:
    self._device = utils.get_device(device=cfg.device)
    self._dtype = utils.get_dtype(dtype=cfg.dtype)
    self._quantization_mode = cfg.quantization_mode
Related to a couple of Kartikay's comments: we should be very explicit about what the supported quantization modes are here.
I added some docs in the quantize.yaml file; please let me know if that's a good place to host them.
# Quantization specific args
quantizer:
  _component_: torchtune.utils.Int4WeightOnlyQuantizer
  groupsize: 256
Thanks for adding this! Do you mind adding some comments explaining what these are here? It's OK to point to the docstring for more info, but add a line or two about what this block is doing.
added these in recipes/configs/quantize.yaml
recipes/eleuther_eval.py
Outdated
@@ -33,6 +33,16 @@
    sys.exit(1)


def _parent_name(target):
Where is this used? I didn't find this in the file below
oh, will remove; this was used to work around issues with the tensor subclass
        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
So when will the failure be surfaced? At the time the config is parsed?
This looks awesome! A few questions and suggestions.
Mind adding some screenshots for:
a) Eval with quantized model on any one task
b) Screenshot of inference with the prompt including memory consumption and tokens/sec.
Both of these will be helpful in the future and act as a reference.
recipes/README.md
Outdated

`receipes/configs/quantize.yaml`

We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument menas:
Suggested change:
- We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument menas:
+ We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument means:
Also, it seems like this is missing some information, i.e. I assumed there would be a description of what each argument is :)
oh sorry, I moved these to quantize.yaml; I'll reword this
recipes/README.md
Outdated
`recipes/eleuther_eval.py`

# to skip running through GPTQ, change model = quantizer.quantize(model) to:
what does "running through GPTQ" mean?
OK, I'll change this a bit and also add a brief description of GPTQ.
# Args:
# `groupsize` (int): a parameter of int4 weight only quantization,
# it refers to the size of quantization groups which get independent quantization parameters
# e.g. 32, 64, 128, 256, smaller numbers means more fine grained and higher accuracy
I'm guessing smaller numbers also mean more memory? Is that right?
yeah smaller numbers will cost more memory, but it's probably not too significant overall
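A back-of-the-envelope check of that memory trade-off (the layer shape and 2-byte scale/zero-point storage are assumptions for illustration): int4 weight-only quantization keeps one scale and one zero point per group, so smaller groups mean more of these parameters alongside the packed weights.

```python
rows, cols = 4096, 4096                      # a hypothetical 7B-sized linear layer
packed_mb = rows * cols * 0.5 / 1e6          # 4 bits per weight -> ~8.4 MB packed
for groupsize in (32, 64, 128, 256):
    n_groups = rows * cols // groupsize
    overhead_mb = n_groups * 2 * 2 / 1e6     # scale + zero point, ~2 bytes each
    print(f"groupsize={groupsize:3d}: ~{overhead_mb:.2f} MB of scales/zeros "
          f"vs ~{packed_mb:.1f} MB of packed int4 weights")
```

Even at groupsize 32 the extra scales/zeros add roughly 2 MB per ~8 MB of packed weights, which matches the point above that the cost grows but stays modest.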
# multiple of groupsize.
# `percdamp`: GPTQ stablization hyperparameter, recommended to be .01
#
# future note: blocksize and percdamp should not have to be 'known' by users by default.
Since most users don't need it, should we just remove this from the config and add these as defaults to the instantiate call? Or maybe not even expose these at all? WDYT?
Makes sense, maybe we can remove these in a future release? We did the branch cut today. cc @HDCharles
recipes/eleuther_eval.py
Outdated
@@ -150,10 +153,15 @@ def _setup_model(
    model_cfg: DictConfig,
    model_state_dict: Dict[str, Any],
) -> nn.Module:
-   with utils.set_default_dtype(self._dtype), self._device:
+   if self._quantization_mode is not None:
This means that we'll init the model in fp32 instead of say bf16. Is that by design? I only ask because this will double the model memory at init time. Is quantization from bf16 not supported?
oh, it is supported; we could init with bf16. I can change this back to init under self._dtype and the device.
recipes/eleuther_eval.py
Outdated
    model.load_state_dict(model_state_dict)
else:
    with utils.set_default_dtype(self._dtype), self._device:
        model = config.instantiate(model_cfg)
    model.load_state_dict(model_state_dict)
Suggested change:
-    model.load_state_dict(model_state_dict)
-else:
-    with utils.set_default_dtype(self._dtype), self._device:
-        model = config.instantiate(model_cfg)
-    model.load_state_dict(model_state_dict)
+else:
+    with utils.set_default_dtype(self._dtype), self._device:
+        model = config.instantiate(model_cfg)
+model.load_state_dict(model_state_dict)
`model.load_state_dict(model_state_dict)` can just be moved outside the if-else block. If we can init the model in bf16 for quantization, then this if-else block can be further simplified.
sounds good
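A minimal sketch of the simplified `_setup_model` body being discussed (a fragment using the helper and attribute names from the snippets above; the exact final structure is assumed): initialize under the configured dtype and device, quantize only when a quantizer is configured, then load the state dict once outside the if-else.

```python
with utils.set_default_dtype(self._dtype), self._device:
    model = config.instantiate(model_cfg)
if self._quantization_mode is not None:
    # Swap in the quantized module structure first so the quantized
    # checkpoint's tensors match what load_state_dict expects.
    model = self._quantizer.quantize(model)
    model = model.to(device=self._device, dtype=self._dtype)
model.load_state_dict(model_state_dict)
```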
if self._quantization_mode is not None:
    model = self._quantizer.quantize(model)
    model = model.to(device=self._device, dtype=self._dtype)
So it seems like initializing the model in bf16 is fine? Can we do the same for eval too?
yeah, done
@@ -9,4 +9,4 @@ tqdm
omegaconf

# Quantization
-torchao-nightly==2024.3.29
+torchao==0.1
Awesome!!
Thanks for adding this functionality and for patiently addressing all of the comments!
Summary:
Allows the user to specify quantization_mode when generating a model in full_finetune_single_device.py and when running inference with the quantized model in generate.py.
Test Plan:
Tested locally.
May need changes to the corresponding yaml files; see README.md for more info.
Results of generate for int4 weight only quantized model:
Results of eval for int4 weight only quantized model:
Reviewers:
Subscribers:
Tasks:
Tags: