Fix FP6-LLM API and add .to(device) op #595

Merged (2 commits) on Aug 5, 2024
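Summary (drawn from the diff below): fp6_llm_weight_only() now returns the quant_llm_fpx_weight_only(3, 2) inserter directly so it can be passed to quantize_, and QuantLlmLinearWeight gains aten.clone and aten._to_copy overrides so the quantized weight can be moved with .to(device) / .cuda() / .cpu(). A minimal usage sketch mirroring the added test_fp6_llm_quantize; the import paths and layer shape here are illustrative assumptions, not part of the PR:

import torch
from torchao.prototype.quant_llm import fp6_llm_weight_only
from torchao.quantization.quant_api import quantize_

# a plain linear layer on CUDA; shapes follow the new test (in=64, out=256)
linear = torch.nn.Linear(64, 256, device="cuda")

# after this PR, fp6_llm_weight_only() can be passed straight to quantize_
# (exactly what the new test_fp6_llm_quantize exercises)
quantize_(linear, fp6_llm_weight_only())

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
out = linear(x)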
24 changes: 24 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -11,6 +11,7 @@
 from torchao.prototype.quant_llm import (
     QuantLlmLinearWeight,
     quant_llm_fpx_weight_only,
+    fp6_llm_weight_only,
     to_scaled_tc_fpx,
     from_scaled_tc_fpx,
 )
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", _FPx_DTYPES)
+    def test_to_copy_device(self, ebits, mbits):
+        x = torch.randn(256, 64)
+        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
+        assert fpx.device.type == "cuda"
+        fpx = fpx.cpu()
+        assert fpx.device.type == "cpu"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("leading_dims", [(4,), (2, 4)])
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp6_llm_quantize(self):
+        N, OC, IC = 4, 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, device=device)
+        fpx_linear = copy.deepcopy(linear)
+        quantize_(fpx_linear, fp6_llm_weight_only())
+
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        expected = fpx_linear(x)
+        actual = torch.compile(fpx_linear, fullgraph=True)(x)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_parametrized_tests(TestQuantLlmLinearWeight)
 
21 changes: 19 additions & 2 deletions torchao/prototype/quant_llm/quant_llm.py
@@ -10,6 +10,7 @@
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
+aten = torch.ops.aten
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
 
 
@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
+@QuantLlmLinearWeight.implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
 
+@QuantLlmLinearWeight.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+
+
+@QuantLlmLinearWeight.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # only support device kwargs, ignore the rest
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+    )
+
 
 def quant_llm_fpx_weight_only(ebits: int, mbits: int):
     def apply_quant_llm(weight: Tensor) -> Tensor:
         out_dim, in_dim = weight.shape
@@ -445,4 +462,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
 
 
 def fp6_llm_weight_only():
-    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
+    return quant_llm_fpx_weight_only(3, 2)
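
For reference, a minimal sketch of what the new aten._to_copy override enables at the tensor level, mirroring the added test_to_copy_device (the assertions are the behavior the test checks; the rest is illustrative and assumes a CUDA device is available):

import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

# pack a random weight into the FP6 layout (ebits=3, mbits=2)
w = torch.randn(256, 64)
fp6_w = QuantLlmLinearWeight.from_float(w, 3, 2)

# .cuda()/.cpu() dispatch to aten._to_copy.default; the new override applies
# the device move to the underlying tensors via _apply_fn_to_data
fp6_w = fp6_w.cuda()
assert fp6_w.device.type == "cuda"
fp6_w = fp6_w.cpu()
assert fp6_w.device.type == "cpu"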