Adding quantization support in torchtune
Summary:
Allows the user to specify quantization_mode when generating a model in full_finetune_single_device.py
and to run inference with the quantized model in generate.py (see the sketch below).
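
For a rough sense of the flow this enables, here is a minimal sketch (not the recipe code itself: the checkpointer that handles the on-disk checkpoint format is elided, and the model builder, quantization mode, and output path are illustrative; see recipes/quantize.py below for the real entry point):

import torch
from torchtune import utils
from torchtune.models.llama2 import llama2_7b

# Build the model the way quantize.py does and assume its fine-tuned
# weights have already been loaded (the recipe uses a checkpointer for that).
with utils.set_default_dtype(torch.bfloat16), torch.device("cuda"):
    model = llama2_7b()

# get_quantizer maps a mode string ("4w", "8w", "8da4w", ...) to a torchao
# quantizer; quantize() swaps the Linear weights for quantized tensor
# subclasses and returns the model.
quantizer = utils.get_quantizer("4w")
model = quantizer.quantize(model)

# The quantized state dict is saved under a mode-suffixed file name,
# e.g. meta_model_0.4w.pt, which the eval/generate configs then point at.
torch.save(model.state_dict(), "/tmp/llama2/meta_model_0.4w.pt")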

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 3, 2024
1 parent 32d66df commit 52b7a4c
Showing 11 changed files with 256 additions and 19 deletions.
13 changes: 8 additions & 5 deletions recipes/configs/eleuther_eval.yaml
@@ -9,22 +9,25 @@ model:

checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /tmp/llama/
checkpoint_files: [finetuned_model.pt]
output_dir: /tmp/llama/
checkpoint_dir: /tmp/llama2/
checkpoint_files: [meta_model_0.4w.pt]
output_dir: /tmp/llama2/
model_type: LLAMA2

# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama/tokenizer.model
path: /tmp/llama2/tokenizer.model

# Environment
device: cuda
dtype: bf16
seed: 217

# EleutherAI specific eval args
tasks: ["truthfulqa_mc2"]
tasks: ["truthfulqa_mc2", "hellaswag"]
limit: null
max_seq_length: 4096

# Quantization
quantization_mode: 4w
1 change: 1 addition & 0 deletions recipes/configs/generate.yaml
@@ -16,6 +16,7 @@ checkpointer:

device: cuda
dtype: bf16
quantization_mode: null

seed: 1234

29 changes: 29 additions & 0 deletions recipes/configs/quant_generate.yaml
@@ -0,0 +1,29 @@

# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /tmp/llama2/
checkpoint_files: [meta_model_0.f16a4w.pt]
output_dir: /tmp/llama2/
model_type: LLAMA2

device: cpu
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: f16a4w

# Tokenizer arguments
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama2/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
max_new_tokens: 300
temperature: 0.8
top_k: 300
18 changes: 18 additions & 0 deletions recipes/configs/quantize.yaml
@@ -0,0 +1,18 @@

# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
checkpoint_dir: /tmp/llama2/
checkpoint_files: [meta_model_0.pt]
output_dir: /tmp/llama2/
model_type: LLAMA2

device: cuda
dtype: bf16
seed: 1234

# Quantization Arguments
quantization_mode: 4w
30 changes: 22 additions & 8 deletions recipes/eleuther_eval.py
@@ -124,20 +124,23 @@ class EleutherEvalRecipe(EvalRecipeInterface):
def __init__(self, cfg: DictConfig) -> None:
self._cfg = cfg

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
return checkpoint_dict

def setup(self) -> None:
self._device = utils.get_device(device=self._cfg.device)
self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
self._limit = self._cfg.limit
self._tasks = list(self._cfg.tasks)
self._quantization_mode = self._cfg.quantization_mode

utils.set_seed(seed=self._cfg.seed)

ckpt_dict = self.load_checkpoint(self._cfg.checkpointer)
# weights_only needs to be False when loading a quantized model
weights_only = (self._quantization_mode is None)
ckpt_dict = self.load_checkpoint(self._cfg.checkpointer, weights_only=weights_only)
self._model = self._setup_model(
model_cfg=self._cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -150,13 +153,24 @@ def _setup_model(
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
) -> nn.Module:
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict)
if self._quantization_mode is not None:
with torch.device("meta"):
model = config.instantiate(model_cfg)
quantizer = utils.get_quantizer(self._quantization_mode)
model = quantizer.quantize(model)
model.load_state_dict(model_state_dict, assign=True)
utils.reset_parameters(model)
model = model.to(device=self._device, dtype=self._dtype)
else:
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)
model.load_state_dict(model_state_dict, assign=True)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO: enable dtype checking for quantization
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")
return model

19 changes: 14 additions & 5 deletions recipes/generate.py
@@ -28,16 +28,19 @@ class InferenceRecipe:
def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._quantization_mode = cfg.quantization_mode

utils.set_seed(seed=cfg.seed)

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
return checkpoint_dict

def setup(self, cfg: DictConfig) -> None:
ckpt_dict = self.load_checkpoint(cfg.checkpointer)
# weights_only needs to be False when loading a quantized model
weights_only = (self._quantization_mode is None)
ckpt_dict = self.load_checkpoint(cfg.checkpointer, weights_only=weights_only)
self._model = self._setup_model(
model_cfg=cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -52,10 +55,16 @@ def _setup_model(
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict)
if self._quantization_mode is not None:
quantizer = utils.get_quantizer(self._quantization_mode)
model = quantizer.quantize(model)

model.load_state_dict(model_state_dict, assign=True)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO: enable this for quantization as well
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")

# Ensure the cache is setup on the right device
97 changes: 97 additions & 0 deletions recipes/quantize.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class QuantizationRecipe:
"""
Recipe for quantizing a Transformer-based LLM.
Supported quantization modes are:
8w: int8 weight only per axis group quantization
4w: int4 weight only per axis group quantization
Available after torch 2.3.0:
8da4w: int8 dynamic activation quantization and int4 weight per axis group quantization
8da4w-gptq: int8 dynamic activation quantization and int4 weight per axis group quantization with GPTQ
4w-gptq: int4 weight only per axis group quantization with GPTQ
"""

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._quantization_mode = cfg.quantization_mode
utils.set_seed(seed=cfg.seed)

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
self._checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict

def setup(self, cfg: DictConfig) -> None:
ckpt_dict = self.load_checkpoint(cfg.checkpointer)
self._model = self._setup_model(
model_cfg=cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
)

def _setup_model(
self,
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
) -> nn.Module:
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict, assign=True)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")

# Ensure the cache is setup on the right device
# with self._device:
# model.setup_caches(max_batch_size=1, dtype=self._dtype)

return model

@torch.no_grad()
def quantize(self, cfg: DictConfig):
quantizer = utils.get_quantizer(self._quantization_mode)
t0 = time.perf_counter()
self._model = quantizer.quantize(self._model)
t = time.perf_counter() - t0
logger.info(
f"Time for quantization: {t:.02f} sec"
)
logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

def save_checkpoint(self, cfg: DictConfig):
ckpt_dict = self._model.state_dict()
file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
torch.save(ckpt_dict, cfg.checkpointer.output_dir + file_name + "." + self._quantization_mode + ".pt")


@config.parse
def main(cfg: DictConfig) -> None:
recipe = QuantizationRecipe(cfg=cfg)
recipe.setup(cfg=cfg)
recipe.quantize(cfg=cfg)
recipe.save_checkpoint(cfg=cfg)


if __name__ == "__main__":
sys.exit(main())
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ tqdm
omegaconf

# Quantization
torchao-nightly==2024.3.29
torchao-nightly==2024.4.2
9 changes: 9 additions & 0 deletions torchtune/_recipe_registry.py
@@ -86,6 +86,7 @@ class Recipe:
file_path="generate.py",
configs=[
Config(name="generate", file_path="generate.yaml"),
Config(name="quant_generate", file_path="quant_generate.yaml"),
],
supports_distributed=False,
),
@@ -97,6 +98,14 @@ class Recipe:
],
supports_distributed=False,
),
Recipe(
name="quantize",
file_path="quantize.py",
configs=[
Config(name="quantize", file_path="quantize.yaml"),
],
supports_distributed=False,
),
]


5 changes: 5 additions & 0 deletions torchtune/utils/__init__.py
@@ -53,6 +53,10 @@
validate_expected_param_dtype,
)
from .seed import set_seed
from .quantization import (
get_quantizer,
reset_parameters,
)

__all__ = [
"save_checkpoint",
@@ -80,4 +84,5 @@
"OptimizerInBackwardWrapper",
"create_optim_in_bwd_wrapper",
"register_optim_in_bwd_hooks",
"get_quantizer",
]
52 changes: 52 additions & 0 deletions torchtune/utils/quantization.py
@@ -0,0 +1,52 @@
from typing import Any
import torch
from torchao.quantization.quant_api import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_woqtensors,
Quantizer,
Int4WeightOnlyGPTQQuantizer,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

class Int4WeightOnlyQuantizer(Quantizer):
def quantize(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
change_linear_weights_to_int4_woqtensors(model)
return model


class Int8WeightOnlyQuantizer(Quantizer):
def quantize(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
change_linear_weights_to_int8_woqtensors(model)
return model


def get_quantizer(quantization_mode, *args, **kwargs):
qmode_to_quantizer = {
"4w": Int4WeightOnlyQuantizer,
"8w": Int8WeightOnlyQuantizer,
"4w-gptq": Int4WeightOnlyGPTQQuantizer,
}
if TORCH_VERSION_AFTER_2_3:
from torchao.quantization.quant_api import (
Int8DynActInt4WeightQuantizer,
Int8DynActInt4WeightGPTQQuantizer,
)

qmode_to_quantizer |= {
"8da4w": Int8DynActInt4WeightQuantizer,
# TODO: merge into 8da4w
"8da4w-gptq": Int8DynActInt4WeightGPTQQuantizer,
}
if quantization_mode not in qmode_to_quantizer:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}, supported modes are: {qmode_to_quantizer.keys()}")
return qmode_to_quantizer[quantization_mode](*args, **kwargs)

def reset_parameters(model: torch.nn.Module):
for name, module in model.named_modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
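
For reference, a minimal sketch of how a checkpoint produced by the quantize recipe is loaded back for inference, mirroring the generate.py change above (the builder, mode, and path are again illustrative). The model is quantized first so its parameters have the matching tensor-subclass types, the checkpoint is read with weights_only=False because it stores those subclasses rather than plain tensors, and load_state_dict uses assign=True so the loaded quantized tensors replace the module parameters outright:

import torch
from torchtune import utils
from torchtune.models.llama2 import llama2_7b

with utils.set_default_dtype(torch.bfloat16), torch.device("cuda"):
    model = llama2_7b()

# Re-apply the same quantization mode that produced the checkpoint.
quantizer = utils.get_quantizer("4w")
model = quantizer.quantize(model)

# weights_only=False: the quantized checkpoint contains tensor subclasses.
state_dict = torch.load("/tmp/llama2/meta_model_0.4w.pt", weights_only=False)

# assign=True swaps in the loaded quantized tensors instead of copying
# values into the existing parameters.
model.load_state_dict(state_dict, assign=True)
model.eval()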
