From 0432ecd2d9d85627859c516529798cc64f5a7a71 Mon Sep 17 00:00:00 2001 From: wsu Date: Thu, 22 Aug 2024 13:44:41 -0700 Subject: [PATCH 1/3] Add a CPU nbit to float dequantization op that supports torch.quintMxN type and QuantizedCPU backend Differential Revision: D61305979 --- .../include/fbgemm_gpu/quantize_ops_utils.h | 19 ++++++ .../include/fbgemm_gpu/utils/ops_utils.h | 5 ++ fbgemm_gpu/include/fbgemm_gpu/utils/types.h | 4 ++ .../src/quantize_ops/quantize_ops_cpu.cpp | 63 ++++++++++++++++++- include/fbgemm/QuantUtils.h | 3 +- src/QuantUtils.cc | 14 +++-- 6 files changed, 102 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h index 68e6bb5686..3d9e35edb8 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h @@ -111,4 +111,23 @@ hfp8_to_float(uint8_t hfp8_val, int ebits, int exponent_bias) { return val_out.F; } +// Get the number of bytes of a row in a tensor with quantized nbit integers +inline int32_t nbit_elems_to_bytes(const at::Tensor& input) { + const auto input_sizes = input.sizes(); + const int32_t ncols = input_sizes[1]; + // at::kQUInt4x2 is the dtype for quantized int4 tensors and at::kQUInt2x4 is + // for quantized int2 tensors. QUIntMxN (M*N=8) means quantized M-bit integer + // with each byte holding N such elements. + // input_sizes[1] is the number of elements in each row, so we need to divide + // it by 2 or 4 for quint4x2 or quint2x4 respectively to get the number of + // bytes in each row. + if (input.dtype() == at::kQUInt2x4) { + return fbgemm_gpu::div_up(ncols, 4); + } else if (input.dtype() == at::kQUInt4x2) { + return fbgemm_gpu::div_up(ncols, 2); + } else { + return ncols; + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h index 58a17eb917..866393db34 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h @@ -56,6 +56,11 @@ __builtin_ia32_serialize(void) { #define DISPATCH_TO_CPU(name, function) \ m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function))) +#define DISPATCH_TO_QUANTIZED_CPU(name, function) \ + m.impl( \ + name, \ + torch::dispatch(c10::DispatchKey::QuantizedCPU, TORCH_FN(function))) + #define DISPATCH_TO_META(name, function) \ m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function))) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/types.h b/fbgemm_gpu/include/fbgemm_gpu/utils/types.h index 3d3fbad3ac..295a8ea70c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/types.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/types.h @@ -15,4 +15,8 @@ using fint32 = union fint32 { float F; }; +inline int64_t div_up(int64_t val, int64_t unit) { + return (val + unit - 1) / unit; +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp index b09321129c..3512132448 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp +++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp @@ -123,7 +123,8 @@ Tensor _fusednbitrowwise_to_float_cpu( const auto input_sizes = input.sizes(); const int64_t nrows = input_sizes[0]; - const int32_t ncols = input_sizes[1]; + // Here we want the number of bytes in a row + const int32_t ncols = nbit_elems_to_bytes(input); const int32_t num_elem_per_byte = 8 / bit_rate; const int32_t output_columns = (ncols - 2 * 
sizeof(at::Half)) * num_elem_per_byte; @@ -149,6 +150,40 @@ Tensor _fusednbitrowwise_to_float_cpu( return output; } +Tensor _fusednbitrowwise_sbfront_to_float_cpu( + const Tensor& input, + const int64_t bit_rate) { + TENSOR_ON_CPU(input); + TENSOR_NDIM_EQUALS(input, 2); + + const auto input_sizes = input.sizes(); + const int64_t nrows = input_sizes[0]; + // Here we want the number of bytes in a row + const int32_t ncols = nbit_elems_to_bytes(input); + const int32_t num_elem_per_byte = 8 / bit_rate; + const int32_t output_columns = + (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte; + + Tensor output; + output = at::empty( + {nrows, output_columns}, // 4 = sizeof(float) + input.options().dtype(at::kFloat)); + + float* output_data = static_cast( + output.data_ptr()); // output.data_ptr(); -> Yields + // unresolved data_ptr symbol. + + fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( + bit_rate, + input.data_ptr(), + nrows, + ncols, + output_data, + /*scale_bias_last=*/false); + + return output; +} + /// @ingroup quantize-data-cpu /// Tensor& _fused8bitrowwise_to_float_cpu_out( @@ -274,6 +309,24 @@ Tensor fusednbitrowwise_to_float_cpu( return _fusednbitrowwise_to_float_cpu(input, bit_rate); } +/// @ingroup quantize-data-cpu +/// @brief Dequantize int4/int2 rows with scale and bias stored in the front +/// into float32. +/// @param input Tensor of int4/int2 rows with scale and bias stored in the +/// front. +/// @param bit_rate Bit rate of each element. Should be 4 or 2. +/// @return Tensor of float32, holding dequantized numbers. +/// +/// Dequantize int4/int2 rows with scale and bias stored in the front into +/// float32. The input tensor should have torch.quint4x2 or torch.quint2x4 dtype +/// and QuantizedCPU backend. This operator is only recommended for testing +/// purpose because its kernel is reference implementation and not optimized. +Tensor fusednbitrowwise_sbfront_to_float_cpu( + const Tensor& input, + const int64_t bit_rate) { + return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate); +} + /// @ingroup quantize-data-cpu /// Tensor fusednbitrowwise_to_half_cpu( @@ -466,6 +519,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor"); m.def( "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor"); + m.def( + "FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor"); m.def( "FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor"); m.def( @@ -485,6 +540,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("dequantize_mx_cuda(Tensor input, int mx_group_size) -> Tensor"); } +TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) { + DISPATCH_TO_QUANTIZED_CPU( + "FusedNBitRowwiseQuantizedSBHalfFrontToFloat", + fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu); +} + TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU( "FloatToFused8BitRowwiseQuantized", diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 8b0adedef0..86d22595fe 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -366,7 +366,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( const uint8_t* input, size_t input_rows, int input_columns, - OutputType* output); + OutputType* output, + bool scale_bias_last = true); /** * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized. 
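The `scale_bias_last` flag added to `FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef` above selects where each row's fp16 scale/bias pair lives: after the packed nbit payload (the default, existing layout) or in front of it, which is the layout the new `FusedNBitRowwiseQuantizedSBHalfFrontToFloat` operator consumes. Below is a minimal sketch of driving the reference kernel directly with the scale-bias-front layout; it assumes the patched `fbgemm/QuantUtils.h` is on the include path and fbgemm is linked, and the zero-filled rows are placeholders for real fp16 scale/bias values and packed int4 nibbles.

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  #include "fbgemm/QuantUtils.h"

  int main() {
    constexpr int bit_rate = 4;                      // int4 payload, 2 elements per byte
    constexpr int num_elem_per_byte = 8 / bit_rate;
    constexpr int output_columns = 8;                // logical int4 elements per row
    // Row layout with scale_bias_last=false: [fp16 scale][fp16 bias][packed nibbles ...]
    const int input_columns =
        2 * static_cast<int>(sizeof(uint16_t)) +
        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte;
    const size_t nrows = 3;

    // Placeholder input: a real test would write per-row fp16 scale/bias and int4 nibbles here.
    std::vector<uint8_t> packed(nrows * input_columns, 0);
    std::vector<float> dequant(nrows * output_columns, 0.0f);

    fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
        bit_rate,
        packed.data(),
        nrows,
        input_columns,
        dequant.data(),
        /*scale_bias_last=*/false);                  // scale/bias sit at the front of each row

    std::printf("dequant[0] = %f\n", dequant[0]);
    return 0;
  }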
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 7e2fb37264..e6a53253f0 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -729,7 +729,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( const uint8_t* input, size_t input_rows, int input_columns, - OutputType* output) { + OutputType* output, + bool scale_bias_last) { static_assert( std::is_same() || std::is_same(), "Only float and float16 types are allowed."); @@ -742,13 +743,17 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( const std::uint8_t* input_row = input + row * input_columns; const float16* input_row_scale_bias = reinterpret_cast( input_row + - (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); + (scale_bias_last + ? (output_columns + num_elem_per_byte - 1) / num_elem_per_byte + : 0)); float scale = cpu_half2float(input_row_scale_bias[0]); float bias = cpu_half2float(input_row_scale_bias[1]); + const std::uint8_t* nums = + (scale_bias_last) ? input_row : input_row + 2 * sizeof(float16); OutputType* output_row = output + row * output_columns; for (int64_t col = 0; col < output_columns; ++col) { - std::uint8_t quantized = input_row[col / num_elem_per_byte]; + std::uint8_t quantized = nums[col / num_elem_per_byte]; quantized >>= (col % num_elem_per_byte) * bit_rate; quantized &= (1 << bit_rate) - 1; float output_value = scale * quantized + bias; @@ -857,7 +862,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( const uint8_t* input, \ size_t input_rows, \ int input_columns, \ - type* output); \ + type* output, \ + bool scale_bias_last); \ template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( \ int bit_rate, \ const uint8_t* input, \ From ea074b481ec63685a968b9d4208350bc7e1d59aa Mon Sep 17 00:00:00 2001 From: wsu Date: Thu, 22 Aug 2024 13:44:41 -0700 Subject: [PATCH 2/3] Add int4 to int4 CPU Sequence TBE kernel Differential Revision: D61305980 --- include/fbgemm/FbgemmEmbedding.h | 6 +- include/fbgemm/Utils.h | 35 +++++++++ src/EmbeddingSpMDMAutovec.cc | 50 ++++++++++--- src/EmbeddingSpMDMAutovec.h | 6 +- src/EmbeddingSpMDMNBit.cc | 99 ++++++++++++++++++++++--- src/RefImplementations.cc | 122 +++++++++++++++++++------------ src/RefImplementations.h | 8 +- 7 files changed, 253 insertions(+), 73 deletions(-) diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index f787a637e3..15d287a54b 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -159,7 +159,7 @@ FBGEMM_API typename EmbeddingSpMDMKernelSignature< OffsetType, OutType>::Type GenerateEmbeddingSpMDMNBitWithStrides( - int bit_rate, + const int input_bit_rate, const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, @@ -169,7 +169,9 @@ GenerateEmbeddingSpMDMNBitWithStrides( std::int64_t output_stride = -1, std::int64_t input_stride = -1, bool scale_bias_last = true, - bool is_bf16_out = false); + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); /** * @param output_stride If -1, output_stride is same as block_size diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 69b2e6d94f..6ad32bf860 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -416,4 +417,38 @@ FBGEMM_API bool is_autovec_disabled(); FBGEMM_API bool is_autovec_forced(); FBGEMM_API bool is_asmjit_disabled(); +/** + * @brief A function to check if the input parameter in the nbit CPU TBE kernel + * is valid. 
+ */ +template +void nbit_embedding_sanity_check( + // assertions are ignored in release mode, in which case these parameters + // will be unused + [[maybe_unused]] const int input_bit_rate, + [[maybe_unused]] const int output_bit_rate, + [[maybe_unused]] const bool no_bag) { + assert( + (input_bit_rate == 2 || input_bit_rate == 4) && + "input_bit_rate must be 2 or 4"); + if (std::is_same::value) { + assert( + (no_bag && input_bit_rate == 4 && output_bit_rate == 4) && + "we currently only support int4 to int4 for sequential TBE"); + } else { + assert( + (output_bit_rate == 8 * sizeof(OutType)) && + "output_bit_rate should be equal to 8 * sizeof(OutType)"); + } +} + +#define WARN_ONCE(...) \ + do { \ + static bool _warned = false; \ + if (!_warned) { \ + _warned = true; \ + fprintf(stderr, __VA_ARGS__); \ + } \ + } while (0) + } // namespace fbgemm diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc index 3247c9026b..67398ab461 100644 --- a/src/EmbeddingSpMDMAutovec.cc +++ b/src/EmbeddingSpMDMAutovec.cc @@ -273,7 +273,7 @@ INSTANTIATE_SPMDM_INDEX_T() template bool EmbeddingSpMDMNBit_autovec( - const int bit_rate, + const int input_bit_rate, const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -289,9 +289,14 @@ bool EmbeddingSpMDMNBit_autovec( int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, const bool scale_bias_last /*=true*/, - const bool is_bf16_out /*=false*/) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); - const int num_elem_per_byte = 8 / bit_rate; + const bool is_bf16_out /*=false*/, + const bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = 8 * sizeof(OutType); + } + nbit_embedding_sanity_check(input_bit_rate, output_bit_rate, no_bag); + const int num_elem_per_byte = 8 / input_bit_rate; if (output_stride == -1) { output_stride = block_size; @@ -335,6 +340,26 @@ bool EmbeddingSpMDMNBit_autovec( } } + if (no_bag) { + // We currently only support int4 to int4 for sequential TBE in this nbit + // kernel. Note that assert() will be ignored in release mode, so we check + // here to double check and also avoid "unused variable" warning + if (!(input_bit_rate == 4 && output_bit_rate == 4)) { + WARN_ONCE("no_bag is only supported for int4 to int4"); + return false; + } + for (int64_t i = 0; i < output_size; ++i) { + const auto idx = indices[i]; + if (idx < 0 || idx > data_size) { + return false; + } + const uint8_t* input_row = input + input_stride * idx; + memcpy(out, input_row, sizeof(uint8_t) * input_stride); + out += input_stride; + } + return true; + } + int64_t current = 0; const int64_t rounded_bs = round_up(block_size, num_elem_per_byte); vector buf(rounded_bs); @@ -387,7 +412,7 @@ bool EmbeddingSpMDMNBit_autovec( const int64_t offset = input_stride * idx + (scale_bias_last ? 
0 : scale_bias_offset); const uint8_t* input_row = input + offset; - if (bit_rate == 4) { + if (input_bit_rate == 4) { const size_t halfbufsz = (block_size + 1) / 2; for (size_t j = 0; j < halfbufsz; ++j) { float quantized1 = float(input_row[j] & 0xf); @@ -395,7 +420,7 @@ bool EmbeddingSpMDMNBit_autovec( buf[j * 2] = std::fma(scale, quantized1, buf[j * 2] + bias); buf[j * 2 + 1] = std::fma(scale, quantized2, buf[j * 2 + 1] + bias); } - } else if (bit_rate == 2) { + } else if (input_bit_rate == 2) { size_t qbufsz = (block_size + 3) / 4; const uint8_t mask1 = 0x3; const uint8_t mask2 = 0xC; @@ -445,7 +470,7 @@ bool EmbeddingSpMDMNBit_autovec( #define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ template FBGEMM_API bool EmbeddingSpMDMNBit_autovec( \ - const int bit_rate, \ + const int input_bit_rate, \ const int64_t block_size, \ const int64_t output_size, \ const int64_t index_size, \ @@ -461,11 +486,14 @@ bool EmbeddingSpMDMNBit_autovec( int64_t output_stride, \ int64_t input_stride, \ const bool scale_bias_last, \ - const bool is_bf16_out); + const bool is_bf16_out, \ + const bool no_bag, \ + int output_bit_rate); -#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) +#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) #define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \ INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \ diff --git a/src/EmbeddingSpMDMAutovec.h b/src/EmbeddingSpMDMAutovec.h index 997c02cdc3..2e7632c785 100644 --- a/src/EmbeddingSpMDMAutovec.h +++ b/src/EmbeddingSpMDMAutovec.h @@ -51,7 +51,7 @@ template < typename OffsetType = std::int32_t, typename OutType = float> FBGEMM_API bool EmbeddingSpMDMNBit_autovec( - const int bit_rate, + const int input_bit_rate, const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, @@ -67,7 +67,9 @@ FBGEMM_API bool EmbeddingSpMDMNBit_autovec( std::int64_t output_stride = -1, std::int64_t input_stride = -1, const bool scale_bias_last = true, - const bool is_bf16_out = false); + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); } // namespace fbgemm diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index f916c568b7..c0e4429bb5 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -1022,7 +1022,7 @@ template < typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDMNBitWithStrides( - int bit_rate, + const int input_bit_rate, const int64_t block_size, bool has_weight, bool normalize_by_lengths, @@ -1032,8 +1032,20 @@ typename EmbeddingSpMDMKernelSignature:: int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, bool scale_bias_last /*=true*/, - bool is_bf16_out) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); + const bool is_bf16_out /*=false*/, + const bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = input_bit_rate; + } + assert( + (input_bit_rate == 2 || input_bit_rate == 4) && + "input_bit_rate must be 2 or 4"); + if (std::is_same::value) { + assert( + (no_bag && input_bit_rate == 4 && output_bit_rate == 4) && + "we currently only support int4 to int4 when using sequential TBE"); + } if (!cpuinfo_initialize()) { throw 
runtime_error("Failed to initialize cpuinfo!"); @@ -1042,10 +1054,74 @@ typename EmbeddingSpMDMKernelSignature:: output_stride = block_size; } if (input_stride == -1) { - int64_t num_elem_per_byte = 8 / bit_rate; + int64_t num_elem_per_byte = 8 / input_bit_rate; input_stride = ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(uint16_t); } + if (no_bag) { + if (!is_autovec_disabled()) { + return [=](int64_t output_size, + int64_t index_size, + int64_t data_size, + const uint8_t* input, + const indxType* indices, + const offsetType* offsets_or_lengths, + const float* weights, + outType* out) { + return EmbeddingSpMDMNBit_autovec( + input_bit_rate, + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets_or_lengths, + weights, + normalize_by_lengths, + out, + is_weight_positional, + use_offsets, + output_stride, + input_stride, + scale_bias_last, + is_bf16_out, + no_bag, + output_bit_rate); + }; + } else { + return [=](int64_t output_size, + int64_t index_size, + int64_t data_size, + const uint8_t* input, + const indxType* indices, + const offsetType* offsets_or_lengths, + const float* weights, + outType* out) { + return EmbeddingSpMDMNBit_ref( + input_bit_rate, + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets_or_lengths, + weights, + normalize_by_lengths, + out, + is_weight_positional, + use_offsets, + output_stride, + input_stride, + scale_bias_last, + is_bf16_out, + no_bag, + output_bit_rate); + }; + } + } + if (fbgemmHasAvx512Support() && !is_asmjit_disabled()) { static GenEmbeddingSpMDMNBitLookup< indxType, @@ -1056,7 +1132,7 @@ typename EmbeddingSpMDMKernelSignature:: THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( - bit_rate, + input_bit_rate, block_size, has_weight, is_weight_positional, @@ -1096,7 +1172,7 @@ typename EmbeddingSpMDMKernelSignature:: THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( - bit_rate, + input_bit_rate, block_size, has_weight, is_weight_positional, @@ -1139,7 +1215,7 @@ typename EmbeddingSpMDMKernelSignature:: const float* weights, outType* out) { return EmbeddingSpMDMNBit_autovec( - bit_rate, + input_bit_rate, block_size, output_size, index_size, @@ -1171,7 +1247,7 @@ typename EmbeddingSpMDMKernelSignature:: const float* weights, outType* out) { return EmbeddingSpMDMNBit_ref( - bit_rate, + input_bit_rate, block_size, output_size, index_size, @@ -1364,7 +1440,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( OFFSET_TYPE, \ OUT_TYPE, \ THREAD_LOCAL>( \ - int bit_rate, \ + const int input_bit_rate, \ const int64_t block_size, \ bool has_weight, \ bool normalize_by_lengths, \ @@ -1374,7 +1450,9 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( int64_t output_stride, \ int64_t input_stride, \ bool scale_bias_last, \ - bool is_bf16_out); + const bool is_bf16_out, \ + const bool no_bag, \ + int output_bit_rate); #define INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ @@ -1396,6 +1474,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \ INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint16_t) \ + INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint8_t) \ template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ uint8_t, \ INDEX_TYPE, \ diff --git a/src/RefImplementations.cc 
b/src/RefImplementations.cc index 0399750b27..e11e345ebd 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -1413,7 +1413,7 @@ bool EmbeddingSpMDM_ref( template bool EmbeddingSpMDMNBit_ref( - int bit_rate, + int input_bit_rate, const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1428,10 +1428,15 @@ bool EmbeddingSpMDMNBit_ref( bool use_offsets /*=true*/, int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, - bool scale_bias_last /*=true*/, - bool is_bf16_out /*=false*/) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); - int num_elem_per_byte = 8 / bit_rate; + const bool scale_bias_last /*=true*/, + const bool is_bf16_out /*=false*/, + const bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = 8 * sizeof(OutType); + } + nbit_embedding_sanity_check(input_bit_rate, output_bit_rate, no_bag); + int num_elem_per_byte = 8 / input_bit_rate; if (output_stride == -1) { output_stride = block_size; @@ -1444,6 +1449,27 @@ bool EmbeddingSpMDMNBit_ref( input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte + scale_bias_offset; } + + if (no_bag) { + // We currently only support int4 to int4 for sequential TBE in this nbit + // kernel. Note that assert() will be ignored in release mode, so we check + // here to double check and also avoid "unused variable" warning + if (!(input_bit_rate == 4 && output_bit_rate == 4)) { + WARN_ONCE("no_bag is only supported for int4 to int4"); + return false; + } + for (int64_t i = 0; i < output_size; ++i) { + const auto idx = indices[i]; + if (idx < 0 || idx > data_size) { + return false; + } + const uint8_t* input_row = input + input_stride * idx; + memcpy(out, input_row, sizeof(uint8_t) * input_stride); + out += input_stride; + } + return true; + } + int64_t current = 0; vector buf(block_size); for (int m = 0; m < output_size; ++m) { @@ -1481,8 +1507,8 @@ bool EmbeddingSpMDMNBit_ref( uint8_t quantized = input [input_stride * idx + j / num_elem_per_byte + (scale_bias_last ? 
0 : scale_bias_offset)]; - quantized >>= (j % num_elem_per_byte) * bit_rate; - quantized &= (1 << bit_rate) - 1; + quantized >>= (j % num_elem_per_byte) * input_bit_rate; + quantized &= (1 << input_bit_rate) - 1; buf[j] = std::fma(scale, quantized, buf[j] + bias); } @@ -2105,47 +2131,53 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t) #undef INSTANTIATE_SPMDM_OUT_T #undef INSTANTIATE_SPMDM_BASE -#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ - template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \ - int bit_rate, \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const INDEX_TYPE* indices, \ - const OFFSET_TYPE* offsets_or_lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OUT_TYPE* out, \ - bool is_weight_positional, \ - bool use_offsets, \ - int64_t output_stride, \ - int64_t input_stride, \ - bool scale_bias_last, \ - bool is_bf16_out); \ - template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const INDEX_TYPE* indices, \ - const OFFSET_TYPE* offsets_or_lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OUT_TYPE* out, \ - bool is_weight_positional, \ - bool use_offsets, \ - int64_t output_stride, \ - int64_t input_stride, \ - int exponent_bits, \ - int exponent_bias, \ +#define INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ + template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \ + const int input_bit_rate, \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const INDEX_TYPE* indices, \ + const OFFSET_TYPE* offsets_or_lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OUT_TYPE* out, \ + bool is_weight_positional, \ + bool use_offsets, \ + int64_t output_stride, \ + int64_t input_stride, \ + const bool scale_bias_last, \ + const bool is_bf16_out, \ + const bool no_bag, \ + int output_bit_rate); +#define INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ + template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const INDEX_TYPE* indices, \ + const OFFSET_TYPE* offsets_or_lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OUT_TYPE* out, \ + bool is_weight_positional, \ + bool use_offsets, \ + int64_t output_stride, \ + int64_t input_stride, \ + int exponent_bits, \ + int exponent_bias, \ bool is_bf16_out); #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) \ template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \ int bit_rate, \ const int64_t block_size, \ diff --git a/src/RefImplementations.h b/src/RefImplementations.h index f01aa57d5a..076e4e7fc4 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -246,7 +246,7 @@ template < typename 
OffsetType = std::int32_t, typename OutType = float> FBGEMM_API bool EmbeddingSpMDMNBit_ref( - int bit_rate, + const int input_bit_rate, const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, @@ -261,8 +261,10 @@ FBGEMM_API bool EmbeddingSpMDMNBit_ref( bool use_offsets = true, std::int64_t output_stride = -1, std::int64_t input_stride = -1, - bool scale_bias_last = true, - bool is_bf16_out = false); + const bool scale_bias_last = true, + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); template < typename IndexType = std::int64_t, From e971c913b3b66c027f8755cb5743c573dbef2d6f Mon Sep 17 00:00:00 2001 From: Wei Su Date: Thu, 22 Aug 2024 14:00:21 -0700 Subject: [PATCH 3/3] Enable int4 to int4 CPU STBE in fbgemm_gpu TBE API (#2994) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2994 X-link: https://github.com/facebookresearch/FBGEMM/pull/89 Enable int4 to int4 sequential CPU TBE in codegen template so that fbgemm_gpu's `IntNBitTableBatchedEmbeddingBagsCodegen` could support it Reviewed By: sryap Differential Revision: D61305978 --- ...bedding_forward_quantized_cpu_template.cpp | 28 +++++++++++++------ .../fbgemm_gpu/utils/dispatch_macros.h | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 111c680aae..2b126c96d7 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Tensor output; SparseType o_dtype = static_cast(output_dtype); - TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16); + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4); bool output_is_bf16 = o_dtype == SparseType::BF16; bool output_is_int8 = o_dtype == SparseType::INT8; + bool output_is_int4 = o_dtype == SparseType::INT4; {% if not nobag %} const int kINT8QparamsBytes = 8; int64_t total_adjusted_D = total_D; @@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ } output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory)); {% else %} - const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout + constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout + constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements int64_t adjusted_D = D; if (o_dtype == SparseType::INT8) { adjusted_D += kINT8QparamsBytes; + } else if (o_dtype == SparseType::INT4) { + adjusted_D += kINT4QparamsElems; } output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory)); @@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ using other_fbgemm_out_t = typename std::conditional< std::is_same::value, float16, - std::conditional::value, bfloat16, float>::type >::type; + std::conditional::value, bfloat16, float>::type> ::type; 
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] { const auto* indices_acc = indices.data_ptr(); const auto* offsets_acc = offsets.data_ptr(); @@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int32_t D_end = D_offsets_acc[t + 1]; const int32_t D = D_end - D_start; {% else %} - const int32_t D_start = offsets_acc[t * B] * adjusted_D; + const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup(adjusted_D, 2) : adjusted_D; + const int32_t D_start = offsets_acc[t * B] * elems_D; {% endif %} const auto placement = static_cast(weights_placements_ptr[t]); @@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ {% endif %} const float* indice_weights_ptr = nullptr; - // int8 output only enabled for nobag case with ref impl - const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }}; + // int8/int4 output only enabled for nobag case + const bool nobag_op = {{ "false" if not nobag else "output_is_int8 || output_is_int4" }}; {% if weighted %} indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr; {% endif %} @@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides" if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides") %} - using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }}; + using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base or use_nbit else "other_fbgemm_out_t" }}; + {% if use_nbit %} + const int output_bit_rate = output_is_int4 ? 4 : sizeof(fbgemm_out_t) * 8; + {% endif %} // TODO: merge nobag int8 path with normal asmjit dispatch {% if nobag %} const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr; @@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ {% endif %} >( {% if use_nbit %} - /*bit_rate=*/bit_rate, + /*input_bit_rate=*/bit_rate, {% endif %} D, {% if has_asmjit %} @@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ /*no_bag=*/nobag_op, {% endif %} /*is_bf16_out=*/output_is_bf16 + {% if use_nbit %} + ,/*no_bag=*/nobag_op, + /*output_bit_rate=*/output_bit_rate + {% endif %} ); success = kernel( {{ "B" if not nobag else "index_size"}}, diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h b/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h index 6705b016e4..ce6a46a813 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h @@ -122,6 +122,8 @@ at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \ PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \ PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + PRIVATE_CASE_TYPE_OUTPUT2( \ + at::ScalarType::QUInt4x2, uint8_t, __VA_ARGS__) \ default: \ AT_ERROR( \ #NAME, \
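For the nobag int4 output path enabled by this patch, the row geometry follows from the template changes above: the fp16 scale and bias occupy 4 bytes, counted as 8 int4 elements (`kINT4QparamsElems`), and per-row offsets are taken in bytes via `at::divup(adjusted_D, 2)`. A small standalone sketch of that arithmetic, using an assumed embedding dimension of D = 128 purely for illustration:

  #include <cstdint>
  #include <cstdio>

  int main() {
    const int64_t D = 128;                     // embedding dimension (assumed example value)
    const int kINT4QparamsElems = 8;           // fp16 scale + bias = 4 bytes = 8 int4 elements
    const int64_t adjusted_D = D + kINT4QparamsElems;
    const int64_t bytes_per_row = (adjusted_D + 1) / 2;  // at::divup(adjusted_D, 2)
    std::printf("int4 nobag output row: %lld int4 elements, %lld bytes\n",
                static_cast<long long>(adjusted_D),
                static_cast<long long>(bytes_per_row));
    return 0;
  }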