Skip to content

Commit

Permalink
ff/baby_bear.hpp: add exponentiation operator^.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jun 26, 2024
1 parent be6cb02 commit 2606ece
Showing 1 changed file with 241 additions and 39 deletions.
280 changes: 241 additions & 39 deletions ff/baby_bear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,162 @@ class __align__(16) bb31_4_t {
inline size_t len() const { return 4; }

inline bb31_4_t() {}
inline bb31_4_t(bb31_t a) { c[0] = a; u[1] = u[2] = u[3] = 0; }
inline bb31_4_t(bb31_t a) { c[0] = a; u[1] = u[2] = u[3] = 0; }
// this is used in constant declaration, e.g. as bb31_4_t{1, 2, 3, 4}
__host__ __device__ __forceinline__ bb31_4_t(int a)
{ c[0] = bb31_t{a}; u[1] = u[2] = u[3] = 0; }
__host__ __device__ __forceinline__ bb31_4_t(int d, int f, int g, int h)
{ c[0] = bb31_t{d}; c[1] = bb31_t{f}; c[2] = bb31_t{g}; c[3] = bb31_t{h}; }

// Polynomial multiplication modulo x^4 - BETA
friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_4_t b)
private:
static inline uint32_t final_sub(uint32_t& u)
{ if (u >= MOD) u -= MOD; return u; }

// Polynomial multiplication/squaring modulo x^4 - BETA
inline bb31_4_t& sqr()
{
bb31_4_t ret;

# ifdef __CUDA_ARCH__
# ifdef __GNUC__
# define asm __asm__ __volatile__
# else
# define asm asm volatile
# endif
// +20% in comparison to multiplication by itself even though
// the amount of instructions is the same...
// ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %2; mul.hi.u32 %hi, %4, %2;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"mad.lo.cc.u32 %lo, %3, %3, %lo; madc.hi.u32 %hi, %3, %3, %hi;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %hi, %m, %5, %hi;\n\t"
//"setp.ge.u32 %p, %hi, %5;\n\t"
//"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %lo, %hi, %7; mul.hi.u32 %hi, %hi, %7;\n\t"
"mad.lo.cc.u32 %lo, %1, %1, %lo; madc.hi.u32 %hi, %1, %1, %hi;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[0])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(MOD), "r"(M), "r"(BETA));

// ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %3; mul.hi.u32 %hi, %4, %3;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %hi, %m, %5, %hi;\n\t"
//"setp.ge.u32 %p, %hi, %5;\n\t"
//"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %lo, %hi, %7; mul.hi.u32 %hi, %hi, %7;\n\t"
"mad.lo.cc.u32 %lo, %2, %1, %lo; madc.hi.u32 %hi, %2, %1, %hi;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[1])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(MOD), "r"(M), "r"(BETA));

// ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %4; mul.hi.u32 %hi, %4, %4;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %m, %m, %5, %hi;\n\t"
//"setp.ge.u32 %p, %m, %5;\n\t"
//"@%p sub.u32 %m, %m, %5;\n\t"

"mul.lo.u32 %lo, %3, %1; mul.hi.u32 %hi, %3, %1;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"mad.lo.cc.u32 %lo, %2, %2, %lo; madc.hi.u32 %hi, %2, %2, %hi;\n\t"
"mad.lo.cc.u32 %lo, %m, %7, %lo; madc.hi.u32 %hi, %m, %7, %hi;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[2])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(MOD), "r"(M), "r"(BETA));

// ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
"mul.lo.u32 %lo, %4, %1; mul.hi.u32 %hi, %4, %1;\n\t"
"mad.lo.cc.u32 %lo, %3, %2, %lo; madc.hi.u32 %hi, %3, %2, %hi;\n\t"
"shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
"setp.ge.u32 %p, %hi, %5;\n\t"
"@%p sub.u32 %hi, %hi, %5;\n\t"

"mul.lo.u32 %m, %lo, %6;\n\t"
"mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
"setp.ge.u32 %p, %0, %5;\n\t"
"@%p sub.u32 %0, %0, %5;\n\t"
"}" : "=r"(ret.u[3])
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(MOD), "r"(M), "r"(BETA));
# undef asm
# else
union { uint64_t wl; uint32_t w[2]; };

// ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
wl = u[1] * (uint64_t)u[3];
wl <<= 1;
wl += u[2] * (uint64_t)u[2]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)u[0];
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[0] = final_sub(w[1]);

// ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
wl = u[2] * (uint64_t)u[3];
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)u[1];
wl <<= 1; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[1] = final_sub(w[1]);

// ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
wl = u[3] * (uint64_t)u[3];
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
auto hi = w[1];
wl = u[0] * (uint64_t)u[2];
wl <<= 1;
wl += u[1] * (uint64_t)u[1];
wl += hi * (uint64_t)BETA; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[2] = final_sub(w[1]);

// ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
wl = u[0] * (uint64_t)u[3];
wl += u[1] * (uint64_t)u[2];
wl <<= 1; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[3] = final_sub(w[1]);
# endif

return *this = ret;
}

inline bb31_4_t& mul(const bb31_4_t& b)
{
bb31_4_t ret;

Expand Down Expand Up @@ -120,7 +267,7 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %9;\n\t"
"@%p sub.u32 %0, %0, %9;\n\t"
"}" : "=r"(ret.u[0])
: "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
"r"(MOD), "r"(M), "r"(BETA));

Expand All @@ -145,7 +292,7 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %9;\n\t"
"@%p sub.u32 %0, %0, %9;\n\t"
"}" : "=r"(ret.u[1])
: "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
"r"(MOD), "r"(M), "r"(BETA));

Expand All @@ -170,7 +317,7 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %9;\n\t"
"@%p sub.u32 %0, %0, %9;\n\t"
"}" : "=r"(ret.u[2])
: "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
"r"(MOD), "r"(M), "r"(BETA));

Expand All @@ -188,54 +335,58 @@ class __align__(16) bb31_4_t {
"setp.ge.u32 %p, %0, %9;\n\t"
"@%p sub.u32 %0, %0, %9;\n\t"
"}" : "=r"(ret.u[3])
: "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
: "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
"r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
"r"(MOD), "r"(M), "r"(BETA));
# undef asm
# else
union { uint64_t ul; uint32_t u[2]; };
union { uint64_t wl; uint32_t w[2]; };

// ret[0] = a[0]*b[0] + BETA*(a[1]*b[3] + a[2]*b[2] + a[3]*b[1]);
ul = a.u[1] * (uint64_t)b.u[3];
ul += a.u[2] * (uint64_t)b.u[2];
ul += a.u[3] * (uint64_t)b.u[1]; if (u[1] >= MOD) u[1] -= MOD;
ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD;
ul = u[1] * (uint64_t)BETA;
ul += a.u[0] * (uint64_t)b.u[0];
ul += (u[0] * M) * (uint64_t)MOD;
ret.u[0] = u[1] >= MOD ? u[1] - MOD : u[1];
wl = u[1] * (uint64_t)b.u[3];
wl += u[2] * (uint64_t)b.u[2];
wl += u[3] * (uint64_t)b.u[1]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)b.u[0];
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[0] = final_sub(w[1]);

// ret[1] = a[0]*b[1] + a[1]*b[0] + BETA*(a[2]*b[3] + a[3]*b[2]);
ul = a.u[2] * (uint64_t)b.u[3];
ul += a.u[3] * (uint64_t)b.u[2];
ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD;
ul = u[1] * (uint64_t)BETA;
ul += a.u[0] * (uint64_t)b.u[1];
ul += a.u[1] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD;
ul += (u[0] * M) * (uint64_t)MOD;
ret.u[1] = u[1] >= MOD ? u[1] - MOD : u[1];
wl = u[2] * (uint64_t)b.u[3];
wl += u[3] * (uint64_t)b.u[2];
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)b.u[1];
wl += u[1] * (uint64_t)b.u[0]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[1] = final_sub(w[1]);

// ret[2] = a[0]*b[2] + a[1]*b[1] + a[2]*b[0] + BETA*(a[3]*b[3]);
ul = a.u[3] * (uint64_t)b.u[3];
ul += (u[0] * M) * (uint64_t)MOD; // if (u[1] >= MOD) u[1] -= MOD;
ul = u[1] * (uint64_t)BETA;
ul += a.u[0] * (uint64_t)b.u[2];
ul += a.u[1] * (uint64_t)b.u[1];
ul += a.u[2] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD;
ul += (u[0] * M) * (uint64_t)MOD;
ret.u[2] = u[1] >= MOD ? u[1] - MOD : u[1];
wl = u[3] * (uint64_t)b.u[3];
wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]);
wl = w[1] * (uint64_t)BETA;
wl += u[0] * (uint64_t)b.u[2];
wl += u[1] * (uint64_t)b.u[1];
wl += u[2] * (uint64_t)b.u[0]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[2] = final_sub(w[1]);

// ret[3] = a[0]*b[3] + a[1]*b[2] + a[2]*b[1] + a[3]*b[0];
ul = a.u[0] * (uint64_t)b.u[3];
ul += a.u[1] * (uint64_t)b.u[2];
ul += a.u[2] * (uint64_t)b.u[1];
ul += a.u[3] * (uint64_t)b.u[0]; if (u[1] >= MOD) u[1] -= MOD;
ul += (u[0] * M) * (uint64_t)MOD;
ret.u[3] = u[1] >= MOD ? u[1] - MOD : u[1];
wl = u[0] * (uint64_t)b.u[3];
wl += u[1] * (uint64_t)b.u[2];
wl += u[2] * (uint64_t)b.u[1];
wl += u[3] * (uint64_t)b.u[0]; final_sub(w[1]);
wl += (w[0] * M) * (uint64_t)MOD;
ret.u[3] = final_sub(w[1]);
# endif

return ret;
return *this = ret;
}

public:
friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_4_t b)
{ return a.mul(b); }
inline bb31_4_t& operator*=(const bb31_4_t& b)
{ return *this = *this * b; }

Expand Down Expand Up @@ -357,6 +508,57 @@ class __align__(16) bb31_4_t {
{ return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0; }
inline bool is_zero() const
{ return u[0]==0 & u[1]==0 & u[2]==0 & u[3]==0; }

// raise to a variable power, variable in respect to threadIdx,
// but mind the ^ operator's precedence!
inline bb31_4_t& operator^=(uint32_t p)
{
bb31_4_t sqr = *this;

if (!(p&1)) {
c[0] = bb31_t{1};
c[1] = c[2] = c[3] = 0;
}

#pragma unroll 1
while (p >>= 1) {
sqr.sqr();
if (p&1)
mul(sqr);
}

return *this;
}
friend inline bb31_4_t operator^(bb31_4_t a, uint32_t p)
{ return a ^= p; }
inline bb31_4_t operator()(uint32_t p)
{ return *this^p; }

// raise to a constant power, e.g. x^7, to be unrolled at compile time
inline bb31_4_t& operator^=(int p)
{
assert(p >= 2);

bb31_4_t sqr = *this;
if ((p&1) == 0) {
do {
sqr.sqr();
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
for (p >>= 1; p; p >>= 1) {
sqr.sqr();
if (p&1)
mul(sqr);
}

return *this;
}
friend inline bb31_4_t operator^(bb31_4_t a, int p)
{ return a ^= p; }
inline bb31_4_t operator()(int p)
{ return *this^p; }
# undef inline

public:
Expand Down

0 comments on commit 2606ece

Please sign in to comment.