
Commit

Fix uninitialized variable in quantized compressors
Both compressors have a can_quantize() check which, if it ever failed,
would trigger:

> UnboundLocalError: cannot access local variable 'quantized_weight' where it is not associated with a value

Add the obvious fix for this, along with highly artificial test cases
that trigger the failure.
markmc committed Nov 13, 2024
1 parent ff121cc commit ce81191
Showing 4 changed files with 57 additions and 13 deletions.
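
For context, a minimal reduction of the failure mode described above (illustrative only, not code from the repository). The name quantized_weight is bound only inside the guarded branch, so any later use raises UnboundLocalError when the guard fails; the fix is the else branch that always binds it:

    import torch

    def compress_weight_sketch(weight: torch.Tensor, quantize_ok: bool, device=None):
        # quantize_ok stands in for the can_quantize() check.
        if quantize_ok:
            quantized_weight = weight.round()  # placeholder for quantize(...)
        else:
            quantized_weight = weight          # the fix: always bind the name
        if device is not None:
            quantized_weight = quantized_weight.to(device)
        return {"weight": quantized_weight}

Delete the else branch and compress_weight_sketch(torch.rand(4), quantize_ok=False) raises exactly the UnboundLocalError quoted in the commit message.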
@@ -93,9 +93,11 @@ def compress_weight(
                 args=quantization_args,
                 dtype=quantization_args.pytorch_dtype(),
             )
+        else:
+            quantized_weight = weight
 
-            if device is not None:
-                quantized_weight = quantized_weight.to(device)
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
 
         return {"weight": quantized_weight}

@@ -94,6 +94,8 @@ def compress_weight(
                 args=quantization_args,
                 dtype=torch.int8,
             )
+        else:
+            quantized_weight = weight
 
         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
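
The packed path then hands the (possibly passed-through) weight to pack_to_int32(). As a rough sketch of the idea behind such a packing routine (a hypothetical stand-in, not the library's implementation), each 32-bit word holds 32 // num_bits quantized values:

    import torch

    def pack_to_int32_sketch(values: torch.Tensor, num_bits: int) -> torch.Tensor:
        # Pack unsigned num_bits-wide integers into int32 words, lowest bits first.
        per_word = 32 // num_bits
        mask = (1 << num_bits) - 1
        flat = values.flatten().to(torch.int64) & mask
        pad = (-flat.numel()) % per_word               # zero-pad to a whole word
        flat = torch.cat([flat, flat.new_zeros(pad)])
        shifts = torch.arange(per_word, dtype=torch.int64) * num_bits
        packed = (flat.view(-1, per_word) << shifts).sum(dim=1)
        packed = (packed + 2**31) % 2**32 - 2**31      # wrap into signed int32 range
        return packed.to(torch.int32)

With num_bits=4 this packs eight values per word; with num_bits=8, four.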
34 changes: 28 additions & 6 deletions tests/test_compressors/quantized_compressors/test_int_quant.py
@@ -91,32 +91,54 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
 
 
 @pytest.mark.parametrize(
-    "strategy,group_size,sc,zp",
+    "strategy,group_size,sc,zp,int8_weights",
     [
-        [QuantizationStrategy.TENSOR, None, 0.01, 0],
+        [QuantizationStrategy.TENSOR, None, 0.01, 0, False],
+        [QuantizationStrategy.TENSOR, None, 1, 0, True],
         [
             QuantizationStrategy.GROUP,
             128,
             torch.rand((300, 8)) * 0.01,
             torch.zeros((300, 8), dtype=torch.int8),
+            False,
         ],
         [
             QuantizationStrategy.CHANNEL,
             None,
             torch.rand((300, 1)) * 0.01,
             torch.zeros((300, 1), dtype=torch.int8),
+            False,
         ],
     ],
 )
-def test_reload_match(strategy, group_size, sc, zp, tmp_path):
+def test_reload_match(strategy, group_size, sc, zp, int8_weights, tmp_path):
     dense_state_dict = {
         "dummy.weight": torch.rand((300, 1024)),
         "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
         "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
-        "dummy2.weight": torch.rand((300, 1024)),
-        "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
     }
+    if not int8_weights:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.rand((300, 1024)),
+                "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
+            }
+        )
+    else:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.randint(
+                    torch.iinfo(torch.int8).min,
+                    torch.iinfo(torch.int8).max,
+                    (511, 350),
+                    dtype=torch.int8,
+                ),
+                "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
+            }
+        )
 
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = IntQuantizationCompressor(config=quant_config)
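
The new int8_weights=True case feeds a weight that is already int8, which is presumably what makes the can_quantize() check fail and exercise the new else branch. A plausible stand-in for such a predicate (an assumption about its behavior, not the library's actual code):

    import torch

    def can_quantize_sketch(weight: torch.Tensor, num_bits: int) -> bool:
        # Assumed rule: quantizing only helps when the requested width is
        # smaller than the width the weight is already stored at.
        if weight.dtype.is_floating_point:
            return num_bits < torch.finfo(weight.dtype).bits
        return num_bits < torch.iinfo(weight.dtype).bits

Under this reading an int8 weight with num_bits=8 returns False, so the compressor passes the weight through unchanged, which fits the test pairing it with a scale of 1 and a zero point of 0.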
28 changes: 23 additions & 5 deletions tests/test_compressors/quantized_compressors/test_pack_quant.py
@@ -140,16 +140,34 @@ def test_repack_8bit(value):
     assert torch.equal(value, unpacked)
 
 
-@pytest.mark.parametrize("num_bits", [4, 8])
-def test_reload_match(tmp_path, num_bits):
+@pytest.mark.parametrize("num_bits,int8_weights", [(4, False), (8, False), (8, True)])
+def test_reload_match(tmp_path, num_bits, int8_weights):
     dense_state_dict = {
         "dummy.weight": torch.rand((511, 350)),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
         "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int8),
-        "dummy2.weight": torch.rand((128, 280)),
-        "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
     }
+    if not int8_weights:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.rand((128, 280)),
+                "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
+            }
+        )
+    else:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.randint(
+                    torch.iinfo(torch.int8).min,
+                    torch.iinfo(torch.int8).max,
+                    (511, 350),
+                    dtype=torch.int8,
+                ),
+                "dummy2.weight_scale": torch.tensor(1, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(0, dtype=torch.int8),
+            }
+        )
 
     names_to_scheme = {
         "dummy": QuantizationArgs(num_bits=num_bits),
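
Both reload tests check that weights survive a compress/save/reload round trip. To make the round-trip idea concrete at the bit level, here is an unpacking sketch that inverts pack_to_int32_sketch from earlier (again a hypothetical stand-in, not the library's unpack routine):

    import torch

    def unpack_from_int32_sketch(packed: torch.Tensor, num_bits: int, numel: int) -> torch.Tensor:
        # Extract num_bits-wide fields from each 32-bit word, lowest bits first.
        per_word = 32 // num_bits
        words = packed.to(torch.int64) & 0xFFFFFFFF    # undo sign extension
        shifts = torch.arange(per_word, dtype=torch.int64) * num_bits
        vals = (words.unsqueeze(-1) >> shifts) & ((1 << num_bits) - 1)
        return vals.flatten()[:numel]                  # drop the zero padding

    # Round trip, mirroring the assert in test_repack_8bit:
    vals = torch.randint(0, 16, (10,))
    packed = pack_to_int32_sketch(vals, num_bits=4)
    assert torch.equal(unpack_from_int32_sketch(packed, num_bits=4, numel=vals.numel()), vals)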
