Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize BDN Signature/Key Aggregation #546

Merged
merged 9 commits into from
Sep 24, 2024
84 changes: 37 additions & 47 deletions sign/bdn/bdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package bdn
import (
"crypto/cipher"
"errors"
"fmt"
"math/big"

"go.dedis.ch/kyber/v4"
Expand All @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI
// We also use the entire roster so that the coefficient will vary for the same
// public key used in different roster
func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) {
peers := make([][]byte, len(pubs))
for i, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}

peers[i] = peer
}

h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
if err != nil {
return nil, err
}

for _, peer := range peers {
_, err := h.Write(peer)
for _, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}
_, err = h.Write(peer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -128,62 +122,58 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {

// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
if len(sigs) != mask.CountEnabled() {
return nil, errors.New("length of signatures and public keys must match")
}

coefs, err := hashPointToR(mask.Publics())
if err != nil {
return nil, err
}

func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *Mask) (kyber.Point, error) {
agg := scheme.sigGroup.Point()
for i, buf := range sigs {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
// this should never happen as we check the lenths at the beginning
// an error here is probably a bug in the mask
return nil, errors.New("couldn't find the index")
for i := range mask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

if len(sigs) == 0 {
return nil, errors.New("length of signatures and public keys must match")
}

buf := sigs[0]
sigs = sigs[1:]

sig := scheme.sigGroup.Point()
err = sig.UnmarshalBinary(buf)
err := sig.UnmarshalBinary(buf)
if err != nil {
return nil, err
}

sigC := sig.Clone().Mul(coefs[peerIndex], sig)
sigC := sig.Clone().Mul(mask.publicCoefs[i], sig)
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
sigC = sigC.Add(sigC, sig)
agg = agg.Add(agg, sigC)
}

if len(sigs) > 0 {
return nil, errors.New("length of signatures and public keys must match")
}

return agg, nil
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: keyGroup -> R with R = {1, ..., 2^128}.
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
coefs, err := hashPointToR(mask.Publics())
if err != nil {
return nil, err
}

func (scheme *Scheme) AggregatePublicKeys(mask *Mask) (kyber.Point, error) {
agg := scheme.keyGroup.Point()
for i := 0; i < mask.CountEnabled(); i++ {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
for i := range mask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, errors.New("couldn't find the index")
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

pub := mask.Publics()[peerIndex]
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
pubC = pubC.Add(pubC, pub)
agg = agg.Add(agg, pubC)
agg = agg.Add(agg, mask.publicTerms[i])
}

return agg, nil
Expand Down Expand Up @@ -217,14 +207,14 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error {
// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128}
// Deprecated: use the new scheme methods instead.
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask)
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: G2 -> R with R = {1, ..., 2^128}.
// Deprecated: use the new scheme methods instead.
func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) {
func AggregatePublicKeys(suite pairing.Suite, mask *Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregatePublicKeys(mask)
}
110 changes: 104 additions & 6 deletions sign/bdn/bdn_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package bdn

import (
"encoding"
"encoding/hex"
"fmt"
"testing"

"github.com/stretchr/testify/require"
"go.dedis.ch/kyber/v4"
"go.dedis.ch/kyber/v4/pairing/bls12381/kilic"
"go.dedis.ch/kyber/v4/pairing/bn256"
"go.dedis.ch/kyber/v4/sign"
"go.dedis.ch/kyber/v4/sign/bls"
"go.dedis.ch/kyber/v4/util/random"
)
Expand All @@ -30,7 +32,7 @@ func TestBDN_HashPointToR_BN256(t *testing.T) {
require.Equal(t, "933f6013eb3f654f9489d6d45ad04eaf", coefs[2].String())
require.Equal(t, 16, coefs[0].MarshalSize())

mask, _ := sign.NewMask([]kyber.Point{p1, p2, p3}, nil)
mask, _ := NewMask([]kyber.Point{p1, p2, p3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)
mask.SetBit(2, true)
Expand All @@ -54,7 +56,7 @@ func TestBDN_AggregateSignatures(t *testing.T) {
sig2, err := Sign(suite, private2, msg)
require.NoError(t, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)

Expand Down Expand Up @@ -92,7 +94,7 @@ func TestBDN_SubsetSignature(t *testing.T) {
sig2, err := Sign(suite, private2, msg)
require.NoError(t, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public3, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public3, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(2, true)

Expand Down Expand Up @@ -131,7 +133,7 @@ func TestBDN_RogueAttack(t *testing.T) {
require.NoError(t, scheme.Verify(agg, msg, sig))

// New scheme that should detect
mask, _ := sign.NewMask(pubs, nil)
mask, _ := NewMask(pubs, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)
agg, err = AggregatePublicKeys(suite, mask)
Expand All @@ -149,7 +151,7 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
sig2, err := Sign(suite, private2, msg)
require.Nil(b, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)

Expand All @@ -158,3 +160,99 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
AggregateSignatures(suite, [][]byte{sig1, sig2}, mask)
}
}

func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) {
suite := kilic.NewBLS12381Suite()
schemeOnG2 := NewSchemeOnG2(suite)

rng := random.New()
pubKeys := make([]kyber.Point, 3000)
privKeys := make([]kyber.Scalar, 3000)
for i := range pubKeys {
privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng)
}

mask, err := NewMask(pubKeys, nil)
require.NoError(b, err)
for i := range pubKeys {
require.NoError(b, mask.SetBit(i, true))
}

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sigs := make([][]byte, len(privKeys))
for i, k := range privKeys {
s, err := schemeOnG2.Sign(k, msg)
require.NoError(b, err)
sigs[i] = s
}

sig, err := schemeOnG2.AggregateSignatures(sigs, mask)
require.NoError(b, err)
sigb, err := sig.MarshalBinary()
require.NoError(b, err)

b.ResetTimer()
for i := 0; i < b.N; i++ {
pk, err := schemeOnG2.AggregatePublicKeys(mask)
require.NoError(b, err)
require.NoError(b, schemeOnG2.Verify(pk, msg, sigb))
}
}

func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T {
t.Helper()
b, err := hex.DecodeString(s)
require.NoError(t, err)
require.NoError(t, into.UnmarshalBinary(b))
return into
}

// This tests exists to make sure we don't accidentally make breaking changes to signature
// aggregation by using checking against known aggregated signatures and keys.
func TestBDNFixtures(t *testing.T) {
suite := bn256.NewSuite()
schemeOnG1 := NewSchemeOnG1(suite)

public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493")
private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38")
public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f")
private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581")
public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec")
private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa")

sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8")
require.NoError(t, err)
sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484")
require.NoError(t, err)
sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195")
require.NoError(t, err)

aggSigExp := unmarshalHex(t, suite.G1().Point(), "43c1d2ad5a7d71a08f3cd7495db6b3c81a4547af1b76438b2f215e85ec178fea048f93f6ffed65a69ea757b47761e7178103bb347fd79689652e55b6e0054af2")
aggKeyExp := unmarshalHex(t, suite.G2().Point(), "43b5161ede207b9a69fc93114b0c5022b76cc22e813ba739c7e622d826b132333cd637505399963b94e393ec7f5d4875f82391620b34be1fde1f232204fa4f723935d4dbfb725f059456bcf2557f846c03190969f7b800e904d25b0b5bcbdd421c9877d443f0313c3425dfc1e7e646b665d27b9e649faadef1129f95670d70e1")

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sig1, err := schemeOnG1.Sign(private1, msg)
require.Nil(t, err)
require.Equal(t, sig1Exp, sig1)

sig2, err := schemeOnG1.Sign(private2, msg)
require.Nil(t, err)
require.Equal(t, sig2Exp, sig2)

sig3, err := schemeOnG1.Sign(private3, msg)
require.Nil(t, err)
require.Equal(t, sig3Exp, sig3)

mask, _ := NewMask([]kyber.Point{public1, public2, public3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)
mask.SetBit(2, true)

aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig3}, mask)
require.NoError(t, err)
require.True(t, aggSigExp.Equal(aggSig))

aggKey, err := schemeOnG1.AggregatePublicKeys(mask)
require.NoError(t, err)
require.True(t, aggKeyExp.Equal(aggKey))
}
55 changes: 52 additions & 3 deletions sign/mask.go → sign/bdn/mask.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
// Package sign contains useful tools for the different signing algorithms.
package sign
package bdn

import (
"errors"
"fmt"
"slices"

"go.dedis.ch/kyber/v4"
)

// Mask is a bitmask of the participation to a collective signature.
type Mask struct {
mask []byte
// The bitmask indicating which public keys are enabled/disabled for aggregation. This is
// the only mutable field.
mask []byte

// The following fields are immutable and should not be changed after the mask is created.
// They may be shared between multiple masks.

// Public keys for aggregation & signature verification.
publics []kyber.Point
// Coefficients used when aggregating signatures.
publicCoefs []kyber.Scalar
// Terms used to aggregate public keys
publicTerms []kyber.Point
}

// NewMask creates a new mask from a list of public keys. If a key is provided, it
// will set the bit of the key to 1 or return an error if it is not found.
//
// The returned Mask will contain pre-computed terms and coefficients for all provided public
// keys, so it should be re-used for optimal performance (e.g., by creating a "base" mask and
// cloning it whenever aggregating signatures and/or public keys).
func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
m := &Mask{
publics: publics,
Expand All @@ -33,6 +48,18 @@ func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
return nil, errors.New("key not found")
}

var err error
m.publicCoefs, err = hashPointToR(publics)
if err != nil {
return nil, fmt.Errorf("failed to hash public keys: %w", err)
}

m.publicTerms = make([]kyber.Point, len(publics))
for i, pub := range publics {
pubC := pub.Clone().Mul(m.publicCoefs[i], pub)
m.publicTerms[i] = pubC.Add(pubC, pub)
}

return m, nil
}

Expand All @@ -58,6 +85,17 @@ func (m *Mask) SetMask(mask []byte) error {
return nil
}

// GetBit returns true if the given bit is set.
func (m *Mask) GetBit(i int) (bool, error) {
if i >= len(m.publics) || i < 0 {
return false, errors.New("index out of range")
}

byteIndex := i / 8
mask := byte(1) << uint(i&7)
return m.mask[byteIndex]&mask != 0, nil
}

// SetBit turns on or off the bit at the given index.
func (m *Mask) SetBit(i int, enable bool) error {
if i >= len(m.publics) || i < 0 {
Expand Down Expand Up @@ -170,3 +208,14 @@ func (m *Mask) Merge(mask []byte) error {

return nil
}

// Clone copies the mask while keeping the precomputed coefficients, etc. This method is thread safe
// and does not modify the original mask. Modifications to the new Mask will not affect the original.
func (m *Mask) Clone() *Mask {
return &Mask{
mask: slices.Clone(m.mask),
publics: m.publics,
publicCoefs: m.publicCoefs,
publicTerms: m.publicTerms,
}
Stebalien marked this conversation as resolved.
Show resolved Hide resolved
}
Loading
Loading