Skip to content

Commit

Permalink
[FEAT] Add custom CUDA tinygemm unpacker (#415)
Browse files Browse the repository at this point in the history
* add unpack cuda

* add tests

* fix tests

* refactor tinygemm unpacking kernel

* add dequant

* add additional dequant check

* update tinygemm dequantize test

* correct dequant kernel logic

* clean up kernel

* update dequantize kernel tests

* rename kernel ops to tensor_core_tiled_layout

* add renamed kernel source

* add back test_aot_dispatch opcheck

* rename innerKTiles to inner_k_tiles

* add unpack and dequant test

* additional numerical checks for unpack then dequant

* rebase test_ops on main

* remove commented out code

* skip dynamic opcheck unless torch>=2.5
  • Loading branch information
jeromeku authored Jul 4, 2024
1 parent 6fa2d96 commit 74846da
Show file tree
Hide file tree
Showing 4 changed files with 652 additions and 3 deletions.
227 changes: 224 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import itertools

import torchao

import torch
from torch.testing._internal.common_utils import (
TestCase,
Expand All @@ -6,7 +10,7 @@
run_tests,
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode
from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5
from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

Expand All @@ -18,6 +22,14 @@
except RuntimeError:
pytest.skip("torchao.ops not available")

from torchao.quantization.utils import (
get_groupwise_affine_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
unpack_tinygemm_scales_and_zeros,
)


class TestOps(TestCase):
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
Expand Down Expand Up @@ -61,9 +73,218 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
relative_error = error / gt
assert relative_error < 1e-3


instantiate_parametrized_tests(TestOps)


## Tests for `tensor_core_layout`
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
(4096, 4096),
# Llama 2 GEMM shapes
(4096, 11008),
(11008, 4096),
# Llama 3 GEMM shapes
(4096, 14336),
(14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
assert torch.equal(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]

# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

opcheck(
torch.ops.torchao.unpack_tensor_core_tiled_layout,
(packed_w, inner_k_tiles),
test_utils=test_utils,
)

def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
n, k = q.shape
assert q.dtype == torch.int

n_groups = k // group_size
assert scales.shape[0] == n and scales.shape[1] == n_groups
assert scales.shape == zeros.shape

midpoint = 2 ** (nbits - 1)

#Convert fron u4 -> s4 and upcast to bfloat16
q = q.sub(midpoint).to(dtype)

# Dequantize
q = q.reshape(-1, group_size)
dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

return dq.reshape(n, k)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16

device = "cuda"

t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

# Quantize
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

# Pack to tensor core layout
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
q_groups = k // group_size
assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])

# Dequantize 'ao' ref
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
q, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
a_eye,
packed,
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

assert diff_op_ao < 1e-1

# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16
device = "cuda"

# Quantize and pack
t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
a_eye,
packed,
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
n, k = shape
device = "cuda"

q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]
# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")
opcheck(
torch.ops.torchao.dequantize_tensor_core_tiled_layout,
(packed_w, scales_and_zeros, group_size, inner_k_tiles),
test_utils=test_utils,
)

if __name__ == "__main__":
run_tests()
run_tests()
Loading

0 comments on commit 74846da

Please sign in to comment.