Perform quantization in Chunks #196

Merged: 1 commit, May 1, 2024
14 changes: 14 additions & 0 deletions test/dtypes/test_nf4.py
@@ -236,6 +236,20 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parametrize("shape", [(16, 16), (32, 16)])
@parametrize("chunk_size", [8, 16, 32])
def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
a = torch.randn(shape, device='cuda', dtype=dtype)
with unittest.mock.patch("torchao.dtypes.nf4tensor.CHUNK_SIZE", chunk_size):
nf4_patched = to_nf4(a, 16, 2)
# The unpatched call below uses the default CHUNK_SIZE (1024**2), which is far
# larger than numel here, so it is effectively unchunked
nf4_base = to_nf4(a, 16, 2)

torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)


instantiate_parametrized_tests(TestNF4Linear)

if __name__ == "__main__":
58 changes: 20 additions & 38 deletions torchao/dtypes/nf4tensor.py
@@ -1,10 +1,10 @@
import functools
from dataclasses import dataclass
import math
from typing import Any, Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


aten = torch.ops.aten
@@ -15,6 +15,14 @@

NF4_OPS_TABLE: Dict[Any, Any] = {}

# Note: Quantize in Chunks
# During quantization to NF4, one of the steps is to convert each original float
# value to the index of the nearest value in the NF4 format. This can cause a
# large memory spike due to the intermediates of the quantization process.
# Instead, we process the original tensor in chunks. This is a tradeoff between
# memory and speed; the number below seems to strike a good balance between the two.
CHUNK_SIZE = 1024**2
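# A rough sense of scale (illustrative numbers, not part of this diff):
# quantize_tensor_nearest broadcasts the input against the 16 NF4 code values,
# materializing a (numel, 16) distance intermediate. For a 4096 x 4096 fp16
# weight, that is 4096 * 4096 * 16 * 2 bytes = 512 MiB in one shot; with
# CHUNK_SIZE = 1024**2, each chunk's intermediate is at most 32 MiB.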


def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
@@ -375,7 +383,7 @@ def dequantize_scalers(

@staticmethod
def convert_to_norm_float_weight(
-inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
+inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
) -> torch.Tensor:
"""Convert a tensor to the normalized float weight format"""
flattened_tensor = inpt_tensor.flatten()
@@ -393,9 +401,13 @@ def convert_to_norm_float_weight(
scaled_blocks = blocks / scales

# Returns a flattened tensor with each element quantized to nf4 index
-quantized_blocks = NF4Tensor.quantize_tensor_nearest(
-    scaled_blocks.flatten(), nf4
-)
+# See Note: Quantize in Chunks
+quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device)
+flattened = scaled_blocks.flatten()
+for chunk_num in range(math.ceil(numel / CHUNK_SIZE)):
+    start = chunk_num * CHUNK_SIZE
+    end = min(start + CHUNK_SIZE, numel)
+    quantized_blocks[start:end] = NF4Tensor.quantize_tensor_nearest(flattened[start:end], nf4).to(torch.uint8)
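# (Illustrative aside, not part of the diff: the final chunk may be shorter than
# CHUNK_SIZE, which the `min` above handles; chunking cannot change the output,
# since quantize_tensor_nearest maps each element independently.)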

# Combine the quantized elements into uint8 values
# This lays out two consecutive elements in the same byte
@@ -435,7 +447,7 @@ def get_original_weight(self) -> torch.Tensor:

@staticmethod
def quantize_tensor_nearest(
-value: torch.float16, nf4: torch.Tensor
+value: torch.Tensor, nf4: torch.Tensor
) -> torch.Tensor:
"""Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
value = value.unsqueeze(-1) # (numel, 1)
@@ -445,36 +457,15 @@
return closest_nf4

@staticmethod
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
"""Dequantize a nf4 value to bfloat16 format"""
# return nf4.index_select(0, value)
return nf4[value]

def unpack(
self,
) -> Tuple[
int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
]:

return (
self.block_size,
self.n_blocks,
self.scaler_block_size,
self.quantized_scalers,
self.quantization_factor,
self.scaler_mean,
self.quantized_data,
)

-def __repr__(self):
+def __repr__(self) -> str:
return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"

-def __str__(self):
+def __str__(self) -> str:
return f"NF4Tensor({self.shape}, {self.block_size})"

def __tensor_flatten__(self):
@@ -501,9 +492,6 @@ def __tensor_flatten__(self):
], ctx

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 5, "Expected 5 inner tensors"
return NF4Tensor(
@@ -567,18 +555,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

class LinearNF4(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
return F.linear(input, weight.to(input.dtype))

@staticmethod
def backward(ctx, grad_output):
"""The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
weight: NF4Tensor = ctx.nf4_weight
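As a closing illustration, here is a self-contained sketch of the pattern this PR applies. It is hedged: `chunked_nearest_quantize` and the linspace codebook are invented for illustration and are not torchao API. It shows why the patched and unpatched paths in the test above agree exactly: the nearest-code lookup is elementwise, so chunking only bounds the size of the (chunk, codebook) distance intermediate without changing any output value.

import math
import torch

def chunked_nearest_quantize(values: torch.Tensor, codebook: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Map each element of `values` to the index of its nearest `codebook` entry,
    processing `chunk_size` elements at a time to bound peak memory."""
    flat = values.flatten()
    numel = flat.numel()
    out = torch.empty(numel, dtype=torch.uint8, device=values.device)
    for chunk_num in range(math.ceil(numel / chunk_size)):
        start = chunk_num * chunk_size
        end = min(start + chunk_size, numel)
        # (end - start, len(codebook)) intermediate instead of (numel, len(codebook))
        distances = (flat[start:end].unsqueeze(-1) - codebook).abs()
        out[start:end] = distances.argmin(dim=-1).to(torch.uint8)
    return out

# The result is independent of chunk_size:
codebook = torch.linspace(-1.0, 1.0, 16)
x = torch.randn(64, 64)
assert torch.equal(
    chunked_nearest_quantize(x, codebook, chunk_size=128),
    chunked_nearest_quantize(x, codebook, chunk_size=x.numel()),
)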