Enable FSDP Test in CI #207

Merged · 4 commits · May 3, 2024

13 changes: 9 additions & 4 deletions dev-requirements.txt
@@ -1,10 +1,15 @@
-pytest
+# Test utilities
+pytest==7.4.0
 expecttest
+unittest-xml-reporting
+parameterized
 packaging
+transformers
+
+# For prototype features and benchmarks
 bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
-matplotlib # needed for triton benchmarking
-pandas # also for triton benchmarking
-transformers #for galore testing
+matplotlib
+pandas
 
 # Custom CUDA Extensions
 ninja
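
For orientation, the newly listed `parameterized` package is the usual way such test suites expand one test function into many reported cases. The sketch below is illustrative only and is not taken from this PR; the class, method, and shape values are made up.

```python
# Illustrative sketch only -- not part of this PR's diff.
# Shows typical use of the newly added "parameterized" test dependency.
import unittest

from parameterized import parameterized


class ShapeTest(unittest.TestCase):
    @parameterized.expand([
        (16, 128),
        (32, 256),
    ])
    def test_dims_are_positive(self, m, k):
        # Each tuple above becomes a separately reported test case.
        self.assertGreater(m * k, 0)


if __name__ == "__main__":
    unittest.main()
```
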
29 changes: 11 additions & 18 deletions test/hqq/test_triton_mm.py
@@ -1,17 +1,11 @@
-# Skip entire test if triton is not available, otherwise CI failure
 import pytest
-try:
-    import triton
-    import hqq
-    if int(triton.__version__.split(".")[0]) < 3:
-        pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True)
-except ImportError:
-    pytest.skip("triton and hqq required to run this test", allow_module_level=True)
 
 import itertools
 import torch
 
-from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
-
+triton = pytest.importorskip("triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test")
+hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
+HQQLinear = pytest.importorskip("hqq.core.quantize.HQQLinear", reason="HQQLinear required to run this test")
+BaseQuantizeConfig = pytest.importorskip("hqq.core.quantize.BaseQuantizeConfig", reason="HQQLinear required to run this test")
 
 from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4
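
For readers unfamiliar with the pattern adopted above: `pytest.importorskip` attempts the import at collection time and skips the whole test module when the import fails or the version is too old, replacing the manual try/except guard that this hunk removes. A minimal sketch follows, using a hypothetical optional dependency `somelib` rather than the real imports in this file.

```python
# Minimal sketch of the importorskip pattern, using a hypothetical
# optional dependency "somelib" (not a real requirement of this repo).
import pytest

# Skips every test in this file when "somelib" is missing or older than 2.0;
# otherwise the imported module object is returned for direct use.
somelib = pytest.importorskip("somelib", minversion="2.0", reason="somelib >= 2.0 required")


def test_somelib_present():
    # Only reached when the optional dependency imported above is available.
    assert somelib is not None
```
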


@@ -61,7 +55,7 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
         **dict(group_size=group_size, axis=axis),
     }
     M, N, K = shape
-
+
     linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
 
     quant_config = BaseQuantizeConfig(
@@ -81,24 +75,23 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
     scales, zeros = meta["scale"], meta["zero"]
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
-
     if transposed:
         x = torch.randn(M, N, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq
+        hqq_out = x @ W_dq
 
-        #Pack uint8 W_q, then run fused dequant matmul
+        #Pack uint8 W_q, then run fused dequant matmul
         packed_w = pack_2xint4(W_q)
         tt_out = triton_mixed_mm(
             x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
         )
     else:
         x = torch.randn(M, K, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq.T
+        hqq_out = x @ W_dq.T
 
         packed_w = pack_2xint4(W_q.T)
         tt_out = triton_mixed_mm(
             x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
         )
 
     assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3)
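
The `check` helper asserted above is defined elsewhere in the test file and is not part of this diff; a hypothetical tolerance check in the same spirit might look like the following (the name, signature, and behavior are assumptions, not the PR's actual helper).

```python
# Hypothetical sketch of a max-abs-difference tolerance check, in the spirit
# of the `check` helper referenced above; the real helper may differ.
import torch


def check(ref: torch.Tensor, out: torch.Tensor, max_diff: float) -> bool:
    # True when the largest element-wise deviation stays within the budget,
    # e.g. 1e-2 for bfloat16 and 1e-3 otherwise, as in the assert above.
    return (ref.float() - out.float()).abs().max().item() <= max_diff
```
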
