Skip to content

Commit

Permalink
Node: Minor tweaks and spy improvement (wormhole-foundation#3974)
Browse files Browse the repository at this point in the history
* Node: Minor tweaks and spy improvement

* Add tests
  • Loading branch information
bruce-riley committed Jun 10, 2024
1 parent fdd2382 commit 0e2ba62
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 107 deletions.
33 changes: 2 additions & 31 deletions node/cmd/spy/spy.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,6 @@ func runSpy(cmd *cobra.Command, args []string) {
// Outbound gossip message queue
sendC := make(chan []byte)

// Inbound observations
obsvC := make(chan *common.MsgWithTimeStamp[gossipv1.SignedObservation], 1024)

// Inbound observation requests
obsvReqC := make(chan *gossipv1.ObservationRequest, 1024)

// Inbound signed VAAs
signedInC := make(chan *gossipv1.SignedVAAWithQuorum, 1024)

Expand All @@ -370,29 +364,6 @@ func runSpy(cmd *cobra.Command, args []string) {
}
}

// Ignore observations
go func() {
for {
select {
case <-rootCtx.Done():
return
case <-obsvC:
}
}
}()

// Ignore observation requests
// Note: without this, the whole program hangs on observation requests
go func() {
for {
select {
case <-rootCtx.Done():
return
case <-obsvReqC:
}
}
}()

// Log signed VAAs
go func() {
for {
Expand Down Expand Up @@ -422,8 +393,8 @@ func runSpy(cmd *cobra.Command, args []string) {
components.Port = *p2pPort
if err := supervisor.Run(ctx,
"p2p",
p2p.Run(obsvC,
obsvReqC,
p2p.Run(nil, // Ignore incoming observations.
nil, // Ignore observation requests.
nil,
sendC,
signedInC,
Expand Down
17 changes: 1 addition & 16 deletions node/pkg/accountant/submit_obs.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
ctx, cancel := context.WithTimeout(ctx, delayInMS)
defer cancel()

msgs, err := readFromChannel[*common.MessagePublication](ctx, subChan, batchSize)
msgs, err := common.ReadFromChannelWithTimeout[*common.MessagePublication](ctx, subChan, batchSize)
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("failed to read messages from channel for %s: %w", tag, err)
}
Expand Down Expand Up @@ -95,21 +95,6 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
return nil
}

// readFromChannel reads events from the channel until a timeout occurs or the batch is full, and returns them.
func readFromChannel[T any](ctx context.Context, ch <-chan T, count int) ([]T, error) {
out := make([]T, 0, count)
for len(out) < count {
select {
case <-ctx.Done():
return out, ctx.Err()
case msg := <-ch:
out = append(out, msg)
}
}

return out, nil
}

// removeCompleted drops any messages that are no longer in the pending transfer map. This is to handle the case where the contract reports
// that a transfer is committed while it is in the channel. There is no point in submitting the observation once the transfer is committed.
func (acct *Accountant) removeCompleted(msgs []*common.MessagePublication) []*common.MessagePublication {
Expand Down
20 changes: 20 additions & 0 deletions node/pkg/common/channel_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package common

import (
"context"
)

// ReadFromChannelWithTimeout reads events from the channel until a timeout occurs or the max maxCount is reached.
func ReadFromChannelWithTimeout[T any](ctx context.Context, ch <-chan T, maxCount int) ([]T, error) {
out := make([]T, 0, maxCount)
for len(out) < maxCount {
select {
case <-ctx.Done():
return out, ctx.Err()
case msg := <-ch:
out = append(out, msg)
}
}

return out, nil
}
80 changes: 80 additions & 0 deletions node/pkg/common/channel_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package common

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const myDelay = time.Millisecond * 100
const myMaxSize = 2
const myQueueSize = myMaxSize * 10

func TestReadFromChannelWithTimeout_NoData(t *testing.T) {
ctx := context.Background()
myChan := make(chan int, myQueueSize)

// No data should timeout.
timeout, cancel := context.WithTimeout(ctx, myDelay)
defer cancel()
observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
assert.Equal(t, err, context.DeadlineExceeded)
assert.Equal(t, 0, len(observations))
}

func TestReadFromChannelWithTimeout_SomeData(t *testing.T) {
ctx := context.Background()
myChan := make(chan int, myQueueSize)
myChan <- 1

// Some data but not enough to fill a message should timeout and return the data.
timeout, cancel := context.WithTimeout(ctx, myDelay)
defer cancel()
observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
assert.Equal(t, err, context.DeadlineExceeded)
require.Equal(t, 1, len(observations))
assert.Equal(t, 1, observations[0])
}

func TestReadFromChannelWithTimeout_JustEnoughData(t *testing.T) {
ctx := context.Background()
myChan := make(chan int, myQueueSize)
myChan <- 1
myChan <- 2

// Just enough data should return the data and no error.
timeout, cancel := context.WithTimeout(ctx, myDelay)
defer cancel()
observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
assert.NoError(t, err)
require.Equal(t, 2, len(observations))
assert.Equal(t, 1, observations[0])
assert.Equal(t, 2, observations[1])
}

func TestReadFromChannelWithTimeout_TooMuchData(t *testing.T) {
ctx := context.Background()
myChan := make(chan int, myQueueSize)
myChan <- 1
myChan <- 2
myChan <- 3

// If there is more data than will fit, it should immediately return a full message, then timeout and return the remainder.
timeout, cancel := context.WithTimeout(ctx, myDelay)
defer cancel()
observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
assert.NoError(t, err)
require.Equal(t, 2, len(observations))
assert.Equal(t, 1, observations[0])
assert.Equal(t, 2, observations[1])

timeout2, cancel2 := context.WithTimeout(ctx, myDelay)
defer cancel2()
observations, err = ReadFromChannelWithTimeout[int](timeout2, myChan, myMaxSize)
assert.Equal(t, err, context.DeadlineExceeded)
require.Equal(t, 1, len(observations))
assert.Equal(t, 3, observations[0])
}
11 changes: 8 additions & 3 deletions node/pkg/common/guardianset.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,20 @@ type GuardianSet struct {
// On-chain set index
Index uint32

// Quorum value for this set of keys
Quorum int
// quorum value for this set of keys
quorum int

// A map from address to index. Testing showed that, on average, a map is almost three times faster than a sequential search of the key slice.
// Testing also showed that the map was twice as fast as using a sorted slice and `slices.BinarySearchFunc`. That being said, on a 4GHz CPU,
// the sequential search takes an average of 800 nanos and the map look up takes about 260 nanos. Is this worth doing?
keyMap map[common.Address]int
}

// Quorum returns the current quorum value.
func (gs *GuardianSet) Quorum() int {
return gs.quorum
}

func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
keyMap := map[common.Address]int{}
for idx, key := range keys {
Expand All @@ -71,7 +76,7 @@ func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
return &GuardianSet{
Keys: keys,
Index: index,
Quorum: vaa.CalculateQuorum(len(keys)),
quorum: vaa.CalculateQuorum(len(keys)),
keyMap: keyMap,
}
}
Expand Down
2 changes: 1 addition & 1 deletion node/pkg/common/guardianset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestNewGuardianSet(t *testing.T) {
gs := NewGuardianSet(keys, 1)
assert.True(t, reflect.DeepEqual(keys, gs.Keys))
assert.Equal(t, uint32(1), gs.Index)
assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum)
assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum())
}

func TestKeyIndex(t *testing.T) {
Expand Down
98 changes: 53 additions & 45 deletions node/pkg/p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,9 @@ func Run(
}

// Send to local observation request queue (the loopback message is ignored)
obsvReqC <- msg
if obsvReqC != nil {
obsvReqC <- msg
}

err = th.Publish(ctx, b)
p2pMessagesSent.Inc()
Expand Down Expand Up @@ -699,59 +701,65 @@ func Run(
}()
}
case *gossipv1.GossipMessage_SignedObservation:
if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
p2pMessagesReceived.WithLabelValues("observation").Inc()
} else {
if components.WarnChannelOverflow {
logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash)))
if obsvC != nil {
if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
p2pMessagesReceived.WithLabelValues("observation").Inc()
} else {
if components.WarnChannelOverflow {
logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash)))
}
p2pReceiveChannelOverflow.WithLabelValues("observation").Inc()
}
p2pReceiveChannelOverflow.WithLabelValues("observation").Inc()
}
case *gossipv1.GossipMessage_SignedVaaWithQuorum:
select {
case signedInC <- m.SignedVaaWithQuorum:
p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc()
default:
if components.WarnChannelOverflow {
// TODO do not log this in production
var hexStr string
if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil {
hexStr = vaa.HexDigest()
if signedInC != nil {
select {
case signedInC <- m.SignedVaaWithQuorum:
p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc()
default:
if components.WarnChannelOverflow {
// TODO do not log this in production
var hexStr string
if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil {
hexStr = vaa.HexDigest()
}
logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr))
}
logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr))
p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc()
}
p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc()
}
case *gossipv1.GossipMessage_SignedObservationRequest:
s := m.SignedObservationRequest
gs := gst.Get()
if gs == nil {
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String()))
}
break
}
r, err := processSignedObservationRequest(s, gs)
if err != nil {
p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc()
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("invalid signed observation request received",
zap.Error(err),
zap.Any("payload", msg.Message),
zap.Any("value", s),
zap.Binary("raw", envelope.Data),
zap.String("from", envelope.GetFrom().String()))
}
} else {
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String()))
if obsvReqC != nil {
s := m.SignedObservationRequest
gs := gst.Get()
if gs == nil {
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String()))
}
break
}
r, err := processSignedObservationRequest(s, gs)
if err != nil {
p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc()
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("invalid signed observation request received",
zap.Error(err),
zap.Any("payload", msg.Message),
zap.Any("value", s),
zap.Binary("raw", envelope.Data),
zap.String("from", envelope.GetFrom().String()))
}
} else {
if logger.Level().Enabled(zapcore.DebugLevel) {
logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String()))
}

select {
case obsvReqC <- r:
p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc()
default:
p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc()
select {
case obsvReqC <- r:
p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc()
default:
p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc()
}
}
}
case *gossipv1.GossipMessage_SignedChainGovernorConfig:
Expand Down
3 changes: 1 addition & 2 deletions node/pkg/processor/broadcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"

ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"google.golang.org/protobuf/proto"

node_common "github.com/certusone/wormhole/node/pkg/common"
Expand Down Expand Up @@ -43,7 +42,7 @@ func (p *Processor) broadcastSignature(
) {
digest := o.SigningDigest()
obsv := gossipv1.SignedObservation{
Addr: crypto.PubkeyToAddress(p.gk.PublicKey).Bytes(),
Addr: p.ourAddr.Bytes(),
Hash: digest.Bytes(),
Signature: signature,
TxHash: txhash,
Expand Down
8 changes: 4 additions & 4 deletions node/pkg/processor/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (p *Processor) handleCleanup(ctx context.Context) {
}

hasSigs := len(s.signatures)
quorum := hasSigs >= gs.Quorum
quorum := hasSigs >= gs.Quorum()

var chain vaa.ChainID
if s.ourObservation != nil {
Expand All @@ -128,7 +128,7 @@ func (p *Processor) handleCleanup(ctx context.Context) {
zap.String("digest", hash),
zap.Duration("delta", delta),
zap.Int("have_sigs", hasSigs),
zap.Int("required_sigs", gs.Quorum),
zap.Int("required_sigs", gs.Quorum()),
zap.Bool("quorum", quorum),
zap.Stringer("emitter_chain", chain),
)
Expand Down Expand Up @@ -245,8 +245,8 @@ func (p *Processor) handleCleanup(ctx context.Context) {
zap.String("digest", hash),
zap.Duration("delta", delta),
zap.Int("have_sigs", hasSigs),
zap.Int("required_sigs", p.gs.Quorum),
zap.Bool("quorum", hasSigs >= p.gs.Quorum),
zap.Int("required_sigs", p.gs.Quorum()),
zap.Bool("quorum", hasSigs >= p.gs.Quorum()),
)
}
delete(p.state.signatures, hash)
Expand Down
Loading

0 comments on commit 0e2ba62

Please sign in to comment.