From 2606ecea5fa2e3690e521ca8a99694dd1160f442 Mon Sep 17 00:00:00 2001
From: Andy Polyakov
Date: Wed, 26 Jun 2024 12:20:19 +0200
Subject: [PATCH] ff/baby_bear.hpp: add exponentiation operator^.

---
 ff/baby_bear.hpp | 280 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 241 insertions(+), 39 deletions(-)

diff --git a/ff/baby_bear.hpp b/ff/baby_bear.hpp
index 47daf49..702df83 100644
--- a/ff/baby_bear.hpp
+++ b/ff/baby_bear.hpp
@@ -81,15 +81,162 @@ class __align__(16) bb31_4_t {
     inline size_t len() const   { return 4; }

     inline bb31_4_t()           {}
-    inline bb31_4_t(bb31_t a)   { c[0] = a; u[1] = u[2] = u[3] = 0; }
+    inline bb31_4_t(bb31_t a)   { c[0] = a; u[1] = u[2] = u[3] = 0; }
     // this is used in constant declarations, e.g. as bb31_4_t{1, 2, 3, 4}
     __host__ __device__ __forceinline__ bb31_4_t(int a)
     {   c[0] = bb31_t{a}; u[1] = u[2] = u[3] = 0;   }
     __host__ __device__ __forceinline__ bb31_4_t(int d, int f, int g, int h)
     {   c[0] = bb31_t{d}; c[1] = bb31_t{f}; c[2] = bb31_t{g}; c[3] = bb31_t{h};   }

-    // Polynomial multiplication modulo x^4 - BETA
-    friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_4_t b)
+private:
+    static inline uint32_t final_sub(uint32_t& u)
+    {   if (u >= MOD) u -= MOD; return u;   }
+
+    // Polynomial multiplication/squaring modulo x^4 - BETA
+    inline bb31_4_t& sqr()
+    {
+        bb31_4_t ret;
+
+# ifdef __CUDA_ARCH__
+# ifdef __GNUC__
+# define asm __asm__ __volatile__
+# else
+# define asm asm volatile
+# endif
+        // ~20% faster than multiplying the value by itself, even though
+        // the number of instructions is the same...
+        // ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
+        asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
+            "mul.lo.u32 %lo, %4, %2; mul.hi.u32 %hi, %4, %2;\n\t"
+            "shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
+            "mad.lo.cc.u32 %lo, %3, %3, %lo; madc.hi.u32 %hi, %3, %3, %hi;\n\t"
+            "setp.ge.u32 %p, %hi, %5;\n\t"
+            "@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %hi, %m, %5, %hi;\n\t"
+            //"setp.ge.u32 %p, %hi, %5;\n\t"
+            //"@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %lo, %hi, %7; mul.hi.u32 %hi, %hi, %7;\n\t"
+            "mad.lo.cc.u32 %lo, %1, %1, %lo; madc.hi.u32 %hi, %1, %1, %hi;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
+            "setp.ge.u32 %p, %0, %5;\n\t"
+            "@%p sub.u32 %0, %0, %5;\n\t"
+            "}" : "=r"(ret.u[0])
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
+                  "r"(MOD), "r"(M), "r"(BETA));
+
+        // ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
+        asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
+            "mul.lo.u32 %lo, %4, %3; mul.hi.u32 %hi, %4, %3;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %hi, %m, %5, %hi;\n\t"
+            //"setp.ge.u32 %p, %hi, %5;\n\t"
+            //"@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %lo, %hi, %7; mul.hi.u32 %hi, %hi, %7;\n\t"
+            "mad.lo.cc.u32 %lo, %2, %1, %lo; madc.hi.u32 %hi, %2, %1, %hi;\n\t"
+            "shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
+            "setp.ge.u32 %p, %hi, %5;\n\t"
+            "@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
+            "setp.ge.u32 %p, %0, %5;\n\t"
+            "@%p sub.u32 %0, %0, %5;\n\t"
+            "}" : "=r"(ret.u[1])
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
+                  "r"(MOD), "r"(M), "r"(BETA));
+
+        // ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
+        asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
+            "mul.lo.u32 %lo, %4, %4; mul.hi.u32 %hi, %4, %4;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %m, %m, %5, %hi;\n\t"
+            //"setp.ge.u32 %p, %m, %5;\n\t"
+            //"@%p sub.u32 %m, %m, %5;\n\t"
+
+            "mul.lo.u32 %lo, %3, %1; mul.hi.u32 %hi, %3, %1;\n\t"
+            "shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
+            "mad.lo.cc.u32 %lo, %2, %2, %lo; madc.hi.u32 %hi, %2, %2, %hi;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %7, %lo; madc.hi.u32 %hi, %m, %7, %hi;\n\t"
+            "setp.ge.u32 %p, %hi, %5;\n\t"
+            "@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
+            "setp.ge.u32 %p, %0, %5;\n\t"
+            "@%p sub.u32 %0, %0, %5;\n\t"
+            "}" : "=r"(ret.u[2])
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
+                  "r"(MOD), "r"(M), "r"(BETA));
+
+        // ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
+        asm("{ .reg.b32 %lo, %hi, %m; .reg.pred %p;\n\t"
+            "mul.lo.u32 %lo, %4, %1; mul.hi.u32 %hi, %4, %1;\n\t"
+            "mad.lo.cc.u32 %lo, %3, %2, %lo; madc.hi.u32 %hi, %3, %2, %hi;\n\t"
+            "shf.l.wrap.b32 %hi, %lo, %hi, 1; shl.b32 %lo, %lo, 1;\n\t"
+            "setp.ge.u32 %p, %hi, %5;\n\t"
+            "@%p sub.u32 %hi, %hi, %5;\n\t"
+
+            "mul.lo.u32 %m, %lo, %6;\n\t"
+            "mad.lo.cc.u32 %lo, %m, %5, %lo; madc.hi.u32 %0, %m, %5, %hi;\n\t"
+            "setp.ge.u32 %p, %0, %5;\n\t"
+            "@%p sub.u32 %0, %0, %5;\n\t"
+            "}" : "=r"(ret.u[3])
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
+                  "r"(MOD), "r"(M), "r"(BETA));
+# undef asm
+# else
+        union { uint64_t wl; uint32_t w[2]; };
+
+        // ret[0] = a[0]*a[0] + BETA*(2*a[1]*a[3] + a[2]*a[2]);
+        wl  = u[1] * (uint64_t)u[3];
+        wl <<= 1;
+        wl += u[2] * (uint64_t)u[2];        final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        wl  = w[1] * (uint64_t)BETA;
+        wl += u[0] * (uint64_t)u[0];
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[0] = final_sub(w[1]);
+
+        // ret[1] = 2*(a[0]*a[1] + BETA*(a[2]*a[3]));
+        wl  = u[2] * (uint64_t)u[3];
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        wl  = w[1] * (uint64_t)BETA;
+        wl += u[0] * (uint64_t)u[1];
+        wl <<= 1;                           final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[1] = final_sub(w[1]);
+
+        // ret[2] = 2*a[0]*a[2] + a[1]*a[1] + BETA*(a[3]*a[3]);
+        wl  = u[3] * (uint64_t)u[3];
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        auto hi = w[1];
+        wl  = u[0] * (uint64_t)u[2];
+        wl <<= 1;
+        wl += u[1] * (uint64_t)u[1];
+        wl += hi * (uint64_t)BETA;          final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[2] = final_sub(w[1]);
+
+        // ret[3] = 2*(a[0]*a[3] + a[1]*a[2]);
+        wl  = u[0] * (uint64_t)u[3];
+        wl += u[1] * (uint64_t)u[2];
+        wl <<= 1;                           final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[3] = final_sub(w[1]);
+# endif
+
+        return *this = ret;
+    }
+
+    inline bb31_4_t& mul(const bb31_4_t& b)
     {
         bb31_4_t ret;

@@ -120,7 +267,7 @@ class __align__(16) bb31_4_t {
             "setp.ge.u32 %p, %0, %9;\n\t"
             "@%p sub.u32 %0, %0, %9;\n\t"
             "}" : "=r"(ret.u[0])
-                : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
                   "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
                   "r"(MOD), "r"(M), "r"(BETA));

@@ -145,7 +292,7 @@ class __align__(16) bb31_4_t {
             "setp.ge.u32 %p, %0, %9;\n\t"
             "@%p sub.u32 %0, %0, %9;\n\t"
             "}" : "=r"(ret.u[1])
-                : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
                   "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
                   "r"(MOD), "r"(M), "r"(BETA));

@@ -170,7 +317,7 @@ class __align__(16) bb31_4_t {
             "setp.ge.u32 %p, %0, %9;\n\t"
             "@%p sub.u32 %0, %0, %9;\n\t"
             "}" : "=r"(ret.u[2])
-                : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
                   "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
                   "r"(MOD), "r"(M), "r"(BETA));

@@ -188,54 +335,58 @@ class __align__(16) bb31_4_t {
             "setp.ge.u32 %p, %0, %9;\n\t"
             "@%p sub.u32 %0, %0, %9;\n\t"
             "}" : "=r"(ret.u[3])
-                : "r"(a.u[0]), "r"(a.u[1]), "r"(a.u[2]), "r"(a.u[3]),
+                : "r"(u[0]), "r"(u[1]), "r"(u[2]), "r"(u[3]),
                   "r"(b.u[0]), "r"(b.u[1]), "r"(b.u[2]), "r"(b.u[3]),
                   "r"(MOD), "r"(M), "r"(BETA));
 # undef asm
 # else
-        union { uint64_t ul; uint32_t u[2]; };
+        union { uint64_t wl; uint32_t w[2]; };

         // ret[0] = a[0]*b[0] + BETA*(a[1]*b[3] + a[2]*b[2] + a[3]*b[1]);
-        ul  = a.u[1] * (uint64_t)b.u[3];
-        ul += a.u[2] * (uint64_t)b.u[2];
-        ul += a.u[3] * (uint64_t)b.u[1];    if (u[1] >= MOD) u[1] -= MOD;
-        ul += (u[0] * M) * (uint64_t)MOD;   // if (u[1] >= MOD) u[1] -= MOD;
-        ul  = u[1] * (uint64_t)BETA;
-        ul += a.u[0] * (uint64_t)b.u[0];
-        ul += (u[0] * M) * (uint64_t)MOD;
-        ret.u[0] = u[1] >= MOD ? u[1] - MOD : u[1];
+        wl  = u[1] * (uint64_t)b.u[3];
+        wl += u[2] * (uint64_t)b.u[2];
+        wl += u[3] * (uint64_t)b.u[1];      final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        wl  = w[1] * (uint64_t)BETA;
+        wl += u[0] * (uint64_t)b.u[0];
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[0] = final_sub(w[1]);

         // ret[1] = a[0]*b[1] + a[1]*b[0] + BETA*(a[2]*b[3] + a[3]*b[2]);
-        ul  = a.u[2] * (uint64_t)b.u[3];
-        ul += a.u[3] * (uint64_t)b.u[2];
-        ul += (u[0] * M) * (uint64_t)MOD;   // if (u[1] >= MOD) u[1] -= MOD;
-        ul  = u[1] * (uint64_t)BETA;
-        ul += a.u[0] * (uint64_t)b.u[1];
-        ul += a.u[1] * (uint64_t)b.u[0];    if (u[1] >= MOD) u[1] -= MOD;
-        ul += (u[0] * M) * (uint64_t)MOD;
-        ret.u[1] = u[1] >= MOD ? u[1] - MOD : u[1];
+        wl  = u[2] * (uint64_t)b.u[3];
+        wl += u[3] * (uint64_t)b.u[2];
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        wl  = w[1] * (uint64_t)BETA;
+        wl += u[0] * (uint64_t)b.u[1];
+        wl += u[1] * (uint64_t)b.u[0];      final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[1] = final_sub(w[1]);

         // ret[2] = a[0]*b[2] + a[1]*b[1] + a[2]*b[0] + BETA*(a[3]*b[3]);
-        ul  = a.u[3] * (uint64_t)b.u[3];
-        ul += (u[0] * M) * (uint64_t)MOD;   // if (u[1] >= MOD) u[1] -= MOD;
-        ul  = u[1] * (uint64_t)BETA;
-        ul += a.u[0] * (uint64_t)b.u[2];
-        ul += a.u[1] * (uint64_t)b.u[1];
-        ul += a.u[2] * (uint64_t)b.u[0];    if (u[1] >= MOD) u[1] -= MOD;
-        ul += (u[0] * M) * (uint64_t)MOD;
-        ret.u[2] = u[1] >= MOD ? u[1] - MOD : u[1];
+        wl  = u[3] * (uint64_t)b.u[3];
+        wl += (w[0] * M) * (uint64_t)MOD;   //final_sub(w[1]);
+        wl  = w[1] * (uint64_t)BETA;
+        wl += u[0] * (uint64_t)b.u[2];
+        wl += u[1] * (uint64_t)b.u[1];
+        wl += u[2] * (uint64_t)b.u[0];      final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[2] = final_sub(w[1]);

         // ret[3] = a[0]*b[3] + a[1]*b[2] + a[2]*b[1] + a[3]*b[0];
-        ul  = a.u[0] * (uint64_t)b.u[3];
-        ul += a.u[1] * (uint64_t)b.u[2];
-        ul += a.u[2] * (uint64_t)b.u[1];
-        ul += a.u[3] * (uint64_t)b.u[0];    if (u[1] >= MOD) u[1] -= MOD;
-        ul += (u[0] * M) * (uint64_t)MOD;
-        ret.u[3] = u[1] >= MOD ? u[1] - MOD : u[1];
+        wl  = u[0] * (uint64_t)b.u[3];
+        wl += u[1] * (uint64_t)b.u[2];
+        wl += u[2] * (uint64_t)b.u[1];
+        wl += u[3] * (uint64_t)b.u[0];      final_sub(w[1]);
+        wl += (w[0] * M) * (uint64_t)MOD;
+        ret.u[3] = final_sub(w[1]);
 # endif

-        return ret;
+        return *this = ret;
     }
+
+public:
+    friend __device__ __noinline__ bb31_4_t operator*(bb31_4_t a, bb31_4_t b)
+    {   return a.mul(b);   }

     inline bb31_4_t& operator*=(const bb31_4_t& b)
     {   return *this = *this * b;   }

@@ -357,6 +508,57 @@ class __align__(16) bb31_4_t {
     {   return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0;   }
     inline bool is_zero() const
     {   return u[0]==0 & u[1]==0 & u[2]==0 & u[3]==0;   }
+
+    // raise to a variable power, variable with respect to threadIdx,
+    // but mind the ^ operator's precedence!
+    inline bb31_4_t& operator^=(uint32_t p)
+    {
+        bb31_4_t sqr = *this;
+
+        if (!(p&1)) {
+            c[0] = bb31_t{1};
+            c[1] = c[2] = c[3] = 0;
+        }
+
+        #pragma unroll 1
+        while (p >>= 1) {
+            sqr.sqr();
+            if (p&1)
+                mul(sqr);
+        }
+
+        return *this;
+    }
+    friend inline bb31_4_t operator^(bb31_4_t a, uint32_t p)
+    {   return a ^= p;   }
+    inline bb31_4_t operator()(uint32_t p)
+    {   return *this^p;   }
+
+    // raise to a constant power, e.g. x^7, to be unrolled at compile time
+    inline bb31_4_t& operator^=(int p)
+    {
+        assert(p >= 2);
+
+        bb31_4_t sqr = *this;
+        if ((p&1) == 0) {
+            do {
+                sqr.sqr();
+                p >>= 1;
+            } while ((p&1) == 0);
+            *this = sqr;
+        }
+        for (p >>= 1; p; p >>= 1) {
+            sqr.sqr();
+            if (p&1)
+                mul(sqr);
+        }
+
+        return *this;
+    }
+    friend inline bb31_4_t operator^(bb31_4_t a, int p)
+    {   return a ^= p;   }
+    inline bb31_4_t operator()(int p)
+    {   return *this^p;   }
 # undef inline

 public:
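
The host-side paths above lean on one Montgomery-reduction idiom: after a
64-bit value is accumulated in wl, the line `wl += (w[0] * M) * (uint64_t)MOD;`
zeroes the low 32 bits (because M*MOD == -1 mod 2^32), so the high word w[1]
holds the value scaled by 2^-32, at most one conditional subtraction away from
canonical. A minimal standalone sketch of that idiom, assuming the BabyBear
constants MOD = 0x78000001 and M = 0x77ffffff defined elsewhere in
ff/baby_bear.hpp:

    #include <cstdint>

    static const uint32_t MOD = 0x78000001u; // BabyBear prime, 15*2^27 + 1
    static const uint32_t M   = 0x77ffffffu; // -MOD^-1, M*MOD == -1 (mod 2^32)

    // Montgomery reduction of a 64-bit accumulator: returns t*2^-32 mod MOD.
    static uint32_t redc(uint64_t t)
    {
        uint32_t lo = (uint32_t)t;          // w[0] in the patch
        t += (uint64_t)(lo * M) * MOD;      // low 32 bits of t are now zero
        uint32_t hi = (uint32_t)(t >> 32);  // w[1] in the patch
        return hi >= MOD ? hi - MOD : hi;   // the final_sub() step
    }

The commented-out //final_sub(w[1]) calls mark points where this subtraction
is deliberately deferred, since the next reduction tolerates the wider range.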
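
The coefficient formulas in the sqr() comments follow from squaring
a0 + a1*x + a2*x^2 + a3*x^3 and folding x^4 down to BETA, x^5 to BETA*x and
x^6 to BETA*x^2. A naive reference, hypothetical and over plain
(non-Montgomery) residues, that reproduces r0..r3 for cross-checking:

    #include <cstdint>

    typedef unsigned __int128 u128;

    // Schoolbook squaring in F_p[x]/(x^4 - beta); this checks the formulas
    // only, not the Montgomery reduction used by bb31_4_t itself.
    static void sqr4_ref(uint64_t r[4], const uint64_t a[4],
                         uint64_t p, uint64_t beta)
    {
        u128 c[7] = {0};
        for (int i = 0; i < 4; i++)         // plain polynomial square
            for (int j = 0; j < 4; j++)
                c[i+j] += (u128)a[i] * a[j];
        for (int k = 6; k >= 4; k--)        // x^k == beta * x^(k-4)
            c[k-4] += (u128)beta * (uint64_t)(c[k] % p);
        for (int k = 0; k < 4; k++)
            r[k] = (uint64_t)(c[k] % p);
    }

For example, r[0] comes out as a0^2 + beta*(2*a1*a3 + a2^2) mod p, matching
the ret[0] comment in sqr().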
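
Both operator^= overloads are square-and-multiply; the int overload
additionally peels trailing zero bits of the exponent, so an even power such
as x^12 starts with pure squarings ((x^2)^2 = x^4, then one multiply by x^8)
instead of multiplications by one. A scalar model of that control flow,
illustrative only, over uint64_t arithmetic modulo a sub-2^32 prime:

    #include <cassert>
    #include <cstdint>

    static uint64_t pow_mod(uint64_t x, unsigned p, uint64_t mod)
    {
        assert(p >= 2 && mod > 1 && mod < ((uint64_t)1 << 32));
        uint64_t sqr = x % mod, ret = sqr;

        if ((p & 1) == 0) {                 // peel factors of 2: x^(2k) = (x^k)^2
            do {
                sqr = sqr * sqr % mod;
                p >>= 1;
            } while ((p & 1) == 0);
            ret = sqr;
        }
        for (p >>= 1; p; p >>= 1) {         // classic square-and-multiply
            sqr = sqr * sqr % mod;
            if (p & 1)
                ret = ret * sqr % mod;
        }
        return ret;
    }

E.g. pow_mod(3, 7, 2013265921) returns 2187 == 3^7. The uint32_t overload
instead handles arbitrary runtime exponents (including 0 and 1) by seeding
the accumulator with 1 for even p and looping with #pragma unroll 1.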
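
Finally, a hypothetical kernel showing how the new operators would be used;
the kernel name and data layout are illustrative, not part of the patch. Note
the precedence warning in the comments: ^ binds looser than * and +, so
compound expressions need parentheses, and the operator() aliases exist to
sidestep exactly that pitfall.

    #include "ff/baby_bear.hpp"

    __global__ void sbox_demo(bb31_4_t* x, uint32_t p)
    {
        bb31_4_t a = x[threadIdx.x];

        bb31_4_t a7 = a^7;          // int literal: constant power, unrolled
        bb31_4_t ap = a^p;          // uint32_t: variable power, looped

        bb31_4_t y = (a^7) * ap;    // parentheses required: ^ binds looser than *
        bb31_4_t z = a(7) * a(p);   // operator() avoids the precedence issue

        x[threadIdx.x] = y * z;
    }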