Add Float8 Weight Only and FP8 weight + dynamic activation #740

Merged · 2 commits · Aug 30, 2024
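For context, a minimal usage sketch of the two quantization APIs this PR adds, pieced together from the test code in the diff below (an illustration, not an official example; exact defaults may differ):

```python
import torch
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

# Toy module; the tests below gate fp8 on CUDA compute capability >= 8.9.
model = torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False))
model = model.to(torch.bfloat16).to("cuda")

# New: weight-only fp8 quantization
quantize_(model, float8_weight_only())

# New: fp8 weights plus dynamic fp8 activation quantization
# quantize_(model, float8_dynamic_activation_float8_weight())
```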
2 changes: 1 addition & 1 deletion scripts/hf_eval.py
@@ -114,7 +114,7 @@ def all_linear(mod, name):
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply')
+parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
71 changes: 46 additions & 25 deletions test/dtypes/test_affine_quantized.py
@@ -10,12 +10,33 @@
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
)
from torch.testing._internal import common_utils
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

import torch
import unittest
import tempfile

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_quantization_functions(do_sparse: bool, do_int4: bool):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
]
if do_int4:
base_functions.append(int4_weight_only(group_size=32))

if do_sparse:
base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())

if is_cuda_8_9:
base_functions.append(float8_weight_only())

return base_functions


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -38,36 +59,36 @@ def test_tensor_core_layout_transpose(self):
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
def test_weights_only(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_to_device(self):
from torchao.quantization import quantize_
for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to("cuda")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to("cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to(device="cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to(device="cuda")
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.cuda()

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.cuda()

common_utils.instantiate_parametrized_tests(TestAffineQuantized)

if __name__ == "__main__":
run_tests()
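Note: the refactor above replaces hand-rolled for-loops with torch's test parametrization helpers, so each quantization function becomes its own named test case and one failing configuration no longer masks the rest. The pattern in miniature (a sketch using only the torch test utilities already imported above; generated test names are approximate):

```python
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests

class MiniTest(TestCase):
    @common_utils.parametrize("value", [1, 2, 3])
    def test_square_nonnegative(self, value):
        self.assertGreaterEqual(value * value, 0)

# Expands the decorated test into one case per parameter,
# e.g. test_square_nonnegative_value_1, _value_2, _value_3.
common_utils.instantiate_parametrized_tests(MiniTest)

if __name__ == "__main__":
    run_tests()
```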
103 changes: 103 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -0,0 +1,103 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)
import pytest

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.testing._internal.common_utils import (
run_tests,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization import (
quantize_,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import copy
import random


random.seed(0)
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
# Inputs are (M,..), N, K
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
((256,), 512, 256),
((64,), 128, 64),
((32, 128), 64, 256),
((64, 256), 512, 128),
],
)
def test_fp8_linear_variants(
self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
):
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

mode_map = {
"dynamic": float8_dynamic_activation_float8_weight,
"weight-only": float8_weight_only,
}

        # Create the reference model in the requested dtype; a deep copy is quantized below
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
        quantize_(quantized_model, factory)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

output_original = model(input_tensor)
output_quantized = quantized_model(input_tensor)

        error = compute_error(output_original, output_quantized)
        assert error > 20, f"Quantization error is too high, got a SQNR of {error}"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

if __name__ == "__main__":
pytest.main([__file__])
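For reference, compute_error returns SQNR in decibels (20·log10(‖ref‖ / ‖ref − out‖) in torchao.float8.float8_utils at the time of this PR), so the > 20 dB assertion roughly requires the quantization noise to stay below 10% of the signal norm. A standalone sketch:

```python
import torch
from torchao.float8.float8_utils import compute_error

ref = torch.randn(64, 128)
# Stand-in for a quantized model's output: reference plus ~1% relative noise.
noisy = ref + 0.01 * torch.randn_like(ref)

sqnr = compute_error(ref, noisy)  # ~40 dB for 1% relative noise
assert sqnr > 20  # the same threshold the test above enforces
```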
4 changes: 4 additions & 0 deletions torchao/dtypes/__init__.py
@@ -12,6 +12,8 @@
PlainLayoutType,
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
Float8LayoutType,
Float8AQTLayout,
)

__all__ = [
@@ -27,4 +29,6 @@
"PlainLayoutType",
"SemiSparseLayoutType",
"TensorCoreTiledLayoutType",
"Float8LayoutType",
"Float8AQTLayout",
]
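With these exports, the fp8 layout types become part of the public torchao.dtypes surface. A quick import sanity check (a sketch, assuming a build that includes this PR):

```python
# Float8LayoutType selects the fp8 storage layout for AffineQuantizedTensor;
# Float8AQTLayout is the layout tensor subclass that implements it.
from torchao.dtypes import Float8LayoutType, Float8AQTLayout

print(Float8LayoutType, Float8AQTLayout)
```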