Skip to content

Commit

Permalink
Minor upgrades to bit pack (pytorch#347)
Browse files Browse the repository at this point in the history
* added dim=-1 and device is now based on input data

* removed device from param list

* fixed randint range
  • Loading branch information
vayuda authored Jun 13, 2024
1 parent 221514e commit 2d27ccf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 204 deletions.
250 changes: 58 additions & 192 deletions benchmarks/benchmark_bitpacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4


def benchmark(function, num_runs, setup =None):
args = setup()
def benchmark(function, args, num_runs):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
Expand All @@ -21,207 +20,74 @@ def benchmark(function, num_runs, setup =None):


def test_vs_existing():
def new_():
fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
def new_(scale):
fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
packed = pack(fake_tensor, 4, dim=1)
unpacked = unpack(packed, 4, dim=1)
def old_():
fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
def old_(scale):
fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
packed = pack_uint4(fake_tensor)
unpacked = unpack_uint4(packed)
new_ = torch.compile(new_, fullgraph=True)
old_ = torch.compile(old_, fullgraph=True)
new_()
old_()
print(f"new: {benchmark(new_, 1000)} ms ")
print(f"old: {benchmark(old_, 1000)} ms")



def test_iso_bitpack():
def load4x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda()
for scale in [256,512, 1024, 2048,4096, 8192]:
new_ = torch.compile(new_, fullgraph=True)
old_ = torch.compile(old_, fullgraph=True)
new_(scale)
old_(scale)
print("scale: ", scale)
print(f"new: {benchmark(new_,[scale], 10)} ms ")
print(f"old: {benchmark(old_,[scale], 10)} ms")


def compare_to_fp16():
class Linear16(torch.nn.Module):
def __init__(self, scale):
super().__init__()
scale += scale % 2
self.l1 = torch.nn.Linear(scale * 2, scale, bias=False,dtype=torch.float16).cuda()
self.l2 = torch.nn.Linear(scale, scale//2, bias=False,dtype=torch.float16).cuda()

def forward(self, x):
return self.l2(self.l1(x))

def load2x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()

def loadx(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
class W4A16_symmetric_weight_only(torch.nn.Module):
def __init__(self, scale):
super().__init__()
assert scale % 4 == 0
self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda()
self.s1 = torch.tensor((scale),dtype=torch.float16).cuda()
self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda()
self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda()

def unpack8to2(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

def unpack8to4(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

def t8to4wmm(scale=1024):
fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
def forward(self, x):
w = unpack(self.l1.detach(), 4, output_dtype=torch.float16)
x = x * self.s1
x = x @ w
w = unpack(self.l2.detach(), 4, output_dtype=torch.float16)
x = x * self.s2
x = x @ w

torch._dynamo.config.specialize_int = True
# _unpack_c = torch.compile(_unpack, fullgraph=True)
unpack_c = torch.compile(unpack, fullgraph=True)

scale = [16,64,256,1024,4096]
load4x_times = []
unpack8to2_times = []
load2x_times = []
unpack8to4_times = []
for s in scale:
res = benchmark(load4x, 50, scale=s)
load4x_times.append(res)
print(f"load(1, {4*s},{s}) time: {res} ms")

res=benchmark(unpack8to2, 50, scale=s)
unpack8to2_times.append(res)
print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
return x

torch._dynamo.config.specialize_int = True
for scale in [256,512, 1024, 2048,4096, 8192]:
a = Linear16(scale)
b = W4A16_symmetric_weight_only(scale)
# a = torch.compile(a, fullgraph=True)
b = torch.compile(b, fullgraph=True)

res = benchmark(load2x, 50, scale=s)
load2x_times.append(res)
print(f"load(1, {2*s},{s}) time: {res} ms")

res = benchmark(unpack8to4, 50, scale=s)
unpack8to4_times.append(res)
print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
print()

# import matplotlib.pyplot as plt
# plt.plot(scale, load4x_times, label="load(1, 4x, x)")
# plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
# plt.plot(scale, load2x_times, label="load(1, 2x, x)")
# plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
# plt.xlabel("scale")
# plt.ylabel("time (ms)")
# plt.yscale("log")
# plt.legend()
# plt.savefig("benchmark_bitpacking.png")


def test_vs_hqqpack():
#requires hqq to be installed
import hqq
import hqq.core.quantize as hqq_quantize
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}
test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
forward_args = [test_input]
b.forward(test_input)
print("scale: ", scale)
print("fp16 time: ", benchmark(a.forward, forward_args, 100))
print("uint4 time: ", benchmark(b.forward, forward_args, 100))

def mixed_mm(
shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_dq = hqq_linear.dequantize()

scales, zeros = meta["scale"], meta["zero"]
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
if pack_fn:
packed_w = pack(W_q.T,4,dim=0,order=False)
else:
packed_w = pack_2xint4(W_q.T)

if transposed:
x = torch.randn(M, N, dtype=dtype, device="cuda")
hqq_out = x @ W_dq

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=True,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)

else:
x = torch.randn(M, K, dtype=dtype, device="cuda")
hqq_out = x @ W_dq.T

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=False,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)

shapes = [
[16, 128, 128],
[16, 4096, 4096],
]
group_sizes = [64, 128]
shape = [16, 128, 128]
group_size = 64
pack = torch.compile(pack, fullgraph=True)
for i in range(2):
shape = shapes[i]
group_size = group_sizes[i]
print("linear layer size: ", shape)
print("group size: ", group_size)
# run once to compile
test_mixed_mm(
shape,
group_size,
1,
torch.float16,
True,
"compute_bound",
torch.uint8,
)
# shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
print("pack time (ms): ", benchmark(test_mixed_mm, 100,
shape,
group_size,
1,
torch.float16,
True,
"compute_bound",
torch.uint8))

print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100,
shape,
group_size,
1,
torch.float16,
True,
"compute_bound", #max autotune doesnt work?
torch.uint8,
pack_fn=False))
print("")



if __name__ == "__main__":
compare_to_fp16()
test_vs_existing()

8 changes: 3 additions & 5 deletions test/prototype/test_bitpacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
dimensions = (2, 1, 0)
dimensions = (2, 1, 0, -1)
orders = (True, False)


Expand Down Expand Up @@ -41,15 +41,13 @@ def test_CPU(dtype, dim, order):
element_type=element_type,
dim = dim,
order = order,
container_dtype = torch.uint8,
device='cpu')
container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size)
unpacked = unpack(packed,
element_bit_width,
element_type=element_type,
dim = dim,
order = order,
device='cpu')
order = order)
assert(unpacked.allclose(test_tensor))


Expand Down
18 changes: 11 additions & 7 deletions torchao/prototype/common/bitpacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

def mod_shape(shape, mod, dim):
"""changes a select dimension of the input shape to mod"""
return (*shape[:dim], mod, *shape[dim+1:])
a = list(shape)
a[dim] = mod
return tuple(a)

def unpack(data: torch.Tensor,
element_bit_width: int,
element_type: Optional[str] = None,
dim: Optional[int] = 0,
order: Optional[bool] = True,
output_dtype: Optional[torch.dtype] = None,
device: Optional[str] ="cuda") -> torch.Tensor:
output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Unpacks small dtype elements from a larger dtype.
Expand All @@ -27,8 +28,10 @@ def unpack(data: torch.Tensor,
"""
container_size = torch.iinfo(data.dtype).bits
scale = container_size // element_bit_width

device = data.device

unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device)

if element_type == "trinary":
unpacked = unpacked.to(torch.int8) - 1
elif output_dtype is not None:
Expand Down Expand Up @@ -59,8 +62,7 @@ def pack(data: torch.Tensor,
dim: Optional[int] = 0,
container_dtype: Optional[torch.dtype] = None,
pad: Optional[bool] = False,
order: Optional[bool] = True,
device: Optional[str] = "cuda") -> torch.Tensor:
order: Optional[bool] = True) -> torch.Tensor:
"""
Packs small dtype elements into a container of a larger dtype.
Expand Down Expand Up @@ -93,6 +95,8 @@ def pack(data: torch.Tensor,
if container_dtype is not None:
data = data.to(container_dtype)

device = data.device

container_size = torch.iinfo(data.dtype).bits
scale = container_size // element_bit_width

Expand All @@ -117,4 +121,4 @@ def _pack(data, container_size, element_bit_width, scale, dim, order, device) ->
else:
packed |= data[slices] << element_bit_width*i
return packed


0 comments on commit 2d27ccf

Please sign in to comment.