diff --git a/ff/baby_bear.hpp b/ff/baby_bear.hpp index 3e866c0..654f5f7 100644 --- a/ff/baby_bear.hpp +++ b/ff/baby_bear.hpp @@ -485,20 +485,69 @@ class __align__(16) bb31_4_t { { c[0] -= b; return *this; } private: - // don't bother with breaking these down, 1/x dominates. - inline bb31_t recip_b0(bb31_t beta) const - { return c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); } - inline bb31_t recip_b2(bb31_t beta) const - { return c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); } - inline bb31_4_t recip_ret(bb31_t b0, bb31_t b2, bb31_t beta) const + inline bb31_t recip_b0() const + { + union { uint64_t wl; uint32_t w[2]; }; + + // c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); + wl = u[1] * (uint64_t)(u[3]<<1); + wl += u[2] * (uint64_t)(MOD-u[2]); + wl += (w[0] * M) * (uint64_t)MOD; final_sub(w[1]); + wl = w[1] * (uint64_t)(MOD-BETA); + wl += u[0] * (uint64_t)u[0]; + wl += (w[0] * M) * (uint64_t)MOD; + + return bb31_t{final_sub(w[1])}; + } + + inline bb31_t recip_b2() const + { + union { uint64_t wl; uint32_t w[2]; }; + + // c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); + wl = u[3] * (uint64_t)u[3]; + wl += (w[0] * M) * (uint64_t)MOD; final_sub(w[1]); + wl = w[1] * (uint64_t)(MOD-BETA); + wl += u[1] * (uint64_t)(MOD-u[1]); final_sub(w[1]); + wl += u[0] * (uint64_t)(u[2]<<1); + wl += (w[0] * M) * (uint64_t)MOD; + + return bb31_t{final_sub(w[1])}; + } + + inline bb31_4_t recip_ret(bb31_t b0, bb31_t b2) const { bb31_4_t ret; - bb31_t beta_b2 = beta*b2; + union { uint64_t wl; uint32_t w[2]; }; - ret[0] = c[0]*b0 - c[2]*beta_b2; - ret[1] = c[3]*beta_b2 - c[1]*b0; - ret[2] = c[2]*b0 - c[0]*b2; - ret[3] = c[1]*b2 - c[3]*b0; + wl = b2[0] * (uint64_t)BETA; + wl += (w[0] * M) * (uint64_t)MOD; //final_sub(w[1]); + + uint32_t beta_b2 = w[1]; + + // ret[0] = c[0]*b0 - c[2]*beta_b2; + wl = u[0] * (uint64_t)b0[0]; + wl += (MOD-u[2]) * (uint64_t)beta_b2; + wl += (w[0] * M) * (uint64_t)MOD; + ret.u[0] = final_sub(w[1]); + + // ret[1] = c[3]*beta_b2 - c[1]b0; + wl = u[3] * (uint64_t)beta_b2; + wl += (MOD-u[1]) * (uint64_t)b0[0]; + wl += (w[0] * M) * (uint64_t)MOD; + ret.u[1] = final_sub(w[1]); + + // ret[2] = c[2]*b0 - c[0]*b2; + wl = u[2] * (uint64_t)b0[0]; + wl += (MOD-u[0]) * (uint64_t)b2[0]; + wl += (w[0] * M) * (uint64_t)MOD; + ret.u[2] = final_sub(w[1]); + + // ret[3] = c[1]*b2 - c[3]*b0; + wl = u[1] * (uint64_t)b2[0]; + wl += (MOD-u[3]) * (uint64_t)b0[0]; + wl += (w[0] * M) * (uint64_t)MOD; + ret.u[3] = final_sub(w[1]); return ret; } @@ -508,15 +557,20 @@ class __align__(16) bb31_4_t { { const bb31_t beta{BETA}; - bb31_t b0 = recip_b0(beta); - bb31_t b2 = recip_b2(beta); + bb31_t b0 = recip_b0(); + bb31_t b2 = recip_b2(); +# if 0 // inefficient code generated by at least 12.5? bb31_t inv = 1/(b0*b0 - beta*b2*b2); +# else + bb31_t inv = b0*b0 - beta*b2*b2; + inv = 1/inv; +# endif b0 *= inv; b2 *= inv; - return recip_ret(b0, b2, beta); + return recip_ret(b0, b2); } friend inline bb31_4_t operator/(int one, const bb31_4_t& a) { assert(one == 1); return a.reciprocal(); } @@ -540,8 +594,8 @@ class __align__(16) bb31_4_t { for (size_t i = 0; i < N; i++) { bb31_4_t tmp = inp[i]; - b0[i] = tmp.recip_b0(beta); - b2[i] = tmp.recip_b2(beta); + b0[i] = tmp.recip_b0(); + b2[i] = tmp.recip_b2(); bx[i] = b0[i]*b0[i] - beta*b2[i]*b2[i]; } @@ -553,7 +607,7 @@ class __align__(16) bb31_4_t { b0[i] *= inv[i]; b2[i] *= inv[i]; bb31_4_t tmp = inp[i]; - out[i] = tmp.recip_ret(b0[i], b2[i], beta); + out[i] = tmp.recip_ret(b0[i], b2[i]); } }