diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 6a6ca3b5a2..f9e793c68e 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -13,6 +13,7 @@ #include #include +#include #ifdef __SSE__ #include diff --git a/faiss/impl/platform_macros.h b/faiss/impl/platform_macros.h index e9910e6356..1f4795c65c 100644 --- a/faiss/impl/platform_macros.h +++ b/faiss/impl/platform_macros.h @@ -49,8 +49,27 @@ inline int __builtin_clzll(uint64_t x) { return (int)__lzcnt64(x); } +#define __builtin_popcount __popcnt #define __builtin_popcountl __popcnt64 +// MSVC does not define __SSEx__, and _M_IX86_FP is only defined when targeting 32-bit x86 +// cf. https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros +#ifdef __AVX__ +#define __SSE__ 1 +#define __SSE2__ 1 +#define __SSE3__ 1 +#define __SSE4_1__ 1 +#define __SSE4_2__ 1 +#endif + +// MSVC enables FMA and F16C automatically under /arch:AVX2 +// Ref. FMA (under /arch:AVX2): https://docs.microsoft.com/en-us/cpp/build/reference/arch-x64 +// Ref. 
F16C (2nd paragraph): https://walbourn.github.io/directxmath-avx2/ +#ifdef __AVX2__ +#define __FMA__ 1 +#define __F16C__ 1 +#endif + #else /******************************************************* * Linux and OSX diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 21bbb5c01d..06fc9203c7 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -16,6 +16,7 @@ #include #include +#include #ifdef __SSE3__ #include @@ -165,7 +166,7 @@ void fvec_inner_products_ny_ref (float * ip, static inline __m128 masked_read (int d, const float *x) { assert (0 <= d && d < 4); - __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; + ALIGNED(16) float buf[4] = {0, 0, 0, 0}; switch (d) { case 3: buf[2] = x[2]; @@ -987,7 +988,7 @@ void compute_PQ_dis_tables_dsub2( simd8float32 centroids[8]; for (int k = 0; k < 8; k++) { - float centroid[8] __attribute__((aligned(32))); + ALIGNED(32) float centroid[8]; size_t wp = 0; size_t rp = (m0 * ksub + k + k0) * 2; for (int m = m0; m < m1; m++) {