Skip to content

Commit

Permalink
Reduce unrolling in Panama dotProduct float variant
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisHegarty committed Dec 16, 2024
1 parent 084480d commit 97b6c7b
Showing 1 changed file with 11 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ public float dotProduct(float[] a, float[] b) {

/** vectorized float dot product body */
private float dotProductBody(float[] a, float[] b, int limit) {
int i = 0;
// vector loop is unrolled 4x (4 accumulators in parallel)
// we don't know how many the cpu can do at once, some can do 2, some 4
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
// vector loop is unrolled 2x (2 accumulators in parallel)
FloatVector acc1 =
FloatVector.fromArray(FLOAT_SPECIES, a, 0).mul(FloatVector.fromArray(FLOAT_SPECIES, b, 0));
FloatVector acc2 =
FloatVector.fromArray(FLOAT_SPECIES, a, FLOAT_SPECIES.length())
.mul(FloatVector.fromArray(FLOAT_SPECIES, b, FLOAT_SPECIES.length()));
final int unrolledLimit = limit - FLOAT_SPECIES.length();
int i = 2 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
Expand All @@ -131,27 +131,15 @@ private float dotProductBody(float[] a, float[] b, int limit) {
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
acc2 = fma(vc, vd, acc2);

// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
acc3 = fma(ve, vf, acc3);

// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
acc4 = fma(vg, vh, acc4);
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
if (i < limit) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = fma(va, vb, acc1);
}
// reduce
FloatVector res1 = acc1.add(acc2);
FloatVector res2 = acc3.add(acc4);
return res1.add(res2).reduceLanes(ADD);
return res1.reduceLanes(ADD);
}

@Override
Expand Down

0 comments on commit 97b6c7b

Please sign in to comment.