From 6bf0400c5e466bcda924a1da44a22f5a7eb53faa Mon Sep 17 00:00:00 2001
From: MITSUNARI Shigeo
Date: Tue, 27 Aug 2024 08:40:35 +0900
Subject: [PATCH] a little optimization of vadd/vsub

---
 src/gen_bint_x64.py | 42 ++++++++++++++++++++++++++++++++++--------
 src/msm_avx.cpp     |  8 ++++----
 2 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/src/gen_bint_x64.py b/src/gen_bint_x64.py
index 6504584b..e7361948 100644
--- a/src/gen_bint_x64.py
+++ b/src/gen_bint_x64.py
@@ -63,7 +63,7 @@ def gen_vsubPre(mont, vN=1):
 def gen_vadd(mont, vN=1):
 	SUF = 'A' if vN == 2 else ''
 	with FuncProc(MSM_PRE+'vadd'+SUF):
-		with StackFrame(3, 0, useRCX=True, vNum=mont.N*2+2, vType=T_ZMM) as sf:
+		with StackFrame(3, 0, useRCX=True, vNum=mont.N*2+3, vType=T_ZMM) as sf:
 			regs = list(reversed(sf.v))
 			W = mont.W
 			N = mont.N
@@ -75,6 +75,7 @@ def gen_vadd(mont, vN=1):
 			t = pops(regs, N)
 			vmask = pops(regs, 1)[0]
 			c = pops(regs, 1)[0]
+			zero = pops(regs, 1)[0]
 
 			mov(rax, mont.mask)
 			vpbroadcastq(vmask, rax)
@@ -104,7 +105,6 @@ def gen_vadd(mont, vN=1):
 					if i > 0:
 						vpsubq(t[i], t[i], c)
 					vpsrlq(c, t[i], S)
-				un(vpandq)(t, t, vmask)
 			else:
 				# a little faster
 				# s = x+y
@@ -124,14 +124,13 @@ def gen_vadd(mont, vN=1):
 					if i > 0:
 						vpsubq(t[i], t[i], c);
 					vpsrlq(c, t[i], S)
-					vpandq(t[i], t[i], vmask)
-				vpxorq(vmask, vmask, vmask)
-				vpcmpgtq(k1, c, vmask) # k1 = t<0
+				vpxorq(zero, zero, zero)
+				vpcmpeqq(k1, c, zero) # k1 = t>=0
 
 				# z = select(k1, s, t)
 				for i in range(N):
-					vmovdqa64(t[i]|k1, s[i])
-				un(vmovdqa64)(ptr(z), t)
+					vpandq(s[i]|k1, t[i], vmask)
+				un(vmovdqa64)(ptr(z), s)
 
 			if vN == 2:
 				add(x, 64)
@@ -226,6 +225,32 @@ def vmulUnitAdd(z, px, y, N, H, t):
 	vpxorq(z[N], z[N], z[N])
 	vmulH(z[N], t, y)
 
+def gen_vmul(mont):
+	with FuncProc(MSM_PRE+'vmul'):
+		with StackFrame(3, 0, vNum=mont.N*2+4, vType=T_ZMM) as sf:
+			regs = list(reversed(sf.v))
+			W = mont.W
+			N = mont.N
+			pz = sf.p[0]
+			px = sf.p[1]
+			py = sf.p[2]
+
+			t = pops(regs, N*2)
+			vmask = pops(regs, 1)[0]
+			c = pops(regs, 1)[0]
+			y = pops(regs, 1)[0]
+			H = pops(regs, 1)[0]
+
+			mov(rax, mont.mask)
+			vpbroadcastq(vmask, rax)
+
+			un = genUnrollFunc()
+
+			vmovdqa64(y, ptr(py))
+			un(vmovdqa64)(t[0:N], ptr(pz))
+			vmulUnitAdd(t, px, y, N, H, c)
+			un(vmovdqa64)(ptr(pz), t[0:N+1])
+
 def msm_data(mont):
 	makeLabel(C_p)
 	dq_(', '.join(map(hex, mont.toArray(mont.p))))
@@ -234,9 +259,10 @@ def msm_code(mont):
 	for vN in [1, 2]:
 		gen_vaddPre(mont, vN)
 		gen_vsubPre(mont, vN)
+		gen_vadd(mont, vN)
-	gen_vadd(mont)
 	gen_vsub(mont)
+	gen_vmul(mont)
 
 	SUF='_fast'
 	param=None
 
diff --git a/src/msm_avx.cpp b/src/msm_avx.cpp
index 3b043db9..b6c38139 100644
--- a/src/msm_avx.cpp
+++ b/src/msm_avx.cpp
@@ -31,8 +31,8 @@ void mcl_c5_vsubPreA(VecA *, const VecA *, const VecA *);
 void mcl_c5_vadd(Vec *, const Vec *, const Vec *);
 void mcl_c5_vsub(Vec *, const Vec *, const Vec *);
-//void mcl_c5_vaddA(VecA *, const VecA *, const VecA *);
-
+void mcl_c5_vmul(Vec *, const Vec *, const Vec *);
+void mcl_c5_vaddA(VecA *, const VecA *, const VecA *);
 }
@@ -169,7 +169,7 @@ inline void vadd(Vec *z, const Vec *x, const Vec *y)
 {
 	mcl_c5_vadd(z, x, y);
 }
-#if 0
+#if 1
 template<>
 inline void vadd(VecA *z, const VecA *x, const VecA *y)
 {
@@ -1727,7 +1727,7 @@ CYBOZU_TEST_AUTO(vaddPre)
 	CYBOZU_BENCH_C("asm vsubPreA", C, mcl_c5_vsubPreA, za.v, za.v, xa.v);
 	CYBOZU_BENCH_C("asm vadd", C, mcl_c5_vadd, z[0].v, z[0].v, x[0].v);
 	CYBOZU_BENCH_C("asm vsub", C, mcl_c5_vsub, z[0].v, z[0].v, x[0].v);
-//	CYBOZU_BENCH_C("asm vaddA", C, mcl_c5_vaddA, za.v, za.v, xa.v);
+	CYBOZU_BENCH_C("asm vaddA", C, mcl_c5_vaddA, za.v, za.v, xa.v);
 #endif
 	CYBOZU_BENCH_C("vadd::Vec", C, vadd, z[0].v, z[0].v, x[0].v);
 	CYBOZU_BENCH_C("vsub::Vec", C, vsub, z[0].v, z[0].v, x[0].v);