Fix FP6-LLM API and add .to(device) op (#595)
* fix fp6_llm_weight_only: return quant_llm_fpx_weight_only(3, 2) directly instead of wrapping it in _get_linear_subclass_inserter, so the result can be passed straight to quantize_

* add aten.clone and aten._to_copy overrides so QuantLlmLinearWeight supports .to(device), .cuda(), and .cpu()
gau-nernst authored and jainapurva committed Aug 7, 2024
1 parent bb52d6c commit c151e6b
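
For context, a minimal usage sketch of the fixed API (illustrative only, modeled on the new test_fp6_llm_quantize test below, not part of the commit; it assumes quantize_ is importable from torchao.quantization and that a CUDA device is available):

```python
import torch
from torchao.prototype.quant_llm import fp6_llm_weight_only
from torchao.quantization import quantize_  # assumed import path

# FP6-LLM (E3M2) weight-only quantization of a single nn.Linear.
linear = torch.nn.Linear(64, 256, device="cuda")
# After this commit fp6_llm_weight_only() returns quant_llm_fpx_weight_only(3, 2)
# and can be passed directly to quantize_.
quantize_(linear, fp6_llm_weight_only())

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
y = linear(x)  # forward pass using the packed FP6 weight
```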
Showing 2 changed files with 43 additions and 2 deletions.
24 changes: 24 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -11,6 +11,7 @@
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
fp6_llm_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
x = torch.randn(256, 64)
fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
assert fpx.device.type == "cuda"
fpx = fpx.cpu()
assert fpx.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("leading_dims", [(4,), (2, 4)])
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_fp6_llm_quantize(self):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, device=device)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fp6_llm_weight_only())

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)

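The new test_to_copy_device above covers the added device ops on the tensor subclass itself. A small sketch of the same behavior (illustrative, not part of the commit; ebits=3 / mbits=2 matches the FP6 configuration used by fp6_llm_weight_only, and a CUDA device is assumed):

```python
import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

w = torch.randn(256, 64)
fpx = QuantLlmLinearWeight.from_float(w, 3, 2)  # pack as FP6 (E3M2)

fpx_gpu = fpx.cuda()        # handled by the new aten._to_copy override
fpx_copy = fpx_gpu.clone()  # handled by the new aten.clone override
fpx_cpu = fpx_gpu.cpu()     # back to CPU, again via aten._to_copy
assert fpx_cpu.device.type == "cpu"
```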
21 changes: 19 additions & 2 deletions torchao/prototype/quant_llm/quant_llm.py
@@ -10,6 +10,7 @@
from torchao.quantization.quant_api import _get_linear_subclass_inserter


aten = torch.ops.aten
_ONES_TABLE = [_n_ones(i) for i in range(8)]


@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
return out.view(*act.shape[:-1], out_dim).to(act.dtype)


@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
@QuantLlmLinearWeight.implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))


@QuantLlmLinearWeight.implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))


@QuantLlmLinearWeight.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
# only support device kwargs, ignore the rest
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
)


def quant_llm_fpx_weight_only(ebits: int, mbits: int):
def apply_quant_llm(weight: Tensor) -> Tensor:
out_dim, in_dim = weight.shape
@@ -445,4 +462,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:


def fp6_llm_weight_only():
return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
return quant_llm_fpx_weight_only(3, 2)
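
Note on scope: as the inline comment in the new aten._to_copy override says, only the device keyword is honored and the remaining kwargs are ignored, so the commit enables .to(device) / .cuda() / .cpu() on the packed weight but not dtype casts. A rough illustration of that reading (not exercised by the tests in this commit):

```python
import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

fpx = QuantLlmLinearWeight.from_float(torch.randn(256, 64), 3, 2)

fpx_gpu = fpx.to(device="cuda")  # honored: packed data moves to CUDA
fpx_f32 = fpx.to(torch.float32)  # the dtype kwarg is dropped by the override,
                                 # so this does not dequantize to fp32
```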
