[Kernel] Extend Fast Decoding to UINT2 + QZeros #25

Merged: 2 commits, Apr 25, 2024
57 changes: 57 additions & 0 deletions python/bitblas/gpu/intrin/lop3.py
@@ -366,6 +366,47 @@
}
"""

decode_i2_to_f16_scale_zeros_quantized = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);

    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // unused in the quantized-zeros path; the runtime median_num below is used instead
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    T4 const zero_r = *zeros;
    // 0xe400 is fp16 -1024.0; OR-ing the integer zero point into the mantissa
    // packs -(1024 + zeros) into both fp16 halves of the word.
    uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);

    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // otherwise the pointer of _i2s should be moved to the next 16-bit chunk.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);

#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));

        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
}
"""

decode_i1_to_f16 = """
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
@@ -1359,6 +1400,21 @@ def fast_decode_impl(
    ),
)

LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = (
    "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_")
TensorIntrin.register(
    LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN,
    *get_fast_decode_intrin(
        source_bit=2,
        storage_dtype="int8",
        target_dtype="float16",
        loops_extent=8,
        with_scale=True,
        with_zeros=True,
        zeros_mode="quantized",
    ),
)
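Once registered, the intrinsic can be consumed by name from a TIR schedule. A hedged sketch of that usage (the schedule object and the block name "B_decode" are hypothetical; only the intrinsic constant comes from this diff):

# sch is a tvm.tir.Schedule over a module containing a decode block.
block = sch.get_block("B_decode")
loop = sch.get_loops(block)[-1]  # innermost loop of extent 8
sch.tensorize(loop, LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN)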

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
    "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_")
TensorIntrin.register(
@@ -1561,6 +1617,7 @@ def get_lop3_intrin_group(
"i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale,
"i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale,
"i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized,
"i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized,
"i1_to_i8": decode_i1s_to_i8s,
"i2_to_i8": decode_i2s_to_i8s,
"i4_to_i8": decode_i4s_to_i8s,
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -26,3 +26,4 @@ psutil
scipy
tornado
torch
thefuzz
1 change: 1 addition & 0 deletions requirements.txt
@@ -19,3 +19,4 @@ psutil
scipy
tornado
torch
thefuzz
39 changes: 39 additions & 0 deletions testing/cpp/lop3_type_conversion/fast_decoding.hpp
@@ -381,6 +381,45 @@ __device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_deco
    decode_i2b_to_f16<T1, T2, false, true, true, 1>(_i2u, B_local_decode, N, scale, zeros);
}

template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);

    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // unused in the quantized-zeros path; the runtime median_num below is used instead
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    T4 const zero_r = *zeros;
    // 0xe400 is fp16 -1024.0; OR-ing the integer zero point into the mantissa
    // packs -(1024 + zeros) into both fp16 halves of the word.
    uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);

    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // otherwise the pointer of _i2s should be moved to the next 16-bit chunk.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);

#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));

        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
}

/*
Kind 0: original
Kind 1: rescale
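The zeros "Kinds" referenced in the comment above differ only in how the zero point is stored and folded in. A plain-Python summary of my reading of the three modes (q is the unsigned code, s the scale; the "quantized" formula is the one this PR's test asserts):

def dequant_original(q, s, z):   # "original": zero point kept in fp16
    return (q - z) * s

def dequant_rescale(q, s, zs):   # "rescale": zero point pre-multiplied by the scale
    return q * s - zs

def dequant_quantized(q, s, zq): # "quantized": zero point stored as an integer code
    return (q - zq) * s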
56 changes: 56 additions & 0 deletions testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
@@ -46,6 +46,7 @@ REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_rescale, dec
REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_rescale, decode_i2u_to_f16_scale_zeros_rescale)
REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale_zeros_rescale, decode_i1u_to_f16_scale_zeros_rescale)
REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_quantized, decode_i4u_to_f16_scale_zeros_quantized)
REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_quantized, decode_i2u_to_f16_scale_zeros_quantized)

TEST(DecodeTest, DecodeInt4ToFloat16)
{
@@ -1076,4 +1077,59 @@ TEST(DecodeTest, DecodeUInt4ToFloat16WithScalingWithZerosQuantized)
    free(ins);
    free(interleaved);
    free(decoded);
}

TEST(DecodeTest, DecodeUInt2ToFloat16WithScalingWithZerosQuantized)
{
    constexpr int nbits = 2;
    constexpr int N = 32 / nbits;
    constexpr int QN = N / 8 * nbits;
    constexpr bool isSigned = false;

    // N (= 16) input codes, one 2-bit value stored per int8_t slot
    int8_t in_data[N] = {0};
    half scale[1] = {__float2half(1.2)};
    uint qzeros[1] = {(1 << (nbits - 1)) - 1};
    // fixed seed for reproducibility
    srand(0);

    // random initialization within the nbits range [0, 2^nbits)
    for (int i = 0; i < N; i++)
    {
        in_data[i] = (rand() % (1 << nbits));
    }

    int8_t *ins = new int8_t[QN];
    general_compress(in_data, ins, nbits, N, isSigned);

    int8_t *interleaved = new int8_t[QN];
    general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false);
    half *decoded = new half[N];
    int8_t *ins_gpu;
    half *decoded_gpu, *scale_gpu;
    uint *qzeros_gpu;

    cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t)));
    cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half)));
    cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half)));
    cudaCheckLastError(cudaMalloc((void **)&qzeros_gpu, 1 * sizeof(uint)));
    cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice));
    cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice));
    cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice));
    cudaCheckLastError(cudaMemcpy(qzeros_gpu, qzeros, 1 * sizeof(uint), cudaMemcpyHostToDevice));
    cudaCheckLastError(cudaDeviceSynchronize());
    // each launch decodes 8 elements from one 16-bit chunk; advance the input
    // pointer by QN / 2 bytes for the second half
    kernelWrapper_i2u_to_f16_scale_zeros_quantized<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(ins_gpu, decoded_gpu, scale_gpu, qzeros_gpu);
    kernelWrapper_i2u_to_f16_scale_zeros_quantized<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu, qzeros_gpu);
    cudaCheckLastError(cudaDeviceSynchronize());
    cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost));
    cudaCheckLastError(cudaFree(ins_gpu));
    cudaCheckLastError(cudaFree(decoded_gpu));
    cudaCheckLastError(cudaFree(scale_gpu));
    cudaCheckLastError(cudaFree(qzeros_gpu));
    for (int i = 0; i < N; i++)
    {
        EXPECT_NEAR(((int)in_data[i] - (int)qzeros[0]) * float(scale[0]), float(decoded[i]), 1e-2);
    }
    // host buffers were allocated with new[], so release them with delete[]
    delete[] ins;
    delete[] interleaved;
    delete[] decoded;
}