From 7e401c095fd2ed7aa3c73882c7c10864e63292cc Mon Sep 17 00:00:00 2001 From: Daya Khudia Date: Fri, 4 Sep 2020 10:46:57 -0700 Subject: [PATCH] part2: Move embedding quantization kernels to fbgemm for better sharing between C2/PT (#425) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/425 8bit with float scale and bias. Test and benchmark added. ``` With scale and bias as float bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 556.20, 2.22 8, 100, 64, 1022.51, 4.09 8, 100, 128, 1121.43, 4.49 8, 100, 256, 1292.61, 5.17 8, 100, 512, 1526.69, 6.11 8, 100, 1024, 1407.09, 5.63 8, 100, 2048, 1620.34, 6.48 8, 120, 16, 562.60, 2.25 8, 120, 64, 1058.52, 4.23 8, 120, 128, 1082.74, 4.33 8, 120, 256, 1382.87, 5.53 8, 120, 512, 1513.15, 6.05 8, 120, 1024, 1441.19, 5.76 8, 120, 2048, 1634.99, 6.54 8, 1000, 16, 598.05, 2.39 8, 1000, 64, 1151.16, 4.60 8, 1000, 128, 1071.58, 4.29 8, 1000, 256, 1278.66, 5.11 8, 1000, 512, 1441.13, 5.76 8, 1000, 1024, 1605.48, 6.42 8, 1000, 2048, 1764.24, 7.06 ``` Reviewed By: supriyar Differential Revision: D23455486 fbshipit-source-id: e0dea307c42d614747302544a7179fa40194dad6 --- bench/EmbeddingQuantizeBenchmark.cc | 46 ++++++++++--- include/fbgemm/QuantUtils.h | 27 ++++++++ include/fbgemm/QuantUtilsAvx2.h | 6 ++ src/QuantUtils.cc | 44 ++++++++++++ src/QuantUtilsAvx2.cc | 101 ++++++++++++++++++++++++++++ test/QuantUtilsTest.cc | 50 +++++++++++++- 6 files changed, 264 insertions(+), 10 deletions(-) diff --git a/bench/EmbeddingQuantizeBenchmark.cc b/bench/EmbeddingQuantizeBenchmark.cc index efbcdbcb8f..bc91eda3cf 100644 --- a/bench/EmbeddingQuantizeBenchmark.cc +++ b/bench/EmbeddingQuantizeBenchmark.cc @@ -8,6 +8,7 @@ #include #include #include +#include #ifdef _OPENMP #include @@ -20,24 +21,41 @@ using namespace std; using namespace fbgemm; +// T is the type of scale and bias +template void performance_test() { constexpr int NWARMUP = 4; constexpr int NITER = 256; + if (is_same::value) { + cout << "With scale and bias as float16" << endl; + } else { + cout << "With scale and bias as float" << endl; + } cout << setw(8) << "bit_rate" << ", " << setw(6) << "rows" << "," << setw(6) << "cols" << "," << setw(16) << "elems_per_usec" << "," << setw(10) << "GB/Sec" << endl; - for (int bit_rate : {2, 4, 8}) { + std::vector bit_rates; + if (is_same::value) { + bit_rates = {2, 4, 8}; + } else { + // float + bit_rates = {8}; + } + for (int bit_rate : bit_rates) { for (int rowSize : {100, 120, 1000}) { for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) { aligned_vector inpVec(rowSize * colSize); randFill(inpVec, -10.0f, 10.0f); - int elements_per_byte = 8 / bit_rate; - int out_emb_cols = - (colSize + elements_per_byte - 1) / elements_per_byte; + int out_emb_cols = colSize; + + if (is_same::value) { + int elements_per_byte = 8 / bit_rate; + out_emb_cols = (colSize + elements_per_byte - 1) / elements_per_byte; + } int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16)); aligned_vector outVec(outVecSize); @@ -45,8 +63,15 @@ void performance_test() { duration = measureWithWarmup( [&]() { - FloatToFusedNBitRowwiseQuantizedSBHalf( - bit_rate, inpVec.data(), rowSize, colSize, outVec.data()); + is_same::value + ? FloatToFusedNBitRowwiseQuantizedSBHalf( + bit_rate, + inpVec.data(), + rowSize, + colSize, + outVec.data()) + : FloatToFused8BitRowwiseQuantizedSBFloat( + inpVec.data(), rowSize, colSize, outVec.data()); }, NWARMUP, NITER, @@ -63,8 +88,10 @@ void performance_test() { cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", " << setw(6) << colSize << ","; - cout << setw(16) << std::fixed << std::setprecision(2) << elements_per_usec << ", "; - cout << setw(10) << std::fixed << std::setprecision(2) << gigabyes_per_sec << endl; + cout << setw(16) << std::fixed << std::setprecision(2) + << elements_per_usec << ", "; + cout << setw(10) << std::fixed << std::setprecision(2) + << gigabyes_per_sec << endl; } // for each cols } // for each rows } // for each bit_rate @@ -78,6 +105,7 @@ int main() { omp_set_num_threads(1); } #endif - performance_test(); + performance_test(); + performance_test(); return 0; } diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 4a8abb9a81..5414c33794 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -274,4 +274,31 @@ FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef( int input_rows, int input_columns, std::uint8_t* output); + +/** + * Convert float inputs to rowwise quantized (8-bit) outputs. + * Scale and Bias are in float. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * This version intentionally supports only 8-bit because we want to discourage + * the usage of float scale and bias with 2 and 4 bit cases as that diminishes + * the overall memory savings. + * + */ +FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloat( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + +/** + * Same as FloatToFused8BitRowwiseQuantizedSBFloat but unoptimized. + * This should not be called directly except in testing. + */ +FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloatRef( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + } // namespace fbgemm diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 62320d743b..1080253b0c 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -126,4 +126,10 @@ void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2( int input_columns, std::uint8_t* output); +void FloatToFused8BitRowwiseQuantizedSBFloatAvx2( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + } // namespace fbgemm diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 99f9d43e36..1edf3aaf54 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -555,4 +555,48 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf( } } +void FloatToFused8BitRowwiseQuantizedSBFloatRef( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + constexpr float kEpsilon = 1e-8f; + + int output_columns = input_columns + 2 * sizeof(float); + for (std::size_t row = 0; row < input_rows; ++row) { + const float* input_row = input + row * input_columns; + std::uint8_t* output_row = output + row * output_columns; + float* output_row_scale_bias = + reinterpret_cast(output_row + input_columns); + + float minimum_element = + *std::min_element(input_row, input_row + input_columns); + float maximum_element = + *std::max_element(input_row, input_row + input_columns); + float range = maximum_element - minimum_element; + + output_row_scale_bias[0] = range / 255.0f; + output_row_scale_bias[1] = minimum_element; + const auto inverse_scale = 255.0f / (range + kEpsilon); + for (std::size_t col = 0; col < input_columns; ++col) { + output_row[col] = + std::lrintf((input_row[col] - minimum_element) * inverse_scale); + } + } +} + +void FloatToFused8BitRowwiseQuantizedSBFloat( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { + FloatToFused8BitRowwiseQuantizedSBFloatAvx2( + input, input_rows, input_columns, output); + } else { + FloatToFused8BitRowwiseQuantizedSBFloatRef( + input, input_rows, input_columns, output); + } +} + } // namespace fbgemm diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index eacb0a2717..ec962e8497 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -1622,4 +1622,105 @@ template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>( int input_columns, std::uint8_t* output); +void FloatToFused8BitRowwiseQuantizedSBFloatAvx2( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + constexpr int VLEN = 8; + constexpr float kEpsilon = 1e-8f; + + __m256i permute_mask1_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + // clang-format off + __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00); + // clang-format on + + __m256i permute_mask2_v = + _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + + int output_columns = input_columns + 2 * sizeof(float); + for (std::size_t row = 0; row < input_rows; ++row) { + const float* input_row = input + row * input_columns; + std::uint8_t* output_row = output + row * output_columns; + float* output_row_scale_bias = + reinterpret_cast(output_row + input_columns); + + float minimum_element = FLT_MAX; + float maximum_element = -FLT_MAX; + __m256 min_v = _mm256_set1_ps(minimum_element); + __m256 max_v = _mm256_set1_ps(maximum_element); + std::size_t col; + for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) { + __m256 in_v = _mm256_loadu_ps(input_row + col); + min_v = _mm256_min_ps(min_v, in_v); + max_v = _mm256_max_ps(max_v, in_v); + } + alignas(64) float min_buf[VLEN], max_buf[VLEN]; + _mm256_store_ps(min_buf, min_v); + _mm256_store_ps(max_buf, max_v); + for (int i = 0; i < VLEN; ++i) { + minimum_element = std::min(minimum_element, min_buf[i]); + maximum_element = std::max(maximum_element, max_buf[i]); + } + for (; col < input_columns; ++col) { + minimum_element = std::min(minimum_element, input_row[col]); + maximum_element = std::max(maximum_element, input_row[col]); + } + + float range = maximum_element - minimum_element; + + output_row_scale_bias[0] = range / 255.0f; + output_row_scale_bias[1] = minimum_element; + const auto inverse_scale = 255.0f / (range + kEpsilon); + min_v = _mm256_set1_ps(minimum_element); + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + + for (col = 0; col < input_columns / (4 * VLEN) * (4 * VLEN); + col += 4 * VLEN) { + __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v), + inverse_scale_v)); + __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + VLEN), min_v), + inverse_scale_v)); + __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 2 * VLEN), min_v), + inverse_scale_v)); + __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 3 * VLEN), min_v), + inverse_scale_v)); + + // An instruction sequence to save 32 32-bit integers as 8-bit integers + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + xyzw_packed_v = + _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(output_row + col), xyzw_packed_v); + } + for (; col < input_columns / VLEN * VLEN; col += VLEN) { + __m256i rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v), + inverse_scale_v)); + + // An instruction sequence to save 8 32-bit integers as 8-bit integers + rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); + rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask2_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(output_row + col), + _mm256_castsi256_si128(rounded_v)); + } + for (; col < input_columns; ++col) { + output_row[col] = + std::lrintf((input_row[col] - minimum_element) * inverse_scale); + } + } +} + } // namespace fbgemm diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index 23d57e22f5..09b69bcc51 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -34,6 +34,11 @@ class FusedQuantizeDequantizeTest : public testing::TestWithParam {}; class EmbeddingQuantizeTest : public testing::TestWithParam> {}; +// Parameter are input rows and input columns +// Scale and Bias are of type float (SBFloat) +class EmbeddingQuantizeSBFloatTest + : public testing::TestWithParam> {}; + INSTANTIATE_TEST_CASE_P( InstantiationName, QuantizeGroupwiseTest, @@ -57,6 +62,13 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn({1, 2, 3}), ::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65}))); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + EmbeddingQuantizeSBFloatTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2, 3}), + ::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65}))); + template void ref_impl( const vector& src, @@ -185,7 +197,14 @@ ::testing::AssertionResult isQEmbeddingClose( res_ref.data() + i * ld + out_emb_cols)[1]); } else { // float scale and bias - // TODO: + scaleTest = reinterpret_cast( + res.data() + i * ld + out_emb_cols)[0]; + biasTest = reinterpret_cast( + res.data() + i * ld + out_emb_cols)[1]; + scaleRef = reinterpret_cast( + res_ref.data() + i * ld + out_emb_cols)[0]; + biasRef = reinterpret_cast( + res_ref.data() + i * ld + out_emb_cols)[1]; } if (fabs(scaleTest - scaleRef) > std::numeric_limits::epsilon()) { ss << " scale mismatch for row:" << i; @@ -548,3 +567,32 @@ TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) { EXPECT_TRUE(isQEmbeddingClose(outVecTest, outVecRef, rows, out_emb_cols)); } + +TEST_P(EmbeddingQuantizeSBFloatTest, embeddingFloatTest) { + int rows, cols; + tie(rows, cols) = GetParam(); + + random_device rd; + mt19937 gen(rd()); + + uniform_real_distribution disFP(-10.0f, 10.0f); + + vector inpVec(rows * cols); + + generate(inpVec.begin(), inpVec.end(), [&, disFP]() mutable { + return disFP(gen); + }); + + int outVecSize = rows * (cols + 2 * sizeof(float)); + + vector outVecRef(outVecSize); + vector outVecTest(outVecSize); + + FloatToFused8BitRowwiseQuantizedSBFloatRef( + inpVec.data(), rows, cols, outVecRef.data()); + FloatToFused8BitRowwiseQuantizedSBFloat( + inpVec.data(), rows, cols, outVecTest.data()); + + // The number of input columns is the same as the number of output columns + EXPECT_TRUE(isQEmbeddingClose(outVecTest, outVecRef, rows, cols)); +}