diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 66a5e51cff..7c55308f15 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -419,7 +419,7 @@ def int8_dynamic_activation_int4_weight(group_size=32): return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size) -def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)): +def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -453,7 +453,7 @@ def apply_int4_weight_only_quant(weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant)