diff --git a/common/hash.go b/common/hash.go index e26b2fc6..304af71b 100644 --- a/common/hash.go +++ b/common/hash.go @@ -93,6 +93,47 @@ func SHA512_256i(in ...*big.Int) *big.Int { return new(big.Int).SetBytes(state.Sum(nil)) } +// SHA512_256i_TAGGED tagged version of SHA512_256i +func SHA512_256i_TAGGED(tag []byte, in ...*big.Int) *big.Int { + tagBz := SHA512_256(tag) + var data []byte + state := crypto.SHA512_256.New() + state.Write(tagBz) + state.Write(tagBz) + inLen := len(in) + if inLen == 0 { + return nil + } + bzSize := 0 + // prevent hash collisions with this prefix containing the block count + inLenBz := make([]byte, 64/8) + // converting between int and uint64 doesn't change the sign bit, but it may be interpreted as a larger value. + // this prefix is never read/interpreted, so that doesn't matter. + binary.LittleEndian.PutUint64(inLenBz, uint64(inLen)) + ptrs := make([][]byte, inLen) + for i, n := range in { + if n == nil { + ptrs[i] = zero.Bytes() + } else { + ptrs[i] = n.Bytes() + } + bzSize += len(ptrs[i]) + } + data = make([]byte, 0, len(inLenBz)+bzSize+inLen) + data = append(data, inLenBz...) + for i := range in { + data = append(data, ptrs[i]...) + data = append(data, hashInputDelimiter) // safety delimiter + } + // n < len(data) or an error will never happen. + // see: https://golang.org/pkg/hash/#Hash and https://github.com/golang/go/wiki/Hashing#the-hashhash-interface + if _, err := state.Write(data); err != nil { + Logger.Error(err) + return nil + } + return new(big.Int).SetBytes(state.Sum(nil)) +} + func SHA512_256iOne(in *big.Int) *big.Int { var data []byte state := crypto.SHA512_256.New() diff --git a/common/int.go b/common/int.go index 379da101..54c798cd 100644 --- a/common/int.go +++ b/common/int.go @@ -62,3 +62,10 @@ func (mi *modInt) i() *big.Int { func IsInInterval(b *big.Int, bound *big.Int) bool { return b.Cmp(bound) == -1 && b.Cmp(zero) >= 0 } + +func AppendBigIntToBytesSlice(commonBytes []byte, appended *big.Int) []byte { + resultBytes := make([]byte, len(commonBytes), len(commonBytes)+len(appended.Bytes())) + copy(resultBytes, commonBytes) + resultBytes = append(resultBytes, appended.Bytes()...) + return resultBytes +} diff --git a/common/random.go b/common/random.go index 26074e66..333cf549 100644 --- a/common/random.go +++ b/common/random.go @@ -111,3 +111,19 @@ func GetRandomQuadraticNonResidue(n *big.Int) *big.Int { } } } + +// GetRandomBytes returns random bytes of length. +func GetRandomBytes(length int) ([]byte, error) { + // Per [BIP32], the seed must be in range [MinSeedBytes, MaxSeedBytes]. + if length <= 0 { + return nil, errors.New("invalid length") + } + + buf := make([]byte, length) + _, err := rand.Read(buf) + if err != nil { + return nil, err + } + + return buf, nil +} diff --git a/crypto/facproof/proof.go b/crypto/facproof/proof.go index 69720bed..6458ff93 100644 --- a/crypto/facproof/proof.go +++ b/crypto/facproof/proof.go @@ -10,8 +10,9 @@ import ( "crypto/elliptic" "errors" "fmt" - "github.com/bnb-chain/tss-lib/common" "math/big" + + "github.com/bnb-chain/tss-lib/common" ) const ( @@ -30,31 +31,31 @@ var ( one = big.NewInt(1) ) -// NewProof implements proofFac -func NewProof(ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.Int) (*ProofFac, error) { +// NewProof implements prooffac +func NewProof(Session []byte, ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.Int) (*ProofFac, error) { if ec == nil || N0 == nil || NCap == nil || s == nil || t == nil || N0p == nil || N0q == nil { return nil, errors.New("ProveFac constructor received nil value(s)") } q := ec.Params().N + q3 := new(big.Int).Mul(q, q) + q3 = new(big.Int).Mul(q, q3) + qNCap := new(big.Int).Mul(q, NCap) + qN0NCap := new(big.Int).Mul(qNCap, N0) + q3NCap := new(big.Int).Mul(q3, NCap) + q3N0NCap := new(big.Int).Mul(q3NCap, N0) sqrtN0 := new(big.Int).Sqrt(N0) - - leSqrtN0 := new(big.Int).Mul(rangeParameter, q) - leSqrtN0 = new(big.Int).Mul(leSqrtN0, sqrtN0) - lNCap := new(big.Int).Mul(rangeParameter, NCap) - lN0NCap := new(big.Int).Mul(lNCap, N0) - leN0NCap := new(big.Int).Mul(lN0NCap, q) - leNCap := new(big.Int).Mul(lNCap, q) + q3SqrtN0 := new(big.Int).Mul(q3, sqrtN0) // Fig 28.1 sample - alpha := common.GetRandomPositiveInt(leSqrtN0) - beta := common.GetRandomPositiveInt(leSqrtN0) - mu := common.GetRandomPositiveInt(lNCap) - nu := common.GetRandomPositiveInt(lNCap) - sigma := common.GetRandomPositiveInt(lN0NCap) - r := common.GetRandomPositiveRelativelyPrimeInt(leN0NCap) - x := common.GetRandomPositiveInt(leNCap) - y := common.GetRandomPositiveInt(leNCap) + alpha := common.GetRandomPositiveInt(q3SqrtN0) + beta := common.GetRandomPositiveInt(q3SqrtN0) + mu := common.GetRandomPositiveInt(qNCap) + nu := common.GetRandomPositiveInt(qNCap) + sigma := common.GetRandomPositiveInt(qN0NCap) + r := common.GetRandomPositiveRelativelyPrimeInt(q3N0NCap) + x := common.GetRandomPositiveInt(q3NCap) + y := common.GetRandomPositiveInt(q3NCap) // Fig 28.1 compute modNCap := common.ModInt(NCap) @@ -76,7 +77,7 @@ func NewProof(ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.Int) (*ProofFac, // Fig 28.2 e var e *big.Int { - eHash := common.SHA512_256i(N0, NCap, s, t, P, Q, A, B, T, sigma) + eHash := common.SHA512_256i_TAGGED(Session, N0, NCap, s, t, P, Q, A, B, T, sigma) e = common.RejectionSample(q, eHash) } @@ -120,82 +121,32 @@ func NewProofFromBytes(bzs [][]byte) (*ProofFac, error) { }, nil } -func (pf *ProofFac) Verify(ec elliptic.Curve, N0, NCap, s, t *big.Int) bool { +func (pf *ProofFac) Verify(Session []byte, ec elliptic.Curve, N0, NCap, s, t *big.Int) bool { if pf == nil || !pf.ValidateBasic() || ec == nil || N0 == nil || NCap == nil || s == nil || t == nil { return false } if N0.Sign() != 1 { return false } - if NCap.Sign() != 1 { - return false - } q := ec.Params().N + q3 := new(big.Int).Mul(q, q) + q3 = new(big.Int).Mul(q, q3) sqrtN0 := new(big.Int).Sqrt(N0) - - leSqrtN0 := new(big.Int).Mul(rangeParameter, q) - leSqrtN0 = new(big.Int).Mul(leSqrtN0, sqrtN0) - lNCap := new(big.Int).Mul(rangeParameter, NCap) - lN0NCap := new(big.Int).Mul(lNCap, N0) - leN0NCap2 := new(big.Int).Lsh(new(big.Int).Mul(lN0NCap, q), 1) - leNCap2 := new(big.Int).Lsh(new(big.Int).Mul(lNCap, q), 1) - - if !common.IsInInterval(pf.P, NCap) { - return false - } - if !common.IsInInterval(pf.Q, NCap) { - return false - } - if !common.IsInInterval(pf.A, NCap) { - return false - } - if !common.IsInInterval(pf.B, NCap) { - return false - } - if !common.IsInInterval(pf.T, NCap) { - return false - } - if !common.IsInInterval(pf.Sigma, lN0NCap) { - return false - } - if new(big.Int).GCD(nil, nil, pf.P, NCap).Cmp(one) != 0 { - return false - } - if new(big.Int).GCD(nil, nil, pf.Q, NCap).Cmp(one) != 0 { - return false - } - if new(big.Int).GCD(nil, nil, pf.A, NCap).Cmp(one) != 0 { - return false - } - if new(big.Int).GCD(nil, nil, pf.B, NCap).Cmp(one) != 0 { - return false - } - if new(big.Int).GCD(nil, nil, pf.T, NCap).Cmp(one) != 0 { - return false - } - if !common.IsInInterval(pf.W1, leNCap2) { - return false - } - if !common.IsInInterval(pf.W2, leNCap2) { - return false - } - if !common.IsInInterval(pf.V, leN0NCap2) { - return false - } + q3SqrtN0 := new(big.Int).Mul(q3, sqrtN0) // Fig 28. Range Check - if !common.IsInInterval(pf.Z1, leSqrtN0) { + if !common.IsInInterval(pf.Z1, q3SqrtN0) { return false } - if !common.IsInInterval(pf.Z2, leSqrtN0) { + if !common.IsInInterval(pf.Z2, q3SqrtN0) { return false } var e *big.Int { - eHash := common.SHA512_256i(N0, NCap, s, t, pf.P, pf.Q, pf.A, pf.B, pf.T, pf.Sigma) + eHash := common.SHA512_256i_TAGGED(Session, N0, NCap, s, t, pf.P, pf.Q, pf.A, pf.B, pf.T, pf.Sigma) e = common.RejectionSample(q, eHash) } diff --git a/crypto/facproof/proof_test.go b/crypto/facproof/proof_test.go index 993abb03..ef73ea74 100644 --- a/crypto/facproof/proof_test.go +++ b/crypto/facproof/proof_test.go @@ -23,6 +23,10 @@ const ( testSafePrimeBits = 1024 ) +var ( + Session = []byte("session") +) + func TestFac(test *testing.T) { ec := tss.EC() @@ -33,31 +37,19 @@ func TestFac(test *testing.T) { primes := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)} NCap, s, t, err := crypto.GenerateNTildei(primes) assert.NoError(test, err) - proof, err := NewProof(ec, N0, NCap, s, t, N0p, N0q) + proof, err := NewProof(Session, ec, N0, NCap, s, t, N0p, N0q) assert.NoError(test, err) - ok := proof.Verify(ec, N0, NCap, s, t) + ok := proof.Verify(Session, ec, N0, NCap, s, t) assert.True(test, ok, "proof must verify") N0p = common.GetRandomPrimeInt(1024) N0q = common.GetRandomPrimeInt(1024) N0 = new(big.Int).Mul(N0p, N0q) - proof, err = NewProof(ec, N0, NCap, s, t, N0p, N0q) + proof, err = NewProof(Session, ec, N0, NCap, s, t, N0p, N0q) assert.NoError(test, err) - ok = proof.Verify(ec, N0, NCap, s, t) + ok = proof.Verify(Session, ec, N0, NCap, s, t) assert.True(test, ok, "proof must verify") - - // factor should have bits [1024-16, 1024+16] - smallFactor := 900 - N0p = common.GetRandomPrimeInt(smallFactor) - N0q = common.GetRandomPrimeInt(2048 - smallFactor) - N0 = new(big.Int).Mul(N0p, N0q) - - proof, err = NewProof(ec, N0, NCap, s, t, N0p, N0q) - assert.NoError(test, err) - - ok = proof.Verify(ec, N0, NCap, s, t) - assert.False(test, ok, "proof must not verify") } diff --git a/crypto/modproof/proof.go b/crypto/modproof/proof.go index 3788a2b9..bcbb9ee0 100644 --- a/crypto/modproof/proof.go +++ b/crypto/modproof/proof.go @@ -37,7 +37,7 @@ func isQuadraticResidue(X, N *big.Int) bool { return big.Jacobi(X, N) == 1 } -func NewProof(N, P, Q *big.Int) (*ProofMod, error) { +func NewProof(Session []byte, N, P, Q *big.Int) (*ProofMod, error) { Phi := new(big.Int).Mul(new(big.Int).Sub(P, one), new(big.Int).Sub(Q, one)) // Fig 16.1 W := common.GetRandomQuadraticNonResidue(N) @@ -45,7 +45,7 @@ func NewProof(N, P, Q *big.Int) (*ProofMod, error) { // Fig 16.2 Y := [Iterations]*big.Int{} for i := range Y { - ei := common.SHA512_256i(append([]*big.Int{W, N}, Y[:i]...)...) + ei := common.SHA512_256i_TAGGED(Session, append([]*big.Int{W, N}, Y[:i]...)...) Y[i] = common.RejectionSample(N, ei) } @@ -112,7 +112,7 @@ func NewProofFromBytes(bzs [][]byte) (*ProofMod, error) { }, nil } -func (pf *ProofMod) Verify(N *big.Int) bool { +func (pf *ProofMod) Verify(Session []byte, N *big.Int) bool { if pf == nil || !pf.ValidateBasic() { return false } @@ -143,7 +143,7 @@ func (pf *ProofMod) Verify(N *big.Int) bool { modN := common.ModInt(N) Y := [Iterations]*big.Int{} for i := range Y { - ei := common.SHA512_256i(append([]*big.Int{pf.W, N}, Y[:i]...)...) + ei := common.SHA512_256i_TAGGED(Session, append([]*big.Int{pf.W, N}, Y[:i]...)...) Y[i] = common.RejectionSample(N, ei) } diff --git a/crypto/modproof/proof_test.go b/crypto/modproof/proof_test.go index 0841e0ce..d5380563 100644 --- a/crypto/modproof/proof_test.go +++ b/crypto/modproof/proof_test.go @@ -15,19 +15,23 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + Session = []byte("session") +) + func TestMod(test *testing.T) { preParams, err := keygen.GeneratePreParams(time.Minute*10, 8) assert.NoError(test, err) P, Q, N := preParams.PaillierSK.P, preParams.PaillierSK.Q, preParams.PaillierSK.N - proof, err := NewProof(N, P, Q) + proof, err := NewProof(Session, N, P, Q) assert.NoError(test, err) proofBzs := proof.Bytes() proof, err = NewProofFromBytes(proofBzs[:]) assert.NoError(test, err) - ok := proof.Verify(N) + ok := proof.Verify(Session, N) assert.True(test, ok, "proof must verify") } diff --git a/crypto/mta/proofs.go b/crypto/mta/proofs.go index 2152371d..55142c7a 100644 --- a/crypto/mta/proofs.go +++ b/crypto/mta/proofs.go @@ -36,7 +36,7 @@ type ( // ProveBobWC implements Bob's proof both with or without check "ProveMtawc_Bob" and "ProveMta_Bob" used in the MtA protocol from GG18Spec (9) Figs. 10 & 11. // an absent `X` generates the proof without the X consistency check X = g^x -func ProveBobWC(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int, X *crypto.ECPoint) (*ProofBobWC, error) { +func ProveBobWC(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int, X *crypto.ECPoint) (*ProofBobWC, error) { if pk == nil || NTilde == nil || h1 == nil || h2 == nil || c1 == nil || c2 == nil || x == nil || y == nil || r == nil { return nil, errors.New("ProveBob() received a nil argument") } @@ -103,9 +103,9 @@ func ProveBobWC(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c var eHash *big.Int // X is nil if called by ProveBob (Bob's proof "without check") if X == nil { - eHash = common.SHA512_256i(append(pk.AsInts(), c1, c2, z, zPrm, t, v, w)...) + eHash = common.SHA512_256i_TAGGED(Session, append(pk.AsInts(), c1, c2, z, zPrm, t, v, w)...) } else { - eHash = common.SHA512_256i(append(pk.AsInts(), X.X(), X.Y(), c1, c2, u.X(), u.Y(), z, zPrm, t, v, w)...) + eHash = common.SHA512_256i_TAGGED(Session, append(pk.AsInts(), X.X(), X.Y(), c1, c2, u.X(), u.Y(), z, zPrm, t, v, w)...) } e = common.RejectionSample(q, eHash) } @@ -139,10 +139,10 @@ func ProveBobWC(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c } // ProveBob implements Bob's proof "ProveMta_Bob" used in the MtA protocol from GG18Spec (9) Fig. 11. -func ProveBob(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int) (*ProofBob, error) { +func ProveBob(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int) (*ProofBob, error) { // the Bob proof ("with check") contains the ProofBob "without check"; this method extracts and returns it // X is supplied as nil to exclude it from the proof hash - pf, err := ProveBobWC(ec, pk, NTilde, h1, h2, c1, c2, x, y, r, nil) + pf, err := ProveBobWC(Session, ec, pk, NTilde, h1, h2, c1, c2, x, y, r, nil) if err != nil { return nil, err } @@ -189,7 +189,7 @@ func ProofBobFromBytes(bzs [][]byte) (*ProofBob, error) { // ProveBobWC.Verify implements verification of Bob's proof with check "VerifyMtawc_Bob" used in the MtA protocol from GG18Spec (9) Fig. 10. // an absent `X` verifies a proof generated without the X consistency check X = g^x -func (pf *ProofBobWC) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2 *big.Int, X *crypto.ECPoint) bool { +func (pf *ProofBobWC) Verify(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2 *big.Int, X *crypto.ECPoint) bool { if pk == nil || NTilde == nil || h1 == nil || h2 == nil || c1 == nil || c2 == nil { return false } @@ -274,12 +274,12 @@ func (pf *ProofBobWC) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, var eHash *big.Int // X is nil if called on a ProveBob (Bob's proof "without check") if X == nil { - eHash = common.SHA512_256i(append(pk.AsInts(), c1, c2, pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...) + eHash = common.SHA512_256i_TAGGED(Session, append(pk.AsInts(), c1, c2, pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...) } else { if !tss.SameCurve(ec, X.Curve()) { return false } - eHash = common.SHA512_256i(append(pk.AsInts(), X.X(), X.Y(), c1, c2, pf.U.X(), pf.U.Y(), pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...) + eHash = common.SHA512_256i_TAGGED(Session, append(pk.AsInts(), X.X(), X.Y(), c1, c2, pf.U.X(), pf.U.Y(), pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...) } e = common.RejectionSample(q, eHash) } @@ -340,12 +340,12 @@ func (pf *ProofBobWC) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, } // ProveBob.Verify implements verification of Bob's proof without check "VerifyMta_Bob" used in the MtA protocol from GG18Spec (9) Fig. 11. -func (pf *ProofBob) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2 *big.Int) bool { +func (pf *ProofBob) Verify(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2 *big.Int) bool { if pf == nil { return false } pfWC := &ProofBobWC{ProofBob: pf, U: nil} - return pfWC.Verify(ec, pk, NTilde, h1, h2, c1, c2, nil) + return pfWC.Verify(Session, ec, pk, NTilde, h1, h2, c1, c2, nil) } func (pf *ProofBob) ValidateBasic() bool { diff --git a/crypto/mta/share_protocol.go b/crypto/mta/share_protocol.go index b8c8be71..40e1b924 100644 --- a/crypto/mta/share_protocol.go +++ b/crypto/mta/share_protocol.go @@ -30,6 +30,7 @@ func AliceInit( } func BobMid( + Session []byte, ec elliptic.Curve, pkA *paillier.PublicKey, pf *RangeProofAlice, @@ -57,11 +58,12 @@ func BobMid( return } beta = common.ModInt(q).Sub(zero, betaPrm) - piB, err = ProveBob(ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand) + piB, err = ProveBob(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand) return } func BobMidWC( + Session []byte, ec elliptic.Curve, pkA *paillier.PublicKey, pf *RangeProofAlice, @@ -90,18 +92,19 @@ func BobMidWC( return } beta = common.ModInt(q).Sub(zero, betaPrm) - piB, err = ProveBobWC(ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand, B) + piB, err = ProveBobWC(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand, B) return } func AliceEnd( + Session []byte, ec elliptic.Curve, pkA *paillier.PublicKey, pf *ProofBob, h1A, h2A, cA, cB, NTildeA *big.Int, sk *paillier.PrivateKey, ) (*big.Int, error) { - if !pf.Verify(ec, pkA, NTildeA, h1A, h2A, cA, cB) { + if !pf.Verify(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB) { return nil, errors.New("ProofBob.Verify() returned false") } alphaPrm, err := sk.Decrypt(cB) @@ -113,6 +116,7 @@ func AliceEnd( } func AliceEndWC( + Session []byte, ec elliptic.Curve, pkA *paillier.PublicKey, pf *ProofBobWC, @@ -120,7 +124,7 @@ func AliceEndWC( cA, cB, NTildeA, h1A, h2A *big.Int, sk *paillier.PrivateKey, ) (*big.Int, error) { - if !pf.Verify(ec, pkA, NTildeA, h1A, h2A, cA, cB, B) { + if !pf.Verify(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, B) { return nil, errors.New("ProofBobWC.Verify() returned false") } alphaPrm, err := sk.Decrypt(cB) diff --git a/crypto/mta/share_protocol_test.go b/crypto/mta/share_protocol_test.go index 313163e1..81ad3c93 100644 --- a/crypto/mta/share_protocol_test.go +++ b/crypto/mta/share_protocol_test.go @@ -26,6 +26,10 @@ const ( testPaillierKeyLength = 2048 ) +var ( + Session = []byte("session") +) + func TestShareProtocol(t *testing.T) { q := tss.EC().Params().N @@ -46,10 +50,10 @@ func TestShareProtocol(t *testing.T) { cA, pf, err := AliceInit(tss.EC(), pk, a, NTildej, h1j, h2j) assert.NoError(t, err) - _, cB, betaPrm, pfB, err := BobMid(tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j) + _, cB, betaPrm, pfB, err := BobMid(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j) assert.NoError(t, err) - alpha, err := AliceEnd(tss.EC(), pk, pfB, h1i, h2i, cA, cB, NTildei, sk) + alpha, err := AliceEnd(Session, tss.EC(), pk, pfB, h1i, h2i, cA, cB, NTildei, sk) assert.NoError(t, err) // expect: alpha = ab + betaPrm @@ -82,10 +86,10 @@ func TestShareProtocolWC(t *testing.T) { gBPoint, err := crypto.NewECPoint(tss.EC(), gBX, gBY) assert.NoError(t, err) - _, cB, betaPrm, pfB, err := BobMidWC(tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j, gBPoint) + _, cB, betaPrm, pfB, err := BobMidWC(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j, gBPoint) assert.NoError(t, err) - alpha, err := AliceEndWC(tss.EC(), pk, pfB, gBPoint, cA, cB, NTildei, h1i, h2i, sk) + alpha, err := AliceEndWC(Session, tss.EC(), pk, pfB, gBPoint, cA, cB, NTildei, h1i, h2i, sk) assert.NoError(t, err) // expect: alpha = ab + betaPrm diff --git a/crypto/schnorr/schnorr_proof.go b/crypto/schnorr/schnorr_proof.go index 61a95862..bf12d44e 100644 --- a/crypto/schnorr/schnorr_proof.go +++ b/crypto/schnorr/schnorr_proof.go @@ -27,7 +27,7 @@ type ( ) // NewZKProof constructs a new Schnorr ZK proof of knowledge of the discrete logarithm (GG18Spec Fig. 16) -func NewZKProof(x *big.Int, X *crypto.ECPoint) (*ZKProof, error) { +func NewZKProof(Session []byte, x *big.Int, X *crypto.ECPoint) (*ZKProof, error) { if x == nil || X == nil || !X.ValidateBasic() { return nil, errors.New("ZKProof constructor received nil or invalid value(s)") } @@ -41,7 +41,7 @@ func NewZKProof(x *big.Int, X *crypto.ECPoint) (*ZKProof, error) { var c *big.Int { - cHash := common.SHA512_256i(X.X(), X.Y(), g.X(), g.Y(), alpha.X(), alpha.Y()) + cHash := common.SHA512_256i_TAGGED(Session, X.X(), X.Y(), g.X(), g.Y(), alpha.X(), alpha.Y()) c = common.RejectionSample(q, cHash) } t := new(big.Int).Mul(c, x) @@ -51,7 +51,7 @@ func NewZKProof(x *big.Int, X *crypto.ECPoint) (*ZKProof, error) { } // NewZKProof verifies a new Schnorr ZK proof of knowledge of the discrete logarithm (GG18Spec Fig. 16) -func (pf *ZKProof) Verify(X *crypto.ECPoint) bool { +func (pf *ZKProof) Verify(Session []byte, X *crypto.ECPoint) bool { if pf == nil || !pf.ValidateBasic() { return false } @@ -62,7 +62,7 @@ func (pf *ZKProof) Verify(X *crypto.ECPoint) bool { var c *big.Int { - cHash := common.SHA512_256i(X.X(), X.Y(), g.X(), g.Y(), pf.Alpha.X(), pf.Alpha.Y()) + cHash := common.SHA512_256i_TAGGED(Session, X.X(), X.Y(), g.X(), g.Y(), pf.Alpha.X(), pf.Alpha.Y()) c = common.RejectionSample(q, cHash) } tG := crypto.ScalarBaseMult(ec, pf.T) @@ -79,7 +79,7 @@ func (pf *ZKProof) ValidateBasic() bool { } // NewZKProof constructs a new Schnorr ZK proof of knowledge s_i, l_i such that V_i = R^s_i, g^l_i (GG18Spec Fig. 17) -func NewZKVProof(V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof, error) { +func NewZKVProof(Session []byte, V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof, error) { if V == nil || R == nil || s == nil || l == nil || !V.ValidateBasic() || !R.ValidateBasic() { return nil, errors.New("ZKVProof constructor received nil value(s)") } @@ -95,7 +95,7 @@ func NewZKVProof(V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof, error) { var c *big.Int { - cHash := common.SHA512_256i(V.X(), V.Y(), R.X(), R.Y(), g.X(), g.Y(), alpha.X(), alpha.Y()) + cHash := common.SHA512_256i_TAGGED(Session, V.X(), V.Y(), R.X(), R.Y(), g.X(), g.Y(), alpha.X(), alpha.Y()) c = common.RejectionSample(q, cHash) } modQ := common.ModInt(q) @@ -105,7 +105,7 @@ func NewZKVProof(V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof, error) { return &ZKVProof{Alpha: alpha, T: t, U: u}, nil } -func (pf *ZKVProof) Verify(V, R *crypto.ECPoint) bool { +func (pf *ZKVProof) Verify(Session []byte, V, R *crypto.ECPoint) bool { if pf == nil || !pf.ValidateBasic() { return false } @@ -116,7 +116,7 @@ func (pf *ZKVProof) Verify(V, R *crypto.ECPoint) bool { var c *big.Int { - cHash := common.SHA512_256i(V.X(), V.Y(), R.X(), R.Y(), g.X(), g.Y(), pf.Alpha.X(), pf.Alpha.Y()) + cHash := common.SHA512_256i_TAGGED(Session, V.X(), V.Y(), R.X(), R.Y(), g.X(), g.Y(), pf.Alpha.X(), pf.Alpha.Y()) c = common.RejectionSample(q, cHash) } tR := R.ScalarMult(pf.T) diff --git a/crypto/schnorr/schnorr_proof_test.go b/crypto/schnorr/schnorr_proof_test.go index c81fed4d..91175024 100644 --- a/crypto/schnorr/schnorr_proof_test.go +++ b/crypto/schnorr/schnorr_proof_test.go @@ -17,11 +17,15 @@ import ( "github.com/bnb-chain/tss-lib/tss" ) +var ( + Session = []byte("session") +) + func TestSchnorrProof(t *testing.T) { q := tss.EC().Params().N u := common.GetRandomPositiveInt(q) uG := crypto.ScalarBaseMult(tss.EC(), u) - proof, _ := NewZKProof(u, uG) + proof, _ := NewZKProof(Session, u, uG) assert.True(t, proof.Alpha.IsOnCurve()) assert.NotZero(t, proof.Alpha.X()) @@ -34,8 +38,8 @@ func TestSchnorrProofVerify(t *testing.T) { u := common.GetRandomPositiveInt(q) X := crypto.ScalarBaseMult(tss.EC(), u) - proof, _ := NewZKProof(u, X) - res := proof.Verify(X) + proof, _ := NewZKProof(Session, u, X) + res := proof.Verify(Session, X) assert.True(t, res, "verify result must be true") } @@ -47,8 +51,8 @@ func TestSchnorrProofVerifyBadX(t *testing.T) { X := crypto.ScalarBaseMult(tss.EC(), u) X2 := crypto.ScalarBaseMult(tss.EC(), u2) - proof, _ := NewZKProof(u2, X2) - res := proof.Verify(X) + proof, _ := NewZKProof(Session, u2, X2) + res := proof.Verify(Session, X) assert.False(t, res, "verify result must be false") } @@ -63,8 +67,8 @@ func TestSchnorrVProofVerify(t *testing.T) { lG := crypto.ScalarBaseMult(tss.EC(), l) V, _ := Rs.Add(lG) - proof, _ := NewZKVProof(V, R, s, l) - res := proof.Verify(V, R) + proof, _ := NewZKVProof(Session, V, R, s, l) + res := proof.Verify(Session, V, R) assert.True(t, res, "verify result must be true") } @@ -78,8 +82,8 @@ func TestSchnorrVProofVerifyBadPartialV(t *testing.T) { Rs := R.ScalarMult(s) V := Rs - proof, _ := NewZKVProof(V, R, s, l) - res := proof.Verify(V, R) + proof, _ := NewZKVProof(Session, V, R, s, l) + res := proof.Verify(Session, V, R) assert.False(t, res, "verify result must be false") } @@ -95,8 +99,8 @@ func TestSchnorrVProofVerifyBadS(t *testing.T) { lG := crypto.ScalarBaseMult(tss.EC(), l) V, _ := Rs.Add(lG) - proof, _ := NewZKVProof(V, R, s2, l) - res := proof.Verify(V, R) + proof, _ := NewZKVProof(Session, V, R, s2, l) + res := proof.Verify(Session, V, R) assert.False(t, res, "verify result must be false") } diff --git a/ecdsa/keygen/local_party.go b/ecdsa/keygen/local_party.go index 9d0066e3..5513aba9 100644 --- a/ecdsa/keygen/local_party.go +++ b/ecdsa/keygen/local_party.go @@ -49,6 +49,8 @@ type ( ui *big.Int // used for tests KGCs []cmt.HashCommitment vs vss.Vs + ssid []byte + ssidNonce *big.Int shares vss.Shares deCommitPolyG cmt.HashDeCommitment } diff --git a/ecdsa/keygen/prepare.go b/ecdsa/keygen/prepare.go index 6076e58e..775ef53e 100644 --- a/ecdsa/keygen/prepare.go +++ b/ecdsa/keygen/prepare.go @@ -24,6 +24,8 @@ const ( safePrimeBitLen = 1024 // Ticker for printing log statements while generating primes/modulus logProgressTickInterval = 8 * time.Second + // Safe big len using random for ssid + SafeBitLen = 1024 ) // GeneratePreParams finds two safe primes and computes the Paillier secret required for the protocol. diff --git a/ecdsa/keygen/round_1.go b/ecdsa/keygen/round_1.go index 9cfda2df..96104475 100644 --- a/ecdsa/keygen/round_1.go +++ b/ecdsa/keygen/round_1.go @@ -100,8 +100,14 @@ func (round *round1) Start() *tss.Error { // and keep in temporary storage: // - VSS Vs // - our set of Shamir shares + round.temp.ssidNonce = new(big.Int).SetUint64(0) round.save.ShareID = ids[i] round.temp.vs = vs + ssid, err := round.getSSID() + if err != nil { + return round.WrapError(errors.New("failed to generate ssid")) + } + round.temp.ssid = ssid round.temp.shares = shares // for this P: SAVE de-commitments, paillier keys for round 2 diff --git a/ecdsa/keygen/round_2.go b/ecdsa/keygen/round_2.go index 4bf6f155..364235e9 100644 --- a/ecdsa/keygen/round_2.go +++ b/ecdsa/keygen/round_2.go @@ -9,11 +9,12 @@ package keygen import ( "encoding/hex" "errors" - "github.com/bnb-chain/tss-lib/crypto/facproof" - "github.com/bnb-chain/tss-lib/crypto/modproof" "math/big" "sync" + "github.com/bnb-chain/tss-lib/crypto/facproof" + "github.com/bnb-chain/tss-lib/crypto/modproof" + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -112,13 +113,14 @@ func (round *round2) Start() *tss.Error { // 5. p2p send share ij to Pj shares := round.temp.shares + ContextI := append(round.temp.ssid, big.NewInt(int64(i)).Bytes()...) for j, Pj := range round.Parties().IDs() { facProof := &facproof.ProofFac{P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero} if !round.Params().NoProofFac() { var err error - facProof, err = facproof.NewProof(round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], + facProof, err = facproof.NewProof(ContextI, round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q) if err != nil { return round.WrapError(err, round.PartyID()) @@ -138,7 +140,7 @@ func (round *round2) Start() *tss.Error { modProof := &modproof.ProofMod{W: zero, X: *new([80]*big.Int), A: zero, B: zero, Z: *new([80]*big.Int)} if !round.Parameters.NoProofMod() { var err error - modProof, err = modproof.NewProof(round.save.PaillierSK.N, + modProof, err = modproof.NewProof(ContextI, round.save.PaillierSK.N, round.save.PaillierSK.P, round.save.PaillierSK.Q) if err != nil { return round.WrapError(err, round.PartyID()) diff --git a/ecdsa/keygen/round_3.go b/ecdsa/keygen/round_3.go index 24d0b59a..205569ed 100644 --- a/ecdsa/keygen/round_3.go +++ b/ecdsa/keygen/round_3.go @@ -65,6 +65,7 @@ func (round *round3) Start() *tss.Error { if j == PIdx { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) // 6-8. go func(j int, ch chan<- vssOut) { // 4-9. @@ -92,7 +93,7 @@ func (round *round3) Start() *tss.Error { ch <- vssOut{errors.New("modProof verify failed"), nil} return } - if ok = modProof.Verify(round.save.PaillierPKs[j].N); !ok { + if ok = modProof.Verify(ContextJ, round.save.PaillierPKs[j].N); !ok { ch <- vssOut{errors.New("modProof verify failed"), nil} return } @@ -117,7 +118,7 @@ func (round *round3) Start() *tss.Error { ch <- vssOut{errors.New("facProof verify failed"), nil} return } - if ok = facProof.Verify(round.EC(), round.save.PaillierPKs[j].N, round.save.NTildei, + if ok = facProof.Verify(ContextJ, round.EC(), round.save.PaillierPKs[j].N, round.save.NTildei, round.save.H1i, round.save.H2i); !ok { ch <- vssOut{errors.New("facProof verify failed"), nil} return diff --git a/ecdsa/keygen/rounds.go b/ecdsa/keygen/rounds.go index 313184ab..5bce6d10 100644 --- a/ecdsa/keygen/rounds.go +++ b/ecdsa/keygen/rounds.go @@ -7,6 +7,9 @@ package keygen import ( + "math/big" + + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -94,3 +97,14 @@ func (round *base) resetOK() { round.ok[j] = false } } + +// get ssid from local params +func (round *base) getSSID() ([]byte, error) { + ssidList := []*big.Int{round.EC().Params().P, round.EC().Params().N, round.EC().Params().Gx, round.EC().Params().Gy} // ec curve + ssidList = append(ssidList, round.Parties().IDs().Keys()...) + ssidList = append(ssidList, big.NewInt(int64(round.number))) // round number + ssidList = append(ssidList, round.temp.ssidNonce) + ssid := common.SHA512_256i(ssidList...).Bytes() + + return ssid, nil +} diff --git a/ecdsa/resharing/local_party.go b/ecdsa/resharing/local_party.go index 7e0b200a..a8efdbb0 100644 --- a/ecdsa/resharing/local_party.go +++ b/ecdsa/resharing/local_party.go @@ -58,6 +58,9 @@ type ( newXi *big.Int newKs []*big.Int newBigXjs []*crypto.ECPoint // Xj to save in round 5 + + ssid []byte + ssidNonce *big.Int } ) diff --git a/ecdsa/resharing/round_1_old_step_1.go b/ecdsa/resharing/round_1_old_step_1.go index 6358512d..954494b4 100644 --- a/ecdsa/resharing/round_1_old_step_1.go +++ b/ecdsa/resharing/round_1_old_step_1.go @@ -9,6 +9,7 @@ package resharing import ( "errors" "fmt" + "math/big" "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/crypto/commitments" @@ -38,6 +39,12 @@ func (round *round1) Start() *tss.Error { } round.allOldOK() + round.temp.ssidNonce = new(big.Int).SetUint64(uint64(0)) + ssid, err := round.getSSID() + if err != nil { + return round.WrapError(err) + } + round.temp.ssid = ssid Pi := round.PartyID() i := Pi.Index diff --git a/ecdsa/resharing/round_2_new_step_1.go b/ecdsa/resharing/round_2_new_step_1.go index 482a64c9..5fec50a9 100644 --- a/ecdsa/resharing/round_2_new_step_1.go +++ b/ecdsa/resharing/round_2_new_step_1.go @@ -8,9 +8,10 @@ package resharing import ( "errors" - "github.com/bnb-chain/tss-lib/crypto/modproof" "math/big" + "github.com/bnb-chain/tss-lib/crypto/modproof" + "github.com/bnb-chain/tss-lib/crypto/dlnproof" "github.com/bnb-chain/tss-lib/ecdsa/keygen" "github.com/bnb-chain/tss-lib/tss" @@ -77,9 +78,10 @@ func (round *round2) Start() *tss.Error { dlnProof2 := dlnproof.NewDLNProof(h2i, h1i, beta, p, q, NTildei) modProof := &modproof.ProofMod{W: zero, X: *new([80]*big.Int), A: zero, B: zero, Z: *new([80]*big.Int)} + ContextI := append(round.temp.ssid, big.NewInt(int64(i)).Bytes()...) if !round.Parameters.NoProofMod() { var err error - modProof, err = modproof.NewProof(preParams.PaillierSK.N, preParams.PaillierSK.P, preParams.PaillierSK.Q) + modProof, err = modproof.NewProof(ContextI, preParams.PaillierSK.N, preParams.PaillierSK.P, preParams.PaillierSK.Q) if err != nil { return round.WrapError(err, Pi) } diff --git a/ecdsa/resharing/round_4_new_step_2.go b/ecdsa/resharing/round_4_new_step_2.go index 83131793..55eeb3a1 100644 --- a/ecdsa/resharing/round_4_new_step_2.go +++ b/ecdsa/resharing/round_4_new_step_2.go @@ -9,10 +9,11 @@ package resharing import ( "encoding/hex" "errors" - "github.com/bnb-chain/tss-lib/crypto/facproof" "math/big" "sync" + "github.com/bnb-chain/tss-lib/crypto/facproof" + errors2 "github.com/pkg/errors" "github.com/bnb-chain/tss-lib/common" @@ -84,7 +85,8 @@ func (round *round4) Start() *tss.Error { common.Logger.Warningf("modProof verify failed for party %s", msg.GetFrom(), err) return } - if ok := modProof.Verify(paiPK.N); !ok { + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) + if ok := modProof.Verify(ContextJ, paiPK.N); !ok { paiProofCulprits[j] = msg.GetFrom() common.Logger.Warningf("modProof verify failed for party %s", msg.GetFrom(), err) } @@ -215,10 +217,11 @@ func (round *round4) Start() *tss.Error { if j == i { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) facProof := &facproof.ProofFac{P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero} if !round.Parameters.NoProofFac() { - facProof, err = facproof.NewProof(round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], + facProof, err = facproof.NewProof(ContextJ, round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q) if err != nil { return round.WrapError(err, Pi) diff --git a/ecdsa/resharing/round_5_new_step_3.go b/ecdsa/resharing/round_5_new_step_3.go index 64304c69..355ffaa2 100644 --- a/ecdsa/resharing/round_5_new_step_3.go +++ b/ecdsa/resharing/round_5_new_step_3.go @@ -8,6 +8,8 @@ package resharing import ( "errors" + "math/big" + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -45,6 +47,7 @@ func (round *round5) Start() *tss.Error { if j == i { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) r4msg1 := msg.Content().(*DGRound4Message1) proof, err := r4msg1.UnmarshalFacProof() if err != nil && round.Parameters.NoProofFac() { @@ -54,7 +57,7 @@ func (round *round5) Start() *tss.Error { common.Logger.Warningf("facProof verify failed for party %s", msg.GetFrom(), err) return round.WrapError(err, round.NewParties().IDs()[j]) } - if ok := proof.Verify(round.EC(), round.save.PaillierPKs[j].N, round.save.NTildei, + if ok := proof.Verify(ContextJ, round.EC(), round.save.PaillierPKs[j].N, round.save.NTildei, round.save.H1i, round.save.H2i); !ok { common.Logger.Warningf("facProof verify failed for party %s", msg.GetFrom(), err) return round.WrapError(err, round.NewParties().IDs()[j]) diff --git a/ecdsa/resharing/rounds.go b/ecdsa/resharing/rounds.go index bcebf453..8dc78a5b 100644 --- a/ecdsa/resharing/rounds.go +++ b/ecdsa/resharing/rounds.go @@ -7,6 +7,11 @@ package resharing import ( + "errors" + "math/big" + + "github.com/bnb-chain/tss-lib/common" + "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/ecdsa/keygen" "github.com/bnb-chain/tss-lib/tss" ) @@ -133,3 +138,22 @@ func (round *base) allNewOK() { round.newOK[j] = true } } + +// get ssid from local params +func (round *base) getSSID() ([]byte, error) { + ssidList := []*big.Int{round.EC().Params().P, round.EC().Params().N, round.EC().Params().B, round.EC().Params().Gx, round.EC().Params().Gy} // ec curve + ssidList = append(ssidList, round.Parties().IDs().Keys()...) // parties + BigXjList, err := crypto.FlattenECPoints(round.input.BigXj) + if err != nil { + return nil, round.WrapError(errors.New("read BigXj failed"), round.PartyID()) + } + ssidList = append(ssidList, BigXjList...) // BigXj + ssidList = append(ssidList, round.input.NTildej...) // NTilde + ssidList = append(ssidList, round.input.H1j...) // h1 + ssidList = append(ssidList, round.input.H2j...) // h2 + ssidList = append(ssidList, big.NewInt(int64(round.number))) // round number + ssidList = append(ssidList, round.temp.ssidNonce) + ssid := common.SHA512_256i(ssidList...).Bytes() + + return ssid, nil +} diff --git a/ecdsa/signing/local_party.go b/ecdsa/signing/local_party.go index ae202590..c6575179 100644 --- a/ecdsa/signing/local_party.go +++ b/ecdsa/signing/local_party.go @@ -91,6 +91,9 @@ type ( Ui, Ti *crypto.ECPoint DTelda cmt.HashDeCommitment + + ssidNonce *big.Int + ssid []byte } ) diff --git a/ecdsa/signing/round_1.go b/ecdsa/signing/round_1.go index 920930db..e1b8d0e3 100644 --- a/ecdsa/signing/round_1.go +++ b/ecdsa/signing/round_1.go @@ -45,6 +45,12 @@ func (round *round1) Start() *tss.Error { round.number = 1 round.started = true round.resetOK() + round.temp.ssidNonce = new(big.Int).SetUint64(0) + ssid, err := round.getSSID() + if err != nil { + return round.WrapError(err) + } + round.temp.ssid = ssid k := common.GetRandomPositiveInt(round.Params().EC().Params().N) gamma := common.GetRandomPositiveInt(round.Params().EC().Params().N) diff --git a/ecdsa/signing/round_2.go b/ecdsa/signing/round_2.go index 79702e0d..489d22bd 100644 --- a/ecdsa/signing/round_2.go +++ b/ecdsa/signing/round_2.go @@ -8,6 +8,7 @@ package signing import ( "errors" + "math/big" "sync" errorspkg "github.com/pkg/errors" @@ -30,6 +31,7 @@ func (round *round2) Start() *tss.Error { errChs := make(chan *tss.Error, (len(round.Parties().IDs())-1)*2) wg := sync.WaitGroup{} wg.Add((len(round.Parties().IDs()) - 1) * 2) + ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) for j, Pj := range round.Parties().IDs() { if j == i { continue @@ -44,6 +46,7 @@ func (round *round2) Start() *tss.Error { return } beta, c1ji, _, pi1ji, err := mta.BobMid( + ContextI, round.Parameters.EC(), round.key.PaillierPKs[j], rangeProofAliceJ, @@ -73,6 +76,7 @@ func (round *round2) Start() *tss.Error { return } v, c2ji, _, pi2ji, err := mta.BobMidWC( + ContextI, round.Parameters.EC(), round.key.PaillierPKs[j], rangeProofAliceJ, diff --git a/ecdsa/signing/round_3.go b/ecdsa/signing/round_3.go index 87c4f8b4..05cc9588 100644 --- a/ecdsa/signing/round_3.go +++ b/ecdsa/signing/round_3.go @@ -38,6 +38,7 @@ func (round *round3) Start() *tss.Error { if j == i { continue } + ContextJ := append(round.temp.ssid, new(big.Int).SetUint64(uint64(j)).Bytes()...) // Alice_end go func(j int, Pj *tss.PartyID) { defer wg.Done() @@ -48,6 +49,7 @@ func (round *round3) Start() *tss.Error { return } alphaIj, err := mta.AliceEnd( + ContextJ, round.Params().EC(), round.key.PaillierPKs[i], proofBob, @@ -72,6 +74,7 @@ func (round *round3) Start() *tss.Error { return } uIj, err := mta.AliceEndWC( + ContextJ, round.Params().EC(), round.key.PaillierPKs[i], proofBobWC, diff --git a/ecdsa/signing/round_4.go b/ecdsa/signing/round_4.go index 9048ff00..afc4b59e 100644 --- a/ecdsa/signing/round_4.go +++ b/ecdsa/signing/round_4.go @@ -41,7 +41,9 @@ func (round *round4) Start() *tss.Error { // compute the multiplicative inverse thelta mod q thetaInverse = modN.ModInverse(thetaInverse) - piGamma, err := schnorr.NewZKProof(round.temp.gamma, round.temp.pointGamma) + i := round.PartyID().Index + ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) + piGamma, err := schnorr.NewZKProof(ContextI, round.temp.gamma, round.temp.pointGamma) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(gamma, bigGamma)")) } diff --git a/ecdsa/signing/round_5.go b/ecdsa/signing/round_5.go index bcaefa5a..0162a778 100644 --- a/ecdsa/signing/round_5.go +++ b/ecdsa/signing/round_5.go @@ -8,6 +8,7 @@ package signing import ( "errors" + "math/big" errors2 "github.com/pkg/errors" @@ -30,6 +31,7 @@ func (round *round5) Start() *tss.Error { if j == round.PartyID().Index { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) r1msg2 := round.temp.signRound1Message2s[j].Content().(*SignRound1Message2) r4msg := round.temp.signRound4Messages[j].Content().(*SignRound4Message) SCj, SDj := r1msg2.UnmarshalCommitment(), r4msg.UnmarshalDeCommitment() @@ -46,7 +48,7 @@ func (round *round5) Start() *tss.Error { if err != nil { return round.WrapError(errors.New("failed to unmarshal bigGamma proof"), Pj) } - ok = proof.Verify(bigGammaJPoint) + ok = proof.Verify(ContextJ, bigGammaJPoint) if !ok { return round.WrapError(errors.New("failed to prove bigGamma"), Pj) } diff --git a/ecdsa/signing/round_6.go b/ecdsa/signing/round_6.go index de930654..b104fdd5 100644 --- a/ecdsa/signing/round_6.go +++ b/ecdsa/signing/round_6.go @@ -8,6 +8,7 @@ package signing import ( "errors" + "math/big" errors2 "github.com/pkg/errors" @@ -23,11 +24,13 @@ func (round *round6) Start() *tss.Error { round.started = true round.resetOK() - piAi, err := schnorr.NewZKProof(round.temp.roi, round.temp.bigAi) + i := round.PartyID().Index + ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) + piAi, err := schnorr.NewZKProof(ContextI, round.temp.roi, round.temp.bigAi) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(roi, bigAi)")) } - piV, err := schnorr.NewZKVProof(round.temp.bigVi, round.temp.bigR, round.temp.si, round.temp.li) + piV, err := schnorr.NewZKVProof(ContextI, round.temp.bigVi, round.temp.bigR, round.temp.si, round.temp.li) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKVProof(bigVi, bigR, si, li)")) } diff --git a/ecdsa/signing/round_7.go b/ecdsa/signing/round_7.go index 3242b64b..d2314ca1 100644 --- a/ecdsa/signing/round_7.go +++ b/ecdsa/signing/round_7.go @@ -32,6 +32,7 @@ func (round *round7) Start() *tss.Error { if j == round.PartyID().Index { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) r5msg := round.temp.signRound5Messages[j].Content().(*SignRound5Message) r6msg := round.temp.signRound6Messages[j].Content().(*SignRound6Message) cj, dj := r5msg.UnmarshalCommitment(), r6msg.UnmarshalDeCommitment() @@ -52,11 +53,11 @@ func (round *round7) Start() *tss.Error { } bigAjs[j] = bigAj pijA, err := r6msg.UnmarshalZKProof(round.Params().EC()) - if err != nil || !pijA.Verify(bigAj) { + if err != nil || !pijA.Verify(ContextJ, bigAj) { return round.WrapError(errors.New("schnorr verify for Aj failed"), Pj) } pijV, err := r6msg.UnmarshalZKVProof(round.Params().EC()) - if err != nil || !pijV.Verify(bigVj, round.temp.bigR) { + if err != nil || !pijV.Verify(ContextJ, bigVj, round.temp.bigR) { return round.WrapError(errors.New("vverify for Vj failed"), Pj) } } diff --git a/ecdsa/signing/rounds.go b/ecdsa/signing/rounds.go index b546b656..5a281b04 100644 --- a/ecdsa/signing/rounds.go +++ b/ecdsa/signing/rounds.go @@ -7,7 +7,11 @@ package signing import ( + "errors" + "math/big" + "github.com/bnb-chain/tss-lib/common" + "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/ecdsa/keygen" "github.com/bnb-chain/tss-lib/tss" ) @@ -121,3 +125,22 @@ func (round *base) resetOK() { round.ok[j] = false } } + +// get ssid from local params +func (round *base) getSSID() ([]byte, error) { + ssidList := []*big.Int{round.EC().Params().P, round.EC().Params().N, round.EC().Params().B, round.EC().Params().Gx, round.EC().Params().Gy} // ec curve + ssidList = append(ssidList, round.Parties().IDs().Keys()...) // parties + BigXjList, err := crypto.FlattenECPoints(round.key.BigXj) + if err != nil { + return nil, round.WrapError(errors.New("read BigXj failed"), round.PartyID()) + } + ssidList = append(ssidList, BigXjList...) // BigXj + ssidList = append(ssidList, round.key.NTildej...) // NTilde + ssidList = append(ssidList, round.key.H1j...) // h1 + ssidList = append(ssidList, round.key.H2j...) // h2 + ssidList = append(ssidList, big.NewInt(int64(round.number))) // round number + ssidList = append(ssidList, round.temp.ssidNonce) + ssid := common.SHA512_256i(ssidList...).Bytes() + + return ssid, nil +} diff --git a/eddsa/keygen/local_party.go b/eddsa/keygen/local_party.go index 8d94f5d4..39900f2e 100644 --- a/eddsa/keygen/local_party.go +++ b/eddsa/keygen/local_party.go @@ -51,6 +51,9 @@ type ( vs vss.Vs shares vss.Shares deCommitPolyG cmt.HashDeCommitment + + ssid []byte + ssidNonce *big.Int } ) diff --git a/eddsa/keygen/round_1.go b/eddsa/keygen/round_1.go index a799d27c..4608de03 100644 --- a/eddsa/keygen/round_1.go +++ b/eddsa/keygen/round_1.go @@ -38,6 +38,13 @@ func (round *round1) Start() *tss.Error { Pi := round.PartyID() i := Pi.Index + round.temp.ssidNonce = new(big.Int).SetUint64(0) + ssid, err := round.getSSID() + if err != nil { + return round.WrapError(err) + } + round.temp.ssid = ssid + // 1. calculate "partial" key share ui ui := common.GetRandomPositiveInt(round.Params().EC().Params().N) round.temp.ui = ui diff --git a/eddsa/keygen/round_2.go b/eddsa/keygen/round_2.go index 0db0d1b5..0148bcf6 100644 --- a/eddsa/keygen/round_2.go +++ b/eddsa/keygen/round_2.go @@ -8,6 +8,7 @@ package keygen import ( "errors" + "math/big" errors2 "github.com/pkg/errors" @@ -45,7 +46,8 @@ func (round *round2) Start() *tss.Error { } // 5. compute Schnorr prove - pii, err := schnorr.NewZKProof(round.temp.ui, round.temp.vs[0]) + ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) + pii, err := schnorr.NewZKProof(ContextI, round.temp.ui, round.temp.vs[0]) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(ui, vi0)")) } diff --git a/eddsa/keygen/round_3.go b/eddsa/keygen/round_3.go index 7a82f883..393af636 100644 --- a/eddsa/keygen/round_3.go +++ b/eddsa/keygen/round_3.go @@ -65,6 +65,8 @@ func (round *round3) Start() *tss.Error { if j == PIdx { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) + // 6-9. go func(j int, ch chan<- vssOut) { // 4-10. @@ -92,7 +94,7 @@ func (round *round3) Start() *tss.Error { ch <- vssOut{errors.New("failed to unmarshal schnorr proof"), nil} return } - ok = proof.Verify(PjVs[0]) + ok = proof.Verify(ContextJ, PjVs[0]) if !ok { ch <- vssOut{errors.New("failed to prove schnorr proof"), nil} return diff --git a/eddsa/keygen/rounds.go b/eddsa/keygen/rounds.go index f87f47a4..153a3b28 100644 --- a/eddsa/keygen/rounds.go +++ b/eddsa/keygen/rounds.go @@ -7,6 +7,9 @@ package keygen import ( + "math/big" + + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -82,3 +85,14 @@ func (round *base) resetOK() { round.ok[j] = false } } + +// get ssid from local params +func (round *base) getSSID() ([]byte, error) { + ssidList := []*big.Int{round.EC().Params().P, round.EC().Params().N, round.EC().Params().Gx, round.EC().Params().Gy} // ec curve + ssidList = append(ssidList, round.Parties().IDs().Keys()...) + ssidList = append(ssidList, big.NewInt(int64(round.number))) // round number + ssidList = append(ssidList, round.temp.ssidNonce) + ssid := common.SHA512_256i(ssidList...).Bytes() + + return ssid, nil +} diff --git a/eddsa/signing/local_party.go b/eddsa/signing/local_party.go index 56aa5f1c..115988a5 100644 --- a/eddsa/signing/local_party.go +++ b/eddsa/signing/local_party.go @@ -59,6 +59,9 @@ type ( // round 3 r *big.Int + + ssid []byte + ssidNonce *big.Int } ) diff --git a/eddsa/signing/round_1.go b/eddsa/signing/round_1.go index 7af1d807..0ece1fcb 100644 --- a/eddsa/signing/round_1.go +++ b/eddsa/signing/round_1.go @@ -9,6 +9,7 @@ package signing import ( "errors" "fmt" + "math/big" "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/crypto" @@ -32,6 +33,12 @@ func (round *round1) Start() *tss.Error { round.started = true round.resetOK() + round.temp.ssidNonce = new(big.Int).SetUint64(0) + var err error + round.temp.ssid, err = round.getSSID() + if err != nil { + return round.WrapError(err) + } // 1. select ri ri := common.GetRandomPositiveInt(round.Params().EC().Params().N) diff --git a/eddsa/signing/round_2.go b/eddsa/signing/round_2.go index 6aa89657..027eb519 100644 --- a/eddsa/signing/round_2.go +++ b/eddsa/signing/round_2.go @@ -8,6 +8,7 @@ package signing import ( "errors" + "math/big" errors2 "github.com/pkg/errors" @@ -32,7 +33,8 @@ func (round *round2) Start() *tss.Error { } // 2. compute Schnorr prove - pir, err := schnorr.NewZKProof(round.temp.ri, round.temp.pointRi) + ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) + pir, err := schnorr.NewZKProof(ContextI, round.temp.ri, round.temp.pointRi) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(ri, pointRi)")) } diff --git a/eddsa/signing/round_3.go b/eddsa/signing/round_3.go index cbcd103f..3f636f2c 100644 --- a/eddsa/signing/round_3.go +++ b/eddsa/signing/round_3.go @@ -8,8 +8,10 @@ package signing import ( "crypto/sha512" + "math/big" "github.com/agl/ed25519/edwards25519" + "github.com/bnb-chain/tss-lib/common" "github.com/pkg/errors" "github.com/bnb-chain/tss-lib/crypto" @@ -38,6 +40,7 @@ func (round *round3) Start() *tss.Error { continue } + ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) msg := round.temp.signRound2Messages[j] r2msg := msg.Content().(*SignRound2Message) cmtDeCmt := commitments.HashCommitDecommit{C: round.temp.cjs[j], D: r2msg.UnmarshalDeCommitment()} @@ -58,7 +61,7 @@ func (round *round3) Start() *tss.Error { if err != nil { return round.WrapError(errors.New("failed to unmarshal Rj proof"), Pj) } - ok = proof.Verify(Rj) + ok = proof.Verify(ContextJ, Rj) if !ok { return round.WrapError(errors.New("failed to prove Rj"), Pj) } diff --git a/eddsa/signing/rounds.go b/eddsa/signing/rounds.go index 57adcd8d..841ed087 100644 --- a/eddsa/signing/rounds.go +++ b/eddsa/signing/rounds.go @@ -7,7 +7,11 @@ package signing import ( + "errors" + "math/big" + "github.com/bnb-chain/tss-lib/common" + "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/eddsa/keygen" "github.com/bnb-chain/tss-lib/tss" ) @@ -97,3 +101,19 @@ func (round *base) resetOK() { round.ok[j] = false } } + +// get ssid from local params +func (round *base) getSSID() ([]byte, error) { + ssidList := []*big.Int{round.EC().Params().P, round.EC().Params().N, round.EC().Params().Gx, round.EC().Params().Gy} // ec curve + ssidList = append(ssidList, round.Parties().IDs().Keys()...) // parties + BigXjList, err := crypto.FlattenECPoints(round.key.BigXj) + if err != nil { + return nil, round.WrapError(errors.New("read BigXj failed"), round.PartyID()) + } + ssidList = append(ssidList, BigXjList...) // BigXj + ssidList = append(ssidList, big.NewInt(int64(round.number))) // round number + ssidList = append(ssidList, round.temp.ssidNonce) + ssid := common.SHA512_256i(ssidList...).Bytes() + + return ssid, nil +}