[RFC] Plans for LLM QAT #86

andrewor14 opened this issue Mar 26, 2024 · 0 comments

Following the recent success of the LLM-QAT paper, our high-level goal is to provide a PyTorch-native workflow for LLM quantization-aware training (QAT) leveraging the quantization primitives and kernels provided by torchao, which is planned to become the de facto OSS library for AO techniques and kernels in PyTorch across different platforms (#47). We also hope to eventually integrate with TorchTune, a recently open-sourced library for fine-tuning and experimenting with LLMs, to provide an end-to-end flow that supports both fine-tuning and QAT.

Workstream 1: Edge Devices

ExecuTorch provides a mechanism for quantizing Llama2 using post-training quantization (PTQ) techniques such as GPTQ, and lowering it to backends like XNNPACK. The main goal of this workstream is to provide a QAT drop-in replacement for GPTQ with superior accuracy, starting with Llama2 7B using the following quantization/training configurations:

  • Linear weights: 4-bit per channel grouped symmetric static quantization
  • Linear activations: 8-bit per token symmetric dynamic quantization
  • Not “data-free”: unlike the LLM-QAT paper, which trains on data generated by the model itself, we use the original dataset

We plan to adopt the same eager mode quantization implementation used by the PTQ flow. If we later decide to experiment with static quantization for activations, for example, we can explore using the PT2 Export QAT flow.
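To make the intended eager-mode, module-swap approach concrete, below is a minimal sketch of a fake-quantized linear layer matching the configuration above (4-bit per-channel grouped symmetric weights, 8-bit per-token symmetric dynamic activations). The names `QATLinear` and `fake_quantize_symmetric` are illustrative only and are not the torchao API; this is a sketch under the assumption of a straight-through estimator for gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fake_quantize_symmetric(x: torch.Tensor, scale: torch.Tensor,
                            quant_min: int, quant_max: int) -> torch.Tensor:
    """Quantize-dequantize with a straight-through estimator so gradients flow."""
    q = torch.clamp(torch.round(x / scale), quant_min, quant_max)
    dq = q * scale
    # Forward uses the dequantized value; backward sees the identity.
    return x + (dq - x).detach()


class QATLinear(nn.Module):
    """Illustrative drop-in replacement for nn.Linear during 8da4w QAT."""

    def __init__(self, linear: nn.Linear, group_size: int = 256):
        super().__init__()
        # Weights stay in full precision during training; group_size must
        # divide in_features.
        self.weight = linear.weight
        self.bias = linear.bias
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 8-bit per-token symmetric dynamic activation fake quantization.
        act_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127
        x_fq = fake_quantize_symmetric(x, act_scale, -128, 127)

        # 4-bit per-channel grouped symmetric weight fake quantization.
        out_features, in_features = self.weight.shape
        w = self.weight.view(out_features, in_features // self.group_size, self.group_size)
        w_scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 7
        w_fq = fake_quantize_symmetric(w, w_scale, -8, 7).view(out_features, in_features)

        return F.linear(x_fq, w_fq, self.bias)
```

During fine-tuning, eligible `nn.Linear` modules would be swapped for such a wrapper; after training, the weights can be quantized for real using the same group-wise scales so that the deployed numerics match what the model saw during QAT.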

Workstream 2: Explore new quantization methods

This workstream is largely backend agnostic; our goal is to motivate the backends (mobile or server CPU/GPU) to build the relevant kernels once we have demonstrated the initial success of a particular quantization configuration. There is a large design space we can experiment with, summarized below. The suggested quantization and training techniques are primarily motivated by the LLM-QAT paper, but also by ongoing developments across the industry.

We can start with the following dimensions:

  • KV-cache quantization: 4- or 8-bit KV-cache quantization can alleviate throughput bottlenecks with long sequences, and the LLM-QAT paper shows that it degrades accuracy far less under QAT than under PTQ (see the sketch after this list).
  • Custom dtypes: The latest Hopper and upcoming Blackwell GPU generations no longer support int4 tensor cores, and so int4 kernels may not be as performant as other 4-bit dtypes. For example, both nf4 and MX4 promise higher fidelity than any a priori fixed quantization like int4. Experimenting with newer dtypes in QAT may lead to further accuracy gains.
  • Lower bit-widths: 2- or 3-bit weight quantization can further lower memory footprint and speed up inference. There have been PTQ attempts at such bit-widths (e.g. QuIP#, AQLM), but QAT has the potential to further mitigate the accuracy degradation.
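As a concrete illustration of the KV-cache item above, the sketch below fake-quantizes cached keys and values before attention so the model learns to tolerate the quantization error during training. This is an illustrative sketch, not a torchao or LLM-QAT API: the function name, the per-head scale granularity, and the placement inside the attention block are all assumptions, and the bit width is a parameter so the same code covers the 4-bit case.

```python
import torch


def fake_quantize_kv_cache(kv: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Fake-quantize a cached key/value tensor of shape (batch, heads, seq, head_dim).

    Scales are computed per head over the sequence and head dimensions; this is
    one of several reasonable granularities (per-token or per-channel also work).
    """
    qmax = 2 ** (n_bits - 1) - 1
    scale = kv.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=1e-5) / qmax
    q = torch.clamp(torch.round(kv / scale), -qmax - 1, qmax)
    dq = q * scale
    # Straight-through estimator keeps gradients flowing to the k/v projections.
    return kv + (dq - kv).detach()


# Hypothetical usage inside an attention block, before attention scores are computed:
# k_cache = fake_quantize_kv_cache(k_cache, n_bits=8)
# v_cache = fake_quantize_kv_cache(v_cache, n_bits=8)
```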

Workstream 3: Server GPU Inference

This extends the recent gpt-fast efforts to quantize Llama, but using QAT instead of PTQ. An important goal here is to reuse the same quantization primitives as Workstream 1 to unify the two flows as much as possible. We can start with the following quantization configurations:

  • Int4 weight-only quantization (see the sketch after this list). This was the focus last half for Llama2, targeting batch-size-1 local chat agent use cases. That workload is memory bound rather than compute bound on GPUs, so quantizing the activations may not be particularly beneficial. For QAT, we can perform the same weight-only quantization for better accuracy.
  • Int4 weight quantization + int8 activation dynamic quantization, similar to Workstream 1. One advantage here is that we will have numerical baselines from the ExecuTorch workstream to compare against. However, as explained above, it may not make sense for Llama2 batch-size-1 use cases on GPUs, so this configuration may be more suitable for larger batch sizes or other more compute-bound models. The plan is to leverage the ongoing effort to provide mixed 4-bit/8-bit GEMM in CUTLASS (NVIDIA/cutlass#1413, “Add support for mixed 4-bit/8-bit data types GEMM”).
  • MX4 weight + activation quantization. Please see the previous section under Custom dtypes for more details.
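To ground the int4 weight-only option above, here is a minimal sketch of group-wise symmetric int4 weight quantization and dequantization, roughly the numerics a QAT-trained checkpoint would be converted into for inference. The function names are hypothetical, and storing the int4 values unpacked in an int8 tensor is a simplification; real kernels (e.g. the gpt-fast int4 path) pack two values per byte and fuse dequantization into the GEMM.

```python
import torch


def quantize_int4_groupwise(w: torch.Tensor, group_size: int = 128):
    """Symmetric group-wise int4 quantization of an (out_features, in_features) weight."""
    out_features, in_features = w.shape
    w_grouped = w.view(out_features, in_features // group_size, group_size)
    scales = w_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 7
    q = torch.clamp(torch.round(w_grouped / scales), -8, 7).to(torch.int8)
    return q, scales  # int4 values stored unpacked in an int8 tensor for clarity


def dequantize_int4_groupwise(q: torch.Tensor, scales: torch.Tensor, shape):
    """Reconstruct an approximate full-precision weight from int4 values and scales."""
    return (q.float() * scales).view(shape)


# Hypothetical round trip on a linear weight:
w = torch.randn(4096, 4096)
q, scales = quantize_int4_groupwise(w)
w_hat = dequantize_int4_groupwise(q, scales, w.shape)
print((w - w_hat).abs().max())  # error bounded by half a quantization step per group
```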
andrewor14 self-assigned this Mar 26, 2024
andrewor14 added a commit that referenced this issue Apr 15, 2024
Summary: This commit adds support for QAT, where linear layers
are fake quantized with int8 per token dynamic activations (8da)
and int4 grouped per channel weights (4w). This initial
implementation uses the same module swap approach as 8da4w PTQ
for simplicity and code reuse. In the future, we may wish to
consider migrating both flows to use tensor subclasses for
better composability with other PyTorch features.

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group
python test/quantization/test_qat.py -k test_fake_quantize_per_token
python test/quantization/test_qat.py -k test_qat_8da4w_linear
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch, HDCharles

Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar

Tasks: #86
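
For reference, the flow added by this commit is roughly a prepare → fine-tune → convert loop. The sketch below shows how it might be driven end to end; the import path and class name (`Int8DynActInt4WeightQATQuantizer`) are assumptions inferred from the test names above and may differ between torchao versions, so treat this as illustrative rather than the definitive API.

```python
import torch

# Assumed import path for the prototype 8da4w QAT quantizer; the exact module
# location and class name may differ between torchao versions.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model; in practice this would be a Llama2 model whose nn.Linear
# layers have in_features divisible by the quantization group size.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096))

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# 1. Swap eligible nn.Linear modules for fake-quantized equivalents.
model = qat_quantizer.prepare(model)

# 2. Fine-tune as usual; fake quantization runs in every forward pass.
#    (training loop elided)

# 3. Convert to the real int8-dynamic-activation / int4-weight quantized model.
model = qat_quantizer.convert(model)
```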
msaroufim added the rfc label May 7, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this issue Jul 31, 2024