Adding quantization support in torchtune #632

Merged · Apr 5, 2024

Conversation

jerryzh168 (Contributor) commented Apr 1, 2024

Summary:
Allows the user to specify a quantization_mode for quantizing a model produced by full_finetune_single_device.py, and to run inference with the quantized model in generate.py.

Test Plan:
Tested locally.

The corresponding YAML config files may need changes; see README.md for more info.

tune run full_finetune_single_device --config llama2/7B_full_single_device max_steps_per_epoch=4 epochs=1
tune run quantize --config quantize
tune run generate --config generate
tune run eleuther_eval --config eleuther_eval

Results of generate for the int4 weight-only quantized model:

$ tune run generate --config quant_generate                                                                                                         
2024-04-04:19:23:15,756 INFO     [_parse.py:52] Running main with parameters {'model': {'_component_': 'torchtune.models.llama2.llama2_7b'}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelTorchTuneCheckpointer', 'checkpoint_dir': '/tmp/llama2/', 'checkpoint_files': ['meta_model_0.4w.pt'], 'output_dir': '/tmp/llama2/', 'model_type': 'LLAMA2'}, 'device': 'cuda', 'dtype': 'bf16', 'seed': 1234, 'quantizer': {'_component_': 'torchtune.utils.quantization.Int4WeightOnlyQuantizer', 'groupsize': 256}, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': '/tmp/llama2/tokenizer.model'}, 'prompt': 'Hello, my name is', 'max_new_tokens': 300, 'temperature': 0.8, 'top_k': 300}
2024-04-04:19:23:16,140 DEBUG    [seed.py:59] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
linear: layers.0.attn.q_proj, in=4096, out=4096
linear: layers.0.attn.k_proj, in=4096, out=4096
linear: layers.0.attn.v_proj, in=4096, out=4096
linear: layers.0.attn.output_proj, in=4096, out=4096
linear: layers.0.mlp.w1, in=4096, out=11008
linear: layers.0.mlp.w2, in=11008, out=4096
…
linear: output, in=4096, out=32000
2024-04-04:19:23:26,511 INFO     [generate.py:68] Model is initialized with precision torch.bfloat16.
2024-04-04:19:23:39,668 INFO     [generate.py:92] Hello, my name is Alicia, I’m a 47-year-old married woman whose husband is older than me. myself. I don’t have children and I’m not expecting any. If you are already 18 years old, and you plan to have a baby with me! I’m an artist and I like to be creative (especially with jewelry and beading). I’m a little clumsy and I don’t know how to drive, but I’m good at doing girly things, lol.
I like to have fun in private, so I’m looking for a guy who likes to have sex and is funny.
I am for private relationships and I understand the importance of loyalty, so I expect the same from you!
I’m here to have fun, so let’s meet and know each other and let’s see if we’ll get along!
British, Europe
Art, Cinema, History, Reading, Shopping, Sleeping
Chillin, Comedy, Conversation, Going out, Intelligence, Movies, Partying
Clothes, Creative, Goal Oriented, Intellectual, Money
2024-04-04:19:23:39,669 INFO     [generate.py:96] Time for inference: 12.82 sec total, 20.44 tokens/sec
2024-04-04:19:23:39,669 INFO     [generate.py:99] Memory used: 17.85 GB

Results of eval for the int4 weight-only quantized model:

$ tune run eleuther_eval --config eleuther_eval
2024-04-04:19:26:20,675 INFO     [_parse.py:52] Running recipe_main with parameters {'model': {'_component_': 'torchtune.models.llama2.llama2_7b'}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelTorchTuneCheckpointer', 'checkpoint_dir': '/tmp/llama2/', 'checkpoint_files': ['meta_model_0.4w.pt'], 'output_dir': '/tmp/llama2/', 'model_type': 'LLAMA2'}, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': '/tmp/llama2/tokenizer.model'}, 'device': 'cuda', 'dtype': 'bf16', 'seed': 217, 'tasks': ['wikitext'], 'limit': None, 'max_seq_length': 4096, 'quantizer': {'_component_': 'torchtune.utils.quantization.Int4WeightOnlyQuantizer', 'groupsize': 256}}
2024-04-04:19:26:21,029 DEBUG    [seed.py:59] Setting manual seed to local seed 217. Local seed is seed + rank = 217 + 0
linear: layers.0.attn.q_proj, in=4096, out=4096
linear: layers.0.attn.k_proj, in=4096, out=4096
linear: layers.0.attn.v_proj, in=4096, out=4096
linear: layers.0.attn.output_proj, in=4096, out=4096
linear: layers.0.mlp.w1, in=4096, out=11008
linear: layers.0.mlp.w2, in=11008, out=4096
...
linear: output, in=4096, out=32000
2024-04-04:19:26:27,681 INFO     [eleuther_eval.py:167] Model is initialized with precision torch.bfloat16.
2024-04-04:19:26:27,699 INFO     [eleuther_eval.py:151] Tokenizer is initialized from file.
2024-04-04:19:26:28,036 INFO     [huggingface.py:162] Using device 'cuda:0'
2024-04-04:19:26:35,919 WARNING  [task.py:763] [Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-04-04:19:26:35,919 WARNING  [task.py:775] [Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-04-04:19:26:35,919 WARNING  [task.py:763] [Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-04-04:19:26:35,919 WARNING  [task.py:775] [Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-04-04:19:26:35,919 WARNING  [task.py:763] [Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
2024-04-04:19:26:35,919 WARNING  [task.py:775] [Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
2024-04-04:19:26:37,607 INFO     [eleuther_eval.py:188] Running evaluation on ['wikitext'] tasks.
2024-04-04:19:26:37,610 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 487.84it/s]
2024-04-04:19:26:37,743 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [04:53<00:00,  4.73s/it]
2024-04-04:19:31:31,458 INFO     [eleuther_eval.py:195] Eval completed in 303.43 seconds.
2024-04-04:19:31:31,458 INFO     [eleuther_eval.py:197] wikitext: {'word_perplexity,none': 9.615681846101303, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.5269407647819768, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.610644096180032, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}



pytorch-bot bot commented Apr 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/632

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4abcdb9 with merge base 45031b3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 1, 2024
@jerryzh168 force-pushed the add-quant branch 3 times, most recently from f780b1b to 47bd46d, on April 3, 2024 at 02:01
@jerryzh168 force-pushed the add-quant branch 2 times, most recently from 434fa6d to b285a16, on April 3, 2024 at 04:14
Comment on lines 60 to 62
# Ensure the cache is setup on the right device
with self._device:
    model.setup_caches(max_batch_size=1, dtype=self._dtype)
Contributor:

We don't need to set up caches for quantization?

@@ -0,0 +1,47 @@
from typing import Any
Contributor:

Can we add some tests for this file? I'm nervous about adding utilities without any associated tests

Contributor:

We should also add some tests to make sure the model is as expected after quantization. This will help catch any breakages in APIs/behaviors if torchao changes.

@rohan-varma any ideas around testing?

Author (jerryzh168):

@kartikayk I changed these to just use classes directly and they will be tested in torchao, is that OK?
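
For reference, a minimal sketch of the kind of regression test being discussed here, assuming only what is visible in this PR (the torchtune.utils.quantization.Int4WeightOnlyQuantizer component and its quantize(model) method). The tiny llama2 config and the specific assertions are illustrative, not the tests that actually landed:

```python
# A minimal sketch of the kind of regression test discussed above, not the actual
# torchtune test suite. It assumes the quantizer exposes .quantize(model) as used in
# the recipes in this PR; the tiny llama2 config and the assertions are illustrative.
import torch
from torchtune.models.llama2 import llama2
from torchtune.utils.quantization import Int4WeightOnlyQuantizer


def test_int4_weight_only_quantize():
    with torch.device("cuda"):
        model = llama2(
            vocab_size=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=4,
            embed_dim=256,
            max_seq_len=128,
        ).to(torch.bfloat16)
    original_keys = set(model.state_dict().keys())

    tokens = torch.randint(0, 128, (1, 16), device="cuda")
    expected_shape = model(tokens).shape

    quantizer = Int4WeightOnlyQuantizer(groupsize=32)
    model = quantizer.quantize(model)

    # Quantization should change how the linear weights are stored, but the model
    # should still run and produce logits of the same shape as before.
    assert set(model.state_dict().keys()) != original_keys
    assert model(tokens).shape == expected_shape
```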

@@ -0,0 +1,29 @@

Contributor:

I don't think we need a separate config here. We should make it clear in our documentation and tutorial how to load a quantized model. It's just a checkpointer change, so I'd remove this.

@@ -124,20 +124,25 @@ class EleutherEvalRecipe(EvalRecipeInterface):
def __init__(self, cfg: DictConfig) -> None:
self._cfg = cfg

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
Contributor:

@joecummings can I get a review for this file? Trying to think about the best way to integrate the quantization changes for eval

Contributor:

Noob q: why would we want to quantize the model in eval?

Contributor:

This is mostly to load a quantized model for evaluation. The flow here is something like this:

  • Finetune + Eval
  • Quantize
  • Eval with quantize to make sure quantized model is still performant
  • Run inference with quantized model to make sure its not doing something crazy

Author (jerryzh168):

any updates on this? @kartikayk @joecummings

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantization_mode = cfg.quantization_mode
Contributor:

Related to a couple of Kartikay's comments, we should be very explicit here about what the supported quantization modes are.

Author (jerryzh168):

I added some docs in the quantize.yaml file; please let me know if that's a good place to host them.
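
To make the reviewer's point concrete, one possible way to surface the supported modes directly in the recipe would look roughly like this; the mode strings and the helper are hypothetical, since this PR documents the options in recipes/configs/quantize.yaml instead:

```python
# Hypothetical sketch: fail fast on an unsupported quantization_mode in the recipe.
# The mode names below are assumptions for illustration only.
from omegaconf import DictConfig

_SUPPORTED_QUANTIZATION_MODES = ("4w", "4w-gptq")  # hypothetical names


def _check_quantization_mode(cfg: DictConfig) -> None:
    mode = cfg.get("quantization_mode", None)
    if mode is not None and mode not in _SUPPORTED_QUANTIZATION_MODES:
        raise ValueError(
            f"Unsupported quantization_mode '{mode}'; "
            f"expected one of {_SUPPORTED_QUANTIZATION_MODES}"
        )
```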

@jerryzh168 force-pushed the add-quant branch 3 times, most recently from 52b7a4c to f794e13, on April 4, 2024 at 05:07
# Quantization specific args
quantizer:
  _component_: torchtune.utils.Int4WeightOnlyQuantizer
  groupsize: 256
Contributor:

Thanks for adding this! Do you mind adding some comments here explaining what these are? It's OK to point to the docstring for more info, but add a line or two about what this block is doing.

Author (jerryzh168) commented Apr 4, 2024:

Added these in recipes/configs/quantize.yaml.
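
For readers following along, here is a rough sketch of how that quantizer block gets consumed by a recipe, based on the config.instantiate pattern and the quantizer.quantize(model) call quoted elsewhere in this PR; the glue code around them is illustrative:

```python
# Sketch: turning the YAML `quantizer` block into an object and applying it.
# config.instantiate and quantizer.quantize(model) mirror usage quoted in this PR;
# everything else here is illustrative scaffolding.
import torch
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.create(
    """
    quantizer:
      _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
      groupsize: 256
    """
)

quantizer = config.instantiate(cfg.quantizer)  # Int4WeightOnlyQuantizer(groupsize=256)


@torch.no_grad()
def quantize_model(model: torch.nn.Module) -> torch.nn.Module:
    # Packs the linear weights to int4, with independent quantization parameters
    # for each group of `groupsize` values.
    return quantizer.quantize(model)
```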

@@ -33,6 +33,16 @@
sys.exit(1)


def _parent_name(target):
Contributor:

Where is this used? I didn't find this in the file below

Author (jerryzh168):

Oh, I'll remove it; this was used to work around issues with tensor subclasses.

        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
Contributor:

So when will the failure be surfaced? At the time the config is parsed?

kartikayk (Contributor) left a comment:

This looks awesome! A few questions and suggestions.

Mind adding some screenshots for:
a) Eval with quantized model on any one task
b) Screenshot of inference with the prompt including memory consumption and tokens/sec.

Both of these will be helpful in the future and act as a reference.


`recipes/configs/quantize.yaml`

We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument menas:
Contributor:

Suggested change:
- We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument menas:
+ We'll publish doc pages for different quantizers in torchao a bit later. For int4 weight only gptq quantizer, here is a brief description of what each argument means:

Contributor:

Also, it seems like this is missing some information, i.e. I assumed there would be a description of what each argument is :)

Author (jerryzh168):

Oh sorry, I moved these to quantize.yaml; I'll reword this.

`recipes/eleuther_eval.py`

```
# to skip running through GPTQ, change model = quantizer.quantize(model) to:
```
Contributor:

what does "running through GPTQ" mean?

Author (jerryzh168):

OK, I'll change this a bit and also add a brief description of GPTQ.

# Args:
# `groupsize` (int): a parameter of int4 weight only quantization,
# it refers to the size of quantization groups which get independent quantization parameters
# e.g. 32, 64, 128, 256, smaller numbers means more fine grained and higher accuracy
Contributor:

I'm guessing smaller numbers also means more memory? Is that right?

Author (jerryzh168):

Yeah, smaller numbers will cost more memory, but it's probably not too significant overall.
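
As a rough back-of-envelope illustration of that tradeoff (this assumes one scale and one zero point stored in 2 bytes each per quantization group, which is an assumption about the packed format rather than something specified in this PR):

```python
# Back-of-envelope estimate of the per-group overhead for int4 weight-only
# quantization of a single 4096x4096 linear layer. Assumes one bf16 scale and one
# bf16 zero point per group (2 bytes each); the real packed layout may differ.
in_features, out_features = 4096, 4096
weight_bytes_int4 = in_features * out_features // 2  # 4 bits per weight = 8 MiB

for groupsize in (32, 64, 128, 256):
    num_groups = out_features * (in_features // groupsize)
    overhead_bytes = num_groups * (2 + 2)  # scale + zero point, bf16 each
    print(
        f"groupsize={groupsize:4d}: "
        f"{overhead_bytes / 2**20:.2f} MiB overhead "
        f"({100 * overhead_bytes / weight_bytes_int4:.1f}% of the int4 weights)"
    )
```

Under these assumptions, groupsize=32 adds about 25% on top of the int4 weights while groupsize=256 adds about 3%, which matches the "more memory, but not too significant" framing above.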

# multiple of groupsize.
# `percdamp`: GPTQ stabilization hyperparameter, recommended to be .01
#
# future note: blocksize and percdamp should not have to be 'known' by users by default.
Contributor:

Since most users don't need it, should we just remove this from the config and add these as defaults to the instantiate call? Or maybe not even expose these at all? WDYT?

Author (jerryzh168):

Makes sense, maybe we can remove these in a future release? We did the branch cut today. cc @HDCharles
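
If these knobs do get dropped from the YAML later, one option is to supply them as defaults at the instantiate call, along these lines; the blocksize value below is a placeholder assumption, and only percdamp=0.01 comes from the config comment above:

```python
# Hypothetical sketch: keep GPTQ-internal knobs out of the user-facing YAML by
# passing them as defaults when the quantizer is instantiated. blocksize=128 is an
# assumed placeholder; percdamp follows the recommended value quoted above.
from omegaconf import DictConfig
from torchtune import config


def setup_gptq_quantizer(cfg: DictConfig):
    # cfg.quantizer only needs `_component_` and `groupsize` from the user.
    return config.instantiate(
        cfg.quantizer,
        blocksize=128,
        percdamp=0.01,
    )
```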

@@ -150,10 +153,15 @@ def _setup_model(
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
) -> nn.Module:
with utils.set_default_dtype(self._dtype), self._device:
if self._quantization_mode is not None:
Contributor:

This means that we'll init the model in fp32 instead of say bf16. Is that by design? I only ask because this will double the model memory at init time. Is quantization from bf16 not supported?

Author (jerryzh168):

Oh, it is supported; we could init with bf16. I can change this back to init under self._dtype and self._device.

Comment on lines 160 to 164
    model.load_state_dict(model_state_dict)
else:
    with utils.set_default_dtype(self._dtype), self._device:
        model = config.instantiate(model_cfg)
        model.load_state_dict(model_state_dict)
Contributor:

Suggested change:
Before:
    model.load_state_dict(model_state_dict)
else:
    with utils.set_default_dtype(self._dtype), self._device:
        model = config.instantiate(model_cfg)
        model.load_state_dict(model_state_dict)
After:
else:
    with utils.set_default_dtype(self._dtype), self._device:
        model = config.instantiate(model_cfg)
        model.load_state_dict(model_state_dict)

Contributor:

model.load_state_dict(model_state_dict) can just be moved outside the if-else block. If we can init the model in bf16 for quantization, then this if-else block can be further simplified.

Author (jerryzh168):

sounds good
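
For what it's worth, the simplified structure being agreed on here would look roughly like the following; this is a sketch of the reviewer's suggestion (init under the dtype/device context, quantize if needed, then load the state dict once), not necessarily the exact code that was merged:

```python
# Sketch of the simplified setup path discussed above. Helper names and argument
# plumbing are illustrative; set_default_dtype, config.instantiate, and
# quantizer.quantize(model) mirror the usage quoted in this PR.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from omegaconf import DictConfig
from torchtune import config, utils


def setup_model(
    model_cfg: DictConfig,
    model_state_dict: Dict[str, Any],
    dtype: torch.dtype,
    device: torch.device,
    quantizer: Optional[Any] = None,
) -> nn.Module:
    # Initialize once under the dtype/device context (bf16 init works for the
    # quantized path too, per the discussion above).
    with utils.set_default_dtype(dtype), device:
        model = config.instantiate(model_cfg)
    if quantizer is not None:
        # Swap in quantized modules first so the quantized checkpoint keys match.
        model = quantizer.quantize(model)
    # Load the state dict a single time, outside the if/else.
    model.load_state_dict(model_state_dict)
    return model
```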

Comment on lines +58 to +62
if self._quantization_mode is not None:
    model = self._quantizer.quantize(model)
    model = model.to(device=self._device, dtype=self._dtype)
Contributor:

So it seems like initializing the model in bf16 is fine? Can we do the same for eval too?

Author (jerryzh168):

yeah, done

@@ -9,4 +9,4 @@ tqdm
 omegaconf

 # Quantization
-torchao-nightly==2024.3.29
+torchao==0.1
Contributor:

Awesome!!

kartikayk (Contributor) left a comment:

Thanks for adding this functionality and for patiently addressing all of the comments!

@kartikayk merged commit 2ab2721 into main on Apr 5, 2024. 20 checks passed.
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024
@joecummings deleted the add-quant branch on April 11, 2024.