Skip to content

Commit

Permalink
EdDSA: Explicit guard against infinite looping
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdettman committed Mar 20, 2024
1 parent deb0954 commit ebe1c75
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 20 deletions.
14 changes: 12 additions & 2 deletions core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed25519.java
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,12 @@ private static boolean implVerify(byte[] sig, int sigOff, byte[] pk, int pkOff,

int[] v0 = new int[4];
int[] v1 = new int[4];
Scalar25519.reduceBasisVar(nA, v0, v1);

if (!Scalar25519.reduceBasisVar(nA, v0, v1))
{
throw new IllegalStateException();
}

Scalar25519.multiply128Var(nS, v1, nS);

PointAccum pZ = new PointAccum();
Expand Down Expand Up @@ -628,7 +633,12 @@ private static boolean implVerify(byte[] sig, int sigOff, PublicPoint publicPoin

int[] v0 = new int[4];
int[] v1 = new int[4];
Scalar25519.reduceBasisVar(nA, v0, v1);

if (!Scalar25519.reduceBasisVar(nA, v0, v1))
{
throw new IllegalStateException();
}

Scalar25519.multiply128Var(nS, v1, nS);

PointAccum pZ = new PointAccum();
Expand Down
14 changes: 12 additions & 2 deletions core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed448.java
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,12 @@ private static boolean implVerify(byte[] sig, int sigOff, byte[] pk, int pkOff,

int[] v0 = new int[8];
int[] v1 = new int[8];
Scalar448.reduceBasisVar(nA, v0, v1);

if (!Scalar448.reduceBasisVar(nA, v0, v1))
{
throw new IllegalStateException();
}

Scalar448.multiply225Var(nS, v1, nS);

PointProjective pZ = new PointProjective();
Expand Down Expand Up @@ -569,7 +574,12 @@ private static boolean implVerify(byte[] sig, int sigOff, PublicPoint publicPoin

int[] v0 = new int[8];
int[] v1 = new int[8];
Scalar448.reduceBasisVar(nA, v0, v1);

if (!Scalar448.reduceBasisVar(nA, v0, v1))
{
throw new IllegalStateException();
}

Scalar448.multiply225Var(nS, v1, nS);

PointProjective pZ = new PointProjective();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ static byte[] reduce512(byte[] n)
return r;
}

static void reduceBasisVar(int[] k, int[] z0, int[] z1)
static boolean reduceBasisVar(int[] k, int[] z0, int[] z1)
{
/*
* Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
Expand All @@ -312,11 +312,18 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1)
int[] v0 = new int[4]; System.arraycopy(k, 0, v0, 0, 4);
int[] v1 = new int[4]; v1[0] = 1;

// Conservative upper bound on the number of loop iterations needed
int iterations = TARGET_LENGTH * 4;
int last = 15;
int len_Nv = ScalarUtil.getBitLengthPositive(last, Nv);

while (len_Nv > TARGET_LENGTH)
{
if (--iterations < 0)
{
return false;
}

int len_p = ScalarUtil.getBitLength(last, p);
int s = len_p - len_Nv;
s &= ~(s >> 31);
Expand Down Expand Up @@ -346,6 +353,7 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1)
// v1 * k == v0 mod L
System.arraycopy(v0, 0, z0, 0, 4);
System.arraycopy(v1, 0, z1, 0, 4);
return true;
}

static void toSignedDigits(int bits, int[] z)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ static byte[] reduce912(byte[] n)
return r;
}

static void reduceBasisVar(int[] k, int[] z0, int[] z1)
static boolean reduceBasisVar(int[] k, int[] z0, int[] z1)
{
/*
* Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
Expand All @@ -577,11 +577,18 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1)
int[] v0 = new int[8]; System.arraycopy(k, 0, v0, 0, 8);
int[] v1 = new int[8]; v1[0] = 1;

// Conservative upper bound on the number of loop iterations needed
int iterations = TARGET_LENGTH * 4;
int last = 27;
int len_Nv = ScalarUtil.getBitLengthPositive(last, Nv);

while (len_Nv > TARGET_LENGTH)
{
if (--iterations < 0)
{
return false;
}

int len_p = ScalarUtil.getBitLength(last, p);
int s = len_p - len_Nv;
s &= ~(s >> 31);
Expand Down Expand Up @@ -614,6 +621,7 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1)
// v1 * k == v0 mod L
System.arraycopy(v0, 0, z0, 0, 8);
System.arraycopy(v1, 0, z1, 0, 8);
return true;
}

static void toSignedDigits(int bits, int[] x, int[] z)
Expand Down
28 changes: 14 additions & 14 deletions core/src/main/java/org/bouncycastle/math/ec/rfc8032/ScalarUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ static void addShifted_NP(int last, int s, int[] Nu, int[] Nv, int[] p, int[] t)

cc_p += p_i & M;
cc_p += Nv[i] & M;
p_i = (int)cc_p; cc_p >>= 32;
p[i] = p_i;
p_i = (int)cc_p; cc_p >>>= 32;
p[i] = p_i;

cc_Nu += p_i & M;
Nu[i] = (int)cc_Nu; cc_Nu >>= 32;
Nu[i] = (int)cc_Nu; cc_Nu >>>= 32;
}
}
else if (s < 32)
Expand All @@ -50,20 +50,20 @@ else if (s < 32)

cc_p += p_i & M;
cc_p += v_s & M;
p_i = (int)cc_p; cc_p >>= 32;
p_i = (int)cc_p; cc_p >>>= 32;
p[i] = p_i;

int q_s = (p_i << s) | (prev_q >>> -s);
prev_q =p_i;
prev_q = p_i;

cc_Nu += q_s & M;
Nu[i] = (int)cc_Nu; cc_Nu >>= 32;
Nu[i] = (int)cc_Nu; cc_Nu >>>= 32;
}
}
else
{
// Keep the original value of p in t.
System.arraycopy(p, 0, t, 0, p.length);
// Copy the low limbs of the original p
System.arraycopy(p, 0, t, 0, last);

int sWords = s >>> 5; int sBits = s & 31;
if (sBits == 0)
Expand All @@ -75,10 +75,10 @@ else if (s < 32)

cc_p += p[i] & M;
cc_p += Nv[i - sWords] & M;
p[i] = (int)cc_p; cc_p >>= 32;
p[i] = (int)cc_p; cc_p >>>= 32;

cc_Nu += p[i - sWords] & M;
Nu[i] = (int)cc_Nu; cc_Nu >>= 32;
Nu[i] = (int)cc_Nu; cc_Nu >>>= 32;
}
}
else
Expand All @@ -102,14 +102,14 @@ else if (s < 32)

cc_p += p[i] & M;
cc_p += v_s & M;
p[i] = (int)cc_p; cc_p >>= 32;
p[i] = (int)cc_p; cc_p >>>= 32;

int next_q = p[i - sWords];
int q_s = (next_q << sBits) | (prev_q >>> -sBits);
prev_q = next_q;

cc_Nu += q_s & M;
Nu[i] = (int)cc_Nu; cc_Nu >>= 32;
Nu[i] = (int)cc_Nu; cc_Nu >>>= 32;
}
}
}
Expand Down Expand Up @@ -251,8 +251,8 @@ else if (s < 32)
}
else
{
// Keep the original value of p in t.
System.arraycopy(p, 0, t, 0, p.length);
// Copy the low limbs of the original p
System.arraycopy(p, 0, t, 0, last);

int sWords = s >>> 5; int sBits = s & 31;
if (sBits == 0)
Expand Down

0 comments on commit ebe1c75

Please sign in to comment.