Skip to content

Commit

Permalink
Feat (llm): export to MatMulNBits
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 8, 2024
1 parent 9048ecb commit 50ea659
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 50 deletions.
34 changes: 34 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,43 @@
import onnx
import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.export.onnx import onnx_export_opset


class MatMulNBitsFn(Function):

@staticmethod
def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
ret = g.op(
'com.microsoft::MatMulNBits',
x,
int_weights,
scales,
zero_points,
K_i=K,
N_i=N,
bits_i=bits,
block_size_i=block_size)
output_size = _get_tensor_sizes(x)
output_size[-1] = N
ret.setType(x.type().with_sizes(output_size))
return ret

@staticmethod
def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
dtype = x.dtype
device = x.device
shape = x.shape
out_shape = list(shape)
out_shape[-1] = N
# Only tensor metadata (shape, dtype, device) are preserved in the forward pass during
# tracing, not the correct value
out = torch.empty(out_shape, dtype=dtype, device=device)
return out


AXIS_OPSET = 13

DATATYPE_DICT = {
Expand Down
1 change: 0 additions & 1 deletion src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
Expand Down
2 changes: 0 additions & 2 deletions src/brevitas/nn/quant_eltwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def forward(self, input: Union[Tensor, QuantTensor],
if self.export_mode:
assert self.cache_quant_io_metadata_only, "Can't cache multiple inputs"
out = self.export_handler(inp=input.value, other=other.value)
self._set_global_is_quant_layer(False)
return out
quant_input = self.input_quant(input)
quant_other = self.input_quant(other)
Expand Down Expand Up @@ -70,7 +69,6 @@ def forward(self,
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler([qt.value for qt in quant_tensor_list])
self._set_global_is_quant_layer(False)
return out
quant_tensor_list = [self.input_quant(qt) for qt in quant_tensor_list]
# trigger an assert if scale factors and bit widths are None or different
Expand Down
2 changes: 0 additions & 2 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(quant_input)
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
out = self.pack_output(out)
Expand Down Expand Up @@ -139,7 +138,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

quant_input = self.input_quant(inp)
Expand Down
3 changes: 0 additions & 3 deletions src/brevitas/nn/quant_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
if self.mode != 'nearest':
Expand Down Expand Up @@ -69,7 +68,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
# round interpolated values to scale
Expand Down Expand Up @@ -97,7 +95,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
y = x.set(value=y_value)
Expand Down
151 changes: 109 additions & 42 deletions src/brevitas_examples/llm/llm_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@

import numpy as np
import torch
from torch.nn import Module
from torch.onnx import register_custom_op_symbolic

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import BaseManager
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.standard.function import MatMulNBitsFn
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.nn import QuantLinear
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


Expand Down Expand Up @@ -52,27 +57,6 @@ def __init__(self):
self.bit_width = None
self.dtype = None

def scaling_impl(self, proxy_module):
return proxy_module.tensor_quant.scaling_impl

def zero_point_impl(self, proxy_module):
return proxy_module.tensor_quant.zero_point_impl

def bit_width_impl(self, proxy_module):
return proxy_module.tensor_quant.msb_clamp_bit_width_impl

def export_scale(self, proxy_module, bit_width):
scaling_impl = self.scaling_impl(proxy_module)
int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
int_threshold = int_scaling_impl(bit_width)
threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
scaling_impl.wrapped_scaling_impl.parameter_list_stats())
return threshold / int_threshold

def export_zero_point(self, proxy_module, scale, bit_width):
zero_point_impl = self.zero_point_impl(proxy_module)
return zero_point_impl.unexpanded_zero_point(scale, bit_width)

@abstractmethod
def prepare_for_export(self, module):
pass
Expand All @@ -83,6 +67,7 @@ def forward(self, x):


class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase):
handled_layer = GroupwiseWeightQuantProxyFromInjector

def __init__(self):
super().__init__()
Expand All @@ -93,20 +78,18 @@ def __init__(self):

def prepare_for_export(self, module):
assert len(module.tracked_module_list) == 1, "Shared quantizers not supported."
self.bit_width = self.bit_width_impl(module)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_layer = module.tracked_module_list[0]
quant_weight = quant_layer.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
signed = module.is_signed
self.int_dtype = torch.int8 if signed else torch.uint8
self.dtype = quant_weight.value.dtype
self.scale = self.export_scale(module, self.bit_width).detach()
self.expanded_groupwise_shape = self.scaling_impl(module).expanded_groupwise_shape
self.reshaped_groupwise_shape = self.scaling_impl(module).reshaped_groupwise_shape
self.scale = quant_weight.scale_
self.expanded_scaling_shape = quant_weight.value_.shape
self.reshaped_scaling_shape = quant_weight.value.shape
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape
self.zero_point = quant_weight.zero_point_
else:
self.zero_point = None

Expand All @@ -131,15 +114,9 @@ def forward(self, x):
x = (x.type(self.dtype) - zero_point) * scale

# Fix shape post quantization
scale = scale.expand(self.expanded_groupwise_shape).contiguous().view(
self.reshaped_groupwise_shape)
# If zero_point is not defined, propagate same shape as scale
if self.zero_point is None:
zero_point = torch.zeros_like(scale).type(self.int_dtype)
else:
zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view(
self.reshaped_zero_point_shape).type(self.int_dtype)
x = x.view(self.reshaped_groupwise_shape)

return x, scale, zero_point, bit_width

Expand Down Expand Up @@ -208,18 +185,17 @@ def lcm(x, y):
raise ValueError(f"Bit width {bit_width} not supported.")

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = self.export_scale(module.weight_quant, self.bit_width)
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(
module.weight_quant, self.scale, self.bit_width)
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.block_size
self.group_size = quant_weight.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())

Expand All @@ -237,10 +213,12 @@ def set_export_handler(cls, module):
_set_proxy_export_handler(cls, module)


def block_quant_layer_level_manager(export_handlers):
def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None):

class BlockQuantLayerLevelManager(BaseManager):
handlers = export_handlers
target_name = '' if target is None else target
custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register

@classmethod
def set_export_handler(cls, module):
Expand Down Expand Up @@ -281,3 +259,92 @@ def replace_call_fn_target(graph_model, src, target):
node.target = target
graph_model.graph.lint()
graph_model.recompile()


class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
handled_layer = QuantLinear

def __init__(self):
super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
self.group_size = None

def pack_int_weights(self, bit_width, int_weights, zero_point):
assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"

is_symmetric = torch.sum(zero_point) == 0
zero_point = zero_point.to(torch.uint8)
rows, cols = int_weights.shape
group_size = self.group_size
blob_size = group_size // 2
k_blocks = (rows + group_size - 1) // group_size
padded_rows = k_blocks * group_size
pad_len = padded_rows - rows

# ONNX operator assumes implicit zp of 8 (largest negative number in Po2)
# If we are in a "symmetric" quantized scenario, we need to add this implicit zero point
# Otherwise it has already been added during the convesion to integer.
# This allows to pack weights always in unsigned integer.
zp = 0 if not int_weights.dtype == torch.int8 else 8
int_weights += zp
if pad_len > 0:
int_weights = torch.nn.functional(int_weights, (0, 0, 0, pad_len))
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
rows, cols = int_weights.shape
int_weights = int_weights.t()
for n in range(cols):
for k_id in range(0, rows, group_size):
blk_int0 = (int_weights[n, k_id:k_id + group_size:2].numpy()).astype("uint8")
blk_int1 = (int_weights[n, k_id + 1:k_id + group_size:2].numpy()).astype("uint8")
packed[n, k_id // group_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

zero_point = zero_point.to(torch.uint8).flatten()

# The constant value 136 is derived from the source code in ORT test suite.
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
base_zp = 136 if is_symmetric else 0
packed_zp = base_zp * torch.ones(
(zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)

i = 0
for column in range(packed_zp.shape[0]):
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_zp[column] |= zero_point[j] << shift_factor
i += 8 // bit_width
return torch.tensor(packed), packed_zp

def prepare_for_export(self, module):
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight, self.zero_point = self.pack_int_weights(self.bit_width, quant_weight.int().t().detach(), self.zero_point)
self.weight_shape = module.weight.shape

def symbolic_execution(self, x):
int_weights = self.int_weight
scale = self.scale
bit_width = self.bit_width
N, K = self.weight_shape
out = MatMulNBitsFn.apply(
x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
return out


def export_packed_onnx(model, input, export_path):
export_class = block_quant_layer_level_manager(
export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd],
target='',
custom_fns_to_register=MatMulNBitsFn)

with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
torch.onnx.export(model, input, export_path)

0 comments on commit 50ea659

Please sign in to comment.