Skip to content

Commit

Permalink
Expose hqq through int4_weight_only API
Browse files Browse the repository at this point in the history
Summary:
att, this is a follow up for pytorch#605 to make hqq available in quantize_ API

`quantize_(model, int4_weight_only(group_size, use_hqq=True)`

Test Plan:

python generate.py --compile --quantization int4wo-hqq-64 --precision bfloat16
Average tokens/sec: 195.24
Average Bandwidth: 729.40 GB/s
Peak Memory Usage: 5.09 GB
Model Size: 3.74 GB

python eval.py --compile --quantization int4wo-hqq-64 --precision bfloat16

wikitext: {'word_perplexity,none': 12.823631773497512, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611400903914048, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6883154699192412, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Sep 4, 2024
1 parent ba2d3b1 commit fa78b5b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
5 changes: 3 additions & 2 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def run_evaluation(
if "int4wo" in quantization and not "gptq" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model.to(device), int4_weight_only(group_size=groupsize))
use_hqq = "hqq" in quantization
quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "int4wo" in quantization and "gptq" in quantization:
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_evaluation(
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq")
parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-hqq-<groupsize>")
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
Expand Down
7 changes: 4 additions & 3 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def generate(
# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)

seq = torch.cat((seq[:T+1], *generated_tokens))

return seq
Expand Down Expand Up @@ -218,7 +218,8 @@ def main(
if "int4wo" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
use_hqq = "hqq" in quantization
quantize_(model, int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
if "autoquant" == quantization:
Expand Down Expand Up @@ -387,7 +388,7 @@ def callback(x):
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-hqq-<groupsize>, autoquant')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
Expand Down
7 changes: 6 additions & 1 deletion torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
```python
# for torch 2.4+
from torchao.quantization import quantize_, int4_weight_only
quantize_(model, int4_weight_only())
group_size = 32

# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through
# use_hqq flag for `int4_weight_only` quantization
use_hqq = False
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
Expand Down
30 changes: 27 additions & 3 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner
size is more fine grained, choices are [256, 128, 64, 32]
`layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
"""
def apply_int4_weight_only_quant(weight, use_hqq=False):
def apply_int4_weight_only_quant(weight):
if weight.shape[-1] % group_size != 0:
return weight

Expand Down Expand Up @@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
"""
Applies float8 weight-only symmetric per-channel quantization to linear layers.
Args:
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
Expand Down Expand Up @@ -617,9 +617,33 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
return weight

layout_type = FpxTensorCoreLayoutType(ebits, mbits)
print("layout type:", layout_type)
return to_affine_quantized_fpx(weight, layout_type)
return _get_linear_subclass_inserter(apply_quant_llm)

def hqq_weight_only(dtype, group_size):

def apply_hqq(weight: torch.Tensor) -> torch.Tensor:
if dtype != torch.uint4:
layout_type = UintxLayoutType(dtype=dtype, pack_dim=-1)
return to_affine_quantized_intx(
input_float=weight,
mapping_type=MappingType.ASYMMETRIC,
block_size=(1, group_size),
target_dtype=dtype,
zero_point_domain=ZeroPointDomain.FLOAT,
layout_type=layout_type,
preserve_zero=False,
use_hqq=True,
)
else:
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT, layout_type=layout_type, use_hqq=True)


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])

0 comments on commit fa78b5b

Please sign in to comment.