Add AffineQuantizedTensor based workflow doc and examples (#277)
jerryzh168 authored May 29, 2024
1 parent 6dd63b8 commit cbc74ee
Showing 1 changed file with 111 additions and 0 deletions.
torchao/quantization/README.md

## Affine Quantization
Affine quantization refers to quantization that maps floating point numbers to quantized numbers (typically integers) with an affine transformation, i.e. `quantized_val = float_val / scale + zero_point`, where `scale` and `zero_point` are quantization parameters computed from the data at some chosen granularity (e.g. per tensor, per channel, per group).
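
For intuition, here is a minimal per-tensor sketch of that transform in plain PyTorch (the `scale`/`zero_point` values are made up for illustration; in practice they are computed by the primitives described below):

```python
import torch

float_val = torch.tensor([-1.0, 0.0, 0.5, 1.0])
scale, zero_point = 0.01, 0   # illustrative qparams for a symmetric int8 mapping
quant_min, quant_max = -128, 127

# quantized_val = float_val / scale + zero_point, rounded and clamped to the int8 range
quantized_val = torch.clamp(torch.round(float_val / scale) + zero_point, quant_min, quant_max).to(torch.int8)
# dequantization is the inverse affine transform (lossy due to rounding/clamping)
dequantized_val = (quantized_val.float() - zero_point) * scale
```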

### Quantization Primitives
We used to have separate quantize and dequantize operators for each quantization granularity, but all of them can be expressed with a single `block_size` argument. We therefore unified the existing quant primitives into `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`, which can represent symmetric/asymmetric per-tensor/channel/token/channel-group quantization. These primitives are used to implement the unified quantized tensor subclass.
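
For example, a rough sketch of per-channel symmetric int8 quantization with these primitives (argument order and keyword names follow `torchao.quantization.quant_primitives` at the time of writing and may change; check the source for the exact signatures):

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

weight = torch.randn(32, 64)
# block_size picks the granularity: (1, 64) -> one scale/zero_point per row (per channel),
# (32, 64) -> per tensor, (1, 32) -> groupwise with group size 32
block_size = (1, weight.shape[1])
scale, zero_point = choose_qparams_affine(
    weight, MappingType.SYMMETRIC, block_size, torch.int8,
    eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64,
)
q = quantize_affine(weight, block_size, scale, zero_point, torch.int8)   # float -> int8
dq = dequantize_affine(q, block_size, scale, zero_point, torch.int8)     # int8 -> float
```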

### Quantized Tensor Subclass
We also have a unified quantized tensor subclass that implements how to construct a quantized tensor from a floating point tensor and what it means to call linear ops (e.g. `F.linear` and `aten.addmm`) on an instance of the tensor. With this we can dispatch to different operators (e.g. the `int4mm` op) based on device (cpu, cuda), quantization settings (`int4`, `int8`) and packing format (e.g. a format optimized for the cpu int4 mm kernel).
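
As a rough illustration of the dispatch idea (a toy sketch, not torchao's actual `AffineQuantizedTensor`), a weight subclass can intercept `F.linear` via `__torch_function__` and decide there which kernel to run:

```python
import torch
import torch.nn.functional as F

# Illustrative only: the subclass stores int8 data plus a per-channel scale and intercepts
# F.linear, which is where a real implementation would pick an optimized kernel based on
# device, quantization settings and packing format.
class ToyQuantizedWeight(torch.Tensor):
    @classmethod
    def from_float(cls, w):
        # per-channel symmetric int8: one scale per output channel (row)
        scale = w.detach().abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
        int_data = torch.clamp(torch.round(w.detach() / scale), -128, 127).to(torch.int8)
        self = torch.Tensor._make_subclass(cls, int_data.to(w.dtype) * scale)
        self.int_data, self.scale = int_data, scale
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w, *bias = args
            # dispatch point: choose e.g. an int4/int8 mm kernel here; this toy just dequantizes
            return F.linear(x, w.int_data.to(x.dtype) * w.scale, *bias)
        return super().__torch_function__(func, types, args, kwargs)

w = ToyQuantizedWeight.from_float(torch.randn(32, 64))
out = F.linear(torch.randn(8, 64), w)  # routed through the subclass
```

A real implementation stores packed int data plus quantization parameters and routes to optimized kernels at the marked dispatch point, which is what the flow below sets up with `to_aq`.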

### Quantization Flow
With the quant primitives and the tensor subclass in place, what we need to do afterwards is roughly the following:

```python
import torch
import torch.nn as nn
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes.aqt import to_aq

def apply_int8wo_quant(weight):
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int64
    # one scale/zero_point per output channel (row) of the weight
    block_size = (1, weight.shape[1])
    return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

for n, m in model.named_modules():
    if isinstance(m, torch.nn.Linear):
        # optional filtering for module name, shape etc.
        m.weight = nn.Parameter(apply_int8wo_quant(m.weight))

        # note: quantization for activations needs to be applied after the weight quantization
        # quantize activations (needed by dynamic quantization)
        # input_quant_func = apply_int8wo_quant  # specify how the input activation is quantized
        # m.weight = nn.Parameter(to_laq(m.weight, input_quant_func))
```
The model/tensor subclass should also be compatible with AOTI and torch.export. Currently we can support `torch.export.export` and `torch._export.aot_compile` with the following workaround:
```python
from torchao.quantization.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)
# export
m = torch.export.export(m_unwrapped, example_inputs).module()
# aot_compile
torch._export.aot_compile(m_unwrapped, example_inputs)
```

We expect this to be integrated into the export path by default in the future.


### Example
Let's use int4 weight-only quantization targeting the tinygemm int4 weight-only quantized matmul kernel as an example:
```python
import torch
import copy
from torchao.quantization.quant_api import (
    quantize,
    get_apply_int4wo_quant,
)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# temporary workaround for tensor subclass + torch.compile
from torchao.quantization.utils import unwrap_tensor_subclass
m = unwrap_tensor_subclass(m)
# compile the model to improve performance
m = torch.compile(m, mode='max-autotune')

# benchmark to see the speedup
from torchao.utils import benchmark_model

num_runs = 100
torch._dynamo.reset()
bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0])
print(f"bf16 mean time: {bf16_time}")
int4_time = benchmark_model(m, num_runs, example_inputs[0])
print(f"int4 weight only quantized mean time: {int4_time}")
print(f"speedup: {bf16_time / int4_time}")

# output (1xA100 GPU machine)
# bf16 mean time: 71.457685546875
# int4 weight only quantized mean time: 31.4580908203125
# speedup: 2.2715200981216173
```

