Add FP6-LLM doc and move FP6-LLM to prototype (pytorch#358)
* add doc. move fp6_llm under prototype

* doc update

* update doc. rename functions
gau-nernst authored Jun 13, 2024
1 parent 2d27ccf commit b352fc1
Showing 11 changed files with 62 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -124,6 +124,7 @@ To learn more try out our APIs, you can check out API examples in
4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
- [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
- [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
- [FP6-LLM](torchao/prototype/fp6_llm) FP16 x FP6 mixed-precision matmul kernel for I/O-bound workloads

## Our Goals

2 changes: 1 addition & 1 deletion benchmarks/benchmark_fp6_llm.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm
@@ -7,7 +7,7 @@
parametrize,
run_tests,
)
from torchao.quantization.fp6_llm import (
from torchao.prototype.fp6_llm.fp6_llm import (
to_tc_float6_e3m2,
from_tc_float6_e3m2,
_to_tc_float6_e3m2_ref,
12 changes: 6 additions & 6 deletions test/test_ops.py
@@ -2,7 +2,7 @@
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
import torchao
from torchao.quantization.fp6_llm import from_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
import pytest
@@ -26,27 +26,27 @@ def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
def test_fp6_llm_linear(self):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -178,7 +178,7 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda);
m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda);
}

} // namespace torchao
2 changes: 1 addition & 1 deletion torchao/csrc/fp6_llm.cpp
@@ -4,5 +4,5 @@

TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
}
6 changes: 3 additions & 3 deletions torchao/ops.py
@@ -12,7 +12,7 @@ def decorator(func):
return decorator


def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
"""
FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
@@ -25,10 +25,10 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
Returns
output of linear layer
"""
return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)
return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK)


@register_custom_op("torchao::fp16act_fp6weight_linear")
@register_custom_op("torchao::fp6_llm_linear")
def _(_in_feats, _weights, _scales, splitK = 1):
torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
torch._check(_in_feats.dtype is torch.float16, lambda: f"input must be FP16, got {_in_feats.dtype}")
1 change: 1 addition & 0 deletions torchao/prototype/README.md
@@ -9,6 +9,7 @@
- `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507)
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed-precision matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)

#### Roadmap

44 changes: 44 additions & 0 deletions torchao/prototype/fp6_llm/README.md
@@ -0,0 +1,44 @@
# FP6-LLM

This is an FP16 x FP6 mixed-precision matmul kernel optimized for I/O-bound workloads, per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel lives under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6, as well as a utility to convert existing models to use FP6.

## Usage

```python
from torchao.prototype.fp6_llm import convert_fp6_llm

model = ...
convert_fp6_llm(model)  # converts the model in-place, replacing nn.Linear modules with Fp6LlmLinear

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
```

It's also possible to pre-process the weight and call the kernel directly.

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()

# Pre-process the weight: quantize it to FP6 and pack it in a special layout
# for tensor cores. Refer to the paper for more details.
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024)
```
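
The kernel also takes an optional `splitK` argument (default 1), which controls the split-K factor for the GEMM reduction; the best value is shape-dependent (the tests in `test_ops.py` use e.g. 5 or 6 for the larger shapes), and `Fp6LlmLinear` picks one automatically via its `get_split_k()` heuristic. A minimal sketch of passing it explicitly, reusing the tensors from the example above:

```python
# splitK > 1 enables a split-K style reduction; the best value is shape-dependent
outputs = fp6_llm_linear(fp16_act, fp6_weight, scales, splitK=1)
```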

## TODO

- [ ] Compile CUDA kernel for Windows
- [ ] Merge FP5 from upstream

## Credits

Credits to the FP6-LLM authors:

- Paper: https://arxiv.org/abs/2401.14112
- Code: https://github.com/usyd-fsalab/fp6_llm
1 change: 1 addition & 0 deletions torchao/prototype/fp6_llm/__init__.py
@@ -0,0 +1 @@
from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2
@@ -5,7 +5,7 @@
from torch import nn, Tensor
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32
from torchao.prototype.mx_formats.constants import F6_E3M2_MAX
from torchao.ops import fp16act_fp6weight_linear
from torchao.ops import fp6_llm_linear


def _pack_2bit(x: Tensor) -> Tensor:
@@ -273,7 +273,7 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None

def forward(self, x: Tensor) -> Tensor:
splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features)
out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
out = fp6_llm_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
if self.bias is not None:
out = out + self.bias
return out.view(*x.shape[:-1], self.out_features).to(x.dtype)
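
For reference, a minimal sketch of constructing `Fp6LlmLinear` directly from a pre-packed weight, assuming the constructor accepts the packed weight and scales produced by `to_scaled_tc_float6_e3m2` (as the `__init__` signature above suggests):

```python
import torch
from torchao.prototype.fp6_llm import Fp6LlmLinear, to_scaled_tc_float6_e3m2

fp32_weight = torch.randn(1024, 512).cuda()
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)  # packed FP6 weight + per-channel scales

# assumption: the packed weight and scales go straight into the constructor; bias is optional
linear = Fp6LlmLinear(fp6_weight, scales)
x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
y = linear(x)  # shape (4, 1024); splitK is chosen internally via get_split_k()
```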