int8 dynamic quant + bsr support (#821)
This PR adds int8 dynamic quant + BSR (block sparse row) support.

Changes:
* Used an i8i8 -> bf16 matmul to maintain accuracy
* Added a block sparse layout type to AffineQuantizedTensor, plus the corresponding dispatch check/impl (see the usage sketch below)
* Cleaned up the benchmark.py script and added a single-line `benchmark.sh` file for acceleration numbers
* Updated eval.py and added a single-line `evaluate.sh` file for accuracy numbers
* Lots of lint formatting and README updates
* torch.compile now works and produces correct results
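
For reference, a minimal usage sketch distilled from the new tests in this commit (the layer shapes, 50% sparsity level, and `blocksize=64` are illustrative; the weights must already be block sparse for the BSR layout to pay off, so synthetic block sparse weights are substituted here the same way the tests do):

```python
import torch
from torch import nn

from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    quantize_,
)
from torchao.sparsity.utils import create_block_sparse_tensor

model = (
    nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Substitute 50% block sparse weights with 64x64 blocks, as the tests do.
for layer in model:
    M, N = layer.weight.shape
    layer.weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)

# int8 dynamic activation quant + int8 weight quant, with the weight stored
# in the new block sparse (BSR) layout; the underlying matmul accumulates
# i8i8 -> bf16 to maintain accuracy.
quantize_(
    model,
    int8_dynamic_activation_int8_weight(
        layout_type=BlockSparseLayoutType(blocksize=64)
    ),
)

model = torch.compile(model)  # torch.compile is supported for this path
out = model(torch.rand(256, 128, dtype=torch.bfloat16, device="cuda"))
```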
jcaip authored Sep 26, 2024
1 parent da0bbe3 commit 4b5b5ee
Showing 16 changed files with 1,442 additions and 922 deletions.
139 changes: 121 additions & 18 deletions test/sparsity/test_sparse_api.py
@@ -4,27 +4,24 @@
 
 import torch
 from torch import nn
 
-from torchao.sparsity import (
-    apply_fake_sparsity,
-    sparsify_,
-    semi_sparse_weight,
-)
+from torch.testing._internal import common_utils
+from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
 from torchao.quantization.quant_api import (
+    int4_weight_only,
     int8_dynamic_activation_int8_weight,
     quantize_,
-    int4_weight_only,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
-from torch.testing._internal.common_utils import TestCase
 
+from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
 
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
-class TestSemiStructuredSparse(TestCase):
+
+class TestSemiStructuredSparse(common_utils.TestCase):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -37,6 +34,7 @@ def test_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -45,13 +43,17 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
-class TestQuantSemiSparse(TestCase):
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+
+class TestQuantSemiSparse(common_utils.TestCase):
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quant_semi_sparse(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_quant_semi_sparse(self, compile):
+        torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
         input = torch.rand((128, 128)).half().cuda()
         model = (
             nn.Sequential(
@@ -60,19 +62,27 @@ def test_quant_semi_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
         apply_fake_sparsity(model)
         model_copy = copy.deepcopy(model)
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         dense_result = model_copy(input)
 
-        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+        )
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_sparse_marlin(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse_marlin(self, compile):
         input = torch.rand((256, 256)).half().cuda()
         model = (
             nn.Sequential(
@@ -81,6 +91,7 @@ def test_sparse_marlin(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -92,9 +103,101 @@ def test_sparse_marlin(self):
 
         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+
+class TestBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((1024, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            block_sparse_weight,
+        )
+
+        sparsify_(model, block_sparse_weight(blocksize=64))
+        # if compile:
+        #     model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((256, 128)).to(torch.bfloat16).cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .to(torch.bfloat16)
+            .cuda()
+            .eval()
+        )
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            blocksparse_int_addmm,
+        )
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = (
+            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
+        )
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        reference = model_copy(input)
+
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(
+                layout_type=BlockSparseLayoutType(blocksize=64)
+            ),
+        )
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
+
+
+common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
+common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
+common_utils.instantiate_parametrized_tests(TestBlockSparseWeight)
+common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight)
 
 if __name__ == "__main__":
     unittest.main()