Skip to content

Commit

Permalink
Add an auto-vectorization implementation for CPU TBE-NBit kernel (#2182)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2182

Trying to reland D50289383 with the following changes:
- Fix bazel build failure in OSS
- Only enable auto-vec implementation in linux and exclude all other OSes to avoid AR/VR as well as other xplat failures

Reviewed By: jasonjk-park

Differential Revision: D51692953

fbshipit-source-id: 07a77d47454dac4dfb3b38f2ac1cd836a69cba1a
  • Loading branch information
excelle08 authored and facebook-github-bot committed Jan 8, 2024
1 parent ca6ac2f commit 273d964
Show file tree
Hide file tree
Showing 13 changed files with 563 additions and 75 deletions.
48 changes: 47 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,23 @@

load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("defs.bzl", "get_fbgemm_avx2_srcs", "get_fbgemm_inline_avx2_srcs", "get_fbgemm_avx512_srcs", "get_fbgemm_inline_avx512_srcs", "get_fbgemm_base_srcs", "get_fbgemm_generic_srcs", "get_fbgemm_public_headers", "get_fbgemm_tests")
load("defs.bzl", "get_fbgemm_avx2_srcs", "get_fbgemm_inline_avx2_srcs", "get_fbgemm_avx512_srcs", "get_fbgemm_inline_avx512_srcs", "get_fbgemm_base_srcs", "get_fbgemm_generic_srcs", "get_fbgemm_public_headers", "get_fbgemm_tests", "get_fbgemm_autovec_srcs")

config_setting(
name = "linux-x86_64",
constraint_values = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
]
)

config_setting(
name = "linux-aarch64",
constraint_values = [
"@platforms//os:linux",
"@platforms//cpu:aarch64",
]
)

cc_library(
name = "fbgemm_base",
Expand All @@ -31,6 +47,7 @@ cc_library(
"src",
],
deps = [
":fbgemm_autovec",
":fbgemm_avx2",
":fbgemm_inline_avx2",
":fbgemm_avx512",
Expand Down Expand Up @@ -115,6 +132,35 @@ cc_library(
linkstatic = 1,
)

cc_library(
name = "fbgemm_autovec",
srcs = get_fbgemm_autovec_srcs(),
hdrs = glob(["src/*.h"]),
copts = select({
":linux-x86_64": [
"-fopenmp",
"-m64",
"-mf16c",
"-mavx2",
"-mavx512f",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-masm=intel",
],
":linux-aarch64": [
"-fopenmp",
"-march=armv9-a+sve2+fp16",
],
"//conditions:default": [],
}),
deps = [
":fbgemm_base",
":fbgemm_headers",
],
linkstatic = 1,
)

cc_library(
name = "fbgemm_headers",
hdrs = get_fbgemm_public_headers(),
Expand Down
23 changes: 20 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,21 @@ get_filelist("get_fbgemm_inline_avx2_srcs(msvc=${MSVC_BOOL})"
get_filelist("get_fbgemm_avx512_srcs(msvc=${MSVC_BOOL})" FBGEMM_AVX512_SRCS)
get_filelist("get_fbgemm_inline_avx512_srcs(msvc=${MSVC_BOOL})"
FBGEMM_AVX512_INLINE_SRCS)
get_filelist("get_fbgemm_autovec_srcs()" FBGEMM_AUTOVEC_SRCS)
get_filelist("get_fbgemm_public_headers()" FBGEMM_PUBLIC_HEADERS)

add_library(fbgemm_generic OBJECT ${FBGEMM_GENERIC_SRCS})
add_library(fbgemm_avx2 OBJECT ${FBGEMM_AVX2_SRCS} ${FBGEMM_AVX2_INLINE_SRCS})
add_library(fbgemm_avx512 OBJECT
${FBGEMM_AVX512_SRCS} ${FBGEMM_AVX512_INLINE_SRCS})
add_library(fbgemm_autovec OBJECT ${FBGEMM_AUTOVEC_SRCS})

# Make libraries depend on defs.bzl
add_custom_target(defs.bzl DEPENDS defs.bzl)
add_dependencies(fbgemm_generic defs.bzl)
add_dependencies(fbgemm_avx2 defs.bzl)
add_dependencies(fbgemm_avx512 defs.bzl)
add_dependencies(fbgemm_autovec defs.bzl)

# On Windows:
# 1) Adding definition of ASMJIT_STATIC to avoid generating asmjit function
Expand All @@ -186,6 +189,9 @@ if(MSVC)
endif()
target_compile_options(fbgemm_avx2 PRIVATE "/arch:AVX2")
target_compile_options(fbgemm_avx512 PRIVATE "/arch:AVX512")
if(OpenMP_FOUND)
target_compile_options(fbgemm_autovec PRIVATE "/openmp:experimental")
endif()
else(MSVC)
string(APPEND CMAKE_CXX_FLAGS " -Wall")
string(APPEND CMAKE_CXX_FLAGS " -Wextra")
Expand Down Expand Up @@ -300,29 +306,40 @@ target_include_directories(fbgemm_avx512 BEFORE
PRIVATE "${ASMJIT_SRC_DIR}/src"
PRIVATE "${CPUINFO_SOURCE_DIR}/include")

target_include_directories(fbgemm_autovec BEFORE
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
PRIVATE "${ASMJIT_SRC_DIR}/src"
PRIVATE "${CPUINFO_SOURCE_DIR}/include")

if(FBGEMM_LIBRARY_TYPE STREQUAL "default")
add_library(fbgemm
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
$<TARGET_OBJECTS:fbgemm_avx512>
$<TARGET_OBJECTS:fbgemm_autovec>)
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "shared")
add_library(fbgemm SHARED
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
$<TARGET_OBJECTS:fbgemm_avx512>
$<TARGET_OBJECTS:fbgemm_autovec>)
set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_autovec PROPERTY POSITION_INDEPENDENT_CODE ON)
elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
add_library(fbgemm STATIC
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
$<TARGET_OBJECTS:fbgemm_avx512>
$<TARGET_OBJECTS:fbgemm_autovec>)
#MSVC need to define FBGEMM_STATIC for fbgemm_generic also to
#avoid generating _dllimport functions.
target_compile_definitions(fbgemm_generic PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm_autovec PRIVATE FBGEMM_STATIC)
target_compile_definitions(fbgemm PRIVATE FBGEMM_STATIC)
else()
message(FATAL_ERROR "Unsupported library type ${FBGEMM_LIBRARY_TYPE}")
Expand Down
176 changes: 128 additions & 48 deletions bench/EmbeddingSpMDMNBitBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmConvert.h"
#include "src/EmbeddingSpMDMAutovec.h"
#include "src/RefImplementations.h"

using namespace std;
Expand Down Expand Up @@ -136,8 +137,8 @@ int run_benchmark(
vector<float> output_slws_ref(output_sls_ref.size()),
output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());

constexpr int NUM_WARMUP = 4;
constexpr int NUM_ITER = 10;
constexpr int NUM_WARMUP = 10;
constexpr int NUM_ITER = 100;
// Only counts the number of bytes for reading embedding table and ignore
// others. Should be good enough as long as embdding_dim is big enough.
double bytes = lengths_sum * fused_embedding_dim;
Expand All @@ -148,36 +149,9 @@ int run_benchmark(

for (bool has_weight : {false, true}) {
vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
vector<float> output_autovec(output_sls_ref.size());

bool success = false, success_ref = false;

if (use_32_bit_indices) {
success_ref = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices_32.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
} else {
success = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
}
bool success = false, success_ref = false, success_autovec = false;

auto kernel_32 = GenerateEmbeddingSpMDMNBit<int32_t>(
bit_rate,
Expand All @@ -194,6 +168,95 @@ int run_benchmark(

vector<float>& output = has_weight ? output_slws : output_sls;
for (bool flush_cache : {false, true}) {
// Reference implementation
double t_ref = measureWithWarmup(
[&]() {
if (use_32_bit_indices) {
success_ref = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices_32.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
} else {
success_ref = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
}
},
NUM_WARMUP,
NUM_ITER,
[&]() {
if (flush_cache) {
cache_evict(fused_embedding_table);
cache_evict(indices);
cache_evict(indices_32);
cache_evict(offsets);
cache_evict(weights);
cache_evict(output);
}
});

// Auto-vectorization implementation
double t_autovec = measureWithWarmup(
[&]() {
if (use_32_bit_indices) {
success_autovec = EmbeddingSpMDMNBit_autovec(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices_32.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_autovec.data());
} else {
success_autovec = EmbeddingSpMDMNBit_autovec(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
num_rows,
fused_embedding_table.data(),
indices.data(),
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_autovec.data());
}
},
NUM_WARMUP,
NUM_ITER,
[&]() {
if (flush_cache) {
cache_evict(fused_embedding_table);
cache_evict(indices);
cache_evict(indices_32);
cache_evict(offsets);
cache_evict(weights);
cache_evict(output);
}
});

// Hand-written AVX2/AVX512 implementation
double t = measureWithWarmup(
[&]() {
if (use_32_bit_indices) {
Expand Down Expand Up @@ -251,37 +314,55 @@ int run_benchmark(
// has_weight ? output_slws_ref : output_sls_ref;
if (success != success_ref) {
assert(
false && "ERROR: refernce impl and JIT imp did not both succeed");
false &&
"ERROR: reference impl and JIT impl did not both succeed");
} else if (success != success_autovec) {
assert(
false &&
"ERROR: reference impl and auto-vec impl did not both succeed");
} else if (success) {
for (size_t i = 0; i < output.size(); ++i) {
assert(fabs(output[i] - output_ref[i]) < 1e-3);
assert(fabs(output_autovec[i] - output_ref[i]) < 1e-3);
if (fabs(output[i] - output_ref[i]) >= 1e-3) {
cout << i << " " << output[i] << " " << output_ref[i] << endl;
cout << "asmjit vs ref : " << i << " " << output[i] << " "
<< output_ref[i] << endl;
}
if (fabs(output_autovec[i] - output_ref[i]) >= 1e-3) {
cout << "autovec vec ref: " << i << " " << output_autovec[i]
<< " " << output_ref[i] << endl;
}
}
}
}

if (has_weight) {
cout << setw(16) << "SLW(WEIGHTED) ";
cout << "SLW(WEIGHTED), ";
} else {
cout << setw(16) << "SLS ";
cout << "SLS, ";
}
if (flush_cache) {
cout << setw(20) << "cache flushed";
cout << "cache flushed, ";
} else {
cout << setw(20) << "cache not flushed";
cout << "cache not flushed, ";
}
if (prefetch) {
cout << setw(16) << "prefetch on";
cout << "prefetch on, ";
} else {
cout << setw(16) << "prefetch off";
cout << "prefetch off, ";
}

cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s"
<< setw(20) << "effective b/w: " << setw(16)
<< bytes_padded / 1e9 / t << "GB/s" << setw(8) << " time "
<< setw(16) << t << endl;
cout << "b/w, " << bytes / 1e9 / t << ", GB/s, "
<< "effective b/w, " << bytes_padded / 1e9 / t << ", GB/s, "
<< "time, " << t << ", autovec b/w, " << bytes / 1e9 / t_autovec
<< ", GB/s, "
<< "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec
<< ", GB/s, "
<< "autovec time, " << t_autovec << ", ref b/w, "
<< bytes / 1e9 / t_ref << ", GB/s, "
<< "ref eff. b/w, " << bytes_padded / 1e9 / t_ref << ", GB/s, "
<< "ref time, " << t_ref << ", autovec speedup, "
<< t_ref / t_autovec << ", asmjit speedup, " << t_ref / t << endl;
} // flush_cache
} // has_weight
return 0;
Expand All @@ -295,18 +376,17 @@ int main() {

vector<vector<int>> inputs(GetInputs_());

for (int bit_rate : {2, 4}) {
for (int bit_rate : {4, 2}) {
for (auto& input : inputs) {
assert(input.size() > 3);
batch_size = input[0];
num_rows = input[1];
embedding_dim = input[2];
average_len = input[3];

cout << "bit_rate" << setw(6) << bit_rate << "batch size" << setw(6)
<< batch_size << setw(10) << "num rows" << setw(16) << num_rows
<< setw(10) << "emb dim" << setw(6) << embedding_dim << setw(16)
<< "avg length" << setw(6) << average_len << endl;
cout << "bit_rate, " << bit_rate << ", batch size, " << batch_size
<< ", num rows, " << num_rows << ", emb dim, " << embedding_dim
<< ", avg length, " << average_len << endl;
// args: batch sz, num rows, emb dim, avg len, normalize, use 32b,
// prefetch
cout << "64 bit indices, ";
Expand Down
Loading

0 comments on commit 273d964

Please sign in to comment.