From 884eca99fc5bc507583b70a09546ed2c007bdc06 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 1 May 2024 10:13:10 -0700
Subject: [PATCH] Perform NF4 quantization in chunks

---
 test/dtypes/test_nf4.py     | 14 +++++++++
 torchao/dtypes/nf4tensor.py | 58 +++++++++++++------------------------
 2 files changed, 34 insertions(+), 38 deletions(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 55bbe0bcb..278931516 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -236,6 +236,20 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
         out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    @parametrize("shape", [(16, 16), (32, 16)])
+    @parametrize("chunk_size", [8, 16, 32])
+    def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
+        a = torch.randn(shape, device='cuda', dtype=dtype)
+        with unittest.mock.patch("torchao.dtypes.nf4tensor.CHUNK_SIZE", chunk_size):
+            nf4_patched = to_nf4(a, 16, 2)
+        # This is essentially unchunked, since numel is a lot smaller than the default CHUNK_SIZE
+        nf4_base = to_nf4(a, 16, 2)
+
+        torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)
+
+
 
 instantiate_parametrized_tests(TestNF4Linear)
 
 if __name__ == "__main__":
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index f09d53821..4628ce994 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -1,10 +1,10 @@
 import functools
 from dataclasses import dataclass
+import math
 from typing import Dict, Tuple
 
 import torch
 import torch.nn.functional as F
-from torch import Tensor
 
 aten = torch.ops.aten
 
@@ -15,6 +15,14 @@
 
 NF4_OPS_TABLE: Dict[Any, Any] = {}
 
+# Note: Quantize in Chunks
+# During quantization to NF4, one of the steps is to convert the original float values
+# to the indices of the nearest values in the NF4 format. This can cause a large memory
+# spike due to the intermediates of the quantization process. Instead, we process the
+# original tensor in chunks. This is a tradeoff between memory and speed; this chunk size
+# seems to strike a good balance between the two.
+CHUNK_SIZE = 1024**2
+
 
 def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
     both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
@@ -375,7 +383,7 @@ def dequantize_scalers(
 
     @staticmethod
     def convert_to_norm_float_weight(
-        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
+        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
         flattened_tensor = inpt_tensor.flatten()
@@ -393,9 +401,13 @@ def convert_to_norm_float_weight(
         scaled_blocks = blocks / scales
 
         # Returns a flattened tensor with each element quantized to nf4 index
-        quantized_blocks = NF4Tensor.quantize_tensor_nearest(
-            scaled_blocks.flatten(), nf4
-        )
+        # See Note: Quantize in Chunks
+        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device)
+        flattened = scaled_blocks.flatten()
+        for chunk_num in range(math.ceil(numel / CHUNK_SIZE)):
+            start = chunk_num * CHUNK_SIZE
+            end = min(start + CHUNK_SIZE, numel)
+            quantized_blocks[start:end] = NF4Tensor.quantize_tensor_nearest(flattened[start:end], nf4).to(torch.uint8)
 
         # Combine the quantized elements into uint8 values
         # This lays out two consecutive elements in the same byte
@@ -435,7 +447,7 @@ def get_original_weight(self) -> torch.Tensor:
 
     @staticmethod
     def quantize_tensor_nearest(
-        value: torch.float16, nf4: torch.Tensor
+        value: torch.Tensor, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
         value = value.unsqueeze(-1)  # (numel, 1)
@@ -445,36 +457,15 @@
         return closest_nf4
 
     @staticmethod
-
-    # inconsistently.
-
-    # defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
         """Dequantize a nf4 value to bfloat16 format"""
         # return nf4.index_select(0, value)
         return nf4[value]
 
-    def unpack(
-        self,
-    ) -> Tuple[
-        int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
-    ]:
-
-        # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
-        return (
-            self.block_size,
-            self.n_blocks,
-            self.scaler_block_size,
-            self.quantized_scalers,
-            self.quantization_factor,
-            self.scaler_mean,
-            self.quantized_data,
-        )
-
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"NF4Tensor({self.shape}, {self.block_size})"
 
     def __tensor_flatten__(self):
@@ -501,9 +492,6 @@ def __tensor_flatten__(self):
         ], ctx
 
     @staticmethod
-
-    # `typing.Dict[, ]` to avoid runtime subscripting errors.
-
     def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 5, "Expected 5 inner tensors"
         return NF4Tensor(
@@ -567,18 +555,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
 class LinearNF4(torch.autograd.Function):
     @staticmethod
-
-    # inconsistently.
-
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
         return F.linear(input, weight.to(input.dtype))
 
     @staticmethod
-
-    # inconsistently.
-
     def backward(ctx, grad_output):
         """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
         weight: NF4Tensor = ctx.nf4_weight
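Reviewer note, not part of the patch: a minimal standalone sketch of the chunked nearest-value quantization that the "Note: Quantize in Chunks" comment describes. The names nf4_values, chunk, and quantize_nearest_chunked below are illustrative stand-ins, not the real NF4 code book or the torchao API; the point is only that building the (numel, 16) difference matrix one chunk at a time bounds the intermediate memory while producing exactly the same indices as a single full-tensor pass.

# Illustrative sketch only -- nf4_values is a uniform stand-in code book,
# not the actual 16-entry NF4 lookup table used by NF4Tensor.
import math

import torch

nf4_values = torch.linspace(-1.0, 1.0, 16)  # stand-in code book


def quantize_nearest_chunked(scaled: torch.Tensor, codebook: torch.Tensor, chunk: int) -> torch.Tensor:
    """Map each element to the index of its nearest code book entry, one chunk at a time.

    Computing the (numel, 16) difference matrix per chunk keeps the intermediate at
    roughly chunk * 16 elements instead of numel * 16.
    """
    flat = scaled.flatten()
    out = torch.empty(flat.numel(), dtype=torch.uint8, device=flat.device)
    for i in range(math.ceil(flat.numel() / chunk)):
        start, end = i * chunk, min((i + 1) * chunk, flat.numel())
        diffs = (flat[start:end].unsqueeze(-1) - codebook).abs()  # (chunk_len, 16)
        out[start:end] = diffs.argmin(dim=-1).to(torch.uint8)
    return out


x = torch.randn(100)
chunked = quantize_nearest_chunked(x, nf4_values, chunk=8)
unchunked = quantize_nearest_chunked(x, nf4_values, chunk=x.numel())
assert torch.equal(chunked, unchunked)  # chunking must not change the result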