diff --git a/README.md b/README.md
index 83a83f50de..cad352324b 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,7 @@ To learn more try out our APIs, you can check out API examples in
 4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
     - [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
     - [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
+    - [FP6-LLM](torchao/prototype/fp6_llm) FP16 x FP6 mixed matmul kernel for io bound workloads

 ## Our Goals
diff --git a/benchmarks/benchmark_fp6_llm.py b/benchmarks/benchmark_fp6_llm.py
index f0d3da72bc..b6fdca6437 100644
--- a/benchmarks/benchmark_fp6_llm.py
+++ b/benchmarks/benchmark_fp6_llm.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
+from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
 from torch.utils.benchmark import Timer
 import pandas as pd
 from tqdm import tqdm
diff --git a/test/quantization/test_fp6_llm.py b/test/prototype/test_fp6_llm.py
similarity index 98%
rename from test/quantization/test_fp6_llm.py
rename to test/prototype/test_fp6_llm.py
index 906154d331..6eddc522ab 100644
--- a/test/quantization/test_fp6_llm.py
+++ b/test/prototype/test_fp6_llm.py
@@ -7,7 +7,7 @@
     parametrize,
     run_tests,
 )
-from torchao.quantization.fp6_llm import (
+from torchao.prototype.fp6_llm.fp6_llm import (
     to_tc_float6_e3m2,
     from_tc_float6_e3m2,
     _to_tc_float6_e3m2_ref,
diff --git a/test/test_ops.py b/test/test_ops.py
index 016ac24fed..920b32c5f2 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2,7 +2,7 @@
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.quantization.fp6_llm import from_tc_float6_e3m2
+from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
 import unittest
 from parameterized import parameterized
 import pytest
@@ -26,7 +26,7 @@ def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
         return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp16act_fp6weight_linear(self):
+    def test_fp6_llm_linear(self):
         BS = 2
         OC = 256
         IC = 256
@@ -34,19 +34,19 @@ def test_fp6_llm_linear(self):
         fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

         # smoke test
-        torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
+    def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
         fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

-        results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

         fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
         results_fp16 = fp16_activation @ fp16_weight.T
diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
index 30b0978a1a..8db5d44303 100644
--- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
+++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -178,7 +178,7 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
 }

 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda);
+  m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda);
 }

 } // namespace torchao
diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp
index 9972655466..bd787385c0 100644
--- a/torchao/csrc/fp6_llm.cpp
+++ b/torchao/csrc/fp6_llm.cpp
@@ -4,5 +4,5 @@
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   m.impl_abstract_pystub("torchao.ops");
-  m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
+  m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
 }
diff --git a/torchao/ops.py b/torchao/ops.py
index d943a6490c..25cbfb5656 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -12,7 +12,7 @@ def decorator(func):
     return decorator


-def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
+def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
     """
     FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.

@@ -25,10 +25,10 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso
     Returns
         output of linear layer
     """
-    return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)
+    return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK)


-@register_custom_op("torchao::fp16act_fp6weight_linear")
+@register_custom_op("torchao::fp6_llm_linear")
 def _(_in_feats, _weights, _scales, splitK = 1):
     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md
index 02ee2dd3be..633099368a 100644
--- a/torchao/prototype/README.md
+++ b/torchao/prototype/README.md
@@ -9,6 +9,7 @@
 - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507)
   - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
   - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
+- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)

 #### Roadmap
diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md
new file mode 100644
index 0000000000..767785275b
--- /dev/null
+++ b/torchao/prototype/fp6_llm/README.md
@@ -0,0 +1,44 @@
+# FP6-LLM
+
+This is an FP16 x FP6 mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6, as well as utilities to convert existing models to FP6.
+
+## Usage
+
+```python
+from torchao.prototype.fp6_llm import convert_fp6_llm
+
+model = ...
+convert_fp6_llm(model)  # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear
+
+# fully compatible with torch.compile()
+model.compile(mode="max-autotune", fullgraph=True)
+```
+
+It's also possible to pre-process the weight and call the kernel directly.
+
+```python
+import torch
+from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
+from torchao.ops import fp6_llm_linear
+
+fp32_weight = torch.randn(1024, 512).cuda()
+
+# Pre-process the weight. This will quantize the weight to FP6 and pack it in a special
+# layout for tensor cores. Refer to the paper for more details.
+fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)
+
+fp16_act = torch.randn(1, 512).cuda().half()
+outputs = fp6_llm_linear(fp16_act, fp6_weight, scales)  # shape (1, 1024)
+```
+
+## TODO
+
+- [ ] Compile the CUDA kernel for Windows
+- [ ] Merge FP5 from upstream
+
+## Credits
+
+Credits to the FP6-LLM authors:
+
+- Paper: https://arxiv.org/abs/2401.14112
+- Code: https://github.com/usyd-fsalab/fp6_llm
diff --git a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py
new file mode 100644
index 0000000000..d1a46339bd
--- /dev/null
+++ b/torchao/prototype/fp6_llm/__init__.py
@@ -0,0 +1 @@
+from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2
diff --git a/torchao/quantization/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py
similarity index 98%
rename from torchao/quantization/fp6_llm.py
rename to torchao/prototype/fp6_llm/fp6_llm.py
index 446d1ad937..570ea13546 100644
--- a/torchao/quantization/fp6_llm.py
+++ b/torchao/prototype/fp6_llm/fp6_llm.py
@@ -5,7 +5,7 @@
 from torch import nn, Tensor
 from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32
 from torchao.prototype.mx_formats.constants import F6_E3M2_MAX
-from torchao.ops import fp16act_fp6weight_linear
+from torchao.ops import fp6_llm_linear


 def _pack_2bit(x: Tensor) -> Tensor:
@@ -273,7 +273,7 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
     def forward(self, x: Tensor) -> Tensor:
         splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features)
-        out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
+        out = fp6_llm_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
         if self.bias is not None:
             out = out + self.bias
         return out.view(*x.shape[:-1], self.out_features).to(x.dtype)
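For downstream callers, here is a minimal before/after sketch of the rename and the module move, assuming only the public names that appear in this diff (`torchao.prototype.fp6_llm`, `to_scaled_tc_float6_e3m2`, `torchao.ops.fp6_llm_linear`); the tensor shapes are illustrative and a CUDA device is required.

```python
# Migration sketch (illustrative shapes; requires CUDA and a torchao build with the FP6-LLM kernel).
import torch

# Before this PR:
#   from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
#   out = torchao.ops.fp16act_fp6weight_linear(fp16_act, fp6_weight, scales, splitK)

# After this PR: the module lives under torchao.prototype.fp6_llm and the op is fp6_llm_linear.
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

# Quantize an FP32 weight to FP6 and pack it into the tensor-core layout expected by the kernel.
fp32_weight = torch.randn(1024, 512, device="cuda")
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

# FP16 activations; splitK keeps its default of 1 as in torchao.ops.fp6_llm_linear.
fp16_act = torch.randn(4, 512, device="cuda", dtype=torch.float16)
out = fp6_llm_linear(fp16_act, fp6_weight, scales, splitK=1)  # shape (4, 1024)
```

Apart from the name, the op schema is unchanged (`(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor`), so only imports and call sites need to be updated.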