Adding quantization support in torchtune
Summary:
Allows the user to specify quantization_mode when generating a model in full_finetune_single_device.py
and to run inference with the quantized model in generate.py.

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 3, 2024
1 parent 32d66df commit b285a16
Showing 11 changed files with 234 additions and 11 deletions.
3 changes: 3 additions & 0 deletions recipes/configs/eleuther_eval.yaml
@@ -28,3 +28,6 @@ seed: 217
tasks: ["truthfulqa_mc2"]
limit: null
max_seq_length: 4096

# Quantization
quantization_mode: null
1 change: 1 addition & 0 deletions recipes/configs/generate.yaml
@@ -16,6 +16,7 @@ checkpointer:

device: cuda
dtype: bf16
quantization_mode: f16a4w

seed: 1234

29 changes: 29 additions & 0 deletions recipes/configs/quant_generate.yaml
@@ -0,0 +1,29 @@

# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.f16a4w.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cpu
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: f16a4w

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
max_new_tokens: 300
temperature: 0.8
top_k: 300
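The checkpoint file referenced above is not part of the base Llama2 download; it is the artifact written by the new quantize recipe (recipes/quantize.py, added later in this diff), which appends the quantization mode to the original checkpoint name. A minimal sketch of the naming, assuming the default paths from quantize.yaml:

# Sketch mirroring QuantizationRecipe.save() further down in this diff;
# the paths below are the defaults from quantize.yaml.
output_dir = "/tmp/llama2/"
checkpoint_file = "meta_model_0.pt"
quantization_mode = "f16a4w"

file_name = checkpoint_file.split(".")[0]
print(output_dir + file_name + "." + quantization_mode + ".pt")
# -> /tmp/llama2/meta_model_0.f16a4w.pt, matching checkpoint_files above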
18 changes: 18 additions & 0 deletions recipes/configs/quantize.yaml
@@ -0,0 +1,18 @@

# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cuda
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: f16a4w
21 changes: 16 additions & 5 deletions recipes/eleuther_eval.py
@@ -124,20 +124,25 @@ class EleutherEvalRecipe(EvalRecipeInterface):
def __init__(self, cfg: DictConfig) -> None:
self._cfg = cfg

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
return checkpoint_dict

def setup(self) -> None:
self._device = utils.get_device(device=self._cfg.device)
self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
self._limit = self._cfg.limit
self._tasks = list(self._cfg.tasks)
self._quantization_mode = self._cfg.quantization_mode

utils.set_seed(seed=self._cfg.seed)

ckpt_dict = self.load_checkpoint(self._cfg.checkpointer)
weights_only = True
if self._quantization_mode is not None:
weights_only = False

ckpt_dict = self.load_checkpoint(self._cfg.checkpointer, weights_only=weights_only)
self._model = self._setup_model(
model_cfg=self._cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -153,10 +158,16 @@ def _setup_model(
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict)
if self._quantization_mode is not None:
quantizer = utils.get_quantizer(self._quantization_mode)
model = quantizer.quantize(model)

model.load_state_dict(model_state_dict, assign=True)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO: enable dtype checking for quantization
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")
return model

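Note on the load path above: when a quantization mode is set, the model is first restructured by the quantizer and the checkpoint is then loaded with assign=True, so the already-quantized tensors are bound directly instead of being copied (and cast) into freshly initialized parameters. A generic PyTorch illustration of the assign=True semantics (not torchtune-specific; requires PyTorch 2.1+):

# Minimal illustration of load_state_dict(..., assign=True): the loaded
# tensors keep their own dtype (or tensor subclass) rather than being
# copied into the module's existing fp32 parameters.
import torch

m = torch.nn.Linear(4, 4)  # fp32 parameters by default
sd = {
    "weight": torch.zeros(4, 4, dtype=torch.float16),
    "bias": torch.zeros(4, dtype=torch.float16),
}
m.load_state_dict(sd, assign=True)
print(m.weight.dtype)  # torch.float16 - re-bound, not copied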
20 changes: 15 additions & 5 deletions recipes/generate.py
@@ -28,16 +28,20 @@ class InferenceRecipe:
def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._quantization_mode = cfg.quantization_mode

utils.set_seed(seed=cfg.seed)

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
return checkpoint_dict

def setup(self, cfg: DictConfig) -> None:
ckpt_dict = self.load_checkpoint(cfg.checkpointer)
weights_only = True
if self._quantization_mode is not None:
weights_only = False
ckpt_dict = self.load_checkpoint(cfg.checkpointer, weights_only=weights_only)
self._model = self._setup_model(
model_cfg=cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -52,10 +56,16 @@ def _setup_model(
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict)
if self._quantization_mode is not None:
quantizer = utils.get_quantizer(self._quantization_mode)
model = quantizer.quantize(model)

model.load_state_dict(model_state_dict, assign=True)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO: enable this for quantization as well
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")

# Ensure the cache is setup on the right device
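As in the eval recipe, weights_only is disabled whenever a quantization mode is set: torch.load with weights_only=True restricts unpickling to plain tensors and builtin containers, which can reject the tensor-subclass weights torchao produces. A small sketch, assuming a quantized checkpoint produced by the quantize recipe added next in this diff, at its default output path:

# Loading a quantized checkpoint through the unrestricted pickle path.
# The path is an assumption based on the default quantize.yaml config.
import torch

path = "/tmp/llama2/meta_model_0.f16a4w.pt"
state_dict = torch.load(path, map_location="cpu", weights_only=False)
print(len(state_dict), "entries loaded")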
93 changes: 93 additions & 0 deletions recipes/quantize.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class QuantizationRecipe:
    """
    Recipe for quantizing a dense Transformer-based LLM.
    Currently this recipe supports quantization on a single device only.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantization_mode = cfg.quantization_mode
        utils.set_seed(seed=cfg.seed)

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        self._checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict, assign=True)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")

        # Ensure the cache is setup on the right device
        with self._device:
            model.setup_caches(max_batch_size=1, dtype=self._dtype)

        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
        quantizer = utils.get_quantizer(self._quantization_mode)
        t0 = time.perf_counter()
        self._model = quantizer.quantize(self._model)
        t = time.perf_counter() - t0
        logger.info(
            f"Time for quantization: {t:.02f} sec"
        )
        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    @torch.no_grad()
    def save(self, cfg: DictConfig):
        ckpt_dict = self._model.state_dict()
        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
        torch.save(ckpt_dict, cfg.checkpointer.output_dir + file_name + "." + self._quantization_mode + ".pt")


@config.parse
def main(cfg: DictConfig) -> None:
    recipe = QuantizationRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.quantize(cfg=cfg)
    recipe.save(cfg=cfg)


if __name__ == "__main__":
    sys.exit(main())
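For reference, a hypothetical way to drive the recipe directly from Python instead of through the torchtune CLI; the import path and config location are assumptions about the local checkout:

# Hypothetical direct invocation, assuming the default quantize.yaml above.
from omegaconf import OmegaConf
from recipes.quantize import QuantizationRecipe

cfg = OmegaConf.load("recipes/configs/quantize.yaml")
recipe = QuantizationRecipe(cfg=cfg)
recipe.setup(cfg=cfg)
recipe.quantize(cfg=cfg)
recipe.save(cfg=cfg)  # writes the .<quantization_mode>.pt checkpoint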
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ tqdm
omegaconf

# Quantization
torchao-nightly==2024.3.29
torchao-nightly==2024.4.2
9 changes: 9 additions & 0 deletions torchtune/_recipe_registry.py
@@ -86,6 +86,7 @@ class Recipe:
file_path="generate.py",
configs=[
Config(name="generate", file_path="generate.yaml"),
Config(name="quant_generate", file_path="quant_generate.yaml"),
],
supports_distributed=False,
),
@@ -97,6 +98,14 @@
],
supports_distributed=False,
),
Recipe(
name="quantize",
file_path="quantize.py",
configs=[
Config(name="quantize", file_path="quantize.yaml"),
],
supports_distributed=False,
),
]


2 changes: 2 additions & 0 deletions torchtune/utils/__init__.py
@@ -53,6 +53,7 @@
validate_expected_param_dtype,
)
from .seed import set_seed
from .quantization import get_quantizer

__all__ = [
"save_checkpoint",
@@ -80,4 +81,5 @@
"OptimizerInBackwardWrapper",
"create_optim_in_bwd_wrapper",
"register_optim_in_bwd_hooks",
"get_quantizer",
]
47 changes: 47 additions & 0 deletions torchtune/utils/quantization.py
@@ -0,0 +1,47 @@
from typing import Any

import torch
from torchao.quantization.quant_api import (
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_woqtensors,
    Quantizer,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3


class FP16ActInt4WeightQuantizer(Quantizer):
    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        change_linear_weights_to_int4_woqtensors(model)
        return model


class FP16ActInt8WeightQuantizer(Quantizer):
    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        change_linear_weights_to_int8_woqtensors(model)
        return model


def get_quantizer(quantization_mode, *args, **kwargs):
    qmode_to_quantizer = {
        "f16a4w": FP16ActInt4WeightQuantizer,
        "f16a8w": FP16ActInt8WeightQuantizer,
    }
    if TORCH_VERSION_AFTER_2_3:
        from torchao.quantization.quant_api import (
            Int8DynActInt4WeightQuantizer,
            Int8DynActInt4WeightGPTQQuantizer,
            Int4WeightGPTQQuantizer,
        )

        qmode_to_quantizer |= {
            "8da4w": Int8DynActInt4WeightQuantizer,
            # TODO: merge into 8da4w
            "8da4w-gptq": Int8DynActInt4WeightGPTQQuantizer,
            # TODO: merge into f16a4w
            "f16a4w-gptq": Int4WeightGPTQQuantizer,
        }
    if quantization_mode not in qmode_to_quantizer:
        raise ValueError(
            f"Unsupported quantization mode: {quantization_mode}, "
            f"supported modes are: {qmode_to_quantizer.keys()}"
        )
    return qmode_to_quantizer[quantization_mode](*args, **kwargs)
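A hypothetical usage sketch for get_quantizer (not part of this diff). The int4 weight-only path generally assumes a CUDA device and bf16 weights, matching the device/dtype defaults in quantize.yaml:

# Hypothetical example: quantize a small stack of Linear layers with the
# "f16a4w" mode registered above. Layer sizes, device and dtype are
# assumptions chosen to suit int4 weight-only kernels.
import torch
from torchtune import utils

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

quantizer = utils.get_quantizer("f16a4w")
model = quantizer.quantize(model)  # Linear weights swapped for int4 weight-only tensors

# utils.get_quantizer("not-a-mode")  # would raise ValueError listing supported modes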
