≈65% speedup of the AVX-512 implementation of ggml_vec_dot_q4_0() #933

Merged (1 commit, Apr 17, 2023)
22 changes: 19 additions & 3 deletions CMakeLists.txt
@@ -55,6 +55,8 @@ option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer"
option(LLAMA_AVX "llama: enable AVX" ON)
option(LLAMA_AVX2 "llama: enable AVX2" ON)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
@@ -220,6 +222,16 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options(/arch:AVX512)
# MSVC has no compile-time flags for enabling specific
# AVX512 extensions, nor does it define the
# macros corresponding to those extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions(__AVX512VBMI__)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions(__AVX512VNNI__)
endif()
elseif (LLAMA_AVX2)
add_compile_options(/arch:AVX2)
elseif (LLAMA_AVX)
@@ -240,9 +252,13 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
endif()
if (LLAMA_AVX512)
add_compile_options(-mavx512f)
-# add_compile_options(-mavx512cd)
-# add_compile_options(-mavx512dq)
-# add_compile_options(-mavx512bw)
add_compile_options(-mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
add_compile_options(-mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_options(-mavx512vnni)
endif()
endif()
else()
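All three AVX-512 options default to OFF, so the new code paths have to be enabled explicitly at configure time. On a CPU that actually has the extensions (e.g. Ice Lake or Zen 4), something along these lines should work (the build directory name is arbitrary):

cmake -B build -DLLAMA_AVX512=ON -DLLAMA_AVX512_VBMI=ON -DLLAMA_AVX512_VNNI=ON
cmake --build build --config Release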
235 changes: 203 additions & 32 deletions ggml.c
@@ -1977,33 +1977,187 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
}

#if __AVX512F__ && QK4_0 == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | :. =_ () [] <> () Zz Yy|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//
// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
// Bytes 40..63 are masked when loading the data, so they are zeroed out.
#ifdef __AVX512VBMI__
const __m512i byte_perm = _mm512_set_epi8(
39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
);
const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
// After applying VPERMB, `permuted` looks like this:
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
#else
const __m512i word_perm = _mm512_set_epi16(
19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
);
const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
// This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
// VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
// Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
#endif

// Shift every odd-numbered 16-bit group to the right by 4 bits.
const __mmask32 shift_mask = 0xaaaaaaaa;
const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
// After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+

// Now we just need to zero out the higher nibble in each byte, and we're done.
const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
return _mm512_and_si512( low_nibble_mask, shifted );
// The final result looks like this:
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
}
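For reference, here is a scalar model of what the unpacking above computes; the function name and layout constants below are ours, for illustration only (assuming <stdint.h> types), and are not part of the patch. Note the interleaved nibble order within each group of 4 output bytes, visible in the final diagram above; the exact order is harmless, since both operands of the dot product are unpacked by the same permutation.

// Scalar model of bytes_from_q4_0_twoblocks_avx512 (illustrative only).
// With QK4_0 == 32, a block is a 4-byte scale + 16 nibble bytes = 20 bytes.
// Each 32-byte output half unpacks one block; within a group of 4 output
// bytes the order is lo(q[2k]), lo(q[2k+1]), hi(q[2k]), hi(q[2k+1]).
static void bytes_from_q4_0_twoblocks_scalar( const uint8_t blocks[64], uint8_t out[64] ) {
    for (int half = 0; half < 2; ++half) {
        const uint8_t * q = blocks + 20 * half + 4; // skip the 4-byte scale
        uint8_t * dst = out + 32 * half;
        for (int k = 0; k < 8; ++k) {
            dst[4*k + 0] = q[2*k + 0] & 0xf;
            dst[4*k + 1] = q[2*k + 1] & 0xf;
            dst[4*k + 2] = q[2*k + 0] >> 4;
            dst[4*k + 3] = q[2*k + 1] >> 4;
        }
    }
}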

static inline __m512 dot_q4_0_twoblocks_avx512(
__m512 acc,
const block_q4_0 * restrict x,
const block_q4_0 * restrict y,
int i
) {
-// Compute combined scale for the block
-__m512 d = _mm512_set1_ps( x[i].d * y[i].d );

-__m256i bx = bytesFromNibbles( x[i].qs );
-__m256i by = bytesFromNibbles( y[i].qs );

-// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-const __m256i off = _mm256_set1_epi8( 8 );
-bx = _mm256_sub_epi8( bx, off );
-by = _mm256_sub_epi8( by, off );

-// Sign-extend 16 signed bytes into int16_t
-__m512i x32 = _mm512_cvtepi8_epi16( bx );
-__m512i y32 = _mm512_cvtepi8_epi16( by );
-// Compute products of int16_t integers, add pairwise
-__m512i i64 = _mm512_madd_epi16( x32, y32 );
// A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
// can potentially be unaddressable, so we make sure to mask them out before the load, even though
// we don't use them at all. This might hurt the performance slightly, since the compiler is forced
// to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
const __mmask8 load_mask = 0x1f;
const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
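// (Illustrative size check, not in the original patch: with QK4_0 == 32,
// sizeof(block_q4_0) == 4 + 16 == 20 bytes, so two consecutive blocks span
// 40 bytes == 5 qwords, hence the qword mask 0b00011111 == 0x1f above.)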

// We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// blocks_0_float
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// blocks_1_float
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
// We absolutely shouldn't touch the floats marked with `xx`: they contain some
// random data, which might very well underflow. At least on Intel, this leads
// to a huge penalty that can't be ignored (easily 100x or more) unless you
// compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
// (and ggml can't assume that you do)...
const __mmask16 scale_mul_mask = 0x21;
#ifdef __clang__
// ...however, clang decides to optimize the multiplication mask away:
// https://godbolt.org/z/P8PqdsfvW
// gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
__m512i scales;
__asm__(
"vmulps %1, %2, %0%{%3%}"
: "=v" ( scales )
: "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
);
#else
const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
#endif
const __m512i scale_perm = _mm512_set_epi32(
5, 5, 5, 5, 5, 5, 5, 5,
0, 0, 0, 0, 0, 0, 0, 0
);
const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
// After VMULPS and VPERMPS, `permuted_scales` looks like this:
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
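// (Illustrative lane bookkeeping, not in the original patch: a float lane is
// 4 bytes, so the block-0 scale at byte offset 0 lands in lane 0 and the
// block-1 scale at byte offset 20 lands in lane 5. That is why the multiply
// mask is 0x21 = 0b100001 (lanes 0 and 5), and why the permute broadcasts
// lane 0 into lanes 0..7 and lane 5 into lanes 8..15.)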

const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );

// Now we want to compute dot products of 4-element byte vectors and store them in
// 32-bit integers. That is (only one 4-element vector is shown for clarity):
// +----+----+----+----+
// ... | 03 | 02 | 01 | 00 |
// +----+----+----+----+
// bytes_0
// +----+----+----+----+
// ... | D | C | B | A |
// +----+----+----+----+
// bytes_1
// +----+----+----+----+
// ... | H | G | F | E |
// +----+----+----+----+
// final_res_int
// +----+----+----+----+
// ... | A*E+B*F+C*G+D*H |
// +----+----+----+----+
const __m512i plus_8 = _mm512_set1_epi8( 8 );
const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );

#ifdef __AVX512VNNI__
// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
// from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
// we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
// which means we only need 2 instructions.
const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
const __m512i minus_8 = _mm512_set1_epi8( -8 );
const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
#else
// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
// This is essentially the AVX-512 version of the AVX2 trick used by GH user Const-me
// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
const __m512i one = _mm512_set1_epi16( 1 );
const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
const __m512i final_res_int = _mm512_madd_epi16( diff, one );
#endif
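// (Illustrative algebra check, not in the original patch: for u, v in [0, 15],
//     (u - 8) * (v - 8) = u * (v - 8) - 8 * (v - 8)
//                       = u * (v - 8) + v * (-8) + 64,
// and each 32-bit lane accumulates 4 such byte products, so the VNNI
// accumulator starts at 4 * 64. The VPMADDUBSW fallback computes the same
// difference u * (v - 8) - 8 * (v - 8) pairwise in 16-bit lanes instead.)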

-// Convert int32_t to float
-__m512 p = _mm512_cvtepi32_ps( i64 );
-// Apply the scale, and accumulate
-return _mm512_fmadd_ps( d, p, acc );
// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
}
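Putting it all together: the horizontal sum of what this function adds to `acc` should match a plain scalar Q4_0 dot product over blocks i and i+1. A sketch of that scalar reference follows; the helper name is ours (block_q4_0 and QK4_0 are the existing ggml definitions), and it is illustrative only, not part of the patch.

// Illustrative scalar reference for dot_q4_0_twoblocks_avx512 (not part of
// the patch): the dot product of blocks i and i+1, each term scaled by the
// product of the two f32 block scales.
static float dot_q4_0_twoblocks_scalar( const block_q4_0 * restrict x, const block_q4_0 * restrict y, int i ) {
    float sum = 0.0f;
    for (int b = i; b < i + 2; ++b) {
        int32_t dot = 0;
        for (int j = 0; j < QK4_0 / 2; ++j) {
            const int x_lo = (x[b].qs[j] & 0xf) - 8, x_hi = (x[b].qs[j] >> 4) - 8;
            const int y_lo = (y[b].qs[j] & 0xf) - 8, y_hi = (y[b].qs[j] >> 4) - 8;
            dot += x_lo * y_lo + x_hi * y_hi;
        }
        sum += x[b].d * y[b].d * (float) dot;
    }
    return sum;
}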
#endif

@@ -2135,25 +2289,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();

-const int superblock_size = 8;
const int superblock_size = 16;

const int superblock_count = nb / superblock_size;

for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size;

-acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
-acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
-acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
-acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
-acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
-acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
-acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
-acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
}

// Remainders
-for (int i = superblock_count * superblock_size; i < nb; ++i) {
-acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
for (int i = superblock_count * superblock_size; i < nb; i += 2) {
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
}

// Horizontal sum of all lanes of the accumulator
@@ -11303,6 +11458,22 @@ int ggml_cpu_has_avx512(void) {
#endif
}

int ggml_cpu_has_avx512_vbmi(void) {
#if defined(__AVX512VBMI__)
return 1;
#else
return 0;
#endif
}

int ggml_cpu_has_avx512_vnni(void) {
#if defined(__AVX512VNNI__)
return 1;
#else
return 0;
#endif
}

int ggml_cpu_has_fma(void) {
#if defined(__FMA__)
return 1;
2 changes: 2 additions & 0 deletions ggml.h
@@ -808,6 +808,8 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_avx512_vbmi(void);
int ggml_cpu_has_avx512_vnni(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);
26 changes: 14 additions & 12 deletions llama.cpp
@@ -1915,18 +1915,20 @@ const char * llama_print_system_info(void) {
static std::string s;

s = "";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";

return s.c_str();
}
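For reference, on a machine where all of the AVX-512 extensions above are compiled in, the resulting string would begin along these lines (illustrative, not captured from a real run):

AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | ...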