forked from pytorch/ao
Commit
Showing 2 changed files with 171 additions and 0 deletions.
@@ -0,0 +1,70 @@
import torch
from torchao.prototype.common.bitpacking import pack, unpack
import pytest
from torch.utils._triton import has_triton
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4


if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def test_uint4_to_uint8_cpu():
    # randint's upper bound is exclusive, so 16 covers the full uint4 range 0..15
    test_tensor = torch.randint(0, 16, (4, 4), dtype=torch.uint8)
    packed = pack(test_tensor, 8, 4, device='cpu')
    unpacked = unpack(packed, 4, device='cpu')
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)


def test_uint3_to_int16_col_wise_cpu():
    test_tensor = torch.randint(0, 8, (8, 5), dtype=torch.int16)
    packed = pack(test_tensor, 16, 3, False, device='cpu')
    unpacked = unpack(packed, 3, False, device='cpu')
    # column-wise packing pads the last dimension, so trim columns rather than rows
    unpadded = unpacked[..., :test_tensor.shape[1]]
    assert unpadded.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint4_to_uint8():
    test_tensor = torch.randint(0, 16, (4, 4), dtype=torch.uint8).cuda()
    packed = pack(test_tensor, 8, 4)
    unpacked = unpack(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint4_to_uint8_compile():
    torch._dynamo.config.specialize_int = True
    pack_compiled = torch.compile(pack, fullgraph=True)
    unpack_compiled = torch.compile(unpack, fullgraph=True)
    test_tensor = torch.randint(0, 16, (3, 4), dtype=torch.uint8).cuda()
    packed = pack_compiled(test_tensor, 8, 4)
    unpacked = unpack_compiled(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16():
    test_tensor = torch.randint(0, 8, (5, 8), dtype=torch.int16).cuda()
    packed = pack(test_tensor, 16, 3)
    unpacked = unpack(packed, 3)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint2_to_uint8_col_wise_compile():
    torch._dynamo.config.specialize_int = True
    pack_compiled = torch.compile(pack, fullgraph=True)
    unpack_compiled = torch.compile(unpack, fullgraph=True)
    test_tensor = torch.randint(0, 4, (8, 8), dtype=torch.uint8).cuda()
    packed = pack_compiled(test_tensor, 8, 2, False)
    unpacked = unpack_compiled(packed, 2, False)
    unpadded = unpacked[..., :test_tensor.shape[1]]
    assert unpadded.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16_col_wise():
    test_tensor = torch.randint(0, 8, (8, 5), dtype=torch.int16).cuda()
    packed = pack(test_tensor, 16, 3, False)
    unpacked = unpack(packed, 3, False)
    unpadded = unpacked[..., :test_tensor.shape[1]]
    assert unpadded.allclose(test_tensor)
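
Every test above follows the same round-trip pattern: pack, unpack, trim any padding, and compare against the input. As a quick illustration (a hypothetical snippet, not part of this commit), packing two rows of uint4 values into one row of uint8 containers and unpacking recovers the original tensor:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

x = torch.tensor([[3, 12], [5, 10]], dtype=torch.uint8)  # uint4 values, two rows
packed = pack(x, 8, 4, device='cpu')  # row 0 -> high nibbles, row 1 -> low nibbles
assert packed.equal(torch.tensor([[0x35, 0xCA]], dtype=torch.uint8))
unpacked = unpack(packed, 4, device='cpu')
assert unpacked.allclose(x)
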
@@ -0,0 +1,101 @@
import torch
from functools import reduce

def unpack(data, data_size, by_rows=True, device="cuda"):
    """
    Unpacks small dtype elements from a larger dtype.

    Inputs:
    data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.
    data_size: int - the size of the small dtype in bits.

    optional:
    by_rows: bool - specifies whether to unpack
        by rows: tensor(n, m) -> tensor(n*scale, m)
        or by columns: tensor(n, m) -> tensor(n, m*scale)

        Defaults to rows because quantization is typically done by rows,
        but choose the version that matches how you quantize, since that
        improves memory access patterns and performance.

    Returns: torch.Tensor - a tensor of the unpacked elements.
    """
    if by_rows:
        return _unpack_by_rows(data, data_size, device)
    else:
        return _unpack_by_cols(data, data_size)
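
# For example (a hypothetical round-trip, assuming uint8 containers that each
# hold two uint4 values): the byte 0x35 == 0b0011_0101 unpacks by rows into
# the nibbles 0b0011 and 0b0101:
#   unpack(torch.tensor([[0x35]], dtype=torch.uint8), 4, device='cpu')
#   -> tensor([[3], [5]], dtype=torch.uint8)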


def pack(data, container_size, data_size, by_rows=True, device="cuda"):
    """
    Packs small dtype elements into a larger dtype.
    Pads rows (or columns) to be divisible by the scale (container_size // data_size).

    Inputs:
    data: torch.Tensor - a tensor of unpacked elements of a small dtype.
    container_size: int - the size of the large dtype in bits.
    data_size: int - the size of the small dtype in bits.

    optional:
    by_rows: bool - specifies whether to pack values
        by rows: tensor(n, m) -> tensor(n//scale, m)
        or by columns: tensor(n, m) -> tensor(n, m//scale)

        Defaults to rows because quantization is typically done by rows,
        but choose the version that matches how you quantize, since that
        improves memory access patterns and performance.

    Returns: torch.Tensor - a tensor of packed elements.
    """
    if by_rows:
        return _pack_by_rows(data, container_size, data_size, device)
    else:
        return _pack_by_cols(data, container_size, data_size, device)
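
# The inverse of the example above (hypothetical values): packing the uint4
# values 3 and 5 by rows shifts 3 into the high nibble and ORs 5 into the low one:
#   pack(torch.tensor([[3], [5]], dtype=torch.uint8), 8, 4, device='cpu')
#   -> tensor([[0x35]], dtype=torch.uint8)  # (3 << 4) | 5 == 53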


def _unpack_by_rows(data, data_size, device) -> torch.Tensor:
    shape = data.shape
    scale = data.element_size() * 8 // data_size

    unpacked_data = torch.zeros((shape[0] * scale, *shape[1:]), dtype=data.dtype).to(device)
    nbits = (1 << data_size) - 1  # mask for the low data_size bits
    for i in range(scale):
        shift_amt = data.element_size() * 8 - data_size * (i + 1)  # how much to shift to reach the ith element
        unpacked_data[i::scale] = (data >> shift_amt) & nbits
    return unpacked_data


def _unpack_by_cols(data, data_size) -> torch.Tensor:
    shape = data.shape
    scale = data.element_size() * 8 // data_size
    unpacked_data = []
    nbits = (1 << data_size) - 1  # mask for the low data_size bits
    for i in range(scale):
        shift_amt = data.element_size() * 8 - data_size * (i + 1)  # how much to shift to reach the ith element
        unpacked_data.append(((data >> shift_amt) & nbits).to(data.dtype))
    # stack the unpacked elements along a new last dim, then fold it into the column dim
    return torch.stack(unpacked_data, dim=-1).view(*shape[:-1], shape[-1] * scale)


def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor:
    scale = container_size // data_size
    assert scale > 1, f"container_size ({container_size}) must be at least twice data_size ({data_size})"
    assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})"
    # pad the rows to be divisible by scale
    if data.shape[0] % scale != 0:
        padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:]), dtype=data.dtype).to(device)
        data = torch.cat([data, padding], dim=0).to(device)

    shape = data.shape
    # shift each group of rows into its slot within the container, then OR the slots together
    ret = reduce(lambda x, y: x | y, [data[i::scale, ...] << container_size - data_size * (i + 1) for i in range(scale)])
    return ret.view(shape[0] // scale, *shape[1:]).to(device)


def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor:
    scale = container_size // data_size
    assert scale > 1, f"container_size ({container_size}) must be at least twice data_size ({data_size})"
    # pad the columns to be divisible by scale
    if data.shape[-1] % scale != 0:
        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device)
        data = torch.cat([data, padding], dim=-1).to(device)

    shape = data.shape
    data = data.contiguous().view(-1)
    # shift each element to its slot within the larger dtype, then OR the slots together
    ret = reduce(lambda x, y: x | y, [data[i::scale] << container_size - data_size * (i + 1) for i in range(scale)])
    return ret.view(*shape[:-1], shape[-1] // scale).to(device)
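
As a sanity check on the shift-and-OR arithmetic above, here is a minimal column-wise round-trip on CPU (a hypothetical snippet, not part of this commit). With uint2 values in uint8 containers, scale = 8 // 2 = 4, so four values pack into each byte:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

x = torch.tensor([[1, 2, 3, 0]], dtype=torch.uint8)  # four uint2 values
packed = pack(x, 8, 2, False, device='cpu')
assert packed.item() == 0b01101100  # (1 << 6) | (2 << 4) | (3 << 2) | 0
unpacked = unpack(packed, 2, False, device='cpu')
assert unpacked.allclose(x)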