add some ops for convenience
gau-nernst committed Aug 5, 2024
1 parent c6adfcb commit 53c7330
Showing 2 changed files with 27 additions and 1 deletion.
9 changes: 9 additions & 0 deletions test/prototype/test_quant_llm.py
@@ -66,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", _FPx_DTYPES)
+    def test_to_copy_device(self, ebits, mbits):
+        x = torch.randn(256, 64)
+        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
+        assert fpx.device.type == "cuda"
+        fpx = fpx.cpu()
+        assert fpx.device.type == "cpu"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("leading_dims", [(4,), (2, 4)])
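For context on why the new test round-trips through .cuda() and .cpu(): on a __torch_dispatch__ tensor subclass, device and dtype conversions both lower to aten._to_copy.default, the op this commit implements in quant_llm.py below. A minimal sketch, independent of this repository (LoggingTensor is an illustrative name, not torchao code), that makes the dispatched op visible:

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        # standard wrapper-subclass pattern: outer tensor mirrors the inner one's metadata
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("dispatched:", func)  # device/dtype conversions arrive here as aten._to_copy.default
        args = tree_map(lambda t: t._data if isinstance(t, LoggingTensor) else t, args)
        return func(*args, **kwargs)

t = LoggingTensor(torch.randn(2, 2))
t.to(torch.float64)  # prints "dispatched: aten._to_copy.default"; .cuda()/.cpu() hit the same op

Without a handler registered for that op, a device move on the subclass would not be supported, which is exactly what the new implementation below adds.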
19 changes: 18 additions & 1 deletion torchao/prototype/quant_llm/quant_llm.py
@@ -10,6 +10,7 @@
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
+aten = torch.ops.aten
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
 
 
@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
+@QuantLlmLinearWeight.implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
 
+@QuantLlmLinearWeight.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+
+
+@QuantLlmLinearWeight.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # only the device kwarg is supported; all other kwargs are ignored
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+    )
+
+
 def quant_llm_fpx_weight_only(ebits: int, mbits: int):
     def apply_quant_llm(weight: Tensor) -> Tensor:
         out_dim, in_dim = weight.shape
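With detach, clone, and _to_copy registered, a QuantLlmLinearWeight moves between devices and clones like a plain tensor. A hedged usage sketch mirroring the new test (the import path and the availability of the prototype package are assumptions, not part of this commit):

import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight  # assumed import path

weight = torch.randn(256, 64)
fpx = QuantLlmLinearWeight.from_float(weight, 3, 2)  # ebits=3, mbits=2, i.e. FP6, as in the tests

if torch.cuda.is_available():
    fpx = fpx.cuda()   # dispatches aten._to_copy.default with device="cuda"
    fpx = fpx.clone()  # dispatches aten.clone.default
    fpx = fpx.cpu()
    assert fpx.device.type == "cpu"

Note that, per the comment in the handler, only the device kwarg of _to_copy is honored, so a dtype argument to .to() will not convert the packed data.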
