Update torchao api reference and add contributor guide (#1255)
* Update torchao api reference and add contributor guide

Summary:
1. updated torchao api reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py
to linear_activation_weight_observed_tensor.py, and removed safe_int_mm and int_scaled_matmul from quant_primitives.py (they are now exported from torchao.kernel)
2. added #391 to torchao docs

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* format

* typo

* renaming

* comma

* format

* comments
jerryzh168 authored Nov 13, 2024
1 parent 333bde6 commit 39f16f4
Showing 16 changed files with 731 additions and 57 deletions.
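The summary above moves `safe_int_mm` and `int_scaled_matmul` out of `quant_primitives.py`; in this diff they are exported from `torchao.kernel` and re-exported from `torchao.quantization`. A minimal sketch of how downstream code might pick them up, assuming a torchao build that includes this commit. The helper name `load_int_mm_kernels` is hypothetical, and the `ImportError` fallback is only there so the snippet degrades gracefully where torchao is not installed:

```python
def load_int_mm_kernels():
    """Return (safe_int_mm, int_scaled_matmul), or None if unavailable.

    Hypothetical helper for illustration; not part of torchao.
    """
    try:
        # Location after this commit (see torchao/kernel/__init__.py).
        from torchao.kernel import int_scaled_matmul, safe_int_mm
    except ImportError:
        # torchao is missing, or predates the torchao.kernel exports.
        return None
    return safe_int_mm, int_scaled_matmul


kernels = load_int_mm_kernels()
if kernels is None:
    print("torchao.kernel exports are not available in this environment")
else:
    print("loaded safe_int_mm and int_scaled_matmul from torchao.kernel")
```

Code that imported these names from `torchao.quantization.quant_primitives` would likely need to switch to `torchao.kernel` or `torchao.quantization`, since this diff drops the import and the matching `__all__` entries from that module.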
4 changes: 3 additions & 1 deletion docs/source/api_ref_dtypes.rst
@@ -12,9 +12,11 @@ torchao.dtypes

to_nf4
to_affine_quantized_intx
to_affine_quantized_floatx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
AffineQuantizedTensor

..
9 changes: 3 additions & 6 deletions docs/source/api_ref_intro.rst
@@ -1,16 +1,13 @@
``torchao`` API Reference
=========================

This section introduces the torchao API reference.
Dive into the details of how torchao integrates with PyTorch to
optimize your machine learning models.
This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models.

.. toctree::
:glob:
:maxdepth: 1
:caption: Python API Reference

api_ref_sparsity
api_ref_quantization
api_ref_dtypes
api_ref_kernel
api_ref_quantization
api_ref_sparsity
38 changes: 31 additions & 7 deletions docs/source/api_ref_quantization.rst
@@ -9,15 +9,39 @@ torchao.quantization
.. autosummary::
:toctree: generated/
:nosignatures:

SmoothFakeDynQuantMixin
SmoothFakeDynamicallyQuantizedLinear
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference
Int4WeightOnlyGPTQQuantizer
Int4WeightOnlyQuantizer
autoquant

quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
int4_weight_only
int8_weight_only
float8_weight_only
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only

to_linear_activation_quantized

swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
quantize_affine
quantize_affine_floatx
dequantize_affine
dequantize_affine_floatx
choose_qparams_and_quantize_affine_hqq
fake_quantize_affine
fake_quantize_affine_cachemask

safe_int_mm
int_scaled_matmul

MappingType
ZeroPointDomain
TorchAODType

604 changes: 604 additions & 0 deletions docs/source/contributor_guide.rst

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions docs/source/index.rst
@@ -1,9 +1,7 @@
Welcome to the torchao Documentation
=======================================

**torchao** is an open-source library that provides the functionality
to quantize and prune your models using native PyTorch. Our documentation is under development
with more content coming soon.
`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials.

..
.. grid:: 3
@@ -81,13 +79,19 @@ with more content coming soon.
:maxdepth: 1
:caption: API Reference

api_ref_sparsity
api_ref_intro
api_ref_quantization
api_ref_dtypes
api_ref_quantization
api_ref_sparsity
..
api_ref_kernel
.. toctree::
:glob:
:maxdepth: 1
:caption: Contributor Guide

contributor_guide

.. toctree::
:glob:
:maxdepth: 1
4 changes: 3 additions & 1 deletion test/integration/test_integration.py
@@ -34,8 +34,10 @@
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
)
from torchao.quantization.quant_primitives import (
from torchao.quantization import (
safe_int_mm,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
dequantize_affine,
4 changes: 3 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -21,6 +21,9 @@
addmm_float8_unwrapped_inference,
preprocess_data,
)
from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.quant_primitives import (
FP8_TYPES,
MappingType,
@@ -31,7 +34,6 @@
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
int_scaled_matmul,
quantize_affine,
quantize_affine_floatx,
)
7 changes: 7 additions & 0 deletions torchao/kernel/__init__.py
@@ -0,0 +1,7 @@
from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm

__all__ = [
"safe_int_mm",
"int_scaled_matmul",
]
80 changes: 55 additions & 25 deletions torchao/quantization/__init__.py
@@ -4,6 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torchao.kernel import (
int_scaled_matmul,
safe_int_mm,
)

from .autoquant import (
DEFAULT_AUTOQUANT_CLASS_LIST,
DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
@@ -51,10 +56,18 @@
)
from .quant_primitives import (
MappingType,
TorchAODType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_floatx,
choose_qparams_affine_with_min_max,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
fake_quantize_affine,
fake_quantize_affine_cachemask,
quantize_affine,
quantize_affine_floatx,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
@@ -72,50 +85,67 @@
from .weight_only import WeightOnlyInt8QuantLinear

__all__ = [
"swap_conv2d_1x1_to_linear",
# top level API - auto
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qparams_affine",
# top level API - manual
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"uintx_weight_only",
"fpx_weight_only",
"LinearActivationQuantizedTensor",
# smooth quant - subject to change
"swap_conv2d_1x1_to_linear",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",
# building blocks
"to_linear_activation_quantized",
"to_weight_tensor_with_linear_activation_scale_metadata",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"Int8DynActInt4WeightGPTQQuantizer",
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightLinear",
"WeightOnlyInt8QuantLinear",
"TwoStepQuantizer",
"Quantizer",
"ZeroPointDomain",
"MappingType",
"AffineQuantizedMinMaxObserver",
"AffineQuantizedObserverBase",
# quant primitive ops
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
"quantize_affine",
"quantize_affine_floatx",
"dequantize_affine",
"dequantize_affine_floatx",
"choose_qparams_and_quantize_affine_hqq",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
# operators/kernels
"safe_int_mm",
"int_scaled_matmul",
# dataclasses and types
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"PerTensor",
"PerAxis",
"PerGroup",
"PerRow",
"PerToken",
"LinearActivationQuantizedTensor",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"Int8DynActInt4WeightGPTQQuantizer",
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightLinear",
"WeightOnlyInt8QuantLinear",
"TwoStepQuantizer",
"Quantizer",
]
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
@@ -10,6 +10,7 @@
TensorCoreTiledLayout,
)
from torchao.float8.inference import Float8MMConfig
from torchao.kernel import safe_int_mm
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
@@ -24,7 +25,6 @@
PerRow,
PerTensor,
)
from .quant_primitives import safe_int_mm
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
4 changes: 2 additions & 2 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -19,7 +19,7 @@
class LinearActivationQuantizedTensor(TorchAOBaseTensor):
"""
Applies activation quantization for linear operator, this is used to support
dynamic quantization or static quantization, user can pass in a `input_quant_func`
dynamic quantization, user can pass in a `input_quant_func`
that is used to quantize the activation
Args:
@@ -60,7 +60,7 @@ def __init__(
self.quant_kwargs = quant_kwargs

def __repr__(self):
return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -38,7 +38,7 @@
)
from torchao.dtypes.uintx.uintx import UintxLayout
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.linear_activation_weight_observer import (
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
6 changes: 3 additions & 3 deletions torchao/quantization/quant_primitives.py
@@ -10,7 +10,6 @@

import torch

from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm
from torchao.prototype.custom_fp_utils import (
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
@@ -24,8 +23,6 @@
)

__all__ = [
"safe_int_mm",
"int_scaled_matmul",
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
@@ -36,6 +33,9 @@
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
"choose_qparams_and_quantize_affine_hqq",
"MappingType",
"ZeroPointDomain",
"TorchAODType",
]


4 changes: 3 additions & 1 deletion torchao/quantization/utils.py
@@ -9,12 +9,14 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode

from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
dequantize_affine,
int_scaled_matmul,
quantize_affine,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ def __init__(
self.quant_kwargs = quant_kwargs

def __repr__(self):
return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"

def __tensor_flatten__(self):
tensor_data = ["original_weight_tensor", "scale"]
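
Both `__repr__` changes in this commit replace the hard-coded class name with `self.__class__.__name__`. A small sketch of why that matters, using hypothetical stand-in classes (`QuantizedWrapper` and `StaticQuantizedWrapper` are not torchao classes and hold plain Python values rather than tensors): a subclass that inherits `__repr__` reports its own name instead of its parent's.

```python
class QuantizedWrapper:
    """Hypothetical stand-in for a tensor wrapper; not a torchao class."""

    def __init__(self, weight, quant_kwargs=None):
        self.weight = weight
        self.quant_kwargs = quant_kwargs or {}

    def __repr__(self):
        # The name is looked up on the instance's class at call time,
        # so subclasses report their own name without overriding __repr__.
        return (
            f"{self.__class__.__name__}"
            f"({self.weight!r}, quant_kwargs={self.quant_kwargs!r})"
        )


class StaticQuantizedWrapper(QuantizedWrapper):
    """Inherits __repr__ unchanged."""


base = QuantizedWrapper([1, 2, 3])
sub = StaticQuantizedWrapper([1, 2, 3], {"bits": 8})
print(repr(base))  # QuantizedWrapper([1, 2, 3], quant_kwargs={})
print(repr(sub))   # StaticQuantizedWrapper([1, 2, 3], quant_kwargs={'bits': 8})
```

With a hard-coded string, both lines would have printed the parent's name, which can be misleading when debugging subclassed tensors.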