Add bb31_4_t batch inversion #46

Merged, 4 commits, Jul 24, 2024

157 changes: 126 additions & 31 deletions ff/baby_bear.hpp
@@ -103,12 +103,13 @@ class __align__(16) bb31_4_t {
# else
# define asm asm volatile
# endif
// +20% in comparison to multiplication by itself even though
// the amount of instructions is the same...
// +25% in comparison to multiplication by itself
uint32_t u3x2 = u[3]<<1;
uint32_t u1x2 = u[1]<<1;

// ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %2; mul.hi.u32 %hi, %4, %2;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"mad.lo.cc.u32 %lo, %3, %3, %lo; madc.hi.u32 %hi, %3, %3, %hi;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"
@@ -126,7 +127,7 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[0])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u3x2),
"r"(MOD), "r"(M), "r"(BETA));

// ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
@@ -140,7 +141,6 @@ class __align__(16) bb31_4_t {

"mul.lo.u32 %lo, %hi, %7; mul.hi.u32 %hi, %hi, %7;\n\t"
"mad.lo.cc.u32 %lo, %2, %1, %lo; madc.hi.u32 %hi, %2, %1, %hi;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

@@ -149,7 +149,7 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[1])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
: "r"(u[0]), "r"(u1x2), "r"(u[2]), "r"(u3x2),
"r"(MOD), "r"(M), "r"(BETA));

// ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
@@ -162,7 +162,6 @@ class __align__(16) bb31_4_t {
//"@%p sub.u32 %m, %m, %5;\n\t"

"mul.lo.u32 %lo, %3, %1; mul.hi.u32 %hi, %3, %1;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"mad.lo.cc.u32 %lo, %2, %2, %lo; madc.hi.u32 %hi, %2, %2, %hi;\n\t"
"mad.lo.cc.u32 %lo, %m, %7, %lo; madc.hi.u32 %hi, %m, %7, %hi;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
@@ -173,14 +172,13 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[2])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
: "r"(u[0]<<1), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(MOD), "r"(M), "r"(BETA));

// ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %1; mul.hi.u32 %hi, %4, %1;\n\t"
"mad.lo.cc.u32 %lo, %3, %2, %lo; madc.hi.u32 %hi, %3, %2, %hi;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

@@ -189,15 +187,16 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[3])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
: "r"(u[0]), "r"(u1x2), "r"(u[2]), "r"(u3x2),
"r"(MOD), "r"(M), "r"(BETA));
# undef asm
# else
union { uint64_t wl; uint32_t w[2]; };
uint32_t u3x2 = u[3]<<1;
uint32_t u1x2 = u[1]<<1;

// ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
wl = u[1] * (uint64_t)u[3];
wl <<= 1;
wl = u[1] * (uint64_t)u3x2;
wl += u[2] * (uint64_t)u[2]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
@@ -206,29 +205,26 @@ class __align__(16) bb31_4_t {
ret.u[0] = final_sub(w[1]);

// ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
wl = u[2] * (uint64_t)u[3];
wl = u[2] * (uint64_t)u3x2;
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)u[1];
wl <<= 1; final_sub(w[1]);
wl += u[0] * (uint64_t)u1x2; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[1] = final_sub(w[1]);

// ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
wl = u[3] * (uint64_t)u[3];
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
auto hi = w[1];
wl = u[0] * (uint64_t)u[2];
wl <<= 1;
wl = u[2] * (uint64_t)(u[0]<<1);
wl += u[1] * (uint64_t)u[1];
wl += hi * (uint64_t)BETA; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[2] = final_sub(w[1]);

// ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
wl = u[0] * (uint64_t)u[3];
wl += u[1] * (uint64_t)u[2];
wl <<= 1; final_sub(w[1]);
wl = u[0] * (uint64_t)u3x2;
wl += u[2] * (uint64_t)u1x2; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[3] = final_sub(w[1]);
# endif
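
For reference, the coefficient formulas quoted in the comments are just the square reduced in the quartic extension; a sketch of the algebra, assuming the field is F_p[x]/(x^4 - β) with β = BETA:

\[
(a_0 + a_1 x + a_2 x^2 + a_3 x^3)^2 \equiv
\bigl(a_0^2 + \beta(2 a_1 a_3 + a_2^2)\bigr)
+ 2\bigl(a_0 a_1 + \beta a_2 a_3\bigr) x
+ \bigl(2 a_0 a_2 + a_1^2 + \beta a_3^2\bigr) x^2
+ 2\bigl(a_0 a_3 + a_1 a_2\bigr) x^3
\pmod{x^4 - \beta}
\]

The rewrite folds the doublings into the pre-shifted operands u1x2 and u3x2 (and u[0]<<1 for ret[2]), which is why the shf.l.wrap/shl pairs and the wl <<= 1 steps disappear from the new code.
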
@@ -488,27 +484,93 @@ class __align__(16) bb31_4_t {
inline bb31_4_t& operator-=(bb31_t b)
{ c[0] -= b; return *this; }

private:
inline bb31_t recip_b0() const
{
union { uint64_t wl; uint32_t w[2]; };

// c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]);
wl = u[1] * (uint64_t)(u[3]<<1);
wl += u[2] * (uint64_t)(MOD-u[2]);
wl += (w[0] * M) * (uint64_t)MOD; final_sub(w[1]);
wl = w[1] * (uint64_t)(MOD-BETA);
wl += u[0] * (uint64_t)u[0];
wl += (w[0] * M) * (uint64_t)MOD;

return bb31_t{final_sub(w[1])};
}

inline bb31_t recip_b2() const
{
union { uint64_t wl; uint32_t w[2]; };

// c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]);
wl = u[3] * (uint64_t)u[3];
wl += (w[0] * M) * (uint64_t)MOD; final_sub(w[1]);
wl = w[1] * (uint64_t)(MOD-BETA);
wl += u[1] * (uint64_t)(MOD-u[1]); final_sub(w[1]);
wl += u[0] * (uint64_t)(u[2]<<1);
wl += (w[0] * M) * (uint64_t)MOD;

return bb31_t{final_sub(w[1])};
}

inline bb31_4_t recip_ret(bb31_t b0, bb31_t b2) const
{
bb31_4_t ret;
union { uint64_t wl; uint32_t w[2]; };

wl = b2[0] * (uint64_t)BETA;
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);

uint32_t beta_b2 = w[1];

// ret[0] = c[0]*b0 - c[2]*beta_b2;
wl = u[0] * (uint64_t)b0[0];
wl += (MOD-u[2]) * (uint64_t)beta_b2;
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[0] = final_sub(w[1]);

// ret[1] = c[3]*beta_b2 - c[1]*b0;
wl = u[3] * (uint64_t)beta_b2;
wl += (MOD-u[1]) * (uint64_t)b0[0];
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[1] = final_sub(w[1]);

// ret[2] = c[2]*b0 - c[0]*b2;
wl = u[2] * (uint64_t)b0[0];
wl += (MOD-u[0]) * (uint64_t)b2[0];
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[2] = final_sub(w[1]);

// ret[3] = c[1]*b2 - c[3]*b0;
wl = u[1] * (uint64_t)b2[0];
wl += (MOD-u[3]) * (uint64_t)b0[0];
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[3] = final_sub(w[1]);

return ret;
}

public:
inline bb31_4_t reciprocal() const
{
const bb31_t beta{BETA};

// don't bother with breaking this down, 1/x dominates.
bb31_t b0 = c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]);
bb31_t b2 = c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]);
bb31_t b0 = recip_b0();
bb31_t b2 = recip_b2();

# if 0 // inefficient code generated by at least 12.5?
bb31_t inv = 1/(b0*b0 - beta*b2*b2);
# else
bb31_t inv = b0*b0 - beta*b2*b2;
inv = 1/inv;
# endif

b0 *= inv;
b2 *= inv;

bb31_4_t ret;
bb31_t beta_b2 = beta*b2;
ret[0] = c[0]*b0 - c[2]*beta_b2;
ret[1] = c[3]*beta_b2 - c[1]*b0;
ret[2] = c[2]*b0 - c[0]*b2;
ret[3] = c[1]*b2 - c[3]*b0;

return ret;
return recip_ret(b0, b2);
}
friend inline bb31_4_t operator/(int one, const bb31_4_t& a)
{ assert(one == 1); return a.reciprocal(); }
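
A sketch of the algebra behind the recip_b0()/recip_b2()/recip_ret() split used by reciprocal() above (same assumption as before, x^4 = β): multiplying a by its conjugate a(-x) lands in the subfield spanned by {1, x^2}, and one more conjugate multiplication lands in the base field, so a single bb31_t inversion suffices:

\[
a(x)\, a(-x) = b_0 + b_2 x^2, \qquad
b_0 = a_0^2 - \beta(2 a_1 a_3 - a_2^2), \qquad
b_2 = 2 a_0 a_2 - a_1^2 - \beta a_3^2
\]
\[
a^{-1} = \frac{a(-x)\,(b_0 - b_2 x^2)}{b_0^2 - \beta b_2^2}
\]

recip_b0() and recip_b2() produce b0 and b2, reciprocal() inverts b0^2 - β·b2^2 in the base field, and recip_ret() expands the numerator; the batch_inversion friend below reuses exactly this split so that only the base-field reciprocals need to be batched.
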
@@ -523,6 +585,39 @@ class __align__(16) bb31_4_t {
inline bb31_4_t& operator/=(bb31_t a)
{ return *this *= a.reciprocal(); }

# ifdef __SPPARK_FF_BATCH_INVERSION_HPP__
template<size_t N, typename S = bb31_4_t[N]>
friend inline void batch_inversion(bb31_4_t out[N], const S inp)
{
const bb31_t beta{BETA};
bb31_t b0[N], b2[N], bx[N];

for (size_t i = 0; i < N; i++) {
bb31_4_t tmp = inp[i];
b0[i] = tmp.recip_b0();
b2[i] = tmp.recip_b2();
bx[i] = b0[i]*b0[i] - beta*b2[i]*b2[i];
}

bb31_t inv[N];

batch_inversion<bb31_t, N>(inv, bx);

for (size_t i = N; i--;) {
b0[i] *= inv[i];
b2[i] *= inv[i];
bb31_4_t tmp = inp[i];
out[i] = tmp.recip_ret(b0[i], b2[i]);
}
}

// Unlike the generic batch_inversion<T, N>, the bb31_4_t procedure
// can perform the inversion in-place.
template<size_t N>
friend inline void batch_inversion(bb31_4_t inout[N])
{ batch_inversion<N>(inout, inout); }
# endif
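
A minimal usage sketch of the new friends (not from the PR: the element count 8, the function names and the include order are illustrative assumptions; the friends are only compiled when __SPPARK_FF_BATCH_INVERSION_HPP__ is already defined, i.e. when ff/batch_inversion.hpp has been seen first):

#include "ff/batch_inversion.hpp"
#include "ff/baby_bear.hpp"

__device__ void invert_all(bb31_4_t out[8], bb31_4_t inp[8])
{
    // one bb31_t inversion amortized over all eight bb31_4_t inputs
    batch_inversion<8>(out, inp);
}

__device__ void invert_all_in_place(bb31_4_t inout[8])
{
    // the bb31_4_t overload may also overwrite its input
    batch_inversion<8>(inout);
}
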

inline bool is_one() const
{ return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0; }
inline bool is_zero() const
50 changes: 50 additions & 0 deletions ff/batch_inversion.hpp
@@ -0,0 +1,50 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#ifndef __SPPARK_FF_BATCH_INVERSION_HPP__
#define __SPPARK_FF_BATCH_INVERSION_HPP__

/*
* Since batch inversion requires twice the storage, on a GPU there is
* an incentive to use shared memory. If deemed beneficial, the
* suggestion is to have the caller wrap T[] in an S with a custom
* operator[] that addresses the shared memory and offloads the input.
*/
template<class T, size_t N, typename S = T[N]>
#ifdef __CUDACC__
__device__ __host__ __forceinline__
#endif
static void batch_inversion(T out[N], const S inp, bool preloaded = false)
{
static_assert(N <= 32, "too large N");

if (!preloaded)
out[0] = inp[0];

bool zero = out[0].is_zero();
out[0] = T::csel(T::one(), out[0], zero);
unsigned int map = zero;

for (size_t i = 1; i < N; i++) {
if (!preloaded)
out[i] = inp[i];
zero = out[i].is_zero();
out[i] *= out[i-1];
out[i] = T::csel(out[i-1], out[i], zero);
map = (map << 1) + zero;
}

T tmp, inv = 1/out[N-1];

for (size_t i = N; --i; map >>= 1) {
out[i] = inv*out[i-1];
tmp = inp[i];
tmp *= inv;
inv = T::csel(inv, tmp, map&1);
out[i] = czero(out[i], map&1);
}

out[0] = czero(inv, map);
}
#endif
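
One possible shape of the wrapper S suggested in the header comment above; purely a sketch, not part of the PR, and the name shmem_view, the constructor and the per-thread slicing are assumptions:

// Hypothetical wrapper: the caller offloads its inputs into its own
// N-element slice of a __shared__ buffer once, so that the inp[i]
// accesses inside batch_inversion() read shared memory rather than
// global memory, while the out[N] working set stays in registers/local.
template<class T, size_t N>
struct shmem_view {
    T* s;   // this thread's slice of a __shared__ T buffer

    __device__ shmem_view(T* slice, const T* glob) : s(slice)
    {
        for (size_t i = 0; i < N; i++)
            s[i] = glob[i];     // offload the input
    }
    __device__ const T& operator[](size_t i) const { return s[i]; }
};

// e.g. inside a kernel, with `slice` pointing into a __shared__ array and
// `glob` at this thread's inputs in global memory:
//     T out[N];
//     batch_inversion<T, N>(out, shmem_view<T, N>{slice, glob});
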