Update quantization to use tensor subclasses
**Summary:** In torchao, we are migrating our quantization flows
from module swap to tensor subclasses. The existing
`Int8DynActInt4WeightQuantizer` will be deprecated in the near
future in favor of quantizing using the `quantize_` API,
so we should do the same in torchtune. This quantizer is
currently only used by QAT, which also recently migrated to
a tensor subclass implementation.
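
For context, a minimal sketch of the old and new torchao flows (assuming a torchao build that exposes `quantize_` and `int8_dynamic_activation_int4_weight`; the toy module and group size here are illustrative only):

```python
import torch
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Old flow (module swap, to be deprecated):
#   from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
#   model = Int8DynActInt4WeightQuantizer(groupsize=256).quantize(model)

# New flow (tensor subclasses): quantize_ rewrites the linear weights in place
# with int8 dynamic activation + int4 per-group weight quantization.
quantize_(model, int8_dynamic_activation_int4_weight(256))
```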

This also changes the eval script slightly, since models
quantized through torchao tensor subclasses are expected
to be loaded with `assign=True`: https://github.com/pytorch/ao/blob/9a56e80cb6070599701b8f5f587bd8187c8dccb4/test/quantization/test_quant_api.py#L610.
We should load the model the same way in torchtune.
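
To illustrate why `assign=True` matters, here is a toy sketch (not the recipe code; it assumes torchao is installed and uses a small linear stack in place of Llama3):

```python
import torch
from torch import nn
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a quantized "checkpoint": the weights are now tensor subclasses.
src = nn.Sequential(nn.Linear(256, 256))
quantize_(src, int8_dynamic_activation_int4_weight(128))
state_dict = src.state_dict()

# Rebuild the quantized model skeleton and load the checkpoint. With assign=True
# the subclass tensors are assigned to the module as-is (after moving them to the
# target device) instead of being copied element-wise into existing parameters.
dst = nn.Sequential(nn.Linear(256, 256))
quantize_(dst, int8_dynamic_activation_int4_weight(128))
state_dict = {k: v.to(device) for k, v in state_dict.items()}
dst.load_state_dict(state_dict, assign=True)
```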

**Test Plan:**

Quantized and evaluated the base Llama3-8B model on 1 A100 GPU:

```
CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/my_quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated.00.pth] \
    checkpointer.model_type=LLAMA3

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
```

Reviewers: ebsmothers, kartikayk, RdoubleA

Subscribers: ebsmothers, kartikayk, RdoubleA

andrewor14 committed Aug 28, 2024
1 parent f95a9f8 commit d020da3
Showing 2 changed files with 39 additions and 23 deletions.
recipes/eleuther_eval.py (6 additions, 2 deletions)
```diff
@@ -219,11 +219,15 @@ def _setup_model(
     ) -> nn.Module:
         with utils.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(model_cfg)
+
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Put model in eval mode.
         # Note: This will not disable the dropout applied in SDPA,
```
torchtune/utils/quantization.py (33 additions, 21 deletions)
```diff
@@ -6,39 +6,51 @@
 
 from typing import Callable, Optional
 
-# importing TORCH_VERSION_AFTER_2_3 because `Int8DynActInt4WeightQuantizer`
-# is only available after 2.3 so we have to guard the pytorch versions to decide
-# the list of supported quantizers
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+from torchao.quantization.prototype.qat import (
+    disable_8da4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int8DynActInt4WeightQATQuantizer,
+)
 
 __all__ = [
     "get_quantizer_mode",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
 ]
 
 
-_quantizer_to_mode = {}
-_quantizer_mode_to_disable_fake_quant = {}
-_quantizer_mode_to_enable_fake_quant = {}
+class Int8DynActInt4WeightQuantizer:
+    """
+    Quantizer for applying int8 per token dynamic activation + int4
+    per group weight quantization to linear layers in the model.
+    """
+
+    def __init__(self, groupsize: int = 256):
+        self.groupsize = groupsize
+
+    def quantize(self, model):
+        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_(model, quantize_fn)
+        return model
 
 
-if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+_quantizer_to_mode = {
+    Int8DynActInt4WeightQuantizer: "8da4w",
+    Int8DynActInt4WeightQATQuantizer: "8da4w-qat",
+}
 
-    __all__.append("Int8DynActInt4WeightQuantizer")
-    _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 
+# TODO: remove these
+_quantizer_mode_to_disable_fake_quant = {
+    "8da4w-qat": disable_8da4w_fake_quant,
+}
 
-if TORCH_VERSION_AFTER_2_4:
-    from torchao.quantization.prototype.qat import (
-        disable_8da4w_fake_quant,
-        enable_8da4w_fake_quant,
-        Int8DynActInt4WeightQATQuantizer,
-    )
 
-    __all__.append("Int8DynActInt4WeightQATQuantizer")
-    _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
-    _quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
-    _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
+_quantizer_mode_to_enable_fake_quant = {
+    "8da4w-qat": enable_8da4w_fake_quant,
+}
 
 
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
```
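
A hypothetical standalone use of the new wrapper, for illustration only (torchtune normally instantiates it from the config via `quantizer._component_`, as in the test plan above; the group size shown is just the default):

```python
from torchtune.models.llama3 import llama3_8b
from torchtune.utils.quantization import (
    Int8DynActInt4WeightQuantizer,
    get_quantizer_mode,
)

model = llama3_8b()
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)  # delegates to torchao's quantize_ under the hood
# Per the _quantizer_to_mode mapping above, this quantizer corresponds to "8da4w".
print(get_quantizer_mode(quantizer))
```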
