backport of commit db822c0 (hashicorp#19204)
Co-authored-by: Christopher Swenson <christopher.swenson@hashicorp.com>
hc-github-team-secure-vault-core and Christopher Swenson committed Feb 15, 2023
1 parent a41a24b commit 48a7feb
Showing 2 changed files with 175 additions and 6 deletions.
64 changes: 58 additions & 6 deletions vault/eventbus/bus.go
@@ -5,9 +5,11 @@ import (
"errors"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
"github.com/hashicorp/go-hclog"
@@ -17,9 +19,13 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

var ErrNotStarted = errors.New("event broker has not been started")
const defaultTimeout = 60 * time.Second

var cloudEventsFormatterFilter *cloudevents.FormatterFilter
var (
ErrNotStarted = errors.New("event broker has not been started")
cloudEventsFormatterFilter *cloudevents.FormatterFilter
subscriptions atomic.Int64 // keeps track of event subscription count in all event buses
)

// EventBus contains the main logic of running an event broker for Vault.
// Start() must be called before the EventBus will accept events for sending.
@@ -28,6 +34,7 @@ type EventBus struct {
broker *eventlogger.Broker
started atomic.Bool
formatterNodeID eventlogger.NodeID
timeout time.Duration
}

type pluginEventBus struct {
@@ -42,6 +49,13 @@ type asyncChanNode struct {
ch chan *logical.EventReceived
namespace *namespace.Namespace
logger hclog.Logger

// used to close the connection
closeOnce sync.Once
cancelFunc context.CancelFunc
pipelineID eventlogger.PipelineID
eventType eventlogger.EventType
broker *eventlogger.Broker
}

var (
@@ -79,6 +93,10 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
Timestamp: timestamppb.New(time.Now()),
}
bus.logger.Info("Sending event", "event", eventReceived)

// We can't easily know when the Send is complete, so we can't call the cancel function.
// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
ctx, _ = context.WithTimeout(ctx, bus.timeout)
_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
if err != nil {
// if no listeners for this event type are registered, that's okay, the event
@@ -142,6 +160,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
logger: logger,
broker: broker,
formatterNodeID: formatterNodeID,
timeout: defaultTimeout,
}, nil
}

@@ -178,7 +197,18 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
defer cancel()
return nil, nil, err
}
return asyncNode.ch, cancel, nil
addSubscriptions(1)
// add info needed to cancel the subscription
asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
asyncNode.eventType = eventlogger.EventType(eventType)
asyncNode.cancelFunc = cancel
return asyncNode.ch, asyncNode.Close, nil
}

// SetSendTimeout sets the timeout for sending events. If an event is not accepted by the
// underlying channel before this timeout, then the subscription is closed.
func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
bus.timeout = timeout
}

func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
@@ -190,8 +220,21 @@ func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hc
}
}

// Close tells the bus to stop sending us events.
func (node *asyncChanNode) Close() {
node.closeOnce.Do(func() {
defer node.cancelFunc()
if node.broker != nil {
err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
if err != nil {
node.logger.Warn("Error removing pipeline for closing node", "error", err)
}
}
addSubscriptions(-1)
})
}

func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
// TODO: add timeout on sending to node.ch
// sends to the channel async in another goroutine
go func() {
eventRecv := e.Payload.(*logical.EventReceived)
@@ -200,12 +243,17 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
if eventRecv.Namespace != node.namespace.Path {
return
}
var timeout bool
select {
case node.ch <- eventRecv:
case <-ctx.Done():
return
timeout = errors.Is(ctx.Err(), context.DeadlineExceeded)
case <-node.ctx.Done():
return
timeout = errors.Is(node.ctx.Err(), context.DeadlineExceeded)
}
if timeout {
node.logger.Info("Subscriber took too long to process event, closing", "ID", eventRecv.Event.ID())
node.Close()
}
}()
return e, nil
Expand All @@ -218,3 +266,7 @@ func (node *asyncChanNode) Reopen() error {
func (node *asyncChanNode) Type() eventlogger.NodeType {
return eventlogger.NodeTypeSink
}

func addSubscriptions(delta int64) {
metrics.SetGauge([]string{"events", "subscriptions"}, float32(subscriptions.Add(delta)))
}
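
To make the new subscription lifecycle concrete, here is a minimal usage sketch (not part of the commit). It assumes the package is importable as github.com/hashicorp/vault/vault/eventbus and uses a hypothetical event type; Subscribe now hands back the node's Close method as its cancel function, and SetSendTimeout bounds how long a send may block on a slow subscriber before that subscription is dropped.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault/eventbus"
)

func main() {
	bus, err := eventbus.NewEventBus(hclog.Default())
	if err != nil {
		panic(err)
	}
	// Drop subscribers that have not drained their channel within 5 seconds
	// (illustrative; the default is 60 seconds via defaultTimeout).
	bus.SetSendTimeout(5 * time.Second)
	bus.Start()

	ctx := context.Background()
	eventType := logical.EventType("example-type") // hypothetical event type

	// The returned cancel function is now asyncChanNode.Close, which also removes
	// the pipeline and decrements the subscription gauge.
	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
	if err != nil {
		panic(err)
	}
	defer cancel()

	event, err := logical.NewEvent()
	if err != nil {
		panic(err)
	}
	if err := bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event); err != nil {
		fmt.Println("send error:", err)
	}

	select {
	case received := <-ch:
		fmt.Println("received event", received.Event.ID())
	case <-time.After(time.Second):
		fmt.Println("no event received")
	}
}

Note that SendInternal deliberately discards the context.WithTimeout cancel function; as the in-code comment explains, the timeout context releases its own resources when the deadline fires, so nothing leaks as long as the configured timeout is reasonable.
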
117 changes: 117 additions & 0 deletions vault/eventbus/bus_test.go
@@ -2,13 +2,16 @@ package eventbus

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)

// TestBusBasics tests that basic event sending and subscribing function.
func TestBusBasics(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -62,6 +65,7 @@ func TestBusBasics(t *testing.T) {
}
}

// TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus.
func TestNamespaceFiltering(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -121,6 +125,7 @@ func TestNamespaceFiltering(t *testing.T) {
}
}

// TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers.
func TestBus2Subscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -180,3 +185,115 @@ func TestBus2Subscriptions(t *testing.T) {
t.Error("Timeout waiting for event2")
}
}

// TestBusSubscriptionsCancel verifies that canceled subscriptions are cleaned up.
func TestBusSubscriptionsCancel(t *testing.T) {
testCases := []struct {
cancel bool
}{
{cancel: true},
{cancel: false},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
subscriptions.Store(0)
bus, err := NewEventBus(nil)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
if !tc.cancel {
// set the timeout very short to make the test faster if we aren't canceling explicitly
bus.SetSendTimeout(100 * time.Millisecond)
}
bus.Start()

// create and stop a bunch of subscriptions
const create = 100
const stop = 50

eventType := logical.EventType("someType")

var channels []<-chan *logical.EventReceived
var cancels []context.CancelFunc
stopped := atomic.Int32{}

received := atomic.Int32{}

for i := 0; i < create; i++ {
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
if err != nil {
t.Fatal(err)
}
t.Cleanup(cancelFunc)
channels = append(channels, ch)
cancels = append(cancels, cancelFunc)

go func(i int32) {
<-ch // always receive one message
received.Add(1)
// continue receiving messages as long as we are not stopped
for i < int32(stop) {
<-ch
received.Add(1)
}
if tc.cancel {
cancelFunc() // stop explicitly to unsubscribe
}
stopped.Add(1)
}(int32(i))
}

// check that all channels receive a message
event, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
if err != nil {
t.Error(err)
}
waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create) })
waitFor(t, 1*time.Second, func() bool { return stopped.Load() == int32(stop) })

// send another message, but half should stop receiving
event, err = logical.NewEvent()
if err != nil {
t.Fatal(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
if err != nil {
t.Error(err)
}
waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create*2-stop) })
// the sends should time out and the subscriptions should drop when cancelFunc is called or the context cancels
waitFor(t, 1*time.Second, func() bool { return subscriptions.Load() == int64(create-stop) })
})
}
}

// waitFor waits for a condition to be true, up to the maximum timeout.
// It waits with a capped exponential backoff starting at 1ms.
// It is guaranteed to try f() at least once.
func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
t.Helper()
start := time.Now()

if f() {
return
}
sleepAmount := 1 * time.Millisecond
for time.Now().Sub(start) <= maxWait {
left := maxWait - time.Now().Sub(start)
sleepAmount = sleepAmount * 2
if sleepAmount > left {
sleepAmount = left
}
time.Sleep(sleepAmount)
if f() {
return
}
}
t.Error("Timeout waiting for condition")
}
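
The commit also starts tracking the number of active subscriptions across all event buses, reported by addSubscriptions as a go-metrics gauge under the key events.subscriptions. Below is a rough sketch of observing that gauge with the go-metrics in-memory sink; the sink setup, the "vault" service name, and the literal SetGauge call are illustrative assumptions, not something this commit configures.

package main

import (
	"fmt"
	"time"

	metrics "github.com/armon/go-metrics"
)

func main() {
	// Route the global go-metrics instance to an in-memory sink so gauges can be inspected.
	cfg := metrics.DefaultConfig("vault")
	cfg.EnableHostname = false // keep gauge names stable for this illustration
	sink := metrics.NewInmemSink(10*time.Second, time.Minute)
	if _, err := metrics.NewGlobal(cfg, sink); err != nil {
		panic(err)
	}

	// The event bus effectively does this each time a subscription is added or removed.
	metrics.SetGauge([]string{"events", "subscriptions"}, 3)

	// Print whatever gauges were recorded in the collected intervals.
	for _, interval := range sink.Data() {
		for name, gauge := range interval.Gauges {
			fmt.Println(name, gauge)
		}
	}
}

In Vault itself the gauge is emitted to whatever telemetry sink the server is configured with; the in-memory sink here only makes the value easy to print.
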
