diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go
index 60db2c939b4b..cf789ef26d9d 100644
--- a/vault/eventbus/bus.go
+++ b/vault/eventbus/bus.go
@@ -5,9 +5,11 @@ import (
 	"errors"
 	"net/url"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
+	"github.com/armon/go-metrics"
 	"github.com/hashicorp/eventlogger"
 	"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
 	"github.com/hashicorp/go-hclog"
@@ -17,9 +19,13 @@ import (
 	"google.golang.org/protobuf/types/known/timestamppb"
 )
 
-var ErrNotStarted = errors.New("event broker has not been started")
+const defaultTimeout = 60 * time.Second
 
-var cloudEventsFormatterFilter *cloudevents.FormatterFilter
+var (
+	ErrNotStarted              = errors.New("event broker has not been started")
+	cloudEventsFormatterFilter *cloudevents.FormatterFilter
+	subscriptions              atomic.Int64 // keeps track of event subscription count in all event buses
+)
 
 // EventBus contains the main logic of running an event broker for Vault.
 // Start() must be called before the EventBus will accept events for sending.
@@ -28,6 +34,7 @@ type EventBus struct {
 	broker          *eventlogger.Broker
 	started         atomic.Bool
 	formatterNodeID eventlogger.NodeID
+	timeout         time.Duration
 }
 
 type pluginEventBus struct {
@@ -42,6 +49,13 @@ type asyncChanNode struct {
 	ch        chan *logical.EventReceived
 	namespace *namespace.Namespace
 	logger    hclog.Logger
+
+	// used to close the subscription
+	closeOnce  sync.Once
+	cancelFunc context.CancelFunc
+	pipelineID eventlogger.PipelineID
+	eventType  eventlogger.EventType
+	broker     *eventlogger.Broker
 }
 
 var (
@@ -79,6 +93,10 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
 		Timestamp: timestamppb.New(time.Now()),
 	}
 	bus.logger.Info("Sending event", "event", eventReceived)
+
+	// We can't easily know when the Send is complete, so we can't call the cancel function.
+	// But it is called automatically after bus.timeout, so there won't be a leak as long as bus.timeout is not too long.
+	ctx, _ = context.WithTimeout(ctx, bus.timeout)
 	_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
 	if err != nil {
 		// if no listeners for this event type are registered, that's okay, the event
@@ -142,6 +160,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
 		logger:          logger,
 		broker:          broker,
 		formatterNodeID: formatterNodeID,
+		timeout:         defaultTimeout,
 	}, nil
 }
 
@@ -178,7 +197,18 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
 		defer cancel()
 		return nil, nil, err
 	}
-	return asyncNode.ch, cancel, nil
+	addSubscriptions(1)
+	// add info needed to cancel the subscription
+	asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
+	asyncNode.eventType = eventlogger.EventType(eventType)
+	asyncNode.cancelFunc = cancel
+	return asyncNode.ch, asyncNode.Close, nil
+}
+
+// SetSendTimeout sets the timeout for sending events. If an event is not accepted by the
+// underlying channel before this timeout, then the subscription is closed.
+func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
+	bus.timeout = timeout
 }
 
 func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
@@ -190,8 +220,21 @@ func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hc
 	}
 }
 
+// Close tells the bus to stop sending us events.
+func (node *asyncChanNode) Close() {
+	node.closeOnce.Do(func() {
+		defer node.cancelFunc()
+		if node.broker != nil {
+			err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
+			if err != nil {
+				node.logger.Warn("Error removing pipeline for closing node", "error", err)
+			}
+		}
+		addSubscriptions(-1)
+	})
+}
+
 func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
-	// TODO: add timeout on sending to node.ch
 	// sends to the channel async in another goroutine
 	go func() {
 		eventRecv := e.Payload.(*logical.EventReceived)
@@ -200,12 +243,17 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
 		if eventRecv.Namespace != node.namespace.Path {
 			return
 		}
+		var timeout bool
 		select {
 		case node.ch <- eventRecv:
 		case <-ctx.Done():
-			return
+			timeout = errors.Is(ctx.Err(), context.DeadlineExceeded)
 		case <-node.ctx.Done():
-			return
+			timeout = errors.Is(node.ctx.Err(), context.DeadlineExceeded)
+		}
+		if timeout {
+			node.logger.Info("Subscriber took too long to process event, closing", "ID", eventRecv.Event.ID())
+			node.Close()
 		}
 	}()
 	return e, nil
@@ -218,3 +266,7 @@ func (node *asyncChanNode) Reopen() error {
 	return nil
 }
 
 func (node *asyncChanNode) Type() eventlogger.NodeType {
 	return eventlogger.NodeTypeSink
 }
+
+func addSubscriptions(delta int64) {
+	metrics.SetGauge([]string{"events", "subscriptions"}, float32(subscriptions.Add(delta)))
+}
diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go
index 8275efcaf16b..94b3dd2c5ecb 100644
--- a/vault/eventbus/bus_test.go
+++ b/vault/eventbus/bus_test.go
@@ -2,6 +2,8 @@ package eventbus
 
 import (
 	"context"
+	"fmt"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -9,6 +11,7 @@ import (
 	"github.com/hashicorp/vault/sdk/logical"
 )
 
+// TestBusBasics tests that basic event sending and subscribing function.
 func TestBusBasics(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -62,6 +65,7 @@
 	}
 }
 
+// TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus.
 func TestNamespaceFiltering(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -121,6 +125,7 @@
 	}
 }
 
+// TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers.
 func TestBus2Subscriptions(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -180,3 +185,115 @@
 		t.Error("Timeout waiting for event2")
 	}
 }
+
+// TestBusSubscriptionsCancel verifies that canceled subscriptions are cleaned up.
+func TestBusSubscriptionsCancel(t *testing.T) {
+	testCases := []struct {
+		cancel bool
+	}{
+		{cancel: true},
+		{cancel: false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
+			subscriptions.Store(0)
+			bus, err := NewEventBus(nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ctx := context.Background()
+			if !tc.cancel {
+				// set the timeout very short to make the test faster if we aren't canceling explicitly
+				bus.SetSendTimeout(100 * time.Millisecond)
+			}
+			bus.Start()
+
+			// create and stop a bunch of subscriptions
+			const create = 100
+			const stop = 50
+
+			eventType := logical.EventType("someType")
+
+			var channels []<-chan *logical.EventReceived
+			var cancels []context.CancelFunc
+			stopped := atomic.Int32{}
+
+			received := atomic.Int32{}
+
+			for i := 0; i < create; i++ {
+				ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
+				if err != nil {
+					t.Fatal(err)
+				}
+				t.Cleanup(cancelFunc)
+				channels = append(channels, ch)
+				cancels = append(cancels, cancelFunc)
+
+				go func(i int32) {
+					<-ch // always receive one message
+					received.Add(1)
+					// continue receiving messages as long as we are not stopped
+					for i < int32(stop) {
+						<-ch
+						received.Add(1)
+					}
+					if tc.cancel {
+						cancelFunc() // stop explicitly to unsubscribe
+					}
+					stopped.Add(1)
+				}(int32(i))
+			}
+
+			// check that all channels receive a message
+			event, err := logical.NewEvent()
+			if err != nil {
+				t.Fatal(err)
+			}
+			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			if err != nil {
+				t.Error(err)
+			}
+			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create) })
+			waitFor(t, 1*time.Second, func() bool { return stopped.Load() == int32(stop) })
+
+			// send another message, but half should stop receiving
+			event, err = logical.NewEvent()
+			if err != nil {
+				t.Fatal(err)
+			}
+			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			if err != nil {
+				t.Error(err)
+			}
+			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create*2-stop) })
+			// the sends should time out and the subscriptions should drop when cancelFunc is called or the context cancels
+			waitFor(t, 1*time.Second, func() bool { return subscriptions.Load() == int64(create-stop) })
+		})
+	}
+}
+
+// waitFor waits for a condition to be true, up to the maximum timeout.
+// It waits with a capped exponential backoff starting at 1ms.
+// It is guaranteed to try f() at least once.
+func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
+	t.Helper()
+	start := time.Now()
+
+	if f() {
+		return
+	}
+	sleepAmount := 1 * time.Millisecond
+	for time.Now().Sub(start) <= maxWait {
+		left := maxWait - time.Now().Sub(start)
+		sleepAmount = sleepAmount * 2
+		if sleepAmount > left {
+			sleepAmount = left
+		}
+		time.Sleep(sleepAmount)
+		if f() {
+			return
+		}
+	}
+	t.Error("Timeout waiting for condition")
+}
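
For readers of this change, a rough usage sketch of the reworked subscription API may help: Subscribe now hands back a close function (asyncChanNode.Close) instead of the raw context cancel, and SetSendTimeout bounds how long a slow subscriber can block a send before its subscription is torn down. The sketch below is illustrative only and is not part of the diff: the import paths, the "kv-write" event type, and the panic-based error handling are assumptions made for brevity, and it only exercises calls that appear in the code and tests above.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/hashicorp/vault/helper/namespace" // assumed import path
	"github.com/hashicorp/vault/sdk/logical"      // assumed import path
	"github.com/hashicorp/vault/vault/eventbus"   // assumed import path
)

func main() {
	// Build and start a bus the same way the tests above do.
	bus, err := eventbus.NewEventBus(nil)
	if err != nil {
		panic(err)
	}
	bus.Start()

	// If a subscriber does not drain its channel within this window,
	// the bus closes that subscription instead of blocking the sender forever.
	bus.SetSendTimeout(5 * time.Second)

	// Subscribe now returns a close function; calling it (or hitting the send
	// timeout) removes the pipeline and decrements the subscriptions gauge.
	eventType := logical.EventType("kv-write") // hypothetical event type
	ch, closeSub, err := bus.Subscribe(context.Background(), namespace.RootNamespace, eventType)
	if err != nil {
		panic(err)
	}
	defer closeSub()

	// Sending is unchanged for callers; the send context now carries bus.timeout.
	event, err := logical.NewEvent()
	if err != nil {
		panic(err)
	}
	if err := bus.SendInternal(context.Background(), namespace.RootNamespace, nil, eventType, event); err != nil {
		panic(err)
	}

	// The subscriber sees the event on its channel.
	select {
	case ev := <-ch:
		fmt.Println("received event in namespace", ev.Namespace)
	case <-time.After(time.Second):
		fmt.Println("no event received within a second")
	}
}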