
Add TorchAOHfQuantizer #32306

Merged (13 commits) on Aug 14, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -161,6 +161,8 @@
title: FBGEMM_FP8
- local: quantization/optimum
title: Optimum
- local: quantization/torchao
title: TorchAO
- local: quantization/contribute
title: Contribute new quantization method
title: Quantization Methods
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] FbgemmFp8Config

## TorchAoConfig

[[autodoc]] TorchAoConfig

2 changes: 1 addition & 1 deletion docs/source/en/quantization/overview.md
@@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use.
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
40 changes: 40 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -0,0 +1,40 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# TorchAO

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile`.

Before you begin, make sure the following libraries are installed at their latest versions:
Collaborator: Let's maybe mention it supports compile?

Contributor Author: Added above.

```bash
pip install --upgrade torch torchao
```


```py
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentation for the arguments can be found at https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
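
Because torchao relies on plain tensor subclasses rather than custom modules, the quantized model also composes with `torch.compile`. The snippet below is a minimal sketch (not part of the PR's docs; actual speedups depend on your hardware and the chosen quantization scheme):

```py
import torch

# Compile the forward pass of the already-quantized model for extra speedup.
# Generation reuses the compiled forward under the hood.
quantized_model.forward = torch.compile(quantized_model.forward, mode="max-autotune")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```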

torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, either through the safetensors option or the [non-safetensors option](https://github.com/huggingface/transformers/issues/32364); we'll update this page with instructions once it works.
Member: But one can serialize the state dict with `quantized_model.state_dict()` and then use it later, right? I'd prefer to have this information included as a workaround.

Also, can we provide a small table noting the expected memory and latency savings here?

Contributor Author: Yeah, `state_dict` works locally, but I'm not sure how to get this working with the `save_pretrained` / `from_pretrained` API. We are adding serialization support for both safetensors and non-safetensors, though.

Yeah, we have some numbers here: https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks; I can add them.
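
As a stopgap while official serialization support lands, the state_dict round trip mentioned above could look roughly like this (an untested sketch that assumes the torchao tensor subclasses pickle cleanly; the file name is made up):

```py
import torch

# Save the quantized weights outside of save_pretrained
torch.save(quantized_model.state_dict(), "llama3-8b-int4wo-group128.pt")

# Later: rebuild a quantized model with the same config and restore the tensors.
# assign=True keeps the loaded quantized tensors instead of copying them into
# the existing parameters.
state_dict = torch.load("llama3-8b-int4wo-group128.pt", weights_only=False)
reloaded_model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", quantization_config=quantization_config
)
reloaded_model.load_state_dict(state_dict, assign=True)
```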

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -939,6 +939,7 @@
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
"TorchAoConfig",
],
}

@@ -5673,6 +5674,7 @@
GPTQConfig,
HqqConfig,
QuantoConfig,
TorchAoConfig,
)

try:
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -26,6 +26,7 @@
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -36,6 +37,7 @@
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -48,6 +50,7 @@
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -60,6 +63,7 @@
"quanto": QuantoConfig,
"hqq": HqqConfig,
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
}


133 changes: 133 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -0,0 +1,133 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging


if is_torch_available():
import torch

if is_torchao_available():
from torchao.quantization import (
quantize_,
)

logger = logging.get_logger(__name__)


# Finds the parent module of the submodule named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent


class TorchAoHfQuantizer(HfQuantizer):
"""
Quantizer for torchao: https://github.com/pytorch/ao/
"""

requires_parameters_quantization = True
requires_calibration = False
required_packages = ["torchao"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

def validate_environment(self, device_map, **kwargs):
if not is_torchao_available():
raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)")

def update_torch_dtype(self, torch_dtype):
if self.quantization_config.quant_type == "int4_weight_only":
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning_once(
f"Setting torch_dtype to {torch_dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the torch_dtype to bfloat16."
)
if torch_dtype is None:
logger.warning_once(
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
)
torch_dtype = torch.bfloat16
return torch_dtype

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
from ..integrations import get_keys_to_not_convert

self.modules_to_not_convert = get_keys_to_not_convert(model)

if self.quantization_config.modules_to_not_convert is not None:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

return

def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
# check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
else:
# we only quantize the weight of nn.Linear
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
Collaborator: I am guessing that torchao does not change the type of the linear layer? In the other quantization methods we check against the quantized linear type.

Member: That's right, torchao only modifies the weight tensor and not the nn.Linear module, so we don't need to check against a quantized linear type.
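
A quick sketch of that point (assumes a CUDA device and a recent torchao; `quantize_` and `int4_weight_only` come from torchao's public quantization API):

```py
import torch
from torchao.quantization import quantize_, int4_weight_only

linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(linear, int4_weight_only(group_size=128))

print(isinstance(linear, torch.nn.Linear))  # True: the module class is unchanged
print(linear.weight)                        # the weight is now a quantized tensor subclass
```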


def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
):
"""
Each nn.Linear layer that needs to be quantized is processed here.
First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
"""
module, tensor_name = get_module_from_name(model, param_name)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def _process_model_after_weight_loading(self, model):
"""No process required for torchao quantized model"""
return

@property
def is_serializable(self):
return False

@property
def is_trainable(self):
# torchao does not have official support for QAT (Quantization Aware Training)
# torchao supports nf4/PEFT, but it is not integrated yet
# TODO: if this is supported in the future, do a version check here.
return False
6 changes: 6 additions & 0 deletions src/transformers/testing_utils.py
@@ -126,6 +126,7 @@
is_torch_tf32_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdynamo_available,
is_torchvision_available,
@@ -902,6 +903,11 @@ def require_torchdynamo(test_case):
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)


def require_torchao(test_case):
"""Decorator marking a test that requires torchao"""
return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
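
For illustration, a hypothetical test using the new decorator might look like this (class and test names are made up):

```py
import unittest

from transformers import TorchAoConfig
from transformers.testing_utils import require_torchao


@require_torchao
class TorchAoConfigTest(unittest.TestCase):
    def test_config_creation(self):
        # Skipped automatically when torchao is not installed.
        config = TorchAoConfig("int4_weight_only", group_size=128)
        self.assertEqual(config.quant_type, "int4_weight_only")
```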


def require_torch_tensorrt_fx(test_case):
"""Decorator marking a test that requires Torch-TensorRT FX"""
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -208,6 +208,7 @@
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -171,6 +171,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
@@ -1045,6 +1046,10 @@ def is_torchaudio_available():
return _torchaudio_available


def is_torchao_available():
return _torchao_available


def is_speech_available():
# For now this depends on torchaudio but the exact dependency might evolve in the future.
return _torchaudio_available