diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py
index fab2d972b..610979674 100644
--- a/test/prototype/test_quant_llm.py
+++ b/test/prototype/test_quant_llm.py
@@ -11,6 +11,7 @@ from torchao.prototype.quant_llm import (
     QuantLlmLinearWeight,
     quant_llm_fpx_weight_only,
+    fp6_llm_weight_only,
     to_scaled_tc_fpx,
     from_scaled_tc_fpx,
 )
 
@@ -65,6 +66,15 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", _FPx_DTYPES)
+    def test_to_copy_device(self, ebits, mbits):
+        x = torch.randn(256, 64)
+        fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
+        assert fpx.device.type == "cuda"
+        fpx = fpx.cpu()
+        assert fpx.device.type == "cpu"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _FPx_DTYPES)
     @parametrize("leading_dims", [(4,), (2, 4)])
@@ -98,6 +108,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp6_llm_quantize(self):
+        N, OC, IC = 4, 256, 64
+        device = "cuda"
+
+        linear = torch.nn.Linear(IC, OC, device=device)
+        fpx_linear = copy.deepcopy(linear)
+        quantize_(fpx_linear, fp6_llm_weight_only())
+
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        expected = fpx_linear(x)
+        actual = torch.compile(fpx_linear, fullgraph=True)(x)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_parametrized_tests(TestQuantLlmLinearWeight)
 
diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py
index bbcc978e7..f41bac9b2 100644
--- a/torchao/prototype/quant_llm/quant_llm.py
+++ b/torchao/prototype/quant_llm/quant_llm.py
@@ -10,6 +10,7 @@
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
+aten = torch.ops.aten
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
 
 
@@ -430,11 +431,27 @@ def _(func, types, args, kwargs):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default)
+@QuantLlmLinearWeight.implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
 
+@QuantLlmLinearWeight.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))
+
+
+@QuantLlmLinearWeight.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # only support device kwargs, ignore the rest
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+    )
+
+
 def quant_llm_fpx_weight_only(ebits: int, mbits: int):
     def apply_quant_llm(weight: Tensor) -> Tensor:
         out_dim, in_dim = weight.shape
@@ -445,4 +462,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
 
 
 def fp6_llm_weight_only():
-    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
+    return quant_llm_fpx_weight_only(3, 2)
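
For reviewers, a minimal usage sketch (not part of the patch) of the two paths the new tests exercise. It mirrors test_fp6_llm_quantize and test_to_copy_device above; the quantize_ import path and the shape/device choices are assumptions, since the test file's existing imports are not shown in this diff.

```python
# Sketch only: assumes a CUDA-enabled torchao build with the quant_llm prototype.
import copy

import torch
from torchao.quantization import quantize_  # assumed import path for quantize_
from torchao.prototype.quant_llm import QuantLlmLinearWeight, fp6_llm_weight_only

# After this patch, fp6_llm_weight_only() is simply quant_llm_fpx_weight_only(3, 2),
# so it can be passed straight to quantize_() (as in test_fp6_llm_quantize).
linear = torch.nn.Linear(64, 256, device="cuda")
fp6_linear = copy.deepcopy(linear)
quantize_(fp6_linear, fp6_llm_weight_only())

x = torch.randn(4, 64, device="cuda", dtype=torch.half)
y = fp6_linear(x)

# The new aten._to_copy.default override lets the quantized weight move between
# devices (as in test_to_copy_device).
fpx = QuantLlmLinearWeight.from_float(torch.randn(256, 64), 3, 2).cuda()
fpx = fpx.cpu()
```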