Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node: Processor performance improvements #3988

29 changes: 29 additions & 0 deletions node/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,35 @@ func (d *Database) StoreSignedVAA(v *vaa.VAA) error {
return nil
}

// StoreSignedVAABatch writes multiple VAAs to the database using the BadgerDB batch API.
// Note that the API takes care of splitting up the slice into the maximum allowed count
// and size so we don't need to worry about that.
func (d *Database) StoreSignedVAABatch(vaaBatch []*vaa.VAA) error {
batchTx := d.db.NewWriteBatch()
defer batchTx.Cancel()

for _, v := range vaaBatch {
if len(v.Signatures) == 0 {
panic("StoreSignedVAABatch called for unsigned VAA")
}

b, err := v.Marshal()
if err != nil {
panic("StoreSignedVAABatch failed to marshal VAA")
}

err = batchTx.Set(VaaIDFromVAA(v).Bytes(), b)
if err != nil {
return err
}
}

// Wait for the batch to finish.
err := batchTx.Flush()
storedVaaTotal.Add(float64(len(vaaBatch)))
bruce-riley marked this conversation as resolved.
Show resolved Hide resolved
return err
}

func (d *Database) HasVAA(id VAAID) (bool, error) {
err := d.db.View(func(txn *badger.Txn) error {
_, err := txn.Get(id.Bytes())
Expand Down
69 changes: 68 additions & 1 deletion node/pkg/db/db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"fmt"
Expand All @@ -22,6 +23,10 @@ import (
)

func getVAA() vaa.VAA {
return getVAAWithSeqNum(1)
}

func getVAAWithSeqNum(seqNum uint64) vaa.VAA {
var payload = []byte{97, 97, 97, 97, 97, 97}
var governanceEmitter = vaa.Address{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4}

Expand All @@ -31,7 +36,7 @@ func getVAA() vaa.VAA {
Signatures: nil,
Timestamp: time.Unix(0, 0),
Nonce: uint32(1),
Sequence: uint64(1),
Sequence: seqNum,
ConsistencyLevel: uint8(32),
EmitterChain: vaa.ChainIDSolana,
EmitterAddress: governanceEmitter,
Expand Down Expand Up @@ -114,6 +119,68 @@ func TestStoreSignedVAASigned(t *testing.T) {
assert.NoError(t, err2)
}

func TestStoreSignedVAABatch(t *testing.T) {
dbPath := t.TempDir()
db, err := Open(dbPath)
if err != nil {
t.Error("failed to open database")
}
defer db.Close()
defer os.Remove(dbPath)

privKey, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
require.NoError(t, err)

require.Less(t, int64(0), db.db.MaxBatchCount()) // In testing this was 104857.
require.Less(t, int64(0), db.db.MaxBatchSize()) // In testing this was 10066329.

// Make sure we exceed the max batch size.
numVAAs := uint64(db.db.MaxBatchCount() + 1)

// Build the VAA batch.
vaaBatch := make([]*vaa.VAA, 0, numVAAs)
for seqNum := uint64(0); seqNum < numVAAs; seqNum++ {
v := getVAAWithSeqNum(seqNum)
v.AddSignature(privKey, 0)
vaaBatch = append(vaaBatch, &v)
}

// Store the batch in the database.
err = db.StoreSignedVAABatch(vaaBatch)
require.NoError(t, err)

// Verify all the VAAs are in the database.
for _, v := range vaaBatch {
storedBytes, err := db.GetSignedVAABytes(*VaaIDFromVAA(v))
require.NoError(t, err)

origBytes, err := v.Marshal()
require.NoError(t, err)

assert.True(t, bytes.Equal(origBytes, storedBytes))
}

// Verify that updates work as well by tweaking the VAAs and rewriting them.
for _, v := range vaaBatch {
v.Nonce += 1
}

// Store the updated batch in the database.
err = db.StoreSignedVAABatch(vaaBatch)
require.NoError(t, err)

// Verify all the updated VAAs are in the database.
for _, v := range vaaBatch {
storedBytes, err := db.GetSignedVAABytes(*VaaIDFromVAA(v))
require.NoError(t, err)

origBytes, err := v.Marshal()
require.NoError(t, err)

assert.True(t, bytes.Equal(origBytes, storedBytes))
}
}

func TestGetSignedVAABytes(t *testing.T) {
dbPath := t.TempDir()
db, err := Open(dbPath)
Expand Down
54 changes: 9 additions & 45 deletions node/pkg/processor/broadcast.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package processor

import (
"encoding/hex"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"

ethcommon "github.com/ethereum/go-ethereum/common"
ethCommon "github.com/ethereum/go-ethereum/common"
"google.golang.org/protobuf/proto"

node_common "github.com/certusone/wormhole/node/pkg/common"
gossipv1 "github.com/certusone/wormhole/node/pkg/proto/gossip/v1"
"github.com/wormhole-foundation/wormhole/sdk/vaa"
)
Expand All @@ -22,31 +18,26 @@ var (
Help: "Total number of signed observations queued for broadcast",
})

observationsPostedInternally = promauto.NewCounter(
prometheus.CounterOpts{
Name: "wormhole_observations_posted_internally",
Help: "Total number of our observations posted internally",
})

signedVAAsBroadcast = promauto.NewCounter(
prometheus.CounterOpts{
Name: "wormhole_signed_vaas_queued_for_broadcast",
Help: "Total number of signed vaas queued for broadcast",
})
)

// broadcastSignature broadcasts the observation for something we observed locally.
func (p *Processor) broadcastSignature(
o Observation,
signature []byte,
messageID string,
txhash []byte,
) {
digest := o.SigningDigest()
digest ethCommon.Hash,
signature []byte,
) (*gossipv1.SignedObservation, []byte) {
obsv := gossipv1.SignedObservation{
Addr: p.ourAddr.Bytes(),
Hash: digest.Bytes(),
Signature: signature,
TxHash: txhash,
MessageId: o.MessageID(),
MessageId: messageID,
}

w := gossipv1.GossipMessage{Message: &gossipv1.GossipMessage_SignedObservation{SignedObservation: &obsv}}
Expand All @@ -59,37 +50,10 @@ func (p *Processor) broadcastSignature(
// Broadcast the observation.
p.gossipSendC <- msg
observationsBroadcast.Inc()

hash := hex.EncodeToString(digest.Bytes())

if p.state.signatures[hash] == nil {
p.state.signatures[hash] = &state{
firstObserved: time.Now(),
nextRetry: time.Now().Add(nextRetryDuration(0)),
signatures: map[ethcommon.Address][]byte{},
source: "loopback",
}
}

p.state.signatures[hash].ourObservation = o
p.state.signatures[hash].ourMsg = msg
p.state.signatures[hash].txHash = txhash
p.state.signatures[hash].source = o.GetEmitterChain().String()
p.state.signatures[hash].gs = p.gs // guaranteed to match ourObservation - there's no concurrent access to p.gs

// Fast path for our own signature
// send to obsvC directly if there is capacity, otherwise do it in a go routine.
// We can't block here because the same process would be responsible for reading from obsvC.
om := node_common.CreateMsgWithTimestamp[gossipv1.SignedObservation](&obsv)
select {
case p.obsvC <- om:
default:
go func() { p.obsvC <- om }()
}

observationsPostedInternally.Inc()
return &obsv, msg
}

// broadcastSignedVAA broadcasts a VAA to the gossip network.
func (p *Processor) broadcastSignedVAA(v *vaa.VAA) {
b, err := v.Marshal()
if err != nil {
Expand Down
35 changes: 20 additions & 15 deletions node/pkg/processor/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,29 +271,34 @@ func (p *Processor) signedVaaAlreadyInDB(hash string, s *state) (bool, error) {
return false, nil
}

vaaID, err := db.VaaIDFromString(s.ourObservation.MessageID())
msgId := s.ourObservation.MessageID()
vaaID, err := db.VaaIDFromString(msgId)
if err != nil {
return false, fmt.Errorf(`failed to generate VAA ID from message id "%s": %w`, s.ourObservation.MessageID(), err)
}

vb, err := p.db.GetSignedVAABytes(*vaaID)
if err != nil {
if err == db.ErrVAANotFound {
if p.logger.Level().Enabled(zapcore.DebugLevel) {
p.logger.Debug("VAA not in DB",
zap.String("message_id", s.ourObservation.MessageID()),
zap.String("digest", hash),
)
// If the VAA is waiting to be written to the DB, use that version. Otherwise use the DB.
v := p.getVaaFromUpdateMap(msgId)
if v == nil {
vb, err := p.db.GetSignedVAABytes(*vaaID)
if err != nil {
if err == db.ErrVAANotFound {
if p.logger.Level().Enabled(zapcore.DebugLevel) {
p.logger.Debug("VAA not in DB",
zap.String("message_id", s.ourObservation.MessageID()),
zap.String("digest", hash),
)
}
return false, nil
}
return false, nil
} else {

return false, fmt.Errorf(`failed to look up message id "%s" in db: %w`, s.ourObservation.MessageID(), err)
}
}

v, err := vaa.Unmarshal(vb)
if err != nil {
return false, fmt.Errorf("failed to unmarshal VAA: %w", err)
v, err = vaa.Unmarshal(vb)
if err != nil {
return false, fmt.Errorf("failed to unmarshal VAA: %w", err)
}
}

oldHash := hex.EncodeToString(v.SigningDigest().Bytes())
Expand Down
62 changes: 37 additions & 25 deletions node/pkg/processor/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package processor

import (
"encoding/hex"
"time"

"github.com/mr-tron/base58"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"

ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
Expand All @@ -26,13 +28,6 @@ var (
Help: "Total number of messages observed",
},
[]string{"emitter_chain"})

messagesSignedTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "wormhole_message_observations_signed_total",
Help: "Total number of message observations that were successfully signed",
},
[]string{"emitter_chain"})
)

// handleMessage processes a message received from a chain and instantiates our deterministic copy of the VAA. An
Expand All @@ -48,18 +43,7 @@ func (p *Processor) handleMessage(k *common.MessagePublication) {
return
}

if p.logger.Core().Enabled(zapcore.DebugLevel) {
p.logger.Debug("message publication confirmed",
zap.String("message_id", k.MessageIDString()),
zap.Uint32("nonce", k.Nonce),
zap.Stringer("txhash", k.TxHash),
zap.Time("timestamp", k.Timestamp),
)
}

messagesObservedTotal.With(prometheus.Labels{
"emitter_chain": k.EmitterChain.String(),
}).Add(1)
messagesObservedTotal.WithLabelValues(k.EmitterChain.String()).Inc()

// All nodes will create the exact same VAA and sign its digest.
// Consensus is established on this digest.
Expand All @@ -83,9 +67,10 @@ func (p *Processor) handleMessage(k *common.MessagePublication) {

// Generate digest of the unsigned VAA.
digest := v.SigningDigest()
hash := hex.EncodeToString(digest.Bytes())

// Sign the digest using our node's guardian key.
s, err := crypto.Sign(digest.Bytes(), p.gk)
signature, err := crypto.Sign(digest.Bytes(), p.gk)
if err != nil {
panic(err)
}
Expand All @@ -95,16 +80,43 @@ func (p *Processor) handleMessage(k *common.MessagePublication) {
zap.String("message_id", k.MessageIDString()),
zap.Stringer("txhash", k.TxHash),
zap.String("txhash_b58", base58.Encode(k.TxHash.Bytes())),
zap.String("digest", hex.EncodeToString(digest.Bytes())),
zap.String("hash", hash),
zap.Uint32("nonce", k.Nonce),
zap.Time("timestamp", k.Timestamp),
zap.Uint8("consistency_level", k.ConsistencyLevel),
zap.String("signature", hex.EncodeToString(s)),
zap.String("signature", hex.EncodeToString(signature)),
zap.Bool("isReobservation", k.IsReobservation),
)
}

messagesSignedTotal.With(prometheus.Labels{
"emitter_chain": k.EmitterChain.String()}).Add(1)
// Broadcast the signature.
obsv, msg := p.broadcastSignature(v.MessageID(), k.TxHash.Bytes(), digest, signature)

p.broadcastSignature(v, s, k.TxHash.Bytes())
// Get / create our state entry.
s := p.state.signatures[hash]
if s == nil {
s = &state{
firstObserved: time.Now(),
nextRetry: time.Now().Add(nextRetryDuration(0)),
signatures: map[ethCommon.Address][]byte{},
source: "loopback",
}

p.state.signatures[hash] = s
}

// Update our state.
s.ourObservation = v
s.txHash = k.TxHash.Bytes()
s.source = v.GetEmitterChain().String()
s.gs = p.gs // guaranteed to match ourObservation - there's no concurrent access to p.gs
s.signatures[p.ourAddr] = signature
s.ourMsg = msg

// Fast path for our own signature.
if !s.submitted {
start := time.Now()
p.checkForQuorum(obsv, s, s.gs, hash)
timeToHandleObservation.Observe(float64(time.Since(start).Microseconds()))
}
}
Loading
Loading