Windows: portable intrinsics (#1684)
Summary:
Trying to compile faiss for AVX2 on Windows in conda-forge/faiss-split-feedstock#27
(after #1600) surfaced a number of issues (#1680, #1681, #1682), but the most voluminous problem
was that MSVC, unlike GCC/Clang, does not accept arithmetic operators or direct casts on the `__m128` / `__m256` vector types.

This led to loads of errors that looked as follows:
```
[...]\faiss\utils\distances_simd.cpp(411): error C2676: binary '+=': '__m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(440): error C2676: binary '-': '__m256' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(441): error C2676: binary '*': 'const __m256' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(446): error C2676: binary '+=': '__m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(451): error C2676: binary '-': '__m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(452): error C2676: binary '*': 'const __m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(459): error C2676: binary '-': '__m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(460): error C2676: binary '*': '__m128' does not define this operator or a conversion to a type acceptable to the predefined operator
[...]\faiss\utils\distances_simd.cpp(471): error C2440: '<function-style-cast>': cannot convert from '__m256i' to '__m256'
```
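
For readers without MSVC at hand, here is a minimal sketch of the pattern behind the `C2676` errors (hypothetical helper function, not code from this PR): GCC and Clang implement `__m128` / `__m256` as vector-extension types that support `+`, `-`, `*` directly, whereas MSVC only accepts the explicit intrinsics, so the portable spelling goes through `_mm256_sub_ps`, `_mm256_mul_ps` and friends.

```
#include <immintrin.h>

// GCC/Clang-only spelling; MSVC rejects the operators with C2676:
//   __m256 d = a - b;
//   return d * d;

// Portable spelling, accepted by GCC, Clang and MSVC alike:
__m256 squared_diff(__m256 a, __m256 b) {
    __m256 d = _mm256_sub_ps(a, b);   // element-wise a - b
    return _mm256_mul_ps(d, d);       // element-wise square
}
```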

I've followed the Intel Intrinsics Guide (https://software.intel.com/sites/landingpage/IntrinsicsGuide/) to try to replace everything correctly,
but this will surely require close review, because I'm not sure how well these code paths are covered by the
test suite.
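
The `C2440` cast error follows the same pattern: function-style and C-style casts between integer and float vector types (`__m256i` ↔ `__m256`, `__m128` ↔ `__m128i`) compile on GCC/Clang but not on MSVC, so they are rewritten with the bit-cast intrinsics. A small sketch of the replacement used for the `signmask` setup in `fvec_L1` / `fvec_Linf` below (the wrapper function is only for illustration):

```
#include <immintrin.h>

// Hypothetical wrapper; faiss uses this pattern inline in fvec_L1 / fvec_Linf.
__m256 abs_ps(__m256 x) {
    // Rejected by MSVC (C2440):
    //   __m256 signmask = __m256(_mm256_set1_epi32(0x7fffffff));
    // Portable bit-cast, no instruction emitted:
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff));
    return _mm256_and_ps(signmask, x);   // clear the sign bit of every lane
}
```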

In any case, with the commits from #1600, #1666, #1680, #1681 and #1682, I was able to build `libfaiss` & `faiss`
for AVX2 on Windows (while remaining "green" on Linux/macOS, both with & without AVX2).

Side note: the issues fixed in the last commit (26fc7cf)
were uncovered by adding the `__SSE3__` compat macros in #1681.
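
For context on that side note: MSVC never defines `__SSE3__` (or the other `__SSEn__` macros), so without a compat shim the `#ifdef __SSE3__` blocks below would simply be skipped on Windows and the operator/cast problems in them would go unnoticed. A rough sketch of the idea, assuming the real definitions in `faiss/impl/platform_macros.h` from #1681 (which may differ in detail):

```
// Hypothetical sketch; see faiss/impl/platform_macros.h for the real macros.
#ifdef _MSC_VER
#ifndef __SSE3__
#ifdef __AVX2__   // MSVC defines __AVX2__ under /arch:AVX2, which implies SSE3
#define __SSE3__ 1
#endif
#endif
#endif
```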

Pull Request resolved: #1684

Test Plan: buck test //faiss/tests/...

Reviewed By: beauby

Differential Revision: D26454443

Pulled By: mdouze

fbshipit-source-id: 70df0818e357f1ecea6a056d619618df0236e0eb
h-vetinari authored and facebook-github-bot committed Feb 18, 2021
1 parent d0ad3d7 commit 1afaddb
Showing 3 changed files with 56 additions and 55 deletions.
13 changes: 7 additions & 6 deletions faiss/Index2Layer.cpp
@@ -14,6 +14,7 @@
#include <cstdio>
#include <cassert>
#include <stdint.h>
+ #include <faiss/impl/platform_macros.h>

#ifdef __SSE3__
#include <immintrin.h>
@@ -285,9 +286,9 @@ struct DistanceXPQ4 : Distance2Level {

for (int m = 0; m < M; m++) {
__m128 qi = _mm_loadu_ps(qa);
- __m128 recons = l1_t[m] + pq_l2_t[*code++];
- __m128 diff = qi - recons;
- accu += diff * diff;
+ __m128 recons = _mm_add_ps (l1_t[m], pq_l2_t[*code++]);
+ __m128 diff = _mm_sub_ps (qi, recons);
+ accu = _mm_add_ps (accu, _mm_mul_ps (diff, diff));
pq_l2_t += 256;
qa += 4;
}
@@ -338,9 +339,9 @@ struct Distance2xXPQ4 : Distance2Level {

for (int m = 0; m < M_2; m++) {
__m128 qi = _mm_loadu_ps(qa);
- __m128 recons = pq_l1[m] + pq_l2_t[*code++];
- __m128 diff = qi - recons;
- accu += diff * diff;
+ __m128 recons = _mm_add_ps (pq_l1[m], pq_l2_t[*code++]);
+ __m128 diff = _mm_sub_ps (qi, recons);
+ accu = _mm_add_ps (accu, _mm_mul_ps (diff, diff));
pq_l2_t += 256;
qa += 4;
}
28 changes: 14 additions & 14 deletions faiss/impl/ScalarQuantizer.cpp
@@ -83,9 +83,9 @@ struct Codec8bit {
i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
__m256 f8 = _mm256_cvtepi32_ps (i8);
__m256 half = _mm256_set1_ps (0.5f);
- f8 += half;
+ f8 = _mm256_add_ps(f8, half);
__m256 one_255 = _mm256_set1_ps (1.f / 255.f);
- return f8 * one_255;
+ return _mm256_mul_ps(f8, one_255);
}
#endif
};
@@ -118,9 +118,9 @@ struct Codec4bit {
i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
__m256 f8 = _mm256_cvtepi32_ps (i8);
__m256 half = _mm256_set1_ps (0.5f);
- f8 += half;
+ f8 = _mm256_add_ps(f8, half);
__m256 one_255 = _mm256_set1_ps (1.f / 15.f);
- return f8 * one_255;
+ return _mm256_mul_ps(f8, one_255);
}
#endif
};
@@ -197,9 +197,9 @@ struct Codec6bit {
// this could also be done with bit manipulations but it is
// not obviously faster
__m256 half = _mm256_set1_ps (0.5f);
- f8 += half;
+ f8 = _mm256_add_ps(f8, half);
__m256 one_63 = _mm256_set1_ps (1.f / 63.f);
- return f8 * one_63;
+ return _mm256_mul_ps(f8, one_63);
}

#endif
@@ -389,7 +389,7 @@ struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {
__m256 reconstruct_8_components (const uint8_t * code, int i) const
{
__m256 xi = Codec::decode_8_components (code, i);
- return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff);
+ return _mm256_add_ps(_mm256_set1_ps(this->vmin), _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
}

};
@@ -449,7 +449,7 @@ struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
__m256 reconstruct_8_components (const uint8_t * code, int i) const
{
__m256 xi = Codec::decode_8_components (code, i);
- return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
+ return _mm256_add_ps(_mm256_loadu_ps(this->vmin + i), _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
}


@@ -808,13 +808,13 @@ struct SimilarityL2<8> {
void add_8_components (__m256 x) {
__m256 yiv = _mm256_loadu_ps (yi);
yi += 8;
- __m256 tmp = yiv - x;
- accu8 += tmp * tmp;
+ __m256 tmp = _mm256_sub_ps(yiv, x);
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
}

void add_8_components_2 (__m256 x, __m256 y) {
- __m256 tmp = y - x;
- accu8 += tmp * tmp;
+ __m256 tmp = _mm256_sub_ps(y, x);
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
}

float result_8 () {
@@ -888,11 +888,11 @@ struct SimilarityIP<8> {
void add_8_components (__m256 x) {
__m256 yiv = _mm256_loadu_ps (yi);
yi += 8;
- accu8 += yiv * x;
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
}

void add_8_components_2 (__m256 x1, __m256 x2) {
- accu8 += x1 * x2;
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
}

float result_8 () {
70 changes: 35 additions & 35 deletions faiss/utils/distances_simd.cpp
@@ -212,8 +212,8 @@ struct ElementOpL2 {
}

static __m128 op (__m128 x, __m128 y) {
- __m128 tmp = x - y;
- return tmp * tmp;
+ __m128 tmp = _mm_sub_ps(x, y);
+ return _mm_mul_ps(tmp, tmp);
}

};
@@ -227,7 +227,7 @@ struct ElementOpIP {
}

static __m128 op (__m128 x, __m128 y) {
- return x * y;
+ return _mm_mul_ps(x, y);
}

};
@@ -299,7 +299,7 @@ void fvec_op_ny_D8 (float * dis, const float * x,

for (size_t i = 0; i < ny; i++) {
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
+ accu = _mm_add_ps (accu, ElementOp::op(x1, _mm_loadu_ps (y))); y += 4;
accu = _mm_hadd_ps (accu, accu);
accu = _mm_hadd_ps (accu, accu);
dis[i] = _mm_cvtss_f32 (accu);
@@ -316,8 +316,8 @@ void fvec_op_ny_D12 (float * dis, const float * x,

for (size_t i = 0; i < ny; i++) {
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
- accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
+ accu = _mm_add_ps (accu, ElementOp::op(x1, _mm_loadu_ps (y))); y += 4;
+ accu = _mm_add_ps (accu, ElementOp::op(x2, _mm_loadu_ps (y))); y += 4;
accu = _mm_hadd_ps (accu, accu);
accu = _mm_hadd_ps (accu, accu);
dis[i] = _mm_cvtss_f32 (accu);
@@ -409,7 +409,7 @@ float fvec_inner_product (const float * x,
}

__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
- msum2 += _mm256_extractf128_ps(msum1, 0);
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
@@ -438,27 +438,27 @@ float fvec_L2sqr (const float * x,
while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
- const __m256 a_m_b1 = mx - my;
- msum1 += a_m_b1 * a_m_b1;
+ const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
+ msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
d -= 8;
}

__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
- msum2 += _mm256_extractf128_ps(msum1, 0);
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
- const __m128 a_m_b1 = mx - my;
- msum2 += a_m_b1 * a_m_b1;
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}

if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
- __m128 a_m_b1 = mx - my;
- msum2 += a_m_b1 * a_m_b1;
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
}

msum2 = _mm_hadd_ps (msum2, msum2);
@@ -469,33 +469,33 @@ float fvec_L2sqr (const float * x,
float fvec_L1 (const float * x, const float * y, size_t d)
{
__m256 msum1 = _mm256_setzero_ps();
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32 (0x7fffffffUL));

while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
- const __m256 a_m_b = mx - my;
- msum1 += _mm256_and_ps(signmask, a_m_b);
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
+ msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
d -= 8;
}

__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
- msum2 += _mm256_extractf128_ps(msum1, 0);
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32 (0x7fffffffUL));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
- const __m128 a_m_b = mx - my;
- msum2 += _mm_and_ps(signmask2, a_m_b);
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
d -= 4;
}

if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
- __m128 a_m_b = mx - my;
- msum2 += _mm_and_ps(signmask2, a_m_b);
+ __m128 a_m_b = _mm_sub_ps(mx, my);
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
}

msum2 = _mm_hadd_ps (msum2, msum2);
@@ -506,32 +506,32 @@ float fvec_L1 (const float * x, const float * y, size_t d)
float fvec_Linf (const float * x, const float * y, size_t d)
{
__m256 msum1 = _mm256_setzero_ps();
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32 (0x7fffffffUL));

while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
- const __m256 a_m_b = mx - my;
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
d -= 8;
}

__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32 (0x7fffffffUL));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
- const __m128 a_m_b = mx - my;
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
d -= 4;
}

if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
- __m128 a_m_b = mx - my;
+ __m128 a_m_b = _mm_sub_ps(mx, my);
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
}

@@ -562,17 +562,17 @@ float fvec_L2sqr (const float * x,
while (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
- const __m128 a_m_b1 = mx - my;
- msum1 += a_m_b1 * a_m_b1;
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}

if (d > 0) {
// add the last 1, 2 or 3 values
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
- __m128 a_m_b1 = mx - my;
- msum1 += a_m_b1 * a_m_b1;
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
}

msum1 = _mm_hadd_ps (msum1, msum1);
@@ -821,7 +821,7 @@ static inline int fvec_madd_and_argmin_sse (
while (n--) {
__m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
*c4 = vc4;
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps (vmin4, vc4));
// imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!

imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
@@ -837,7 +837,7 @@ static inline int fvec_madd_and_argmin_sse (
{
idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
__m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps (vmin4, vc4));
imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
_mm_andnot_si128 (mask, imin4));
vmin4 = _mm_min_ps (vmin4, vc4);
@@ -846,7 +846,7 @@
{
idx4 = _mm_shuffle_epi32 (imin4, 1);
__m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps (vmin4, vc4));
imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
_mm_andnot_si128 (mask, imin4));
// vmin4 = _mm_min_ps (vmin4, vc4);
