From 09cc153dea939f23747bea622560c84b5a95183f Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 8 May 2024 02:10:49 -0700 Subject: [PATCH] Support NF4 on CPU backend --- bitsandbytes/autograd/_functions.py | 3 +- bitsandbytes/backends/cpu.py | 15 +- bitsandbytes/backends/cpu_xpu_common.py | 266 +++++++++++++++++++++++- bitsandbytes/nn/modules.py | 7 +- 4 files changed, 284 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7d570f28b..6dea211ff 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -572,7 +572,8 @@ def matmul_4bit( bias=None, ): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False: + if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False: + # CPU backend does not require A to be a vector if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index d6a9192e4..a5e123e62 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -9,6 +9,9 @@ double_quant_impl, igemmlt_impl, mm_dequant_impl, + quantize_4bit_impl, + dequantize_4bit_impl, + gemm_4bit_impl, ) Tensor = torch.Tensor @@ -132,7 +135,8 @@ def quantize_4bit( quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, absmax, out]) + return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) def dequantize_4bit( self, @@ -143,7 +147,8 @@ def dequantize_4bit( blocksize: int = 64, quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, absmax, out]) + return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) def gemv_4bit( self, @@ -154,7 +159,11 @@ def gemv_4bit( transposed_B=False, state: QuantState = None, ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) def dequantize_blockwise( self, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index f4e5ed3ec..078b81680 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,6 +1,12 @@ import warnings - import torch +from typing import Optional +from bitsandbytes.functional import ( + get_4bit_type, + quantize_blockwise, + dequantize_blockwise, + QuantState, +) try: # to support Intel CPU/GPU (XPU) backend @@ -228,3 +234,261 @@ def mm_dequant_impl( out = out + bias.to(compute_dtype) out = out.to(output_dtype) return out + + +NF4_QUANT_TABLE = [ + -1.0 - 1e-2, # 0b0000 + -0.8480964004993439, # 0b0001 + -0.6106329262256622, # 0b0010 + -0.4599952697753906, # 0b0011 + -0.33967943489551544, # 0b0100 + -0.23460740596055984, # 0b0101 + -0.13791173323988914, # 0b0110 + -0.045525018125772476, # 0b0111 + 0.03979014977812767, # 0b1000 + 0.1202552504837513, # 0b1001 + 0.2035212516784668, # 0b1010 + 0.2920137718319893, # 0b1011 + 0.3893125355243683, # 0b1100 + 0.5016634166240692, # 0b1101 + 0.6427869200706482, # 0b1110 + 0.8614784181118011, # 0b1111 +] + + +# It's faster not to use torch.compile +def quantize_4bit_impl( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", +) -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if quant_type != "nf4": + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU." + ) + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[:n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem:]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8) + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) + + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0: + state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( + out.reshape([input_shape[0], input_shape[1] // 2]), + ipex_cpu.quantization.WoqWeightDtype.NF4, + input_shape, # weight shape + absmax.view(input_shape[0], input_shape[1] // blocksize), # scales + None, # zero_points + None, # bias + None, # g_idx + None, # batch_size + blocksize, + int(ipex_cpu.quantization.WoqLowpMode.BF16), + -1, # act_quant_mode + ) + + return out, state + + +@_maybe_torch_compile +def dequantize_4bit_impl( + A: Tensor, + quant_state = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="nf4", +) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + + else: + absmax = quant_state.absmax + + if quant_state.quant_type != "nf4": + raise NotImplementedError( + f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." + ) + + if quant_state.nested: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + + if out is None: + out = torch.empty( + quant_state.shape, dtype=quant_state.dtype, device=A.device + ) + + n = out.numel() + # Map nf4 to [-1, 1] + out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device) + out_uint8[::2] = A.bitwise_and(0xF) + out_uint8[1::2] = A.bitwise_right_shift(4) + out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype) + for i in range(len(quant_state.code)): + out_dq[out_uint8 == i] = quant_state.code[i] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + out_reshaped = out.reshape(-1) + out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1) + if has_rem: + out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1] + + # take transpose here because weight is transposed (again) for computation + return out.t() + + +# Do not need torch.compile here as we are calling torch/ipex kernel +def gemm_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, +) -> torch.Tensor: + """ + Matrix-matrix multiplication with 4-bit quantization. + + Parameters + ---------- + A : torch.Tensor + The first input tensor. Usually the activation tensor. + B : torch.Tensor + The second input tensor. Usually the weight tensor. + out : torch.Tensor + The output tensor. + transposed_A : bool + Whether A is transposed + transposed_B : bool + Whether B is transposed + state : QuantState + Contains quantization info, such as blocksize and dtype + + Returns + ------- + torch.Tensor: + GEMM output tensor. + """ + if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and hasattr(state, "op_context"): + assert state.op_context is not None + output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) + else: + dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) + output = torch.matmul(A, dqB) + if out is not None: + out.copy_(output) + else: + out = output + return out diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 7e9ab8d05..d52cb4847 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -285,7 +285,7 @@ def from_prequantized( return self def _quantize(self, device): - w = self.data.contiguous().cuda(device) + w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( w, blocksize=self.blocksize, @@ -303,6 +303,9 @@ def _quantize(self, device): def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + def cpu(self, non_blocking: bool = False): + return self.to(device="cpu", non_blocking=non_blocking) + @overload def to( self: T, @@ -320,7 +323,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: