From 2ce9bc473e34478704ca65317047b18a6455e934 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 2 Jul 2024 11:25:53 -0700 Subject: [PATCH] Renaming `quantize` to `quantize_` Summary: Addressing feedback for `quantize` API from https://github.com/pytorch/ao/issues/391#issuecomment-2174713094 this is an API that changes model inplace, so we want to change the name to reflect that. inplace model quantization is important especially for LLM since it will be hard to load the model to memory. we typically load the model to meta device and then load the quantized weight. Test Plan: python test/quantization/test_quant_api.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 18 +++++++++--------- test/prototype/test_quant_llm.py | 4 ++-- test/quantization/test_quant_api.py | 18 +++++++++--------- torchao/__init__.py | 4 ++-- torchao/quantization/README.md | 10 +++++----- torchao/quantization/__init__.py | 2 +- torchao/quantization/quant_api.py | 14 +++++++------- tutorials/quantize_vit/run_vit_b_quant.py | 4 ++-- 8 files changed, 37 insertions(+), 37 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4d5a2c511c..c21f3a38be 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -23,7 +23,7 @@ int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, - quantize, + quantize_, _replace_with_custom_fn_if_matches_filter, ) # APIs to be deprecated (used for torch 2.2.2 and 2.3) @@ -98,21 +98,21 @@ def _int8wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_weight_only(), set_inductor_config=False) + quantize_(mod, int8_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_woqtensors(mod) def _int8da_int8w_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int4_weight_only(), set_inductor_config=False) + quantize_(mod, int4_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int4_woqtensors(mod) @@ -127,8 +127,8 @@ def _int4wo_api(mod): def undo_recommended_configs(): torch._inductor.config.coordinate_descent_tuning = False torch._inductor.config.coordinate_descent_check_all_directions = False - torch._inductor.config.force_fuse_int_mm_with_mul = False - torch._inductor.config.fx_graph_cache = False + torch._inductor.config.force_fuse_int_mm_with_mul = False + torch._inductor.config.fx_graph_cache = False torch._inductor.config.triton.unique_kernel_names = False torch.set_float32_matmul_precision("highest") @@ -844,7 +844,7 @@ def api(mod): kwargs_copy = kwargs.copy() kwargs_copy["group_size"] = groupsize del kwargs_copy["groupsize"] - quantize(mod, int4_weight_only(**kwargs_copy)) + quantize_(mod, int4_weight_only(**kwargs_copy)) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int4_woqtensors(mod, **kwargs) @@ -865,7 +865,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype): out3 = mod(example_input) sqnr2 = SQNR(out, out3) self.assertTrue(sqnr2 >= 30) - + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, [ diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py index 77eac6f69d..fab2d972b1 100644 --- a/test/prototype/test_quant_llm.py +++ b/test/prototype/test_quant_llm.py @@ -16,7 +16,7 @@ ) from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 -from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import quantize_ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias): linear = torch.nn.Linear(IC, OC, bias=bias, device=device) fpx_linear = copy.deepcopy(linear) - quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) + quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fpx_linear(x) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e8b9d606d7..b137cd22dc 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -31,7 +31,7 @@ Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, ) -from torchao import quantize +from torchao import quantize_ from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, Quantizer, @@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - quantize(model, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight()) return model class ToyLinearModel(torch.nn.Module): @@ -152,7 +152,7 @@ class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - m = quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) quantized = m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self): ) m = ToyLinearModel().eval().cpu() def api(model): - model = quantize(model, int8_weight_only()) + quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) api(m) @@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self): m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size)) + quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self): example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") group_size = 32 - m = quantize(m, int4_weight_only(group_size=group_size)) + quantize_(m, int4_weight_only(group_size=group_size)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - m = quantize(m, int8_weight_only()) + quantize_(m, int8_weight_only()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") - m = quantize(m, int8_dynamic_activation_int8_weight()) + quantize_(m, int8_dynamic_activation_int8_weight()) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self): m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16) - m = quantize(m, int8_weight_only()) + quantize_(m, int8_weight_only()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) diff --git a/torchao/__init__.py b/torchao/__init__.py index 3b5a1b3c0f..104dc5f311 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -30,14 +30,14 @@ from torchao.quantization import ( autoquant, - quantize, + quantize_, ) from . import dtypes __all__ = [ "dtypes", "autoquant", - "quantize", + "quantize_", ] # test-pytorchbot diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 76e7cd9ff2..4765d6a5fc 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.dtypes import to_affine_quantized import copy from torchao.quantization.quant_api import ( - quantize, + quantize_, int4_weight_only, ) @@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) group_size = 32 # only works for torch 2.4+ -m = quantize(m, int4_weight_only(group_size=group_size)) +quantize_(m, int4_weight_only(group_size=group_size)) # temporary workaround for tensor subclass + torch.compile from torchao.utils import unwrap_tensor_subclass @@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True # for torch 2.4+ from torchao.quantization import quantize, int8_dynamic_activation_int8_weight -quantize(model, int8_dynamic_activation_int8_weight()) +quantize_(model, int8_dynamic_activation_int8_weight()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors @@ -180,7 +180,7 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.4+ from torchao.quantization import quantize, int8_weight_only -quantize(model, int8_weight_only()) +quantize_(model, int8_weight_only()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors @@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ```python # for torch 2.4+ from torchao.quantization import quantize, int4_weight_only -quantize(model, int4_weight_only()) +quantize_(model, int4_weight_only()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 115062c8f6..a1cf1bf034 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -29,7 +29,7 @@ "quantize_affine", "dequantize_affine", "choose_qprams_affine", - "quantize", + "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int4_weight_only", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 31ab71f385..3da530b940 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -54,7 +54,7 @@ "Int4WeightOnlyQuantizer", "autoquant", "_get_subclass_inserter", - "quantize", + "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int4_weight_only", @@ -259,8 +259,8 @@ def insert_subclass(lin): return insert_subclass -def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module: - """Convert the weight of linear modules in the model with `apply_tensor_subclass` +def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True): + """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace Args: model (torch.nn.Module): input model @@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens import torch import torch.nn as nn - from torchao import quantize + from torchao import quantize_ # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to # optimized execution paths or kernels (e.g. int4 tinygemm kernel) @@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens from torchao.quantization.quant_api import int4_weight_only m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - m = quantize(m, int4_weight_only(group_size=32)) + quantize_(m, int4_weight_only(group_size=32)) # 2. write your own new apply_tensor_subclass # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor @@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: return isinstance(module, nn.Linear) m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - m = quantize(m, apply_weight_quant, filter_fn) + quantize_(m, apply_weight_quant, filter_fn) """ if set_inductor_config: @@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: _get_linear_subclass_inserter(apply_tensor_subclass), _is_linear if filter_fn is None else filter_fn, ) - return model + def int8_dynamic_activation_int4_weight(group_size=32): """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 07e0118d20..a082cfe53a 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -19,9 +19,9 @@ # for APIs for earlier torch version and other quantization techniques # for torch 2.4+ -from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import quantize_ from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight -quantize(model, int8_dynamic_activation_int8_weight()) +quantize_(model, int8_dynamic_activation_int8_weight()) ## Quantization code - end ## compilation configs