Adding quantization support in torchtune
Summary: Allows the user to specify a quantization_mode for the model produced by full_finetune_single_device.py and to run inference with the quantized model in generate.py.
Test Plan: tested locally
Reviewers:
Subscribers:
Tasks:
Tags:
Commit b285a16, 1 parent 32d66df. Showing 11 changed files with 234 additions and 11 deletions.
@@ -28,3 +28,6 @@ seed: 217
 tasks: ["truthfulqa_mc2"]
 limit: null
 max_seq_length: 4096
+
+# Quantization
+quantization_mode: null

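Setting the new key to null leaves evaluation unquantized; a supported mode (the modes are listed in the quantization utilities at the end of this diff) switches quantization on. The eval-recipe change itself is in one of the collapsed files, so the following is only a hedged sketch of the pattern a recipe could use to consume this key; maybe_quantize is a hypothetical helper, not part of the commit:

from typing import Optional

import torch

from torchtune import utils  # the quantize recipe below accesses get_quantizer through this module


def maybe_quantize(model: torch.nn.Module, quantization_mode: Optional[str]) -> torch.nn.Module:
    """Hypothetical helper: quantize only when a mode is configured (null -> no-op)."""
    if quantization_mode is None:
        return model
    return utils.get_quantizer(quantization_mode).quantize(model)
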
@@ -16,6 +16,7 @@ checkpointer:

 device: cuda
 dtype: bf16
+quantization_mode: f16a4w

 seed: 1234

@@ -0,0 +1,29 @@

# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.f16a4w.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cpu
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: f16a4w

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
max_new_tokens: 300
temperature: 0.8
top_k: 300

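Note the checkpoint_files entry meta_model_0.f16a4w.pt together with FullModelTorchTuneCheckpointer: the file name is the original checkpoint stem plus the quantization mode, which is exactly how the quantize recipe's save method further down in this diff names its output. A small sketch of that naming convention; quantized_checkpoint_path is a hypothetical helper, not part of the commit:

def quantized_checkpoint_path(output_dir: str, checkpoint_file: str, quantization_mode: str) -> str:
    # Mirrors QuantizationRecipe.save: "<output_dir><original stem>.<quantization_mode>.pt"
    file_stem = checkpoint_file.split(".")[0]
    return output_dir + file_stem + "." + quantization_mode + ".pt"


# With the quantization config below this prints /tmp/llama2/meta_model_0.f16a4w.pt,
# the same file this generation config loads.
print(quantized_checkpoint_path("/tmp/llama2/", "meta_model_0.pt", "f16a4w"))
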
@@ -0,0 +1,18 @@

# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cuda
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: f16a4w

@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class QuantizationRecipe:
    """
    Recipe for quantizing a dense Transformer-based LLM.
    Currently this recipe supports quantization on a single device only.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantization_mode = cfg.quantization_mode
        utils.set_seed(seed=cfg.seed)

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        self._checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict, assign=True)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")

        # Ensure the cache is setup on the right device
        with self._device:
            model.setup_caches(max_batch_size=1, dtype=self._dtype)

        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
        quantizer = utils.get_quantizer(self._quantization_mode)
        t0 = time.perf_counter()
        self._model = quantizer.quantize(self._model)
        t = time.perf_counter() - t0
        logger.info(f"Time for quantization: {t:.02f} sec")
        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    @torch.no_grad()
    def save(self, cfg: DictConfig):
        ckpt_dict = self._model.state_dict()
        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
        torch.save(
            ckpt_dict,
            cfg.checkpointer.output_dir + file_name + "." + self._quantization_mode + ".pt",
        )


@config.parse
def main(cfg: DictConfig) -> None:
    recipe = QuantizationRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.quantize(cfg=cfg)
    recipe.save(cfg=cfg)


if __name__ == "__main__":
    sys.exit(main())

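For illustration, a minimal sketch of driving the recipe programmatically with the values from the quantization config above; normally it would be launched through the config-parsing main entry point, and the importable module path here is an assumption:

from omegaconf import OmegaConf

from recipes.quantize import QuantizationRecipe  # assumed module path for the new recipe file

cfg = OmegaConf.create(
    {
        "model": {"_component_": "torchtune.models.llama2.llama2_7b"},
        "checkpointer": {
            "_component_": "torchtune.utils.FullModelMetaCheckpointer",
            "checkpoint_dir": "/tmp/llama2/",
            "checkpoint_files": ["meta_model_0.pt"],
            "output_dir": "/tmp/llama2/",
            "model_type": "LLAMA2",
        },
        "device": "cuda",
        "dtype": "bf16",
        "seed": 1234,
        "quantization_mode": "f16a4w",
    }
)

recipe = QuantizationRecipe(cfg=cfg)
recipe.setup(cfg=cfg)      # instantiate llama2_7b and load the Meta checkpoint
recipe.quantize(cfg=cfg)   # f16a4w: swap Linear weights for int4 weight-only tensors
recipe.save(cfg=cfg)       # writes /tmp/llama2/meta_model_0.f16a4w.pt
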
@@ -8,4 +8,4 @@ tqdm
 omegaconf

 # Quantization
-torchao-nightly==2024.3.29
+torchao-nightly==2024.4.2

@@ -0,0 +1,47 @@
from typing import Any

import torch
from torchao.quantization.quant_api import (
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_woqtensors,
    Quantizer,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3


class FP16ActInt4WeightQuantizer(Quantizer):
    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        change_linear_weights_to_int4_woqtensors(model)
        return model


class FP16ActInt8WeightQuantizer(Quantizer):
    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        change_linear_weights_to_int8_woqtensors(model)
        return model


def get_quantizer(quantization_mode, *args, **kwargs):
    qmode_to_quantizer = {
        "f16a4w": FP16ActInt4WeightQuantizer,
        "f16a8w": FP16ActInt8WeightQuantizer,
    }
    if TORCH_VERSION_AFTER_2_3:
        from torchao.quantization.quant_api import (
            Int8DynActInt4WeightQuantizer,
            Int8DynActInt4WeightGPTQQuantizer,
            Int4WeightGPTQQuantizer,
        )

        qmode_to_quantizer |= {
            "8da4w": Int8DynActInt4WeightQuantizer,
            # TODO: merge into 8da4w
            "8da4w-gptq": Int8DynActInt4WeightGPTQQuantizer,
            # TODO: merge into f16a4w
            "f16a4w-gptq": Int4WeightGPTQQuantizer,
        }
    if quantization_mode not in qmode_to_quantizer:
        raise ValueError(
            f"Unsupported quantization mode: {quantization_mode}, "
            f"supported modes are: {list(qmode_to_quantizer.keys())}"
        )
    return qmode_to_quantizer[quantization_mode](*args, **kwargs)
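
A short usage sketch of the helpers above (the recipe reaches them via utils.get_quantizer); the toy model and the hardware caveat are illustrative only:

import torch

# Toy model: the quantizers above rewrite nn.Linear weights in place.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

# "f16a4w" maps to FP16ActInt4WeightQuantizer (half-precision activations, int4 weight-only).
# torchao's int4 weight-only path generally expects half-precision weights on CUDA,
# so this is a sketch rather than a guaranteed CPU-runnable example.
if torch.cuda.is_available():
    model = model.to(device="cuda", dtype=torch.bfloat16)
    model = get_quantizer("f16a4w").quantize(model)

# Unsupported modes raise a ValueError listing the available keys.
try:
    get_quantizer("int3")
except ValueError as err:
    print(err)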