From 37e4d2d89540c65282b55c0df8072f03f1230ac8 Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Thu, 23 May 2024 19:58:52 -0400 Subject: [PATCH 01/19] init additions --- bitnet-test.py | 102 +++++++++++++++++ test/dtypes/test_trinary.py | 43 +++++++ torchao/dtypes/trinary.py | 216 ++++++++++++++++++++++++++++++++++++ 3 files changed, 361 insertions(+) create mode 100644 bitnet-test.py create mode 100644 test/dtypes/test_trinary.py create mode 100644 torchao/dtypes/trinary.py diff --git a/bitnet-test.py b/bitnet-test.py new file mode 100644 index 000000000..1a655c8d6 --- /dev/null +++ b/bitnet-test.py @@ -0,0 +1,102 @@ +import torch +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch.library import impl, Library +import lovely_tensors as lt +lt.monkey_patch() + +def down_size(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" + return (*size[:-1], size[-1] // 4) + +def up_size(size): + return (*size[:-1], size[-1] * 4) + +def unpack_uint2(uint8_data) -> torch.Tensor: + """Get the original weight from the normalized float weight format""" + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + unpacked_data = torch.empty((*shape, 4), dtype=torch.uint8) + + unpacked_data[..., 0] = (uint8_data >> 6) & 0b11 + unpacked_data[..., 1] = (uint8_data >> 4) & 0b11 + unpacked_data[..., 2] = (uint8_data >> 2) & 0b11 + unpacked_data[..., 3] = uint8_data & 0b11 + return unpacked_data.view(up_size(shape)) + +def pack_uint2(uint8_data) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data + +def roundclip(x, a, b): + return torch.max(torch.tensor(a), torch.min(torch.tensor(b), torch.round(x))) + +def quantize_per_tensor_uint2_trinary(weights): + # Compute the average absolute value of the weight tensor + gamma = torch.mean(torch.abs(weights)) + + # Scale the weight tensor by the average absolute value + scaled_weights = weights / (gamma + 1e-8) + + # Round each scaled weight to the nearest integer in {-1, 0, +1} + quantized_weights = roundclip(scaled_weights, -1, 1) + + #Shift the distribution over by 1 so we can pack into a uint and not deal with signs + return quantized_weights.to(torch.int8) + +test_tensor = torch.randint(0, 3, (1024, 16, 8), dtype=torch.uint8) +print(test_tensor) +packed = pack_uint2(test_tensor) +unpacked = unpack_uint2(packed) +print(unpacked.allclose(test_tensor)) +assert(unpacked.allclose(test_tensor)) + +test_layer = torch.rand(1024, 16, 8) * 500.0 - 250.0 + +#Quantize our fake layer with bitnet method. +original_fake_layer = quantize_per_tensor_uint2_trinary(test_layer) +print(original_fake_layer) + +#Shift distribution from -1, 1 -> 0, 2 to we can use unsigned storage. 
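+# e.g. -1 -> 0, 0 -> 1, +1 -> 2, so each ternary value fits in two unsigned bits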
+shifted_fake_layer = (original_fake_layer + 1.0).to(torch.uint8) +print("original: ") +print(shifted_fake_layer) + + + +def unpack_uint8_to_trinary(uint8_data) -> torch.Tensor: + """Get the original weight from the normalized float weight format""" + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + unpacked_data = torch.empty((*shape, 4), dtype=torch.int8) + + unpacked_data[..., 0] = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 1] = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 2] = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 3] = (uint8_data & 0b11).to(torch.int8) - 1.0 + return unpacked_data.view(up_size(shape)) + +def pack_uint2(uint8_data) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data + + +packed = pack_uint2(shifted_fake_layer) +print("after packing: ") +print(packed) +unpacked = unpack_uint8_to_trinary(packed) +print("after unpacking: ") +print(unpacked) +print(unpacked.dtype) +print(unpacked.allclose(original_fake_layer)) +assert(unpacked.allclose(original_fake_layer)) + +unpack_empty = torch.compile(unpack_uint8_to_trinary, mode="reduce-overhead") \ No newline at end of file diff --git a/test/dtypes/test_trinary.py b/test/dtypes/test_trinary.py new file mode 100644 index 000000000..688413104 --- /dev/null +++ b/test/dtypes/test_trinary.py @@ -0,0 +1,43 @@ +import torch +from torchao.dtypes.trinary import ( + TrinaryTensor, + quantize_per_tensor_trinary, +) +import unittest +from unittest import TestCase, main +from torch._export import capture_pre_autograd_graph +from torch._export import dynamic_dim +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, +) +from torchao.quantization.utils import ( + compute_error, +) +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, +) +from torch.ao.quantization.observer import ObserverBase +from torch import nn +from torch.fx import ( + Node, + GraphModule, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, +) +import copy + +def _quantize_linear_weights_only(model): + def fn(mod): + mod.weight = torch.nn.Parameter(quantize_per_tensor_trinary(mod.weight), requires_grad=False) + return mod + + _replace_with_custom_fn_if_matches_filter( + model, + lambda mod: fn(mod), + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + +class TestTrinary(QuantizationTestCase): + \ No newline at end of file diff --git a/torchao/dtypes/trinary.py b/torchao/dtypes/trinary.py new file mode 100644 index 000000000..4a7c4cdf4 --- /dev/null +++ b/torchao/dtypes/trinary.py @@ -0,0 +1,216 @@ +import torch +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch.library import impl, Library +import lovely_tensors as lt +lt.monkey_patch() + +def down_size(size): + assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" + return (*size[:-1], size[-1] // 4) + +def up_size(size): + return (*size[:-1], size[-1] * 4) + +def roundclip(x, a, b): + return torch.max(torch.tensor(a), torch.min(torch.tensor(b), torch.round(x))) + +def quantize_per_tensor_trinary(weights): + # Compute the average absolute value of the weight tensor + gamma = 
torch.mean(torch.abs(weights)) + + # Scale the weight tensor by the average absolute value + scaled_weights = weights / (gamma + 1e-8) + + # Round each scaled weight to the nearest integer in {-1, 0, +1} and shift to {0, 1, 2} + quantized_weights = roundclip(scaled_weights, -1, 1) + 1 + + return quantized_weights.to(torch.uint8) + +def unpack_trinary(uint8_data) -> torch.Tensor: + """Get the original weight from the normalized float weight format""" + # since we are using uint8 we will decode 4 entries per byte + shape = uint8_data.shape + unpacked_data = torch.empty((*shape, 4), dtype=torch.int8) + + #shift back to {-1, 0, 1} while unpacking + unpacked_data[..., 0] = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 1] = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 2] = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1.0 + unpacked_data[..., 3] = (uint8_data & 0b11).to(torch.int8) - 1.0 + return unpacked_data.view(up_size(shape)) + +def pack_trinary(uint8_data) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 4 == 0 + uint8_data = uint8_data.contiguous().view(-1) + packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) + return packed_data + +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. + Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + +class TrinaryTensor(torch.Tensor): + def __new__(cls, data, *args, **kwargs): + assert elem.dtype is torch.uint8 + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False + + return torch.Tensor._make_wrapper_subclass( + cls, up_size(elem.shape), dtype=torch.trinary, **kwargs + ) + + def __init__(self, elem, **kwargs): + self.elem = elem + + @classmethod + def from_unpacked(cls, unpacked): + return TrinaryTensor(pack_trinary(unpacked)) + + def tolist(self): + return self.to(torch.uint8).tolist() + + def __tensor_flatten__(self): + return ["elem"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + assert meta is None + elem = flattened["elem"] + return TrinaryTensor(elem) + + def __hash__(self): + return hash(self.elem) + + def __eq__(self, other): + return torch.equal(self.elem, other.elem) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func is torch.ops.aten.view.default: + self, size = args + size = utils.infer_size(size, self.numel()) + assert not kwargs + # WARNING: views not preserved + return TrinaryTensor(self.elem.reshape(down_size(size))) + elif func is torch.ops.aten.view.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_trinary(self.elem).view(torch.uint8) + return 
NotImplementedError(f"view {args}") + elif func is torch.ops.aten.to.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_trinary(self.elem).view(torch.uint8) + return NotImplementedError(f"to {args}") + elif func is torch.ops.aten.eq.Tensor: + args = pytree.tree_map_only( + TrinaryTensor, lambda x: x.elem.view(torch.uint8), args + ) + kwargs = pytree.tree_map_only( + TrinaryTensor, lambda x: x.elem.view(torch.uint8), kwargs + ) + return torch.ops.aten.eq.Tensor(*args, **kwargs) + elif func is torch.ops.aten._to_copy.default: + (self,) = args + if kwargs == {"dtype": torch.uint8}: + return unpack_trinary(self.elem).view(self.shape) # no wrap + else: + raise NotImplementedError(f"_to_copy {kwargs}") + elif func is torch.ops.aten.unbind.int: + # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to + # create four tensors containing one element each. But we can't + # do this with uint4 because such a tensor's size is not divisible + # by bytes. What I am going to do instead is promote to uint8 + # when this happens + self, dim = fill_defaults(args, 2, [0]) + if dim != self.dim() - 1: + raise NotImplementedError(f"unbind dim={dim}") + else: + # We're unbinding the last dimension, need to promote + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind( + dim + ) + elif func is torch.ops.aten.select.int: + self, dim, index = args + if dim != self.dim() - 1: + return TrinaryTensor(torch.ops.aten.select.int(self.elem, dim, index)) + else: + raise NotImplementedError(f"select dim={dim}") + elif func is torch.ops.aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == self.dim() - 1: + # hard case + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % 2 == 0, start + assert end >= self.shape[dim] or end % 2 == 0, end + return TrinaryTensor( + # Not sure about this one + torch.ops.aten.slice.Tensor(self.elem, dim, start // 4, end // 4, 1) + ) + else: + # easy case + return TrinaryTensor( + torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step) + ) + elif func is torch.ops.aten.t.default: + # assert False, "transpose is not properly implemented currently" + (self,) = args + unpacked = unpack_trinary(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + transposed_and_packed = pack_trinary(transposed) + return TrinaryTensor(transposed_and_packed) + elif func is torch.ops.aten.transpose_copy.int: + self, dim0, dim1 = args + unpacked = unpack_trinary(self.elem).view(self.shape) + transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) + transposed_and_packed = pack_trinary(transposed) + return TrinaryTensor(transposed_and_packed) + + elif func is torch.ops.aten.as_strided.default: + # size, stride, storage_offset are referring to tensor elements, not physical bytes + self, size, stride, storage_offset = args + size = down_size(size) + + new_stride = [] + for s in stride: + if s != 1: + # since four trinary values equals to 1 uint8 + new_stride.append(s // 4) + else: + new_stride.append(s) + stride = new_stride + + storage_offset //= 4 + return TrinaryTensor( + torch.ops.aten.as_strided.default( + self.elem, size, stride, storage_offset + ) + ) + + raise NotImplementedError(f"{func}") + + __torch_function__ = torch._C._disabled_torch_function_impl \ No newline at end of file From b37b52919f1302005567b82f987c04d78d4053bb Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Sat, 25 May 2024 16:37:08 -0400 Subject: [PATCH 02/19] 
extended pack/unpack --- test-uint4.ipynb | 117 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 test-uint4.ipynb diff --git a/test-uint4.ipynb b/test-uint4.ipynb new file mode 100644 index 000000000..c0f2392b4 --- /dev/null +++ b/test-uint4.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch._prims_common as utils\n", + "import torch.utils._pytree as pytree\n", + "from torch.library import impl, Library\n", + "from functools import reduce\n", + "import lovely_tensors as lt\n", + "lt.monkey_patch()" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "def unpack_uint4(uint_data) -> torch.Tensor:\n", + " \"\"\"Get the original weight from the normalized float weight format\"\"\"\n", + " # since we are using uint8 we will decode 2 entries per byte\n", + " # Shift elements down 4 and select out the bottom 4 bits\n", + " shape = uint_data.shape\n", + " scale = uint_data.element_size() * 8 // 4 # how many uint4s can fit in a dtype_size\n", + " unpacked_data = torch.empty((*shape, scale), dtype=uint_data.dtype)\n", + " for i in range(scale):\n", + " unpacked_data[..., i] = (uint_data >> int(uint_data.element_size()*8- 4*(i+1))) & 0b1111\n", + " return unpacked_data.view(up_size(shape, scale))\n", + "\n", + "\n", + "def pack_uint4(uint_data, dtype_size=8) -> torch.Tensor:\n", + " # converting to uint8 for operations\n", + " scale = dtype_size // 4 # how many uint4s can fit in a dtype_size\n", + " padding = torch.zeros((*uint_data.shape[:-1], (scale - uint_data.shape[-1] % scale)%scale), dtype=uint_data.dtype)\n", + " uint_data = torch.cat([uint_data, padding], dim=-1)\n", + " shape = uint_data.shape\n", + " uint_data = uint_data.contiguous().view(-1)\n", + " return reduce(lambda x,y: x|y,[uint_data[i::scale] << dtype_size-4*(i+1) for i in range(scale)]).view(down_size(shape, scale))\n", + "\n", + "def down_size(size, amt):\n", + " assert size[-1] % amt == 0, f\"{size} last dim not divisible by {amt}\"\n", + " return (*size[:-1], size[-1] // amt)\n", + "\n", + "\n", + "def up_size(size, amt):\n", + " return (*size[:-1], size[-1] * amt)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8)\n", + "# print('og', test_tensor)\n", + "packed = pack_uint4(test_tensor)\n", + "# print('packed', packed)\n", + "unpacked = unpack_uint4(packed)\n", + "# print('unpacked', unpacked)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n", + "\n", + "test_tensor = torch.randint(0, 3, (5,1, 4), dtype=torch.int16)\n", + "# print('og', test_tensor)\n", + "packed = pack_uint4(test_tensor,16)\n", + "# print('packed', packed)\n", + "unpacked = unpack_uint4(packed)\n", + "# print('unpacked', unpacked)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n", + "\n", + "test_tensor = torch.randint(0, 3, (3,1, 9), dtype=torch.int32)\n", + "# print('og', test_tensor)\n", + "packed = pack_uint4(test_tensor,32)\n", + "# print('packed', packed)\n", + "unpacked = unpack_uint4(packed)\n", + "# print('unpacked', unpacked)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 33845feedf5b47ae04c9a83c6bc9fef94d59bf80 Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Sun, 26 May 2024 18:13:23 -0400 Subject: [PATCH 03/19] pack/unpack from n to m dtypes --- test-uint4.ipynb | 117 ------------------------------------- test_pack-unpack.ipynb | 130 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 117 deletions(-) delete mode 100644 test-uint4.ipynb create mode 100644 test_pack-unpack.ipynb diff --git a/test-uint4.ipynb b/test-uint4.ipynb deleted file mode 100644 index c0f2392b4..000000000 --- a/test-uint4.ipynb +++ /dev/null @@ -1,117 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch._prims_common as utils\n", - "import torch.utils._pytree as pytree\n", - "from torch.library import impl, Library\n", - "from functools import reduce\n", - "import lovely_tensors as lt\n", - "lt.monkey_patch()" - ] - }, - { - "cell_type": "code", - "execution_count": 74, - "metadata": {}, - "outputs": [], - "source": [ - "def unpack_uint4(uint_data) -> torch.Tensor:\n", - " \"\"\"Get the original weight from the normalized float weight format\"\"\"\n", - " # since we are using uint8 we will decode 2 entries per byte\n", - " # Shift elements down 4 and select out the bottom 4 bits\n", - " shape = uint_data.shape\n", - " scale = uint_data.element_size() * 8 // 4 # how many uint4s can fit in a dtype_size\n", - " unpacked_data = torch.empty((*shape, scale), dtype=uint_data.dtype)\n", - " for i in range(scale):\n", - " unpacked_data[..., i] = (uint_data >> int(uint_data.element_size()*8- 4*(i+1))) & 0b1111\n", - " return unpacked_data.view(up_size(shape, scale))\n", - "\n", - "\n", - "def pack_uint4(uint_data, dtype_size=8) -> torch.Tensor:\n", - " # converting to uint8 for operations\n", - " scale = dtype_size // 4 # how many uint4s can fit in a dtype_size\n", - " padding = torch.zeros((*uint_data.shape[:-1], (scale - uint_data.shape[-1] % scale)%scale), dtype=uint_data.dtype)\n", - " uint_data = torch.cat([uint_data, padding], dim=-1)\n", - " shape = uint_data.shape\n", - " uint_data = uint_data.contiguous().view(-1)\n", - " return reduce(lambda x,y: x|y,[uint_data[i::scale] << dtype_size-4*(i+1) for i in range(scale)]).view(down_size(shape, scale))\n", - "\n", - "def down_size(size, amt):\n", - " assert size[-1] % amt == 0, f\"{size} last dim not divisible by {amt}\"\n", - " return (*size[:-1], size[-1] // amt)\n", - "\n", - "\n", - "def up_size(size, amt):\n", - " return (*size[:-1], size[-1] * amt)" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [], - "source": [ - "test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8)\n", - "# print('og', test_tensor)\n", - "packed = pack_uint4(test_tensor)\n", - "# print('packed', packed)\n", - "unpacked = unpack_uint4(packed)\n", - "# print('unpacked', unpacked)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - 
"assert(unpadded.allclose(test_tensor))\n", - "\n", - "test_tensor = torch.randint(0, 3, (5,1, 4), dtype=torch.int16)\n", - "# print('og', test_tensor)\n", - "packed = pack_uint4(test_tensor,16)\n", - "# print('packed', packed)\n", - "unpacked = unpack_uint4(packed)\n", - "# print('unpacked', unpacked)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))\n", - "\n", - "test_tensor = torch.randint(0, 3, (3,1, 9), dtype=torch.int32)\n", - "# print('og', test_tensor)\n", - "packed = pack_uint4(test_tensor,32)\n", - "# print('packed', packed)\n", - "unpacked = unpack_uint4(packed)\n", - "# print('unpacked', unpacked)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/test_pack-unpack.ipynb b/test_pack-unpack.ipynb new file mode 100644 index 000000000..b0358a223 --- /dev/null +++ b/test_pack-unpack.ipynb @@ -0,0 +1,130 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from functools import reduce" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def unpack(data, data_size) -> torch.Tensor:\n", + " \"\"\"\n", + " Unpacks small dtype elements from a larger dtype.\n", + " \n", + " Inputs:\n", + " data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.\n", + " data_size: int - the size of the small dtype in bits.\n", + " \n", + " Returns: torch.Tensor - a tensor of the unpacked elements.\n", + " \"\"\"\n", + " shape = data.shape\n", + " scale = data.element_size() * 8 // data_size\n", + " unpacked_data = []\n", + " for i in range(scale):\n", + " shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint\n", + " nbits = (1 << data_size) - 1 # mask for the last dtype_size bits\n", + " unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))\n", + " return torch.stack(unpacked_data,dim=-1).view(up_size(shape, scale)) # stack the unpacked data and reshape to the original shape\n", + "\n", + "\n", + "def pack(data, container_size, data_size) -> torch.Tensor:\n", + " \"\"\"\n", + " Packs small dtype elements into a larger dtype.\n", + " \n", + " Inputs:\n", + " data: torch.Tensor - a tensor of unpacked elements of a small dtype.\n", + " container_size: int - the size of the large dtype in bits.\n", + " data_size: int - the size of the small dtype in bits.\n", + " \n", + " Returns: torch.Tensor - a tensor of the packed elements.\n", + " \"\"\"\n", + " scale = container_size // data_size\n", + " assert scale > 1, f\"container_size ({container_size}) not double the capacity ofdata_size ({data_size})\"\n", + " # pad the data to be divisible by scale\n", + " padding = torch.zeros((*data.shape[:-1], (scale - data.shape[-1] % scale)%scale), dtype=data.dtype)\n", + " packed = torch.cat([data, padding], dim=-1)\n", + " \n", + " shape = 
packed.shape\n", + " packed = packed.contiguous().view(-1)\n", + " #shift the data to the different indexes within the larger dtype and then union them together\n", + " return reduce(lambda x,y: x|y,[packed[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).view(down_size(shape, scale))\n", + "\n", + "def down_size(size, amt):\n", + " assert size[-1] % amt == 0, f\"{size} last dim not divisible by {amt}\"\n", + " return (*size[:-1], size[-1] // amt)\n", + "\n", + "\n", + "def up_size(size, amt):\n", + " return (*size[:-1], size[-1] * amt)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8)\n", + "packed = pack(test_tensor, 8, 4)\n", + "unpacked = unpack(packed, 4)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n", + "\n", + "test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16)\n", + "packed = pack(test_tensor,16, 3)\n", + "unpacked = unpack(packed, 3)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n", + "\n", + "test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32)\n", + "packed = pack(test_tensor,32, 16)\n", + "unpacked = unpack(packed,16)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n", + "\n", + "test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8)\n", + "packed = pack(test_tensor, 8, 2)\n", + "unpacked = unpack(packed,2)\n", + "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", + "assert(unpadded.allclose(test_tensor))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 3e7ca9b91c64cb27dd2fb99ae9ace9e8427d15e8 Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Sun, 26 May 2024 18:53:04 -0400 Subject: [PATCH 04/19] works with torch.compile, but not optimized --- test.py | 84 ++++++++++++++++++++++++++++++++++++++++++ test_pack-unpack.ipynb | 27 +++++++++----- 2 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 000000000..fe6589be6 --- /dev/null +++ b/test.py @@ -0,0 +1,84 @@ +import torch +from functools import reduce +import os + +@torch.compile +def unpack(data, data_size) -> torch.Tensor: + """ + Unpacks small dtype elements from a larger dtype. + + Inputs: + data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. + data_size: int - the size of the small dtype in bits. + + Returns: torch.Tensor - a tensor of the unpacked elements. 
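+
+    Example (minimal sketch, assuming two uint4 values packed per uint8 byte):
+        unpack(torch.tensor([[0x12]], dtype=torch.uint8), 4)  # -> tensor([[1, 2]], dtype=torch.uint8)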
+ """ + shape = data.shape + scale = data.element_size() * 8 // data_size + unpacked_data = [] + for i in range(scale): + shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint + nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) + return torch.stack(unpacked_data,dim=-1).view(up_size(shape, scale)) # stack the unpacked data and reshape to the original shape + +@torch.compile +def pack(data, container_size, data_size) -> torch.Tensor: + """ + Packs small dtype elements into a larger dtype. + + Inputs: + data: torch.Tensor - a tensor of unpacked elements of a small dtype. + container_size: int - the size of the large dtype in bits. + data_size: int - the size of the small dtype in bits. + + Returns: torch.Tensor - a tensor of the packed elements. + """ + scale = container_size // data_size + assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" + # pad the data to be divisible by scale + if data.shape[-1] % scale != 0: + padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype) + data = torch.cat([data, padding], dim=-1) + + shape = data.shape + data = data.contiguous().view(-1) + #shift the data to the different indexes within the larger dtype and then union them together + ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]) + newshape = down_size(shape, scale) + return ret.view(newshape) + +def down_size(size, amt): + assert size[-1] % amt == 0, f"{size} last dim not divisible by {amt}" + return (*size[:-1], size[-1] // amt) + + +def up_size(size, amt): + return (*size[:-1], size[-1] * amt) + + +torch._dynamo.config.specialize_int = True +os.environ["TORCH_LOGS"] = "output_code" +test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8) +packed = pack(test_tensor, 8, 4) +unpacked = unpack(packed, 4) +unpadded = unpacked[..., :test_tensor.shape[-1]] +assert(unpadded.allclose(test_tensor)) + +test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16) +packed = pack(test_tensor,16, 3) +unpacked = unpack(packed, 3) +unpadded = unpacked[..., :test_tensor.shape[-1]] +assert(unpadded.allclose(test_tensor)) + +test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32) +packed = pack(test_tensor,32, 16) +unpacked = unpack(packed,16) +unpadded = unpacked[..., :test_tensor.shape[-1]] +assert(unpadded.allclose(test_tensor)) + +test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8) +packed = pack(test_tensor, 8, 2) +unpacked = unpack(packed,2) +unpadded = unpacked[..., :test_tensor.shape[-1]] +assert(unpadded.allclose(test_tensor)) diff --git a/test_pack-unpack.ipynb b/test_pack-unpack.ipynb index b0358a223..a4616de9f 100644 --- a/test_pack-unpack.ipynb +++ b/test_pack-unpack.ipynb @@ -2,20 +2,22 @@ "cells": [ { "cell_type": "code", - "execution_count": 31, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", - "from functools import reduce" + "from functools import reduce\n", + "import os" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ + "@torch.compile\n", "def unpack(data, data_size) -> torch.Tensor:\n", " \"\"\"\n", " Unpacks small dtype elements from a larger dtype.\n", @@ -35,7 +37,7 @@ " unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))\n", " return 
torch.stack(unpacked_data,dim=-1).view(up_size(shape, scale)) # stack the unpacked data and reshape to the original shape\n", "\n", - "\n", + "@torch.compile\n", "def pack(data, container_size, data_size) -> torch.Tensor:\n", " \"\"\"\n", " Packs small dtype elements into a larger dtype.\n", @@ -50,13 +52,16 @@ " scale = container_size // data_size\n", " assert scale > 1, f\"container_size ({container_size}) not double the capacity ofdata_size ({data_size})\"\n", " # pad the data to be divisible by scale\n", - " padding = torch.zeros((*data.shape[:-1], (scale - data.shape[-1] % scale)%scale), dtype=data.dtype)\n", - " packed = torch.cat([data, padding], dim=-1)\n", + " if data.shape[-1] % scale != 0:\n", + " padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype)\n", + " data = torch.cat([data, padding], dim=-1)\n", " \n", - " shape = packed.shape\n", - " packed = packed.contiguous().view(-1)\n", + " shape = data.shape\n", + " data = data.contiguous().view(-1)\n", " #shift the data to the different indexes within the larger dtype and then union them together\n", - " return reduce(lambda x,y: x|y,[packed[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).view(down_size(shape, scale))\n", + " ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)])\n", + " newshape = down_size(shape, scale)\n", + " return ret.view(newshape)\n", "\n", "def down_size(size, amt):\n", " assert size[-1] % amt == 0, f\"{size} last dim not divisible by {amt}\"\n", @@ -69,10 +74,12 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ + "torch._dynamo.config.specialize_int = True\n", + "os.environ[\"TORCH_LOGS\"] = \"output_code\"\n", "test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8)\n", "packed = pack(test_tensor, 8, 4)\n", "unpacked = unpack(packed, 4)\n", From 88fe113e9727e2ceb069c647961342ebf2817c7a Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Sun, 26 May 2024 19:00:02 -0400 Subject: [PATCH 05/19] works on gpu --- test.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test.py b/test.py index fe6589be6..8807fbb76 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,5 @@ import torch from functools import reduce -import os @torch.compile def unpack(data, data_size) -> torch.Tensor: @@ -38,13 +37,13 @@ def pack(data, container_size, data_size) -> torch.Tensor: assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" # pad the data to be divisible by scale if data.shape[-1] % scale != 0: - padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype) - data = torch.cat([data, padding], dim=-1) + padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda() + data = torch.cat([data, padding], dim=-1).cuda() shape = data.shape data = data.contiguous().view(-1) #shift the data to the different indexes within the larger dtype and then union them together - ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]) + ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() newshape = down_size(shape, scale) return ret.view(newshape) @@ -58,26 +57,26 @@ def up_size(size, amt): torch._dynamo.config.specialize_int = True -os.environ["TORCH_LOGS"] = "output_code" 
-test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8) + +test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8).cuda() packed = pack(test_tensor, 8, 4) unpacked = unpack(packed, 4) unpadded = unpacked[..., :test_tensor.shape[-1]] assert(unpadded.allclose(test_tensor)) -test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16) +test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16).cuda() packed = pack(test_tensor,16, 3) unpacked = unpack(packed, 3) unpadded = unpacked[..., :test_tensor.shape[-1]] assert(unpadded.allclose(test_tensor)) -test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32) +test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32).cuda() packed = pack(test_tensor,32, 16) unpacked = unpack(packed,16) unpadded = unpacked[..., :test_tensor.shape[-1]] assert(unpadded.allclose(test_tensor)) -test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8) +test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8).cuda() packed = pack(test_tensor, 8, 2) unpacked = unpack(packed,2) unpadded = unpacked[..., :test_tensor.shape[-1]] From 80b9a41223554c4862c411a4df1532a08697ffbd Mon Sep 17 00:00:00 2001 From: JayakumarPawan <120random.things@gmail.com> Date: Tue, 28 May 2024 14:24:04 -0400 Subject: [PATCH 06/19] added row-wise bitpack --- bitnet-test.py | 102 ------------- bitpacking.py | 318 +++++++++++++++++++++++++++++++++++++++++ test_pack-unpack.ipynb | 137 ------------------ 3 files changed, 318 insertions(+), 239 deletions(-) delete mode 100644 bitnet-test.py create mode 100644 bitpacking.py delete mode 100644 test_pack-unpack.ipynb diff --git a/bitnet-test.py b/bitnet-test.py deleted file mode 100644 index 1a655c8d6..000000000 --- a/bitnet-test.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import torch._prims_common as utils -import torch.utils._pytree as pytree -from torch.library import impl, Library -import lovely_tensors as lt -lt.monkey_patch() - -def down_size(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4) - -def up_size(size): - return (*size[:-1], size[-1] * 4) - -def unpack_uint2(uint8_data) -> torch.Tensor: - """Get the original weight from the normalized float weight format""" - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - unpacked_data = torch.empty((*shape, 4), dtype=torch.uint8) - - unpacked_data[..., 0] = (uint8_data >> 6) & 0b11 - unpacked_data[..., 1] = (uint8_data >> 4) & 0b11 - unpacked_data[..., 2] = (uint8_data >> 2) & 0b11 - unpacked_data[..., 3] = uint8_data & 0b11 - return unpacked_data.view(up_size(shape)) - -def pack_uint2(uint8_data) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) - return packed_data - -def roundclip(x, a, b): - return torch.max(torch.tensor(a), torch.min(torch.tensor(b), torch.round(x))) - -def quantize_per_tensor_uint2_trinary(weights): - # Compute the average absolute value of the weight tensor - gamma = torch.mean(torch.abs(weights)) - - # Scale the weight tensor by the average absolute value - scaled_weights = weights / (gamma + 1e-8) - - # Round each scaled weight to the nearest integer in {-1, 0, +1} - quantized_weights = roundclip(scaled_weights, -1, 1) - - #Shift the distribution over by 1 so we can pack into a uint and 
not deal with signs - return quantized_weights.to(torch.int8) - -test_tensor = torch.randint(0, 3, (1024, 16, 8), dtype=torch.uint8) -print(test_tensor) -packed = pack_uint2(test_tensor) -unpacked = unpack_uint2(packed) -print(unpacked.allclose(test_tensor)) -assert(unpacked.allclose(test_tensor)) - -test_layer = torch.rand(1024, 16, 8) * 500.0 - 250.0 - -#Quantize our fake layer with bitnet method. -original_fake_layer = quantize_per_tensor_uint2_trinary(test_layer) -print(original_fake_layer) - -#Shift distribution from -1, 1 -> 0, 2 to we can use unsigned storage. -shifted_fake_layer = (original_fake_layer + 1.0).to(torch.uint8) -print("original: ") -print(shifted_fake_layer) - - - -def unpack_uint8_to_trinary(uint8_data) -> torch.Tensor: - """Get the original weight from the normalized float weight format""" - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - unpacked_data = torch.empty((*shape, 4), dtype=torch.int8) - - unpacked_data[..., 0] = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 1] = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 2] = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 3] = (uint8_data & 0b11).to(torch.int8) - 1.0 - return unpacked_data.view(up_size(shape)) - -def pack_uint2(uint8_data) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) - return packed_data - - -packed = pack_uint2(shifted_fake_layer) -print("after packing: ") -print(packed) -unpacked = unpack_uint8_to_trinary(packed) -print("after unpacking: ") -print(unpacked) -print(unpacked.dtype) -print(unpacked.allclose(original_fake_layer)) -assert(unpacked.allclose(original_fake_layer)) - -unpack_empty = torch.compile(unpack_uint8_to_trinary, mode="reduce-overhead") \ No newline at end of file diff --git a/bitpacking.py b/bitpacking.py new file mode 100644 index 000000000..9a4f31152 --- /dev/null +++ b/bitpacking.py @@ -0,0 +1,318 @@ +import torch +from functools import reduce +from typing import List, Optional, Tuple +from enum import Enum +#debug +import lovely_tensors +lovely_tensors.monkey_patch() + +class ZeroPointDomain(Enum): + INT = 0 + FLOAT = 1 + +_DTYPE_TO_QVALUE_BOUNDS = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), + torch.uint1: (0, 2**1-1), + torch.uint2: (0, 2**2-1), + torch.uint3: (0, 2**3-1), + torch.uint4: (0, 2**4-1), + torch.uint5: (0, 2**5-1), + torch.uint6: (0, 2**6-1), + torch.uint7: (0, 2**7-1), +} +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + assert quant_min >= quant_min_lower_bound, \ + "quant_min out of bound for dtype, " \ + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + + assert quant_max <= quant_max_upper_bound, \ + "quant_max out of bound for dtype, " \ + 
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + return quant_min, quant_max + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + assert len(block_size) == len(input_size) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + assert input_size[i] % block_size[i] == 0, f"Expecting input size at {i} dimension: {input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + +@torch.compile +def quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +): + """ + Args: + input (torch.Tensor): original float32, float16 or bfloat16 Tensor + block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Note: + How can block_size represent different granularities? 
+ let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different + granularities: + + granularity type | block_size + per_tensor | (3, 3, 10, 10) + per_axis (axis=0) | (1, 3, 10, 10) + per_axis (axis=1) | (3, 1, 10, 10) + per_group (groupsize=2) | (3, 3, 10, 2) + per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + + Output: + quantized tensor with requested dtype + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}" + quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) + shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT: + quant = torch.clamp( + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max + ).to(output_dtype) + else: + assert zero_point_domain == ZeroPointDomain.FLOAT + mid_point = (quant_max + quant_min + 1) / 2 + min_val = zero_point - scale * mid_point + quant = ( + torch.clamp( + torch.round((input - min_val) / scale), + quant_min, quant_max) + ).to(output_dtype) + quant = quant.view(original_shape) + + return quant + +@torch.compile +def dequantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, +): + """ + Args: + input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument + block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (Tensor): quantization parameter for affine quantization + zero_point (Tensor): quantization parameter for affine quantization + dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for input Tensor + quant_max (Optional[int]): maximum quantized value for input Tensor + output_dtype (torch.dtype): dtype for output Tensor, default is fp32 + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Output: + dequantized Tensor, with requested dtype or fp32 + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + assert input.dtype == input_dtype + assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" + quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) + + shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT: + dequant = input.to(torch.int32) + if zero_point is not None: + dequant -= zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant *= scale + else: + assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}" + mid_point = (quant_max + quant_min + 1) / 2 + dequant = input - mid_point + dequant = dequant.to(output_dtype) + dequant *= scale + if zero_point is not None: + dequant += zero_point + + return dequant.view(original_shape).to(output_dtype) + +@torch.compile +def unpack(data, data_size) -> torch.Tensor: + """ + Unpacks small dtype elements from a larger dtype. + + Inputs: + data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. + data_size: int - the size of the small dtype in bits. + + Returns: torch.Tensor - a tensor of the unpacked elements. + """ + shape = data.shape + scale = data.element_size() * 8 // data_size + + unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda() + nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + for i in range(scale): + shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint + unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) + # print(unpacked_data) + return unpacked_data + +@torch.compile +def pack(data, container_size, data_size) -> torch.Tensor: + """ + Packs small dtype elements into a larger dtype. + Pads rows to be divisible by the scale. + + Inputs: + data: torch.Tensor - a tensor of unpacked elements of a small dtype. + container_size: int - the size of the large dtype in bits. + data_size: int - the size of the small dtype in bits. + + Returns: torch.Tensor - a tensor of the packed elements. 
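+
+    Example (minimal sketch; rows pack along dim 0, two uint4 values per uint8, and a CUDA device is assumed since the result is moved with .cuda()):
+        pack(torch.tensor([[1], [2], [3], [4]], dtype=torch.uint8).cuda(), 8, 4)  # -> tensor([[18], [52]]) since (1 << 4) | 2 == 18 and (3 << 4) | 4 == 52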
+ """ + scale = container_size // data_size + assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" + assert data.shape[0] > scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" + # pad the data to be divisible by scale + if data.shape[0] % scale != 0: + padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda() + data = torch.cat([data, padding], dim=0).cuda() + + shape = data.shape + # data = data.contiguous().view(-1) + #shift the data to the different indexes within the larger dtype and then union them together + # for i in range(scale): + # print(data[i::scale, ...]) + ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]).cuda() + return ret.view(shape[0] // scale, *shape[1:]) + + + +torch._dynamo.config.specialize_int = True + +test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() +# print("original", test_tensor) +packed = pack(test_tensor, 8, 4) +# print("packed", packed) +unpacked = unpack(packed, 4) +# print("unpacked", unpacked) +unpadded = unpacked[:test_tensor.shape[0], ...] +# print("unpadded", unpadded) +assert(unpadded.allclose(test_tensor)) +print("test passed\n") +test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() +# print("original", test_tensor) +packed = pack(test_tensor,16, 3) +# print("packed", packed) +unpacked = unpack(packed, 3) +# print("unpacked", unpacked) +unpadded = unpacked[:test_tensor.shape[0], ...] +# print('unpadded: ', unpadded) +assert(unpadded.allclose(test_tensor)) +print("test passed\n") +test_tensor = torch.randint(0, 15, (1, 9), dtype=torch.int32).cuda() +packed = pack(test_tensor,32, 16) +print("packed", packed) +unpacked = unpack(packed,16) +# print("unpacked", unpacked) +unpadded = unpacked[:test_tensor.shape[0], ...] +assert(unpadded.allclose(test_tensor)) +print("test passed\n") +test_tensor = torch.randint(0, 3, (8, 7), dtype=torch.uint8).cuda() +packed = pack(test_tensor, 8, 2) +unpacked = unpack(packed,2) +unpadded = unpacked[:test_tensor.shape[0], ...] 
+assert(unpadded.allclose(test_tensor)) +print("test passed\n") +test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.float32).cuda() +quantize_affine(test_tensor, (4, 4), torch.tensor(0.5), torch.tensor(0), torch.uint8) diff --git a/test_pack-unpack.ipynb b/test_pack-unpack.ipynb deleted file mode 100644 index a4616de9f..000000000 --- a/test_pack-unpack.ipynb +++ /dev/null @@ -1,137 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from functools import reduce\n", - "import os" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "@torch.compile\n", - "def unpack(data, data_size) -> torch.Tensor:\n", - " \"\"\"\n", - " Unpacks small dtype elements from a larger dtype.\n", - " \n", - " Inputs:\n", - " data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.\n", - " data_size: int - the size of the small dtype in bits.\n", - " \n", - " Returns: torch.Tensor - a tensor of the unpacked elements.\n", - " \"\"\"\n", - " shape = data.shape\n", - " scale = data.element_size() * 8 // data_size\n", - " unpacked_data = []\n", - " for i in range(scale):\n", - " shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint\n", - " nbits = (1 << data_size) - 1 # mask for the last dtype_size bits\n", - " unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))\n", - " return torch.stack(unpacked_data,dim=-1).view(up_size(shape, scale)) # stack the unpacked data and reshape to the original shape\n", - "\n", - "@torch.compile\n", - "def pack(data, container_size, data_size) -> torch.Tensor:\n", - " \"\"\"\n", - " Packs small dtype elements into a larger dtype.\n", - " \n", - " Inputs:\n", - " data: torch.Tensor - a tensor of unpacked elements of a small dtype.\n", - " container_size: int - the size of the large dtype in bits.\n", - " data_size: int - the size of the small dtype in bits.\n", - " \n", - " Returns: torch.Tensor - a tensor of the packed elements.\n", - " \"\"\"\n", - " scale = container_size // data_size\n", - " assert scale > 1, f\"container_size ({container_size}) not double the capacity ofdata_size ({data_size})\"\n", - " # pad the data to be divisible by scale\n", - " if data.shape[-1] % scale != 0:\n", - " padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype)\n", - " data = torch.cat([data, padding], dim=-1)\n", - " \n", - " shape = data.shape\n", - " data = data.contiguous().view(-1)\n", - " #shift the data to the different indexes within the larger dtype and then union them together\n", - " ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)])\n", - " newshape = down_size(shape, scale)\n", - " return ret.view(newshape)\n", - "\n", - "def down_size(size, amt):\n", - " assert size[-1] % amt == 0, f\"{size} last dim not divisible by {amt}\"\n", - " return (*size[:-1], size[-1] // amt)\n", - "\n", - "\n", - "def up_size(size, amt):\n", - " return (*size[:-1], size[-1] * amt)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "torch._dynamo.config.specialize_int = True\n", - "os.environ[\"TORCH_LOGS\"] = \"output_code\"\n", - "test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8)\n", - "packed = pack(test_tensor, 8, 4)\n", - "unpacked = unpack(packed, 4)\n", - "unpadded = unpacked[..., 
:test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))\n", - "\n", - "test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16)\n", - "packed = pack(test_tensor,16, 3)\n", - "unpacked = unpack(packed, 3)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))\n", - "\n", - "test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32)\n", - "packed = pack(test_tensor,32, 16)\n", - "unpacked = unpack(packed,16)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))\n", - "\n", - "test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8)\n", - "packed = pack(test_tensor, 8, 2)\n", - "unpacked = unpack(packed,2)\n", - "unpadded = unpacked[..., :test_tensor.shape[-1]]\n", - "assert(unpadded.allclose(test_tensor))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 47d9c92444d3b986d9cf4891e69c6c8bc9f44c33 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 12:57:17 -0400 Subject: [PATCH 07/19] restructured into prototype/ --- bitpacking.py | 318 ------------------------- log.txt | 0 setup.py | 1 - test.py | 83 ------- torchao/__init__.py | 4 - torchao/prototype/common/bitpacking.py | 152 ++++++++++++ 6 files changed, 152 insertions(+), 406 deletions(-) delete mode 100644 bitpacking.py create mode 100644 log.txt delete mode 100644 test.py create mode 100644 torchao/prototype/common/bitpacking.py diff --git a/bitpacking.py b/bitpacking.py deleted file mode 100644 index 9a4f31152..000000000 --- a/bitpacking.py +++ /dev/null @@ -1,318 +0,0 @@ -import torch -from functools import reduce -from typing import List, Optional, Tuple -from enum import Enum -#debug -import lovely_tensors -lovely_tensors.monkey_patch() - -class ZeroPointDomain(Enum): - INT = 0 - FLOAT = 1 - -_DTYPE_TO_QVALUE_BOUNDS = { - torch.uint8: (0, 255), - torch.int8: (-128, 127), - torch.int16: (-(2**15), 2**15 - 1), - torch.int32: (-(2**31), 2**31 - 1), - torch.uint1: (0, 2**1-1), - torch.uint2: (0, 2**2-1), - torch.uint3: (0, 2**3-1), - torch.uint4: (0, 2**4-1), - torch.uint5: (0, 2**5-1), - torch.uint6: (0, 2**6-1), - torch.uint7: (0, 2**7-1), -} -def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): - """Get quant_min and quant_max args based on dtype and also - verify that they are within the range of possible quant_min/quant_max - for dtype - """ - if dtype not in _DTYPE_TO_QVALUE_BOUNDS: - raise ValueError(f"Unsupported dtype: {dtype}") - quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] - if quant_min is None: - quant_min = quant_min_lower_bound - if quant_max is None: - quant_max = quant_max_upper_bound - - assert quant_min >= quant_min_lower_bound, \ - "quant_min out of bound for dtype, " \ - f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" - - assert quant_max <= quant_max_upper_bound, \ - "quant_max out of bound for dtype, " \ - f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" - return quant_min, quant_max - 
-def _get_reduction_params(block_size, input_size): - """Given block_size and input size find the parameters for reduction: - - Output: - shape_for_reduction: the shape we use to `view` input to prepare it for reduction - reduction_dims: the dims we'll do reduction over - - Example:: - Input: - block_size: (3, 3, 2, 10) - input_size: (3, 3, 10, 10) - - Output: - shape_for_reduction: (3, 3, 5, 2, 10) - reduction_dim: [0, 1, 3, 4] - """ - assert len(block_size) == len(input_size) - shape_for_reduction = [] - reduction_dims = [] - cur_dim = 0 - for i in range(len(block_size)): - if block_size[i] != input_size[i] and block_size[i] > 1: - assert input_size[i] % block_size[i] == 0, f"Expecting input size at {i} dimension: {input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" - shape_for_reduction.append(input_size[i] // block_size[i]) - shape_for_reduction.append(block_size[i]) - # reduce over the block_size[i] dim - reduction_dims.append(cur_dim + 1) - cur_dim += 2 - else: - # block_size[i] == input_size[i] or block_size[i] == 1 - shape_for_reduction.append(input_size[i]) - # we only need to reduce over the dimension if block_size is greater than 1 - # otherwise it's already the same as reduced dimension - if block_size[i] != 1: - reduction_dims.append(cur_dim) - cur_dim += 1 - return shape_for_reduction, reduction_dims - -@torch.compile -def quantize_affine( - input: torch.Tensor, - block_size: Tuple[int, ...], - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - output_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, -): - """ - Args: - input (torch.Tensor): original float32, float16 or bfloat16 Tensor - block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. when size is the same as the input tensor dimension, we are using per tensor quantization - scale (float): quantization parameter for affine quantization - zero_point (int): quantization parameter for affine quantization - output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor - quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype - quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - - Note: - How can block_size represent different granularities? 
- let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different - granularities: - - granularity type | block_size - per_tensor | (3, 3, 10, 10) - per_axis (axis=0) | (1, 3, 10, 10) - per_axis (axis=1) | (3, 1, 10, 10) - per_group (groupsize=2) | (3, 3, 10, 2) - per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) - - - Output: - quantized tensor with requested dtype - """ - # TODO: validations - # TODO: validate scale/zero_point dimensions are compatible with block_size - assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}" - quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) - shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) - original_shape = input.shape - input = input.view(shape_for_reduction) - shape_after_reduction = shape_for_reduction - for i in reduction_dims: - shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) - if zero_point is not None: - zero_point = zero_point.view(shape_after_reduction) - - if zero_point_domain == ZeroPointDomain.INT: - quant = torch.clamp( - torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max - ).to(output_dtype) - else: - assert zero_point_domain == ZeroPointDomain.FLOAT - mid_point = (quant_max + quant_min + 1) / 2 - min_val = zero_point - scale * mid_point - quant = ( - torch.clamp( - torch.round((input - min_val) / scale), - quant_min, quant_max) - ).to(output_dtype) - quant = quant.view(original_shape) - - return quant - -@torch.compile -def dequantize_affine( - input: torch.Tensor, - block_size: Tuple[int, ...], - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - input_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - *, - output_dtype: torch.dtype = torch.float32, -): - """ - Args: - input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument - block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. when size is the same as the input tensor dimension, we are using per tensor quantization - scale (Tensor): quantization parameter for affine quantization - zero_point (Tensor): quantization parameter for affine quantization - dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor - quant_min (Optional[int]): minimum quantized value for input Tensor - quant_max (Optional[int]): maximum quantized value for input Tensor - output_dtype (torch.dtype): dtype for output Tensor, default is fp32 - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - - Output: - dequantized Tensor, with requested dtype or fp32 - """ - # TODO: validations - # TODO: validate scale/zero_point dimensions are compatible with block_size - assert input.dtype == input_dtype - assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" - quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) - - shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) - original_shape = input.shape - input = input.view(shape_for_reduction) - shape_after_reduction = shape_for_reduction - for i in reduction_dims: - shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) - if zero_point is not None: - zero_point = zero_point.view(shape_after_reduction) - - if zero_point_domain == ZeroPointDomain.INT: - dequant = input.to(torch.int32) - if zero_point is not None: - dequant -= zero_point.to(torch.int32) - dequant = dequant.to(output_dtype) - dequant *= scale - else: - assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}" - mid_point = (quant_max + quant_min + 1) / 2 - dequant = input - mid_point - dequant = dequant.to(output_dtype) - dequant *= scale - if zero_point is not None: - dequant += zero_point - - return dequant.view(original_shape).to(output_dtype) - -@torch.compile -def unpack(data, data_size) -> torch.Tensor: - """ - Unpacks small dtype elements from a larger dtype. - - Inputs: - data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. - data_size: int - the size of the small dtype in bits. - - Returns: torch.Tensor - a tensor of the unpacked elements. - """ - shape = data.shape - scale = data.element_size() * 8 // data_size - - unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda() - nbits = (1 << data_size) - 1 # mask for the last dtype_size bits - for i in range(scale): - shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint - unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) - # print(unpacked_data) - return unpacked_data - -@torch.compile -def pack(data, container_size, data_size) -> torch.Tensor: - """ - Packs small dtype elements into a larger dtype. - Pads rows to be divisible by the scale. - - Inputs: - data: torch.Tensor - a tensor of unpacked elements of a small dtype. - container_size: int - the size of the large dtype in bits. - data_size: int - the size of the small dtype in bits. - - Returns: torch.Tensor - a tensor of the packed elements. 
- """ - scale = container_size // data_size - assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" - assert data.shape[0] > scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" - # pad the data to be divisible by scale - if data.shape[0] % scale != 0: - padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda() - data = torch.cat([data, padding], dim=0).cuda() - - shape = data.shape - # data = data.contiguous().view(-1) - #shift the data to the different indexes within the larger dtype and then union them together - # for i in range(scale): - # print(data[i::scale, ...]) - ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]).cuda() - return ret.view(shape[0] // scale, *shape[1:]) - - - -torch._dynamo.config.specialize_int = True - -test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() -# print("original", test_tensor) -packed = pack(test_tensor, 8, 4) -# print("packed", packed) -unpacked = unpack(packed, 4) -# print("unpacked", unpacked) -unpadded = unpacked[:test_tensor.shape[0], ...] -# print("unpadded", unpadded) -assert(unpadded.allclose(test_tensor)) -print("test passed\n") -test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() -# print("original", test_tensor) -packed = pack(test_tensor,16, 3) -# print("packed", packed) -unpacked = unpack(packed, 3) -# print("unpacked", unpacked) -unpadded = unpacked[:test_tensor.shape[0], ...] -# print('unpadded: ', unpadded) -assert(unpadded.allclose(test_tensor)) -print("test passed\n") -test_tensor = torch.randint(0, 15, (1, 9), dtype=torch.int32).cuda() -packed = pack(test_tensor,32, 16) -print("packed", packed) -unpacked = unpack(packed,16) -# print("unpacked", unpacked) -unpadded = unpacked[:test_tensor.shape[0], ...] -assert(unpadded.allclose(test_tensor)) -print("test passed\n") -test_tensor = torch.randint(0, 3, (8, 7), dtype=torch.uint8).cuda() -packed = pack(test_tensor, 8, 2) -unpacked = unpack(packed,2) -unpadded = unpacked[:test_tensor.shape[0], ...] -assert(unpadded.allclose(test_tensor)) -print("test passed\n") -test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.float32).cuda() -quantize_affine(test_tensor, (4, 4), torch.tensor(0.5), torch.tensor(0), torch.uint8) diff --git a/log.txt b/log.txt new file mode 100644 index 000000000..e69de29bb diff --git a/setup.py b/setup.py index 65ec21e15..fd8668f93 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,6 @@ def get_extensions(): package_data={ "torchao.kernel.configs": ["*.pkl"], }, - ext_modules=get_extensions(), install_requires=read_requirements("requirements.txt"), extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", diff --git a/test.py b/test.py deleted file mode 100644 index 8807fbb76..000000000 --- a/test.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from functools import reduce - -@torch.compile -def unpack(data, data_size) -> torch.Tensor: - """ - Unpacks small dtype elements from a larger dtype. - - Inputs: - data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. - data_size: int - the size of the small dtype in bits. - - Returns: torch.Tensor - a tensor of the unpacked elements. 
- """ - shape = data.shape - scale = data.element_size() * 8 // data_size - unpacked_data = [] - for i in range(scale): - shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint - nbits = (1 << data_size) - 1 # mask for the last dtype_size bits - unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) - return torch.stack(unpacked_data,dim=-1).view(up_size(shape, scale)) # stack the unpacked data and reshape to the original shape - -@torch.compile -def pack(data, container_size, data_size) -> torch.Tensor: - """ - Packs small dtype elements into a larger dtype. - - Inputs: - data: torch.Tensor - a tensor of unpacked elements of a small dtype. - container_size: int - the size of the large dtype in bits. - data_size: int - the size of the small dtype in bits. - - Returns: torch.Tensor - a tensor of the packed elements. - """ - scale = container_size // data_size - assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" - # pad the data to be divisible by scale - if data.shape[-1] % scale != 0: - padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda() - data = torch.cat([data, padding], dim=-1).cuda() - - shape = data.shape - data = data.contiguous().view(-1) - #shift the data to the different indexes within the larger dtype and then union them together - ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() - newshape = down_size(shape, scale) - return ret.view(newshape) - -def down_size(size, amt): - assert size[-1] % amt == 0, f"{size} last dim not divisible by {amt}" - return (*size[:-1], size[-1] // amt) - - -def up_size(size, amt): - return (*size[:-1], size[-1] * amt) - - -torch._dynamo.config.specialize_int = True - -test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8).cuda() -packed = pack(test_tensor, 8, 4) -unpacked = unpack(packed, 4) -unpadded = unpacked[..., :test_tensor.shape[-1]] -assert(unpadded.allclose(test_tensor)) - -test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16).cuda() -packed = pack(test_tensor,16, 3) -unpacked = unpack(packed, 3) -unpadded = unpacked[..., :test_tensor.shape[-1]] -assert(unpadded.allclose(test_tensor)) - -test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32).cuda() -packed = pack(test_tensor,32, 16) -unpacked = unpack(packed,16) -unpadded = unpacked[..., :test_tensor.shape[-1]] -assert(unpadded.allclose(test_tensor)) - -test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8).cuda() -packed = pack(test_tensor, 8, 2) -unpacked = unpack(packed,2) -unpadded = unpacked[..., :test_tensor.shape[-1]] -assert(unpadded.allclose(test_tensor)) diff --git a/torchao/__init__.py b/torchao/__init__.py index c8f04c1d9..4f83cf3e5 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -4,10 +4,6 @@ torch._utils_internal.IS_FBSOURCE ) -if not _IS_FBCODE: - from . import _C - from . import ops - from torchao.quantization import ( apply_weight_only_int8_quant, apply_dynamic_quant, diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py new file mode 100644 index 000000000..f686b5acf --- /dev/null +++ b/torchao/prototype/common/bitpacking.py @@ -0,0 +1,152 @@ +import torch +from functools import reduce + + + +def unpack(data, data_size, by_rows = True): + """ + Unpacks small dtype elements from a larger dtype. 
+ + Inputs: + data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. + data_size: int - the size of the small dtype in bits. + + optional: + by_rows: bool - specifies whether to unpack... + by rows: tensor(n,m) -> tensor(n*scale, m) + or by columns: tensor(n,m) -> tensor(n,m*scale) + + defaults to rows because quantization is typically done by rows + but choose the version which matches how you quantize as this improves memory accesses/performance + + Returns: torch.Tensor - a tensor of the unpacked elements. + """ + if by_rows: + return _unpack_by_rows(data, data_size) + else: + return _unpack_by_cols(data, data_size) + +def pack(data, container_size, data_size, by_rows = True): + """ + Packs small dtype elements into a larger dtype. + Pads rows to be divisible by the scale. + + Inputs: + data: torch.Tensor - a tensor of unpacked elements of a small dtype. + container_size: int - the size of the large dtype in bits. + data_size: int - the size of the small dtype in bits. + + optional: + by_rows: bool - specifies whether to pack values... + by rows: tensor(n,m) -> tensor(n//scale, m) + or by columns: tensor(n,m) -> tensor(n,m//scale) + + defaults to rows because quantization is typically done by rows + but choose the version which matches how you quantize as this improves memory accesses/performance + + Returns: torch.Tensor - a tensor of packed elements. + """ + if by_rows: + return _pack_by_rows(data, container_size, data_size) + else: + return _pack_by_cols(data, container_size, data_size) + +def _unpack_by_rows(data, data_size) -> torch.Tensor: + shape = data.shape + scale = data.element_size() * 8 // data_size + + unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda() + nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + for i in range(scale): + shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint + unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) + # print(unpacked_data) + return unpacked_data + +def _unpack_by_cols(data, data_size) -> torch.Tensor: + shape = data.shape + scale = data.element_size() * 8 // data_size + unpacked_data = [] + nbits = (1 << data_size) - 1 # mask for the last dtype_size bits + for i in range(scale): + shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint + unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) + return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape + +def _pack_by_rows(data, container_size, data_size) -> torch.Tensor: + + scale = container_size // data_size + assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" + assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" + # pad the data to be divisible by scale + if data.shape[0] % scale != 0: + padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda() + data = torch.cat([data, padding], dim=0).cuda() + + shape = data.shape + ret = reduce(lambda x,y: x|y,[data[i::scale, ...] 
<< container_size-data_size*(i+1) for i in range(scale)]).cuda() + return ret.view(shape[0] // scale, *shape[1:]) + +def _pack_by_cols(data, container_size, data_size) -> torch.Tensor: + scale = container_size // data_size + assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" + # pad the data to be divisible by scale + if data.shape[-1] % scale != 0: + padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda() + data = torch.cat([data, padding], dim=-1).cuda() + + shape = data.shape + data = data.contiguous().view(-1) + #shift the data to the different indexes within the larger dtype and then union them together + ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() + return ret.view(*shape[:-1],shape[-1] // scale) + +if __name__ == '__main__': + #debug + # import lovely_tensors + # lovely_tensors.monkey_patch() + + torch._dynamo.config.specialize_int = True + pack = torch.compile(pack) + unpack = torch.compile(unpack) + + test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 4) + unpacked = unpack(packed, 4) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + + test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() + packed = pack(test_tensor,16, 3) + unpacked = unpack(packed, 3) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + test_tensor = torch.randint(0, 15, (3, 9), dtype=torch.int32).cuda() + packed = pack(test_tensor,32, 16) + unpacked = unpack(packed,16) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + + test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 2, False) + unpacked = unpack(packed,2, False) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 4, False) + unpacked = unpack(packed, 4, False) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + + test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() + packed = pack(test_tensor,16, 3, False) + unpacked = unpack(packed, 3, False) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + From 8d1ea34a134ebec9064e7ee5d5fc16481162734d Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 13:10:11 -0400 Subject: [PATCH 08/19] revert nuclear fix --- setup.py | 1 + torchao/__init__.py | 4 ++++ torchao/prototype/common/bitpacking.py | 1 - 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fd8668f93..65ec21e15 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ def get_extensions(): package_data={ "torchao.kernel.configs": ["*.pkl"], }, + ext_modules=get_extensions(), install_requires=read_requirements("requirements.txt"), extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", diff --git a/torchao/__init__.py b/torchao/__init__.py index 4f83cf3e5..c8f04c1d9 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -4,6 +4,10 @@ torch._utils_internal.IS_FBSOURCE ) +if not _IS_FBCODE: + from . import _C + from . 
import ops + from torchao.quantization import ( apply_weight_only_int8_quant, apply_dynamic_quant, diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index f686b5acf..43e7f40f2 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -60,7 +60,6 @@ def _unpack_by_rows(data, data_size) -> torch.Tensor: for i in range(scale): shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) - # print(unpacked_data) return unpacked_data def _unpack_by_cols(data, data_size) -> torch.Tensor: From 6e1a7d60c7d9cbfaffd92685980a05d37608a0e1 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 13:17:36 -0400 Subject: [PATCH 09/19] removed temp log --- log.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 log.txt diff --git a/log.txt b/log.txt deleted file mode 100644 index e69de29bb..000000000 From 46e39fdce3d50390a1e52006a8f3b5e21033f481 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 13:22:27 -0400 Subject: [PATCH 10/19] removed trinary stuff from this branch --- test/dtypes/test_trinary.py | 43 ------- torchao/dtypes/trinary.py | 216 ------------------------------------ 2 files changed, 259 deletions(-) delete mode 100644 test/dtypes/test_trinary.py delete mode 100644 torchao/dtypes/trinary.py diff --git a/test/dtypes/test_trinary.py b/test/dtypes/test_trinary.py deleted file mode 100644 index 688413104..000000000 --- a/test/dtypes/test_trinary.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -from torchao.dtypes.trinary import ( - TrinaryTensor, - quantize_per_tensor_trinary, -) -import unittest -from unittest import TestCase, main -from torch._export import capture_pre_autograd_graph -from torch._export import dynamic_dim -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, - QuantizationTestCase, -) -from torchao.quantization.utils import ( - compute_error, -) -from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, -) -from torch.ao.quantization.observer import ObserverBase -from torch import nn -from torch.fx import ( - Node, - GraphModule, -) -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, -) -import copy - -def _quantize_linear_weights_only(model): - def fn(mod): - mod.weight = torch.nn.Parameter(quantize_per_tensor_trinary(mod.weight), requires_grad=False) - return mod - - _replace_with_custom_fn_if_matches_filter( - model, - lambda mod: fn(mod), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) - -class TestTrinary(QuantizationTestCase): - \ No newline at end of file diff --git a/torchao/dtypes/trinary.py b/torchao/dtypes/trinary.py deleted file mode 100644 index 4a7c4cdf4..000000000 --- a/torchao/dtypes/trinary.py +++ /dev/null @@ -1,216 +0,0 @@ -import torch -import torch._prims_common as utils -import torch.utils._pytree as pytree -from torch.library import impl, Library -import lovely_tensors as lt -lt.monkey_patch() - -def down_size(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4) - -def up_size(size): - return (*size[:-1], size[-1] * 4) - -def roundclip(x, a, b): - return torch.max(torch.tensor(a), torch.min(torch.tensor(b), torch.round(x))) - -def quantize_per_tensor_trinary(weights): - # Compute the average absolute value of the weight tensor - gamma = 
torch.mean(torch.abs(weights)) - - # Scale the weight tensor by the average absolute value - scaled_weights = weights / (gamma + 1e-8) - - # Round each scaled weight to the nearest integer in {-1, 0, +1} and shift to {0, 1, 2} - quantized_weights = roundclip(scaled_weights, -1, 1) + 1 - - return quantized_weights.to(torch.uint8) - -def unpack_trinary(uint8_data) -> torch.Tensor: - """Get the original weight from the normalized float weight format""" - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - unpacked_data = torch.empty((*shape, 4), dtype=torch.int8) - - #shift back to {-1, 0, 1} while unpacking - unpacked_data[..., 0] = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 1] = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 2] = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1.0 - unpacked_data[..., 3] = (uint8_data & 0b11).to(torch.int8) - 1.0 - return unpacked_data.view(up_size(shape)) - -def pack_trinary(uint8_data) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape)) - return packed_data - -def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. - Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r - -class TrinaryTensor(torch.Tensor): - def __new__(cls, data, *args, **kwargs): - assert elem.dtype is torch.uint8 - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - - return torch.Tensor._make_wrapper_subclass( - cls, up_size(elem.shape), dtype=torch.trinary, **kwargs - ) - - def __init__(self, elem, **kwargs): - self.elem = elem - - @classmethod - def from_unpacked(cls, unpacked): - return TrinaryTensor(pack_trinary(unpacked)) - - def tolist(self): - return self.to(torch.uint8).tolist() - - def __tensor_flatten__(self): - return ["elem"], None - - @staticmethod - def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): - assert meta is None - elem = flattened["elem"] - return TrinaryTensor(elem) - - def __hash__(self): - return hash(self.elem) - - def __eq__(self, other): - return torch.equal(self.elem, other.elem) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func is torch.ops.aten.view.default: - self, size = args - size = utils.infer_size(size, self.numel()) - assert not kwargs - # WARNING: views not preserved - return TrinaryTensor(self.elem.reshape(down_size(size))) - elif func is torch.ops.aten.view.dtype: - self, dtype = args - if dtype == torch.uint8: - return unpack_trinary(self.elem).view(torch.uint8) - return 
NotImplementedError(f"view {args}") - elif func is torch.ops.aten.to.dtype: - self, dtype = args - if dtype == torch.uint8: - return unpack_trinary(self.elem).view(torch.uint8) - return NotImplementedError(f"to {args}") - elif func is torch.ops.aten.eq.Tensor: - args = pytree.tree_map_only( - TrinaryTensor, lambda x: x.elem.view(torch.uint8), args - ) - kwargs = pytree.tree_map_only( - TrinaryTensor, lambda x: x.elem.view(torch.uint8), kwargs - ) - return torch.ops.aten.eq.Tensor(*args, **kwargs) - elif func is torch.ops.aten._to_copy.default: - (self,) = args - if kwargs == {"dtype": torch.uint8}: - return unpack_trinary(self.elem).view(self.shape) # no wrap - else: - raise NotImplementedError(f"_to_copy {kwargs}") - elif func is torch.ops.aten.unbind.int: - # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to - # create four tensors containing one element each. But we can't - # do this with uint4 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to uint8 - # when this happens - self, dim = fill_defaults(args, 2, [0]) - if dim != self.dim() - 1: - raise NotImplementedError(f"unbind dim={dim}") - else: - # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind( - dim - ) - elif func is torch.ops.aten.select.int: - self, dim, index = args - if dim != self.dim() - 1: - return TrinaryTensor(torch.ops.aten.select.int(self.elem, dim, index)) - else: - raise NotImplementedError(f"select dim={dim}") - elif func is torch.ops.aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == self.dim() - 1: - # hard case - if step != 1: - raise NotImplementedError(f"slice step={step}") - assert start % 2 == 0, start - assert end >= self.shape[dim] or end % 2 == 0, end - return TrinaryTensor( - # Not sure about this one - torch.ops.aten.slice.Tensor(self.elem, dim, start // 4, end // 4, 1) - ) - else: - # easy case - return TrinaryTensor( - torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step) - ) - elif func is torch.ops.aten.t.default: - # assert False, "transpose is not properly implemented currently" - (self,) = args - unpacked = unpack_trinary(self.elem) - transposed = torch.ops.aten.t.default(unpacked) - transposed_and_packed = pack_trinary(transposed) - return TrinaryTensor(transposed_and_packed) - elif func is torch.ops.aten.transpose_copy.int: - self, dim0, dim1 = args - unpacked = unpack_trinary(self.elem).view(self.shape) - transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) - transposed_and_packed = pack_trinary(transposed) - return TrinaryTensor(transposed_and_packed) - - elif func is torch.ops.aten.as_strided.default: - # size, stride, storage_offset are referring to tensor elements, not physical bytes - self, size, stride, storage_offset = args - size = down_size(size) - - new_stride = [] - for s in stride: - if s != 1: - # since four trinary values equals to 1 uint8 - new_stride.append(s // 4) - else: - new_stride.append(s) - stride = new_stride - - storage_offset //= 4 - return TrinaryTensor( - torch.ops.aten.as_strided.default( - self.elem, size, stride, storage_offset - ) - ) - - raise NotImplementedError(f"{func}") - - __torch_function__ = torch._C._disabled_torch_function_impl \ No newline at end of file From dcada3eb4cec32d8ba14f5eee4db54e2a9d74395 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 13:59:13 -0400 Subject: [PATCH 11/19] moved tests 
to tests/ --- test/prototype/mx_formats/test_bitpacking.py | 46 +++++++++++++++++ torchao/prototype/common/bitpacking.py | 52 +------------------- 2 files changed, 47 insertions(+), 51 deletions(-) create mode 100644 test/prototype/mx_formats/test_bitpacking.py diff --git a/test/prototype/mx_formats/test_bitpacking.py b/test/prototype/mx_formats/test_bitpacking.py new file mode 100644 index 000000000..2a7056b2d --- /dev/null +++ b/test/prototype/mx_formats/test_bitpacking.py @@ -0,0 +1,46 @@ +import torch +from torchao.prototype.common.bitpacking import pack, unpack +import pytest + + +def test_uint4_to_uint8(): + test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 4) + unpacked = unpack(packed, 4) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + +def test_uint4_to_uint8_compile(): + torch._dynamo.config.specialize_int = True + pack = torch.compile(pack) + unpack = torch.compile(unpack) + test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 4) + unpacked = unpack(packed, 4) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + +def test_uint3_to_int16(): + test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() + packed = pack(test_tensor,16, 3) + unpacked = unpack(packed, 3) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + + +def test_uint2_to_uint8_col_wise_compile(): + torch._dynamo.config.specialize_int = True + pack = torch.compile(pack) + unpack = torch.compile(unpack) + test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() + packed = pack(test_tensor, 8, 2, False) + unpacked = unpack(packed,2, False) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + +def test_uint3_to_int16_col_wise(): + test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() + packed = pack(test_tensor,16, 3, False) + unpacked = unpack(packed, 3, False) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) \ No newline at end of file diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 43e7f40f2..51580a499 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -98,54 +98,4 @@ def _pack_by_cols(data, container_size, data_size) -> torch.Tensor: data = data.contiguous().view(-1) #shift the data to the different indexes within the larger dtype and then union them together ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() - return ret.view(*shape[:-1],shape[-1] // scale) - -if __name__ == '__main__': - #debug - # import lovely_tensors - # lovely_tensors.monkey_patch() - - torch._dynamo.config.specialize_int = True - pack = torch.compile(pack) - unpack = torch.compile(unpack) - - test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 4) - unpacked = unpack(packed, 4) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - - - test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() - packed = pack(test_tensor,16, 3) - unpacked = unpack(packed, 3) - unpadded = unpacked[:test_tensor.shape[0], ...] 
- assert(unpadded.allclose(test_tensor)) - - test_tensor = torch.randint(0, 15, (3, 9), dtype=torch.int32).cuda() - packed = pack(test_tensor,32, 16) - unpacked = unpack(packed,16) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - - - test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 2, False) - unpacked = unpack(packed,2, False) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - - test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 4, False) - unpacked = unpack(packed, 4, False) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - - - test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() - packed = pack(test_tensor,16, 3, False) - unpacked = unpack(packed, 3, False) - unpadded = unpacked[:test_tensor.shape[0], ...] - assert(unpadded.allclose(test_tensor)) - - + return ret.view(*shape[:-1],shape[-1] // scale) \ No newline at end of file From a5dd25d601a64b34d934cd6b267aefc0fadedc03 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 14:25:40 -0400 Subject: [PATCH 12/19] updated tests to skip if cuda DNE --- .../{mx_formats => }/test_bitpacking.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) rename test/prototype/{mx_formats => }/test_bitpacking.py (64%) diff --git a/test/prototype/mx_formats/test_bitpacking.py b/test/prototype/test_bitpacking.py similarity index 64% rename from test/prototype/mx_formats/test_bitpacking.py rename to test/prototype/test_bitpacking.py index 2a7056b2d..57fcd7aa5 100644 --- a/test/prototype/mx_formats/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -3,43 +3,49 @@ import pytest +if not TORCH_VERSION_AFTER_2_4: + pytest.skip("Unsupported PyTorch version") + def test_uint4_to_uint8(): - test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) packed = pack(test_tensor, 8, 4) unpacked = unpack(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) - + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint4_to_uint8_compile(): torch._dynamo.config.specialize_int = True - pack = torch.compile(pack) - unpack = torch.compile(unpack) + pack_compiled = torch.compile(pack) + unpack_compiled = torch.compile(unpack) test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 4) - unpacked = unpack(packed, 4) + packed = pack_compiled(test_tensor, 8, 4) + unpacked = unpack_compiled(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) def test_uint3_to_int16(): - test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() + test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16) packed = pack(test_tensor,16, 3) unpacked = unpack(packed, 3) unpadded = unpacked[:test_tensor.shape[0], ...] 
assert(unpadded.allclose(test_tensor)) - +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint2_to_uint8_col_wise_compile(): torch._dynamo.config.specialize_int = True - pack = torch.compile(pack) - unpack = torch.compile(unpack) + pack_compiled = torch.compile(pack) + unpack_compiled = torch.compile(unpack) test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 2, False) - unpacked = unpack(packed,2, False) + packed = pack_compiled(test_tensor, 8, 2, False) + unpacked = unpack_compiled(packed,2, False) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) def test_uint3_to_int16_col_wise(): - test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() + test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16) packed = pack(test_tensor,16, 3, False) unpacked = unpack(packed, 3, False) unpadded = unpacked[:test_tensor.shape[0], ...] From 5ec3deb92d0517e81b4aede3345b509e8d935721 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 14:57:06 -0400 Subject: [PATCH 13/19] added full_graph=True to compile --- test/prototype/test_bitpacking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 57fcd7aa5..e3bbc9a72 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -2,7 +2,7 @@ from torchao.prototype.common.bitpacking import pack, unpack import pytest - +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version") @@ -18,7 +18,7 @@ def test_uint4_to_uint8(): def test_uint4_to_uint8_compile(): torch._dynamo.config.specialize_int = True pack_compiled = torch.compile(pack) - unpack_compiled = torch.compile(unpack) + unpack_compiled = torch.compile(unpack, full_graph=True) test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() packed = pack_compiled(test_tensor, 8, 4) unpacked = unpack_compiled(packed, 4) @@ -37,7 +37,7 @@ def test_uint3_to_int16(): def test_uint2_to_uint8_col_wise_compile(): torch._dynamo.config.specialize_int = True pack_compiled = torch.compile(pack) - unpack_compiled = torch.compile(unpack) + unpack_compiled = torch.compile(unpack, full_graph=True) test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() packed = pack_compiled(test_tensor, 8, 2, False) unpacked = unpack_compiled(packed,2, False) From 6e9a7381494d221804b9f75ede995120d5b172ee Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 29 May 2024 12:23:38 -0700 Subject: [PATCH 14/19] Apply suggestions from code review --- test/prototype/test_bitpacking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index e3bbc9a72..c04982b71 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -17,8 +17,8 @@ def test_uint4_to_uint8(): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint4_to_uint8_compile(): torch._dynamo.config.specialize_int = True - pack_compiled = torch.compile(pack) - unpack_compiled = torch.compile(unpack, full_graph=True) + pack_compiled = torch.compile(pack, fullgraph=True) + unpack_compiled = torch.compile(unpack, fullgraph=True) test_tensor = torch.randint(0, 15, (3, 
4), dtype=torch.uint8).cuda() packed = pack_compiled(test_tensor, 8, 4) unpacked = unpack_compiled(packed, 4) @@ -36,8 +36,8 @@ def test_uint3_to_int16(): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint2_to_uint8_col_wise_compile(): torch._dynamo.config.specialize_int = True - pack_compiled = torch.compile(pack) - unpack_compiled = torch.compile(unpack, full_graph=True) + pack_compiled = torch.compile(pack, fullgraph=True) + unpack_compiled = torch.compile(unpack, fullgraph=True) test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() packed = pack_compiled(test_tensor, 8, 2, False) unpacked = unpack_compiled(packed,2, False) From 189677d491719353cad60ab1e56c446c74d82b2b Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 15:41:47 -0400 Subject: [PATCH 15/19] fixed test skipping --- test/prototype/test_bitpacking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index e3bbc9a72..cb7f5224d 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -4,7 +4,7 @@ from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 if not TORCH_VERSION_AFTER_2_4: - pytest.skip("Unsupported PyTorch version") + pytest.skip("Unsupported PyTorch version", allow_module_level=True) def test_uint4_to_uint8(): test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) From c8055b2f6fde7a0b0ab94842dce4008392c1ff06 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 17:25:38 -0400 Subject: [PATCH 16/19] added import for has_triton --- test/prototype/test_bitpacking.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 8cfc3249f..c59360030 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -1,8 +1,9 @@ import torch from torchao.prototype.common.bitpacking import pack, unpack import pytest - +from torch.utils._triton import has_triton from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) From ae91c6c923d3eb00f1b38d03529da7e384dbb777 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 18:07:38 -0400 Subject: [PATCH 17/19] added support for any device type --- test/prototype/test_bitpacking.py | 25 +++++++++++++++++---- torchao/prototype/common/bitpacking.py | 30 +++++++++++++------------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index c59360030..11c9eb6d5 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -7,13 +7,28 @@ if not TORCH_VERSION_AFTER_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -def test_uint4_to_uint8(): +def test_uint4_to_uint8_CPU(): test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - packed = pack(test_tensor, 8, 4) - unpacked = unpack(packed, 4) + packed = pack(test_tensor, 8, 4, device='cpu') + unpacked = unpack(packed, 4, device='cpu') + unpadded = unpacked[:test_tensor.shape[0], ...] 
+ assert(unpadded.allclose(test_tensor)) + +def test_uint3_to_int16_col_wise_cpu(): + test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16) + packed = pack(test_tensor,16, 3, False, device='cpu') + unpacked = unpack(packed, 3, False, device='cpu') unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_uint4_to_uint8_CPU(): + test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) + packed = pack(test_tensor, 8, 4, device='cpu') + unpacked = unpack(packed, 4) + unpadded = unpacked[:test_tensor.shape[0], ...] + assert(unpadded.allclose(test_tensor)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") def test_uint4_to_uint8_compile(): @@ -25,7 +40,8 @@ def test_uint4_to_uint8_compile(): unpacked = unpack_compiled(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) - + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint3_to_int16(): test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16) packed = pack(test_tensor,16, 3) @@ -45,6 +61,7 @@ def test_uint2_to_uint8_col_wise_compile(): unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor)) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint3_to_int16_col_wise(): test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16) packed = pack(test_tensor,16, 3, False) diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py index 51580a499..35e471c34 100644 --- a/torchao/prototype/common/bitpacking.py +++ b/torchao/prototype/common/bitpacking.py @@ -3,7 +3,7 @@ -def unpack(data, data_size, by_rows = True): +def unpack(data, data_size, by_rows = True, device="cuda"): """ Unpacks small dtype elements from a larger dtype. @@ -22,11 +22,11 @@ def unpack(data, data_size, by_rows = True): Returns: torch.Tensor - a tensor of the unpacked elements. """ if by_rows: - return _unpack_by_rows(data, data_size) + return _unpack_by_rows(data, data_size, device) else: return _unpack_by_cols(data, data_size) -def pack(data, container_size, data_size, by_rows = True): +def pack(data, container_size, data_size, by_rows = True, device="cuda"): """ Packs small dtype elements into a larger dtype. Pads rows to be divisible by the scale. @@ -47,15 +47,15 @@ def pack(data, container_size, data_size, by_rows = True): Returns: torch.Tensor - a tensor of packed elements. 
""" if by_rows: - return _pack_by_rows(data, container_size, data_size) + return _pack_by_rows(data, container_size, data_size, device) else: - return _pack_by_cols(data, container_size, data_size) + return _pack_by_cols(data, container_size, data_size, device) -def _unpack_by_rows(data, data_size) -> torch.Tensor: +def _unpack_by_rows(data, data_size, device) -> torch.Tensor: shape = data.shape scale = data.element_size() * 8 // data_size - unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda() + unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).to(device) nbits = (1 << data_size) - 1 # mask for the last dtype_size bits for i in range(scale): shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint @@ -72,30 +72,30 @@ def _unpack_by_cols(data, data_size) -> torch.Tensor: unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape -def _pack_by_rows(data, container_size, data_size) -> torch.Tensor: +def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor: scale = container_size // data_size assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" # pad the data to be divisible by scale if data.shape[0] % scale != 0: - padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda() + padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).to(device) data = torch.cat([data, padding], dim=0).cuda() shape = data.shape - ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]).cuda() - return ret.view(shape[0] // scale, *shape[1:]) + ret = reduce(lambda x,y: x|y,[data[i::scale, ...] 
<< container_size-data_size*(i+1) for i in range(scale)]) + return ret.view(shape[0] // scale, *shape[1:]).to(device) -def _pack_by_cols(data, container_size, data_size) -> torch.Tensor: +def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor: scale = container_size // data_size assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" # pad the data to be divisible by scale if data.shape[-1] % scale != 0: - padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda() + padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device) data = torch.cat([data, padding], dim=-1).cuda() shape = data.shape data = data.contiguous().view(-1) #shift the data to the different indexes within the larger dtype and then union them together - ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() - return ret.view(*shape[:-1],shape[-1] // scale) \ No newline at end of file + ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]) + return ret.view(*shape[:-1],shape[-1] // scale).to(device) \ No newline at end of file From 68179e8dc7689e70585800546533a27bde7a45a1 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 19:15:10 -0400 Subject: [PATCH 18/19] fix gpu tests --- test/prototype/test_bitpacking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 11c9eb6d5..c482888cc 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -23,7 +23,7 @@ def test_uint3_to_int16_col_wise_cpu(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint4_to_uint8_CPU(): - test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) + test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() packed = pack(test_tensor, 8, 4, device='cpu') unpacked = unpack(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] @@ -43,7 +43,7 @@ def test_uint4_to_uint8_compile(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint3_to_int16(): - test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16) + test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() packed = pack(test_tensor,16, 3) unpacked = unpack(packed, 3) unpadded = unpacked[:test_tensor.shape[0], ...] @@ -63,7 +63,7 @@ def test_uint2_to_uint8_col_wise_compile(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_uint3_to_int16_col_wise(): - test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16) + test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() packed = pack(test_tensor,16, 3, False) unpacked = unpack(packed, 3, False) unpadded = unpacked[:test_tensor.shape[0], ...] 
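
A minimal round-trip sketch of the device-aware pack/unpack API introduced in this patch, mirroring the column-wise tests above. It assumes torchao.prototype.common.bitpacking is importable and falls back to CPU when CUDA is absent.

import torch
from torchao.prototype.common.bitpacking import pack, unpack

# Round-trip sketch: pack 2-bit values column-wise into uint8 containers, then unpack.
device = "cuda" if torch.cuda.is_available() else "cpu"
weights = torch.randint(0, 3, (8, 8), dtype=torch.uint8).to(device)

packed = pack(weights, container_size=8, data_size=2, by_rows=False, device=device)
restored = unpack(packed, data_size=2, by_rows=False, device=device)

# Column-wise packing pads the last dimension when needed, so trim before comparing.
assert restored[..., :weights.shape[-1]].allclose(weights)

Keeping the device explicit matches the tests, which exercise both the CPU and CUDA paths of the same helpers.
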
From 4ce12955bb3b47e267e09766847d082249e932e1 Mon Sep 17 00:00:00 2001 From: vayuda <120random.things@gmail.com> Date: Wed, 29 May 2024 19:15:55 -0400 Subject: [PATCH 19/19] fix gpu tests --- test/prototype/test_bitpacking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index c482888cc..c1b60e07f 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -22,9 +22,9 @@ def test_uint3_to_int16_col_wise_cpu(): assert(unpadded.allclose(test_tensor)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_uint4_to_uint8_CPU(): +def test_uint4_to_uint8(): test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() - packed = pack(test_tensor, 8, 4, device='cpu') + packed = pack(test_tensor, 8, 4) unpacked = unpack(packed, 4) unpadded = unpacked[:test_tensor.shape[0], ...] assert(unpadded.allclose(test_tensor))
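
A compact sketch of the compiled round-trip the CUDA tests above exercise, assuming CUDA and Triton are available as the skip markers require.

import torch
from torchao.prototype.common.bitpacking import pack, unpack

# Compile the packing helpers as the tests do; specialize_int treats the bit-width
# arguments as constants so the graph can be captured with fullgraph=True.
torch._dynamo.config.specialize_int = True
pack_compiled = torch.compile(pack, fullgraph=True)
unpack_compiled = torch.compile(unpack, fullgraph=True)

test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
packed = pack_compiled(test_tensor, 8, 4)   # two uint4 values per uint8 container
unpacked = unpack_compiled(packed, 4)

# Row-wise packing pads the first dimension when needed, so trim before comparing.
assert unpacked[:test_tensor.shape[0], ...].allclose(test_tensor)
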