Skip to content

Commit

Permalink
Renaming quantize to quantize_
Browse files Browse the repository at this point in the history
Summary:
Addressing feedback for `quantize` API from pytorch#391 (comment)

this is an API that changes model inplace, so we want to change the name to reflect that. inplace model quantization is important especially for LLM since it will be hard to load the model to memory. we typically load the model to meta device and then load the quantized weight.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Jul 2, 2024
1 parent 5d22ad2 commit 934fa21
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 35 deletions.
18 changes: 9 additions & 9 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
quantize,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
# APIs to be deprecated (used for torch 2.2.2 and 2.3)
Expand Down Expand Up @@ -98,21 +98,21 @@

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_weight_only(), set_inductor_config=False)
quantize_(mod, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4_weight_only(), set_inductor_config=False)
quantize_(mod, int4_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)
Expand All @@ -127,8 +127,8 @@ def _int4wo_api(mod):
def undo_recommended_configs():
torch._inductor.config.coordinate_descent_tuning = False
torch._inductor.config.coordinate_descent_check_all_directions = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.triton.unique_kernel_names = False
torch.set_float32_matmul_precision("highest")

Expand Down Expand Up @@ -844,7 +844,7 @@ def api(mod):
kwargs_copy = kwargs.copy()
kwargs_copy["group_size"] = groupsize
del kwargs_copy["groupsize"]
quantize(mod, int4_weight_only(**kwargs_copy))
quantize_(mod, int4_weight_only(**kwargs_copy))
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
Expand All @@ -865,7 +865,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
Expand Down Expand Up @@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype):
out3 = mod(example_input)
sqnr2 = SQNR(out, out3)
self.assertTrue(sqnr2 >= 30)


@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
Expand Down
18 changes: 9 additions & 9 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao import quantize
from torchao import quantize_
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Quantizer,
Expand Down Expand Up @@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

class TorchCompileDynamicQuantizer(Quantizer):
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
return model

class ToyLinearModel(torch.nn.Module):
Expand Down Expand Up @@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
)
m = ToyLinearModel().eval().cpu()
def api(model):
model = quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())
unwrap_tensor_subclass(model)

api(m)
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand Down Expand Up @@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))
quantize_(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

Expand All @@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

m = quantize(m, int8_weight_only())
quantize_(m, int8_weight_only())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
Expand All @@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
m = quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand Down Expand Up @@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = quantize(m, int8_weight_only())
quantize_(m, int8_weight_only())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
Expand Down
4 changes: 2 additions & 2 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@

from torchao.quantization import (
autoquant,
quantize,
quantize_,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"quantize_",
]

# test-pytorchbot
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.dtypes import to_affine_quantized
import copy
from torchao.quantization.quant_api import (
quantize,
quantize_,
int4_weight_only,
)

Expand All @@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
group_size = 32
# only works for torch 2.4+
m = quantize(m, int4_weight_only(group_size=group_size))
quantize_(m, int4_weight_only(group_size=group_size))

# temporary workaround for tensor subclass + torch.compile
from torchao.utils import unwrap_tensor_subclass
Expand Down Expand Up @@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True

# for torch 2.4+
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
Expand All @@ -180,7 +180,7 @@ change_linear_weights_to_int8_dqtensors(model)
```python
# for torch 2.4+
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
Expand All @@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
```python
# for torch 2.4+
from torchao.quantization import quantize, int4_weight_only
quantize(model, int4_weight_only())
quantize_(model, int4_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int4_weight_only",
Expand Down
14 changes: 7 additions & 7 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"Int4WeightOnlyQuantizer",
"autoquant",
"_get_subclass_inserter",
"quantize",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int4_weight_only",
Expand Down Expand Up @@ -259,8 +259,8 @@ def insert_subclass(lin):

return insert_subclass

def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`
def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
Args:
model (torch.nn.Module): input model
Expand All @@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
import torch
import torch.nn as nn
from torchao import quantize
from torchao import quantize_
# 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
# optimized execution paths or kernels (e.g. int4 tinygemm kernel)
Expand All @@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
from torchao.quantization.quant_api import int4_weight_only
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
m = quantize(m, int4_weight_only(group_size=32))
quantize_(m, int4_weight_only(group_size=32))
# 2. write your own new apply_tensor_subclass
# You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
Expand All @@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
m = quantize(m, apply_weight_quant, filter_fn)
quantize_(m, apply_weight_quant, filter_fn)
"""
if set_inductor_config:
Expand All @@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
_get_linear_subclass_inserter(apply_tensor_subclass),
_is_linear if filter_fn is None else filter_fn,
)
return model


def int8_dynamic_activation_int4_weight(group_size=32):
"""Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
Expand Down
4 changes: 2 additions & 2 deletions tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
# for APIs for earlier torch version and other quantization techniques

# for torch 2.4+
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
## Quantization code - end

## compilation configs
Expand Down

0 comments on commit 934fa21

Please sign in to comment.