Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Panama float vector distance impls inlinable #14031

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ Optimizations
* GITHUB#14032: Speed up PostingsEnum when positions are requested.
(Adrien Grand)

* GITHUB#14031: Ensure Panama float vector distance impls inlinable.
(Robert Muir, Chris Hegarty)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
}

// Cached lane count of FLOAT_SPECIES, read once at class initialization.
// Methods in this class use this constant instead of calling FLOAT_SPECIES.length()
// at every use site, which keeps their bodies smaller.
// NOTE(review): presumably the smaller bodies are what keeps these methods
// inlinable by the JIT, per the change description - confirm.
private static final int FLOAT_SPECIES_LENGTH = FLOAT_SPECIES.length();

// the way FMA should work! if available use it, otherwise fall back to mul/add
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
if (Constants.HAS_FAST_VECTOR_FMA) {
Expand All @@ -99,7 +102,7 @@ public float dotProduct(float[] a, float[] b) {
float res = 0;

// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
if (a.length > 2 * FLOAT_SPECIES.length()) {
if (a.length > 2 * FLOAT_SPECIES_LENGTH) {
i += FLOAT_SPECIES.loopBound(a.length);
res += dotProductBody(a, b, i);
}
Expand All @@ -120,30 +123,33 @@ private float dotProductBody(float[] a, float[] b, int limit) {
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
final int unrolledLimit = limit - 3 * FLOAT_SPECIES_LENGTH;
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES_LENGTH) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = fma(va, vb, acc1);

// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
final int i2 = i + FLOAT_SPECIES_LENGTH;
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i2);
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i2);
acc2 = fma(vc, vd, acc2);

// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
final int i3 = i2 + FLOAT_SPECIES_LENGTH;
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i3);
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i3);
acc3 = fma(ve, vf, acc3);

// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
final int i4 = i3 + FLOAT_SPECIES_LENGTH;
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i4);
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i4);
acc4 = fma(vg, vh, acc4);
}
// vector tail: fewer scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
for (; i < limit; i += FLOAT_SPECIES_LENGTH) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = fma(va, vb, acc1);
Expand All @@ -162,7 +168,7 @@ public float cosine(float[] a, float[] b) {
float norm2 = 0;

// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
if (a.length > 2 * FLOAT_SPECIES.length()) {
if (a.length > 2 * FLOAT_SPECIES_LENGTH) {
i += FLOAT_SPECIES.loopBound(a.length);
float[] ret = cosineBody(a, b, i);
sum += ret[0];
Expand Down Expand Up @@ -190,8 +196,8 @@ private float[] cosineBody(float[] a, float[] b, int limit) {
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
final int unrolledLimit = limit - FLOAT_SPECIES_LENGTH;
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES_LENGTH) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
Expand All @@ -200,14 +206,15 @@ private float[] cosineBody(float[] a, float[] b, int limit) {
norm2_1 = fma(vb, vb, norm2_1);

// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
final int i2 = i + FLOAT_SPECIES_LENGTH;
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i2);
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i2);
sum2 = fma(vc, vd, sum2);
norm1_2 = fma(vc, vc, norm1_2);
norm2_2 = fma(vd, vd, norm2_2);
}
// vector tail: fewer scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
for (; i < limit; i += FLOAT_SPECIES_LENGTH) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = fma(va, vb, sum1);
Expand All @@ -227,7 +234,7 @@ public float squareDistance(float[] a, float[] b) {
float res = 0;

// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
if (a.length > 2 * FLOAT_SPECIES.length()) {
if (a.length > 2 * FLOAT_SPECIES_LENGTH) {
i += FLOAT_SPECIES.loopBound(a.length);
res += squareDistanceBody(a, b, i);
}
Expand All @@ -240,6 +247,12 @@ public float squareDistance(float[] a, float[] b) {
return res;
}

/**
 * Squared-difference accumulation step shared by the square-distance loops.
 *
 * <p>Computes the difference {@code a.sub(b)} once and passes it as both multiplicands to
 * {@link #fma}, with {@code c} as the third argument.
 *
 * @param a first operand vector
 * @param b second operand vector, subtracted from {@code a}
 * @param c accumulator passed through to {@code fma}
 * @return {@code fma(a.sub(b), a.sub(b), c)}
 */
private static FloatVector square(FloatVector a, FloatVector b, FloatVector c) {
  // compute the difference once so it can be reused for both fma multiplicands
  final FloatVector delta = a.sub(b);
  return fma(delta, delta, c);
}

/** vectorized square distance body */
private float squareDistanceBody(float[] a, float[] b, int limit) {
int i = 0;
Expand All @@ -249,38 +262,36 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
final int unrolledLimit = limit - 3 * FLOAT_SPECIES_LENGTH;
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES_LENGTH) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
FloatVector diff1 = va.sub(vb);
acc1 = fma(diff1, diff1, acc1);
acc1 = square(va, vb, acc1);

// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
FloatVector diff2 = vc.sub(vd);
acc2 = fma(diff2, diff2, acc2);
final int i2 = i + FLOAT_SPECIES_LENGTH;
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i2);
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i2);
acc2 = square(vc, vd, acc2);

// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
FloatVector diff3 = ve.sub(vf);
acc3 = fma(diff3, diff3, acc3);
final int i3 = i2 + FLOAT_SPECIES_LENGTH;
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i3);
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i3);
acc3 = square(ve, vf, acc3);

// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
FloatVector diff4 = vg.sub(vh);
acc4 = fma(diff4, diff4, acc4);
final int i4 = i3 + FLOAT_SPECIES_LENGTH;
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i4);
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i4);
acc4 = square(vg, vh, acc4);
}
// vector tail: fewer scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
for (; i < limit; i += FLOAT_SPECIES_LENGTH) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
FloatVector diff = va.sub(vb);
acc1 = fma(diff, diff, acc1);
acc1 = square(va, vb, acc1);
}
// reduce
FloatVector res1 = acc1.add(acc2);
Expand Down
Loading