Highlights
We are excited to announce the 0.6.1 release of torchao! This release adds Auto-Round support, float8 axiswise scaled training, a BitNet b1.58 training recipe, an implementation of AWQ, and much more!
Auto-Round Support (#581)
Auto-Round is a new weight-only quantization algorithm. It has achieved superior accuracy compared to GPTQ, AWQ, and OmniQuant across 11 tasks, particularly excelling in low-bit quantization (e.g., 2 and 3 bits). Auto-Round supports quantization from 2 to 8 bits, involves low tuning costs, and imposes no additional overhead during inference. Detailed results are available in the paper, the GitHub repository, and the Hugging Face low-bit quantization leaderboard.
```python
from torchao.prototype.autoround.core import prepare_model_for_applying_auto_round_
from torchao.prototype.autoround.core import apply_auto_round
from torchao.prototype.autoround.multi_tensor import MultiTensor  # import path in the prototype
from torchao.quantization import quantize_

# Step 1: prepare the target modules for Auto-Round optimization
prepare_model_for_applying_auto_round_(
    model,
    is_target_module=is_target_module,
    bits=4,
    group_size=128,
    iters=200,
    device=device,
)

# Step 2: run the calibration data through the model as one MultiTensor
input_ids_lst = []
for data in dataloader:
    input_ids_lst.append(data["input_ids"].to(model_device))

multi_t_input_ids = MultiTensor(input_ids_lst)
out = model(multi_t_input_ids)

# Step 3: apply the optimized quantization to the target modules
quantize_(model, apply_auto_round(), is_target_module)
```
Added float8 training axiswise scaling support with per-gemm-argument configuration (#940)
We added experimental support for rowwise scaled float8 gemm to `torchao.float8`, with per-gemm-input configurability to enable exploration of various recipes. Here is how a user can configure all-axiswise scaling:
```python
import torchao.float8
from torchao.float8.config import Float8LinearRecipeName

# all-axiswise scaling
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m = torchao.float8.convert_to_float8_training(m, config=config)

# or, a custom recipe by @lw where grad_weight is left in bfloat16
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
m = torchao.float8.convert_to_float8_training(m, config=config)
```
Early performance benchmarks show that all-axiswise scaling achieves a 1.13x speedup over bf16 on torchtitan / Llama 3 8B / 8 H100 GPUs (compared to 1.17x from all-tensorwise scaling in the same setup), with loss curves that match bf16 and all-tensorwise scaling. Further performance and accuracy benchmarks will follow in future releases.
Introduced BitNet b1.58 training recipe (#930)
Adds a recipe for [BitNet b1.58](https://arxiv.org/abs/2402.17764) training with ternary weight clamping.
```python
from torchao.prototype.quantized_training import bitnet_training
from torchao import quantize_

model = ...
quantize_(model, bitnet_training())
```
Notably, our implementation utilizes INT8 Tensor Cores to make up for the slowdown introduced by the extra quantization work; in fact, it is faster than BF16 training in most cases.
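For intuition, here is a minimal sketch of the absmean ternary quantization described in the BitNet b1.58 paper; the `ternary_quantize_sketch` helper is illustrative and is not torchao's internal implementation.

```python
import torch

def ternary_quantize_sketch(w: torch.Tensor, eps: float = 1e-5):
    """Illustrative absmean ternary quantization (BitNet b1.58 paper).

    The weight is scaled by its mean absolute value, rounded, and clamped
    to {-1, 0, +1}; the returned scale dequantizes it (w ~ w_ternary * scale).
    """
    scale = w.abs().mean().clamp(min=eps)          # per-tensor absmean scale
    w_ternary = (w / scale).round().clamp(-1, 1)   # values in {-1, 0, +1}
    return w_ternary, scale

w = torch.randn(128, 256)
w_q, s = ternary_quantize_sketch(w)
w_dq = w_q * s  # ternary approximation used in the forward pass
```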
[Prototype] Implemented Activation-aware Weight Quantization (AWQ) (#743)
Performance measured on an A100 GPU:
Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
---|---|---|---|---|---|
Llama-2-7b-chat-hf | bfloat16 | 107.38 | 1418.93 | 13.88 | 13.21 |
Llama-2-7b-chat-hf | awq-hqq-int4 | 196.6 | 761.2 | 5.05 | 3.87 |
Llama-2-7b-chat-hf | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 |
Llama-2-7b-chat-hf | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 |
Llama-2-7b-chat-hf | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 |
Usage:
```python
import torch
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
from torchao.quantization import quantize_

quant_dtype = torch.uint4
group_size = 64
calibration_limit = 10
calibration_seq_length = 1024

# Insert observers to collect activation statistics during calibration
model = model.to(device)
insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)

# Run calibration data through the model
with torch.no_grad():
    for batch in calibration_data:
        model(batch.to(device))

# Quantize only the observed linear layers
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
```
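For background on what the observers are used for: AWQ protects the most salient weight channels by folding per-channel scales, derived from the calibration activation statistics, into the weights before quantizing them. The snippet below is a hand-written sketch of that core idea; `awq_equalize_sketch`, `act_abs_mean`, and the fixed `alpha` exponent are illustrative and not part of the torchao API.

```python
import torch

def awq_equalize_sketch(weight: torch.Tensor, act_abs_mean: torch.Tensor, alpha: float = 0.5):
    """Illustrative AWQ-style equalization for a linear layer.

    weight:       (out_features, in_features)
    act_abs_mean: (in_features,) mean |activation| per input channel, from calibration
    Input channels that see larger activations are scaled up before weight
    quantization (and the activations are divided by the same scales at runtime),
    which keeps the layer output unchanged while reducing their quantization error.
    """
    scales = act_abs_mean.clamp(min=1e-5).pow(alpha)
    scales = scales / scales.mean()        # normalize for numerical stability
    scaled_weight = weight * scales        # fold the scales into the weights
    # ... quantize `scaled_weight` to uint4 here ...
    return scaled_weight, scales           # divide the layer input by `scales` at runtime

w = torch.randn(64, 128)
act_stats = torch.rand(128) + 0.1
w_scaled, s = awq_equalize_sketch(w, act_stats)
```

In the actual algorithm the scaling exponent is searched per layer to minimize quantization error; the fixed `alpha=0.5` above is just for illustration.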
New Features
- [Prototype] Added Float8 support for AQT tensor parallel (#1003)
- Added composable QAT quantizer (#938)
- Introduced torchchat quantizer (#897)
- Added INT8 mixed-precision training (#748)
- Implemented sparse marlin AQT layout (#621)
- Added a PerTensor static quant api (#787)
- Introduced uintx quant to generate and eval (#811)
- Added float8 weight-only and float8 weight + dynamic activation quantization (#740); see the sketch after this list
- Implemented Auto-Round support (#581)
- Added 2, 3, 4, 5 bit custom ops (#828)
- Introduced symmetric quantization with no clipping error in the tensor subclass based API (#845)
- Added int4 weight-only embedding QAT (#947)
- Added support for 1-bit and 6-bit quantization for Llama in torchchat (#910, #1007)
- Added a linear_observer class for doing static activation calibration (#807)
- Exposed hqq through uintx_weight_only API (#786)
- Added RowWise scaling option for Float8 dynamic activation quantization (#819)
- Added Float8 weight only to autoquant api (#866)
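As a rough illustration of the float8 inference options listed above (#740): the sketch below applies them through `quantize_`; treat the exact import locations as an assumption that may differ slightly between releases.

```python
import torch
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# float8 weight-only quantization
quantize_(model, float8_weight_only())

# or: float8 weights with dynamically quantized float8 activations
# quantize_(model, float8_dynamic_activation_float8_weight())
```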
Improvements
- Enhanced Auto-Round functionality (#870)
- Improved FSDP support for low-bit optimizers (#538)
- Added support for using AffineQuantizedTensor with `weights_only=True` for torch.load (#630)
- Optimized 3-bit packing (#1029)
- Added more evaluation metrics to llama/eval.sh (#934)
- Improved eager numerics for dynamic scales in float8 (#904)
Bug fixes
- Fixed inference_mode issues (#885)
- Fixed failing FP6 benchmark (#931)
- Resolved various issues with float8 support (#918, #923)
- Fixed load state dict when device is different for low-bit optim (#1021)
Performance
- Added SM75 (Turing) support for FP6 kernel (#942)
- Implemented int8 dynamic quant + bsr support (#821)
- Added a workaround to recover the performance of quantized ViT in torch.compile (#926)
INT8 Mixed-Precision Training
On NVIDIA GPUs, INT8 Tensor Cores are approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can dynamically down-cast activations and weights to INT8 to leverage the faster matmuls. However, since INT8 has a very limited range [-128, 127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. The master weights are still kept in the original precision.
```python
from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = ...

# apply INT8 matmul to all 3 matmuls
quantize_(model, int8_mixed_precision_training())

# or customize which matmuls are left in the original precision
config = Int8MixedPrecisionTrainingConfig(
    output=True,
    grad_input=True,
    grad_weight=False,
)
quantize_(model, int8_mixed_precision_training(config))
```
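For intuition on the row-wise quantization mentioned above, here is a hand-written sketch of dynamically casting matmul inputs to INT8 with one scale per row; `rowwise_int8_quantize` is illustrative, and the matmul is emulated in floating point rather than with the INT8 Tensor Core kernels torchao actually uses.

```python
import torch

def rowwise_int8_quantize(x: torch.Tensor):
    """Quantize each row of a 2D tensor to INT8 with its own scale.

    Per-row scales keep the limited INT8 range [-128, 127] usable even when
    different rows have very different magnitudes.
    """
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return x_int8, scale

a = torch.randn(8, 4096)            # activations
b = torch.randn(4096, 4096)         # weight, used as a @ b
a_i8, a_scale = rowwise_int8_quantize(a)
b_i8, b_scale = rowwise_int8_quantize(b.t())   # per output-channel scales for the weight

# Real kernels run this matmul in INT8 with INT32 accumulation on Tensor Cores;
# here it is emulated in float for clarity, then rescaled back to high precision.
out = (a_i8.float() @ b_i8.float().t()) * a_scale * b_scale.t()
```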
End2end speed benchmark using `benchmarks/quantized_training/pretrain_llama2.py`:
Model & GPU | bs x seq_len | Config | Tok/s | Peak mem (GB) |
---|---|---|---|---|
Llama2-7B, A100 | 8 x 2048 | BF16 (baseline) | ~4400 | 59.69 |
Llama2-7B, A100 | 8 x 2048 | INT8 mixed-precision | ~6100 (+39%) | 58.28 |
Llama2-1B, 4090 | 16 x 2048 | BF16 (baseline) | ~17,900 | 18.23 |
Llama2-1B, 4090 | 16 x 2048 | INT8 mixed-precision | ~30,700 (+72%) | 18.34 |
Docs
- Updated README with more current float8 speedup information (#816)
- Added tutorial for trainable tensor subclass (#908)
- Improved documentation for float8 unification and inference (#895, #896)
Devs
- Added compile tests to test suite (#906)
- Improved CI setup and build processes (#887)
- Added M1 wheel support (#822)
- Added more benchmarking and profiling tools (#1017)
- Renamed `fpx` to `floatx` (#877)
- Removed torchao_nightly package (#661)
- Added more lint fixes (#827)
- Added better subclass testing support (#839)
- Added CI to catch syntax errors (#861)
- Added tutorial on composing quantized subclass w/ Dtensor based TP (#785)
Security
No significant security updates in this release.
Uncategorized
- Added basic SAM2 AutomaticMaskGeneration example server (#1039)
New Contributors
- @iseeyuan made their first contribution in #805
- @YihengBrianWu made their first contribution in #860
- @kshitij12345 made their first contribution in #863
- @ZainRizvi made their first contribution in #887
- @alexsamardzic made their first contribution in #899
- @vaishnavi17 made their first contribution in #911
- @tobiasvanderwerff made their first contribution in #931
- @kwen2501 made their first contribution in #937
- @y-sq made their first contribution in #912
- @jimexist made their first contribution in #969
- @danielpatrickhug made their first contribution in #914
- @ramreddymounica made their first contribution in #1007
- @yushangdi made their first contribution in #1006
- @ringohoffman made their first contribution in #1023
Full Changelog: v0.5.0...v0.6.1