Skip to content

Commit

Permalink
a little optimized split function for BLS12_381
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Dec 9, 2024
1 parent 1d2e60a commit 92d184c
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 57 deletions.
32 changes: 24 additions & 8 deletions include/mcl/ec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,30 +243,45 @@ void normalizeVecT(Eout& Q, Ein& P, size_t n, size_t N = 256)
}

/*
split x to (a, b) such that x = a + b L where 0 <= a, b <= L, 0 <= x <= r-1 = L^2+L
if adj is true, then 0 <= a < L, 0 <= b <= L+1
split x in [0, r-1] to (a, b) such that x = a + b L, 0 <= a < L, 0 <= b <= L+1
a[] : 128 bit
b[] : 128 bit
x[] : 256 bit
*/
inline void optimizedSplitRawForBLS12_381(Unit *a, Unit *b, const Unit *x)
{
const bool adj = false;
/*
z = -0xd201000000010000
L = z^2-1 = 0xac45a4010001a40200000000ffffffff
r = L^2+L+1 = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
s=255
v = (1<<s)//L = 0xbe35f678f00fd56eb1fb72917b67f718
q = (1<<s)//L = 0xbe35f678f00fd56eb1fb72917b67f718
H = 1<<128
*/
static const Unit L[] = { MCL_U64_TO_UNIT(0x00000000ffffffff), MCL_U64_TO_UNIT(0xac45a4010001a402) };
static const Unit v[] = { MCL_U64_TO_UNIT(0xb1fb72917b67f718), MCL_U64_TO_UNIT(0xbe35f678f00fd56e) };
static const Unit q[] = { MCL_U64_TO_UNIT(0xb1fb72917b67f718), MCL_U64_TO_UNIT(0xbe35f678f00fd56e) };
static const Unit one[] = { MCL_U64_TO_UNIT(1), MCL_U64_TO_UNIT(0) };
static const size_t n = 128 / mcl::UnitBitSize;
#if 1
Unit xH[n+1]; // x = xH * (H/2) + xL
mcl::bint::shrT<n+1>(xH, x+n-1, mcl::UnitBitSize-1); // >>127
Unit t[n*2];
mcl::bint::mulT<n>(t, xH, q);
mcl::bint::copyT<n>(b, t+n); // (xH * q)/H
mcl::bint::mulT<n>(t, b, L); // bL
mcl::bint::subT<n*2>(t, x, t); // x - bL
Unit d = mcl::bint::subT<n>(a, t, L);
if (t[n] - d == 0) {
mcl::bint::addT<n>(b, b, one);
} else {
mcl::bint::copyT<n>(a, t);
}
#else
const bool adj = false;
Unit t[n*3];
// n = 128 bit
// t[n*3] = x[n*2] * v[n]
mcl::bint::mulNM(t, x, n*2, v, n);
// t[n*3] = x[n*2] * q[n]
mcl::bint::mulNM(t, x, n*2, q, n);
// b[n] = t[n*3]>>255
mcl::bint::shrT<n+1>(t, t+n*2-1, mcl::UnitBitSize-1); // >>255
mcl::bint::copyT<n>(b, t);
Expand All @@ -283,6 +298,7 @@ inline void optimizedSplitRawForBLS12_381(Unit *a, Unit *b, const Unit *x)
mcl::bint::clearT<n>(a);
}
}
#endif
}

} // mcl::ec::local
Expand Down Expand Up @@ -571,7 +587,7 @@ void addJacobi(E& R, const E& P, const E& Q)

/*
accept P == Q
https://github.com/apache/incubator-milagro-crypto-c/blob/fa0a45a3/src/ecp.c.in#L767-L976
https://eprint.iacr.org/2015/1060
(x, y, z) is zero <=> x = 0, y = 1, z = 0
*/

Expand Down
55 changes: 29 additions & 26 deletions misc/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,56 +27,59 @@ G1 mulEach
### Definition of parameters

```python
M = 1<<256
H = 1<<128
m = 128
H = 1<<m
z = -0xd201000000010000
L = z*z - 1
r = L*L + L + 1
s = r.bit_length()
S = 1<<s
v = S // L
S = 1<<s # H * (H/2)
q = S // L
r0 = S % L

adj = False
def split(x):
b = (x * v) >> s
a = x - b * L
if adj:
if a >= L:
xH = x >> (m-1) # x // (H/2)
b = (xH * q) >> m # (xH * q) // H
a = x - b * L
if a >= L:
a -= L
b += 1
return (a, b)
return (a, b)
```

variables|z|L|r|S|v
variables|z|L|r|S|q
-|-|-|-|-|-
bit_length|64|128|255|256|128

- x in [0, r-1]
- a + b L = x for (a, b) = split(x).

### Theorem
0 <= a < 1.11 L < H and 0 <= b < L+1 for x in [0, r-1].
0 <= a < L and 0 <= b <= L+1

### Proof

```
Let r0 := L S % r, then S=v L + r0 and r0 in [0, L-1]. In fact, r0 ~ 0.11 L.
Let r1 := x v % S, then x v = b S + r1 and r1 in [0, S-1].
```
S = q * L + r0 where 0 <= r0 < L, r0 ~ 0.11 L
H/2 ~ 0.74 L
x = xH * (H/2) + xL where 0 <= xL < H/2, xH <= (r-1)/(H/2)
```
b <= xv / S < (r-1) (S/L)/S = (r-1)/L = L+1.
```
b = (xH * q) // H <= xH * q / H = xH * H/2 * q / (H * H/2) = (x-xL) * q / S
<= x * (S//L) / S <= x /L <= (r-1) / L = L+1
=> 0 <= x - b L = a
xH * q = b * H + r1 where 0 <= r1 < H
a H = (x - b L) * H = x * H - b * H * L = (xH * (H/2) + xL) * H - (xH * q - r1) * L
= xH * S + xL * H - xH * q * L + r1 * L
= xH * S + xL * H - xH * (S - r0) + r1 * L
= xL * H + xH * r0 + r1 * L
a = xL + xH * r0 / H + r1 * L / H
<= H/2 + (r-1)/(H/2) * r0 / H + (H-1) * L / H
< H/2 + (r-1)/S*r0 + L
~ 0.74 L + 0.1 L + L = 1.84 L < 2 L
so at most one correction (a -= L, b += 1) is needed to get 0 <= a < L
```
aS = (x - bL)S = xS - bSL = xS - (xv - r1)L = x(S - vL) + r1 L = r0 x + r1 L
<= r0 (r-1) + (S-1)L = S L + (r-1)r0 - L.
a <= L + ((r-1)r0 - L)/S
((r-1)r0 - L)/S ~ 0.10016 L < 0.11 L.
```
### Remark
If adj is true, then a is in [0, L-1].


## window size
Expand Down
94 changes: 71 additions & 23 deletions misc/mulvec_test.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#include <mcl/ec.hpp>
#include <math.h>

void put(int n, int x)
// Dummy definition of mcl::fp::getRefArgminForce for this test program.
// The argument is ignored; every call returns a reference to the same
// function-local static, which starts at 0.
namespace mcl {
namespace fp {

size_t& getRefArgminForce(size_t /*unused*/)
{
	static size_t dummy = 0;
	return dummy;
}

} // namespace fp
} // namespace mcl

void put(size_t n, size_t x)
{
printf(" x=%d(%zd)", x, mcl::ec::costMulVec(n, x));
printf(" x=%zd(%zd)", x, mcl::ec::costMulVec(n, x));
}

int getmin(int n)
size_t getmin(size_t n)
{
int min = 100000000;
int a = 0;
for (int x = 1; x < 30; x++) {
int v = mcl::ec::costMulVec(n, x);
size_t min = 100000000;
size_t a = 0;
for (size_t x = 1; x < 30; x++) {
size_t v = mcl::ec::costMulVec(n, x);
if (v < min) {
a = x;
min = v;
Expand All @@ -20,36 +28,76 @@ int getmin(int n)
return a;
}

void disp(int n)
void disp(size_t n)
{
int x0 = getmin(n);
int x1 = mcl::ec::argminForMulVec(n);
printf("n=%d", n);
size_t x0 = getmin(n);
size_t x1 = mcl::ec::argminForMulVec(n);
printf("n=%zd", n);
put(n, x0);
put(n, x1);
printf(" diff=%d\n", x1-x0);
printf(" diff=%zd\n", x1-x0);
}

// Returns cybozu::bsr(n) + 1 for nonzero n, and 0 for n == 0.
// NOTE(review): presumably this is the bit length of n (bsr assumed to give
// the index of the highest set bit) - confirm against cybozu::bsr.
inline size_t ilog2(size_t n)
{
	if (n != 0) {
		return cybozu::bsr(n) + 1;
	}
	return 0;
}

// Cost estimate (n + 2^(x+1) - 1) / x for n points and window parameter x.
// Integer (truncating) division; x must be nonzero.
inline size_t costMulVec(size_t n, size_t x)
{
	const size_t pow2 = size_t(1) << (x + 1);
	const size_t numer = n + pow2 - 1;
	return numer / x;
}
// Approximation of argmin over x of (n + 2^(x+1) - 1) / x.
// For n <= 16 the answer is fixed at 2; otherwise it is
// ilog2(n) - ilog2(ilog2(n)).
inline size_t argminForMulVec0(size_t n)
{
	if (n > 16) {
		const size_t bitLen = ilog2(n);
		return bitLen - ilog2(bitLen);
	}
	return 2;
}

/*
	Refine the estimate x = argminForMulVec0(n) by also evaluating costMulVec
	at x-1 and x+1, and return the chosen x (the argument, not the cost).
	- x-1 is returned when its cost is <= the cost at x (ties go to x-1);
	  when x <= 1 the cost of the lower neighbour is taken to be n.
	- otherwise x+1 is returned when its cost is strictly smaller.
	- otherwise x is returned.
*/
inline size_t argminForMulVec(size_t n)
{
	const size_t guess = argminForMulVec0(n);
#if 1
	const size_t costHere = costMulVec(n, guess);
	size_t costBelow = n;
	if (guess > 1) {
		costBelow = costMulVec(n, guess - 1);
	}
	if (costBelow <= costHere) {
		return guess - 1;
	}
	const size_t costAbove = costMulVec(n, guess + 1);
	if (costAbove < costHere) {
		return guess + 1;
	}
#endif
	return guess;
}

int main()
{
for (int i = 1; i < 16; i++) {
for (size_t i = 1; i < 16; i++) {
disp(i);
}
for (int i = 4; i < 30; i++) {
int n = 1 << i;
for (size_t i = 4; i < 30; i++) {
size_t n = size_t(1) << i;
disp(n*0.9);
disp(n);
disp(n*1.1);
}
for (size_t i = 5; i <= 27; i++) {
size_t n = size_t(1) << i;
size_t glvN = n/8*2;
printf("n=2^%zd=%zd v=%zd\n", i, n, getmin(glvN));
}
puts("all search");
for (int i = 1; i < 100000000; i++) {
int x0 = getmin(i);
int x1 = mcl::ec::argminForMulVec(i);
// if (std::abs(x0-x1) > 1) printf("i=%d x0=%d x1=%d\n", i, x0, x1);
if (x0 != x1) printf("i=%d x0=%d x1=%d\n", i, x0, x1);
for (size_t i = 1; i < 200000000; i++) {
size_t x0 = getmin(i);
size_t x1 = argminForMulVec(i);
// if (std::abs(x0-x1) > 1) printf("i=%zd x0=%zd x1=%zd\n", i, x0, x1);
if (x0 != x1) printf("i=%zd x0=%zd x1=%zd\n", i, x0, x1);
}
for (int i = 1; i <= 100000000; i *= 10) {
int x = mcl::ec::argminForMulVec(i);
printf("i=%d x=%d\n", i, x);
for (size_t i = 1; i <= 100000000; i *= 10) {
size_t x = argminForMulVec(i);
printf("i=%zd x=%zd\n", i, x);
}
}
32 changes: 32 additions & 0 deletions test/bls12_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,38 @@ CYBOZU_TEST_AUTO(verifyG2)
CYBOZU_TEST_ASSERT(n == 0);
}

void splitTest(const mpz_class& mx, const mpz_class& L)
{
mcl::Unit x[4], a[2], b[2];
mcl::gmp::getArray(x, 4, mx);
mcl::ec::local::optimizedSplitRawForBLS12_381(a, b, x);
mpz_class ma, mb;
mcl::gmp::setArray(ma, a, 2);
mcl::gmp::setArray(mb, b, 2);
CYBOZU_TEST_EQUAL(mb, mx / L);
CYBOZU_TEST_EQUAL(ma, mx % L);
}

// Exercises splitTest with L = 0xac45a4010001a40200000000ffffffff:
// first on 100 Fr values filled via setByCSPRNG from a default-constructed
// XorShift, then on a fixed table of boundary values.
CYBOZU_TEST_AUTO(split)
{
	const char *hexL = "ac45a4010001a40200000000ffffffff";
	mpz_class L;
	mcl::gmp::setStr(L, hexL, 16);

	cybozu::XorShift rg;
	Fr r;
	int remain = 100;
	while (remain-- > 0) {
		r.setByCSPRNG(rg);
		splitTest(r.getMpz(), L);
	}

	// Small values, values around L and 2L, and values around L^2 and L^2+L.
	const mpz_class L2 = L * L;
	const mpz_class edge[] = {
		0, 1, 2, 3,
		L-1, L, L+1,
		L*2, L*2-1, L*2+1,
		L2-L, L2-1, L2+1,
		L2+L-2, L2+L-1, L2+L,
	};
	const size_t edgeNum = CYBOZU_NUM_OF_ARRAY(edge);
	for (size_t k = 0; k < edgeNum; k++) {
		splitTest(edge[k], L);
	}
}

typedef std::vector<Fp> FpVec;

void f(FpVec& zv, const FpVec& xv, const FpVec& yv)
Expand Down

0 comments on commit 92d184c

Please sign in to comment.