diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index ce9f0a57..3a512915 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -136,8 +136,20 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor: """ Packs a tensor of quantized weights stored in int8 into int32s with padding + Pseudocode: + 1. Shift wrt num_bits to convert to unsigned. num_bits=8 + [1,2] -> [129, 130] + 2. Pad to fill in 32 bits + [129, 130] -> [129, 130, 0, 0] + 3. convert to binary align in order + [129, 130, 0, 0] -> 00000000 00000000 10000010 10000001 + 4. convert aligned binary to number + 00000000000000001000001010000001 -> 33409 + 5. covert back to uint32 + 33409 -> 33409 + :param value: tensor to pack - :param num_bits: number of bits used to store underlying data + :param num_bits: number of bits used to store underlying data, must be at least 1 :returns: packed int32 tensor """ if value.dtype is not torch.int8: @@ -146,19 +158,22 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor: if num_bits > 8: raise ValueError("Packing is only supported for less than 8 bits") + if num_bits < 1: + raise ValueError(f"num_bits must be at least 1, got {num_bits}") + # convert to unsigned for packing - offset = pow(2, num_bits) // 2 + offset = 1 << (num_bits - 1) value = (value + offset).to(torch.uint8) value = value.cpu().numpy().astype(np.uint32) pack_factor = 32 // num_bits # pad input tensor and initialize packed output packed_size = math.ceil(value.shape[1] / pack_factor) - packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32) - padding = packed.shape[1] * pack_factor - value.shape[1] + padding = packed_size * pack_factor - value.shape[1] value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0) # pack values + packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32) for i in range(pack_factor): packed |= value[:, i::pack_factor] << num_bits * i @@ -172,7 +187,9 @@ def unpack_from_int32( ) -> torch.Tensor: """ Unpacks a tensor of packed int32 weights into individual int8s, maintaining the - original their bit range + original bit range. + + Return tensors in int8 :param value: tensor to upack :param num_bits: number of bits to unpack each data point into @@ -190,7 +207,7 @@ def unpack_from_int32( pack_factor = 32 // num_bits # unpack - mask = pow(2, num_bits) - 1 + mask = (1 << num_bits) - 1 unpacked = torch.zeros( (value.shape[0], value.shape[1] * pack_factor), device=value.device, diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index fde57c4b..3b0da609 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -250,3 +250,165 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration): assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"]) shutil.rmtree(tmp_path) + + +@pytest.mark.parametrize( + "num_bits,values,expected_values", + [ + ( + 4, + torch.tensor([[1]]), + torch.tensor([[9]], dtype=torch.int32), + ), + ( + 8, + torch.tensor([[1]]), + torch.tensor([[129]], dtype=torch.int32), + ), + # 0000 0000 0000 0000 1100 1011 1010 1001 + (4, torch.tensor([[1, 2, 3, 4]]), torch.tensor([[52137]], dtype=torch.int32)), + # 0111 0110 0101 0100 0011 0010 0001 0000 + ( + 4, + torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]]), + torch.tensor([[1985229328]], dtype=torch.int32), + ), + # 10000100 10000011 10000010 10000001 + ( + 8, + torch.tensor([[1, 2, 3, 4]]), + torch.tensor([[-2071756159]], dtype=torch.int32), + ), + # 00000011 00000010 00000001 00000000 + ( + 8, + torch.tensor([[-128, -127, -126, -125]]), + torch.tensor([[50462976]], dtype=torch.int32), + ), + ( + 4, + torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]]), + torch.tensor([[1985229328, 52137]], dtype=torch.int32), + ), + ( + 4, + torch.tensor( + [ + [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8], + [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1], + ] + ), + torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32), + ), + ( + 8, + torch.tensor( + [ + [1, 2, 3, 4], + [-128, -127, -126, -125], + ] + ), + torch.tensor([[-2071756159], [50462976]], dtype=torch.int32), + ), + ( + 8, + torch.tensor( + [ + [1, 2, 3, 4, -128, -127, -126, -125], + [-128, -127, -126, -125, 1, 2, 3, 4], + ] + ), + torch.tensor( + [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32 + ), + ), + ], +) +def test_pack_to_int32(num_bits, values, expected_values): + values = values.to(torch.int8) + packed_values = pack_to_int32(values, num_bits) + assert torch.equal(packed_values, expected_values) + assert packed_values.dtype == expected_values.dtype + + +@pytest.mark.parametrize( + "num_bits,values,expected_tensor", + [ + ( + 4, + torch.tensor([[9]], dtype=torch.int32), + torch.tensor([[1]], dtype=torch.int8), + ), + ( + 8, + torch.tensor([[129]], dtype=torch.int32), + torch.tensor([[1]], dtype=torch.int8), + ), + ( + 4, + torch.tensor([[52137]], dtype=torch.int32), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int8), + ), + ( + 4, + torch.tensor([[1985229328]], dtype=torch.int32), + torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], dtype=torch.int8), + ), + ( + 8, + torch.tensor([[-2071756159]], dtype=torch.int32), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int8), + ), + ( + 8, + torch.tensor([[50462976]], dtype=torch.int32), + torch.tensor([[-128, -127, -126, -125]], dtype=torch.int8), + ), + ( + 4, + torch.tensor([[1985229328, 52137]], dtype=torch.int32), + torch.tensor( + [[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]], dtype=torch.int8 + ), + ), + ( + 4, + torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32), + torch.tensor( + [ + [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8], + [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1], + ], + dtype=torch.int8, + ), + ), + ( + 8, + torch.tensor([[-2071756159], [50462976]], dtype=torch.int32), + torch.tensor( + [ + [1, 2, 3, 4], + [-128, -127, -126, -125], + ], + dtype=torch.int8, + ), + ), + ( + 8, + torch.tensor( + [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32 + ), + torch.tensor( + [ + [1, 2, 3, 4, -128, -127, -126, -125], + [-128, -127, -126, -125, 1, 2, 3, 4], + ], + dtype=torch.int8, + ), + ), + ], +) +def test_unpack_from_int32(num_bits, values, expected_tensor): + unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape) + assert torch.equal(unpacked_tensor, unpacked_tensor) + assert unpacked_tensor.dtype == unpacked_tensor.dtype