Add MatMul FP4 and NF4 Support (#18066)
### Description
Add a contrib op MatMulBnb4 (FP4 and NF4) and the related toolchain to
support weight quantization.

This PR adds:
- a schema for the contrib op MatMulBnb4, which supports FP4 (4-bit floating
point) and NF4 (4-bit NormalFloat) quantization of weights.
- a naive implementation of MatMulBnb4 on CPU and GPU, i.e. computed as
MatMul(A, Dequantize(B)).
- a specialized GEMV implementation for MatMulBnb4 and a related benchmark
tool.
- a tool to quantize a model to FP4 or NF4 (a conceptual sketch of the
block-wise scheme follows this list).
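
The quantization side works block by block: each block of the transposed, flattened weight records its absolute maximum, the values are scaled into [-1, 1], and each scaled value is mapped to the closest entry of a 16-value FP4 or NF4 table. The sketch below is only a conceptual illustration of that scheme using a naive nearest-entry search; the header added by this PR (blockwise_quant_block_bnb4.h, shown further down) implements the same mapping with hard-coded threshold comparisons, and `QuantizeBlockNearest` is not part of the actual code.

```cpp
// Conceptual sketch of block-wise 4-bit quantization (not the actual code path):
// per block, record absmax, normalize to [-1, 1], pick the nearest table entry.
#include <cmath>
#include <cstdint>
#include <vector>

// quant_map is a 16-entry FP4 or NF4 value table (see the header further down).
void QuantizeBlockNearest(const float* src, int64_t block_len,
                          const float quant_map[16],
                          float& absmax, std::vector<uint8_t>& codes) {
  absmax = 0.0f;
  for (int64_t i = 0; i < block_len; ++i) absmax = std::fmax(absmax, std::fabs(src[i]));
  const float inv = absmax > 0.0f ? 1.0f / absmax : 0.0f;

  codes.resize(block_len);
  for (int64_t i = 0; i < block_len; ++i) {
    const float v = src[i] * inv;  // normalized into [-1, 1]
    uint8_t best = 0;
    for (uint8_t c = 1; c < 16; ++c)
      if (std::fabs(quant_map[c] - v) < std::fabs(quant_map[best] - v)) best = c;
    codes[i] = best;  // 4-bit code; two codes are later packed per byte
  }
}
```

Two 4-bit codes are then packed into each output byte, which is why the packed weight occupies [(N * K + 1) / 2] bytes.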
jambayk authored Oct 25, 2023
1 parent d88d52e commit d30d4d3
Showing 23 changed files with 2,236 additions and 0 deletions.
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/dequantize_blockwise_bnb4.cuh"
"quantization/dequantize_blockwise_bnb4.cu"
"quantization/matmul_bnb4.cc"
"quantization/matmul_bnb4.cuh"
"quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
57 changes: 57 additions & 0 deletions docs/ContribOperators.md
@@ -47,6 +47,7 @@ Do not modify directly.*
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
* <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
@@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>

MatMulBnb4 is a MatMul whose weight is quantized to 4 bits using either the FP4 or the NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It performs matrix multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul), with the following differences:
1. Input B is a 2-D constant matrix. Its input feature count and output feature count are specified by attributes 'K' and 'N'.
2. Input B is quantized to 4 bits with the quantization data type specified by attribute 'quant_type'. It is transposed, flattened, and quantized block-wise with the block size specified by attribute 'block_size'.
   'block_size' is not arbitrary: it must be a power of 2 and no smaller than 16 (e.g. 16, 32, 64, 128, ...).
3. Input B's quantization constants (scales) are specified by input 'absmax'.

Input B is stored as uint8_t with shape [(N * K + 1) / 2].
Input absmax is stored in the same type as the original type of B (float32 or float16) with shape [(N * K + block_size - 1) / block_size].
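
As a quick sanity check on those shapes, the snippet below (an illustration only, with arbitrary example dimensions) computes the packed-weight and absmax buffer lengths from N, K, and block_size.

```cpp
// Buffer sizes implied by the shapes above (illustrative values only).
#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = 4096, K = 4096, block_size = 64;  // example dimensions
  const int64_t packed_b_len = (N * K + 1) / 2;                      // two 4-bit codes per byte
  const int64_t absmax_len = (N * K + block_size - 1) / block_size;  // one scale per block
  std::cout << "B bytes: " << packed_b_len << ", absmax elements: " << absmax_len << "\n";
  // For N = K = 4096 and block_size = 64: 8388608 bytes and 262144 scales.
}
```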

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>block (group) size used for weight quantization. It must be a power of 2 and not smaller than 16.</dd>
<dt><tt>quant_type</tt> : int (required)</dt>
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional tensor holding the packed, quantized weight data</dd>
<dt><tt>absmax</tt> : T1</dt>
<dd>quantization constants</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>The output tensor. It has the same rank as the input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>
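
Putting the spec together, a minimal reference sketch of the operator's semantics is shown below. It is not the shipped kernel: it assumes the layout described above (B transposed and flattened with two 4-bit codes per byte, the even-indexed code in the high nibble as packed by blockwise_quant_block_bnb4.h further down, and one absmax scale per block) and takes the 16-entry FP4 or NF4 lookup table as a parameter.

```cpp
// Reference-only sketch of MatMulBnb4: Y = MatMul(A, Dequantize(B)).
#include <cstdint>
#include <vector>

std::vector<float> MatMulBnb4Reference(const std::vector<float>& A,          // [M, K]
                                       const std::vector<uint8_t>& B_packed,  // [(N*K + 1) / 2]
                                       const std::vector<float>& absmax,      // one scale per block
                                       const float quant_map[16],             // FP4 or NF4 table
                                       int64_t M, int64_t N, int64_t K,
                                       int64_t block_size) {
  // Dequantize the packed weight into a dense [K, N] matrix.
  std::vector<float> B(K * N);
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t k = 0; k < K; ++k) {
      const int64_t idx = n * K + k;  // index into the transposed, flattened weight
      const uint8_t byte = B_packed[idx / 2];
      const uint8_t code = (idx % 2 == 0) ? (byte >> 4) : (byte & 0xF);
      B[k * N + n] = quant_map[code] * absmax[idx / block_size];
    }
  }
  // Plain MatMul on the reconstructed weight.
  std::vector<float> Y(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n)
      for (int64_t k = 0; k < K; ++k)
        Y[m * N + n] += A[m * K + k] * B[k * N + n];
  return Y;
}
```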


### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>

Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
@@ -457,6 +457,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
@@ -852,6 +853,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
#endif
202 changes: 202 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
@@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <algorithm>
#include <cmath>

namespace onnxruntime {
namespace contrib {

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

typedef enum Bnb_DataType_t {
FP4 = 0,
NF4 = 1,
} Bnb_DataType_t;

FORCEINLINE uint8_t QuantizeOneFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12

// we do a binary search
// the pivots are divided by 12 (the FP4 absmax)
// since we assume input data is in [-1.0, 1.0]

// be careful here: it's easy to make a mistake
// that is difficult to notice if you add an extra
// zero somewhere!

uint8_t sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if (x > 0.29166667f) {
if (x > 0.583333f) {
if (x > 0.8333333f) {
return 0b0011 + sign;
} else {
return 0b0010 + sign;
}
} else if (x > 0.4166667f) {
return 0b101 + sign;
} else {
return 0b100 + sign;
}
} else if (x > 0.0859375f) {
if (x > 0.20833333f) {
return 0b0111 + sign;
} else {
return 0b0110 + sign;
}
} else if (x > 0.00260417f) {
return 0b0001 + sign;
} else {
return 0b0000 + sign;
}
}

FORCEINLINE uint8_t QuantizeOneNF4(float x) {
if (x > 0.03979014977812767f) {
if (x > 0.3893125355243683f) { // 1
if (x > 0.6427869200706482f) { // 11
if (x > 0.8614784181118011f) { // 111
return 0b1111;
} else {
return 0b1110;
}
} else if (x > 0.5016634166240692f) { // 110
return 0b1101;
} else {
return 0b1100;
}
} else if (x > 0.2035212516784668f) { // 10
if (x > 0.2920137718319893f) { // 101
return 0b1011;
} else {
return 0b1010;
}
} else if (x > 0.1202552504837513f) { // 100
return 0b1001;
} else {
return 0b1000;
}
} else if (x > -0.33967943489551544f) { // 0
if (x > -0.13791173323988914f) { // 01
if (x > -0.045525018125772476f) { // 011
return 0b0111;
} else {
return 0b0110;
}
} else if (x > -0.23460740596055984f) { // 010
return 0b0101;
} else {
return 0b0100;
}
} else if (x > -0.6106329262256622f) { // 00
if (x > -0.4599952697753906f) { // 001
return 0b0011;
} else {
return 0b0010;
}
} else if (x > -0.8480964004993439f) { // 000
return 0b0001;
} else {
return 0b0000;
}
}

template <int32_t DATA_TYPE>
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
if constexpr (DATA_TYPE == FP4)
return QuantizeOneFP4(x);
else
return QuantizeOneNF4(x);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
float local_absmax = 0.0f;

int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size;
int32_t dst_offset = block_idx * block_size / 2;

for (int32_t idx = 0; idx < block_len; idx++) {
const float v = static_cast<float>(src[src_offset + idx]);
local_absmax = fmaxf(local_absmax, fabsf(v));
}

absmax_block = static_cast<T>(local_absmax);
const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;

for (int32_t idx = 0; idx < block_len; idx += 2) {
const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);

const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);

dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
}
}

static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};

static float nf4_qaunt_map[16] = {-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};

template <typename T, int32_t DATA_TYPE>
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
if constexpr (DATA_TYPE == FP4)
return static_cast<T>(fp4_qaunt_map[x]);
else
return static_cast<T>(nf4_qaunt_map[x]);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size / 2;
int32_t dst_offset = block_idx * block_size;

for (int32_t idx = 0; idx < block_len; idx += 2) {
const uint8_t val = src[src_offset + idx / 2];

dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
}
}

} // namespace contrib
} // namespace onnxruntime
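
For illustration, the sketch below round-trips a single block through the templates above (assuming this header is available on the include path); the input values and block size are arbitrary.

```cpp
// Round-trip sketch: quantize one block with QuantizeBlockBnb4, then
// dequantize it with DequantizeBlockBnb4 and compare against the input.
// Assumes blockwise_quant_block_bnb4.h above is on the include path.
#include <cstdio>
#include <vector>
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"

int main() {
  using onnxruntime::contrib::QuantizeBlockBnb4;
  using onnxruntime::contrib::DequantizeBlockBnb4;
  using onnxruntime::contrib::NF4;

  constexpr int32_t block_size = 16;
  const int32_t numel = block_size;  // a single block for simplicity
  std::vector<float> src = {0.9f, -0.4f, 0.05f, -1.2f, 0.3f, 0.0f, -0.75f, 0.6f,
                            0.12f, -0.02f, 1.1f, -0.33f, 0.48f, -0.9f, 0.2f, -0.6f};

  std::vector<uint8_t> packed(block_size / 2);  // two 4-bit codes per byte
  float absmax = 0.0f;
  QuantizeBlockBnb4<float, block_size, NF4>(src.data(), packed.data(), absmax,
                                            /*block_idx=*/0, numel);

  std::vector<float> restored(block_size);
  DequantizeBlockBnb4<float, block_size, NF4>(packed.data(), restored.data(), absmax,
                                              /*block_idx=*/0, numel);

  for (int32_t i = 0; i < block_size; ++i)
    std::printf("%8.4f -> %8.4f\n", src[i], restored[i]);
}
```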