Commit 09cc153: Support NF4 on CPU backend
Xia-Weiwen committed May 8, 2024 (1 parent: 8561f09)
Showing 4 changed files with 284 additions and 7 deletions.
3 changes: 2 additions & 1 deletion bitsandbytes/autograd/_functions.py
@@ -572,7 +572,8 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
- if A.numel() == A.shape[-1] and A.requires_grad == False:
+ if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+ # CPU backend does not require A to be a vector
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
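With this change, the fast no-grad 4-bit path is taken for any CPU input, not only for row vectors. A minimal sketch of what that enables, assuming a multi-backend bitsandbytes build that includes this CPU backend, so that bitsandbytes.functional.quantize_4bit and bnb.matmul_4bit dispatch to it for CPU tensors:

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a weight matrix to NF4 on the CPU (double quantization is not supported here).
W = torch.randn(128, 64)
W_4bit, quant_state = F.quantize_4bit(
    W, blocksize=64, quant_type="nf4", compress_statistics=False
)

# A is a whole batch of activations rather than a single row vector.
A = torch.randn(8, 64)
out = bnb.matmul_4bit(A, W_4bit.t(), quant_state=quant_state)  # .t() mirrors how Linear4bit passes its weight
print(out.shape)  # expected: torch.Size([8, 128])
```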
15 changes: 12 additions & 3 deletions bitsandbytes/backends/cpu.py
Expand Up @@ -9,6 +9,9 @@
double_quant_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
dequantize_4bit_impl,
gemm_4bit_impl,
)

Tensor = torch.Tensor
@@ -132,7 +135,8 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
- raise NotImplementedError("Not yet implemented for CPU backend")
+ assert_on_cpu([A, absmax, out])
+ return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

def dequantize_4bit(
self,
@@ -143,7 +147,8 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
- raise NotImplementedError("Not yet implemented for CPU backend")
+ assert_on_cpu([A, absmax, out])
+ return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

def gemv_4bit(
self,
@@ -154,7 +159,11 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
- raise NotImplementedError("Not yet implemented for CPU backend")
+ assert_on_cpu([A, B, out])
+ if state is None:
+ raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
+
+ return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

def dequantize_blockwise(
self,
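With these three methods wired up, the public 4-bit entry points work for CPU tensors. A rough round-trip check, again assuming the multi-backend dispatch routes CPU tensors to this backend:

```python
import torch
import bitsandbytes.functional as F

w = torch.randn(256, 64)
packed, state = F.quantize_4bit(
    w, blocksize=64, quant_type="nf4", compress_statistics=False
)
print(packed.dtype, packed.numel())  # torch.uint8, w.numel() // 2 (two codes per byte)

# The CPU implementation returns the dequantized weight transposed
# (see the `return out.t()` in dequantize_4bit_impl further down).
w_dq = F.dequantize_4bit(packed, state, blocksize=64, quant_type="nf4")
print((w - w_dq.t()).abs().max())  # small blockwise NF4 quantization error
```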
266 changes: 265 additions & 1 deletion bitsandbytes/backends/cpu_xpu_common.py
@@ -1,6 +1,12 @@
import warnings

import torch
from typing import Optional, Tuple
from bitsandbytes.functional import (
get_4bit_type,
quantize_blockwise,
dequantize_blockwise,
QuantState,
)

try:
# to support Intel CPU/GPU (XPU) backend
@@ -228,3 +234,261 @@ def mm_dequant_impl(
out = out + bias.to(compute_dtype)
out = out.to(output_dtype)
return out


NF4_QUANT_TABLE = [
-1.0 - 1e-2, # 0b0000
-0.8480964004993439, # 0b0001
-0.6106329262256622, # 0b0010
-0.4599952697753906, # 0b0011
-0.33967943489551544, # 0b0100
-0.23460740596055984, # 0b0101
-0.13791173323988914, # 0b0110
-0.045525018125772476, # 0b0111
0.03979014977812767, # 0b1000
0.1202552504837513, # 0b1001
0.2035212516784668, # 0b1010
0.2920137718319893, # 0b1011
0.3893125355243683, # 0b1100
0.5016634166240692, # 0b1101
0.6427869200706482, # 0b1110
0.8614784181118011, # 0b1111
]
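# NOTE: the thresholds above appear to be the midpoints between adjacent NF4 code values
# (compare get_4bit_type("nf4")), so the `scaled_A > NF4_QUANT_TABLE[i]` sweep in
# quantize_4bit_impl maps each scaled value to the index of its nearest NF4 code.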


# It's faster not to use torch.compile
def quantize_4bit_impl(
A: Tensor,
absmax: Tensor = None,
out: Tensor = None,
blocksize=64,
compress_statistics=False,
quant_type="nf4",
) -> Tuple[Tensor, QuantState]:
"""
Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized to 4-bit values (currently NF4 only).
Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor (8-bit).
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
Returns
-------
torch.Tensor:
The 8-bit tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
if quant_type != "nf4":
raise NotImplementedError(
f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU."
)
n = A.numel()
input_shape = A.shape
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0

if absmax is None:
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)

if out is None:
out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
rem = n % blocksize
has_rem = rem > 0

# Scale tensor to [-1, 1]
A_reshaped = A.reshape(n)
A_com = A_reshaped[:n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem:]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
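# Pack two 4-bit codes per byte: even-indexed codes go to the low nibble,
# odd-indexed codes to the high nibble (dequantize_4bit_impl unpacks in the same order)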
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])

code = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
else:
state = QuantState(
absmax=absmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)
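# If IPEX >= 2.2 is available and the first weight dimension is a multiple of the
# blocksize, prepack the NF4 weight for IPEX's weight-only quantized linear kernel;
# gemm_4bit_impl uses this op_context as its fast path.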

if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0:
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
out.reshape([input_shape[0], input_shape[1] // 2]),
ipex_cpu.quantization.WoqWeightDtype.NF4,
input_shape, # weight shape
absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
blocksize,
int(ipex_cpu.quantization.WoqLowpMode.BF16),
-1, # act_quant_mode
)

return out, state


@_maybe_torch_compile
def dequantize_4bit_impl(
A: Tensor,
quant_state = None,
absmax: Tensor = None,
out: Tensor = None,
blocksize: int = 64,
quant_type="nf4",
) -> Tensor:
"""
Dequantizes 4-bit blockwise quantized values (currently NF4 only).
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
Parameters
----------
A : torch.Tensor
The input 8-bit tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
Returns
-------
torch.Tensor:
Dequantized tensor.
"""

if quant_state is None:
assert absmax is not None and out is not None

quant_state = QuantState(
absmax=absmax,
shape=out.shape,
dtype=out.dtype,
blocksize=blocksize,
quant_type=quant_type,
)

else:
absmax = quant_state.absmax

if quant_state.quant_type != "nf4":
raise NotImplementedError(
f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU."
)

if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if out is None:
out = torch.empty(
quant_state.shape, dtype=quant_state.dtype, device=A.device
)

n = out.numel()
# Map nf4 to [-1, 1]
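# Unpack: low nibble -> even positions, high nibble -> odd positions
# (the inverse of the packing in quantize_4bit_impl)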
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]

# Apply scales
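# (quantize_4bit_impl pads to an even number of 4-bit codes; drop the extra one when n is odd)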
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1)
if has_rem:
out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1]

# take transpose here because weight is transposed (again) for computation
return out.t()


# Do not need torch.compile here as we are calling torch/ipex kernel
def gemm_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
"""
Matrix-matrix multiplication with 4-bit quantization.
Parameters
----------
A : torch.Tensor
The first input tensor. Usually the activation tensor.
B : torch.Tensor
The second input tensor. Usually the weight tensor.
out : torch.Tensor
The output tensor.
transposed_A : bool
Whether A is transposed
transposed_B : bool
Whether B is transposed
state : QuantState
Contains quantization info, such as blocksize and dtype
Returns
-------
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize)
output = torch.matmul(A, dqB)
if out is not None:
out.copy_(output)
else:
out = output
return out
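As a standalone illustration of the nibble-packing convention shared by quantize_4bit_impl and dequantize_4bit_impl above (pure PyTorch; the variable names are made up for the example):

```python
import torch

codes = torch.randint(0, 16, (9,), dtype=torch.uint8)  # one 4-bit NF4 index per element
if codes.numel() % 2:
    # pad to an even count, as quantize_4bit_impl does
    codes = torch.nn.functional.pad(codes, (0, 1), value=0)

# odd-indexed codes -> high nibble, even-indexed codes -> low nibble
packed = codes[1::2].bitwise_left_shift(4).bitwise_or_(codes[::2])

# the inverse, as in dequantize_4bit_impl
unpacked = torch.empty(packed.numel() * 2, dtype=torch.uint8)
unpacked[::2] = packed.bitwise_and(0xF)
unpacked[1::2] = packed.bitwise_right_shift(4)

assert torch.equal(unpacked, codes)
```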
7 changes: 5 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -285,7 +285,7 @@ def from_prequantized(
return self

def _quantize(self, device):
- w = self.data.contiguous().cuda(device)
+ w = self.data.contiguous().to(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(
w,
blocksize=self.blocksize,
@@ -303,6 +303,9 @@ def _quantize(self, device):
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

def cpu(self, non_blocking: bool = False):
return self.to(device="cpu", non_blocking=non_blocking)

@overload
def to(
self: T,
@@ -320,7 +323,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

- if device is not None and device.type == "cuda" and not self.bnb_quantized:
+ if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
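With Params4bit quantizing on .to("cpu") as well, an NF4 Linear4bit layer can be built and run entirely on the CPU. A minimal sketch, assuming a build that includes this commit; keyword names of Linear4bit/Params4bit can differ slightly between releases, and double quantization has to stay off because the CPU backend does not support it yet:

```python
import torch
import bitsandbytes as bnb

fp_linear = torch.nn.Linear(64, 128, bias=False)

qlinear = bnb.nn.Linear4bit(
    64, 128, bias=False,
    quant_type="nf4",
    compress_statistics=False,   # double quant not supported on CPU yet
    compute_dtype=torch.float32,
)
qlinear.weight = bnb.nn.Params4bit(
    fp_linear.weight.data,
    requires_grad=False,
    quant_type="nf4",
    compress_statistics=False,
)

# With this commit, .to("cpu") quantizes the weight just as .to("cuda") always did.
qlinear = qlinear.to("cpu")

x = torch.randn(2, 64)
print(qlinear(x).shape)  # expected: torch.Size([2, 128])
```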
