Skip to content

Commit

Permalink
Add more docs for int4_weight_only API that targets tinygemm
Browse files Browse the repository at this point in the history
Summary:
att, per request in pytorch#415 (comment)

Test Plan:
doc changes

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Jul 2, 2024
1 parent 5d22ad2 commit c1dc6e3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ Affine quantization refers to the type of quantization that maps from floating p
### Quantization Primitives
We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass.

Note: these primitive ops supports two "types" of quantization, distinguished by whether `zero_point` is in floating point domain or integer domain. See docstrings for `choose_qparams` for more details.

### Quantized Tensor Subclass
We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel)

#### Layouts
We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.

### Quantization Flow Example
Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul
as an example:
Expand Down
8 changes: 8 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,14 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
"tensor_core_tiled" layout for speedup with tinygemm kernel
Note:
This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference
of quantization algorithm compared to the more traditional type of integer quantization is the following:
1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`)
2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
please follow the relevant code in `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`
to learn about how the quantization parameters are chosen and how the Tensor is quantized/dequantized for tinygemm
Args:
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, choices are [256, 128, 64, 32]
Expand Down
1 change: 1 addition & 0 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def dequantize_affine(
dequant *= scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
# TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
Expand Down

0 comments on commit c1dc6e3

Please sign in to comment.