diff --git a/app/app_test.go b/app/app_test.go index 562e43e33e..fa7783d9b3 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -1,5 +1,3 @@ -//go:build all || race - package app import ( @@ -123,7 +121,7 @@ func newStartedApp( var err error if peers == nil { - peers, err = peer.NewPeers(context.Background(), c) + peers, err = peer.NewPeers(context.Background(), c, make(chan struct{})) assert.NoError(t, err) } diff --git a/cmd/refinery/main.go b/cmd/refinery/main.go index 0690d96509..a619468d13 100644 --- a/cmd/refinery/main.go +++ b/cmd/refinery/main.go @@ -106,7 +106,8 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), c.GetPeerTimeout()) defer cancel() - peers, err := peer.NewPeers(ctx, c) + done := make(chan struct{}) + peers, err := peer.NewPeers(ctx, c, done) if err != nil { fmt.Printf("unable to load peers: %+v\n", err) @@ -226,5 +227,8 @@ func main() { // block on our signal handler to exit sig := <-sigsToExit + // unregister ourselves before we go + close(done) + time.Sleep(100 * time.Millisecond) a.Logger.Error().Logf("Caught signal \"%s\"", sig) } diff --git a/config/config_test_reload_error_test.go b/config/config_test_reload_error_test.go index 307166b27a..ca7e00c64a 100644 --- a/config/config_test_reload_error_test.go +++ b/config/config_test_reload_error_test.go @@ -1,5 +1,3 @@ -//go:build all || !race - package config import ( diff --git a/internal/peer/peers.go b/internal/peer/peers.go index c17000ef1d..ed84ab7764 100644 --- a/internal/peer/peers.go +++ b/internal/peer/peers.go @@ -3,6 +3,7 @@ package peer import ( "context" "errors" + "github.com/honeycombio/refinery/config" ) @@ -13,7 +14,7 @@ type Peers interface { RegisterUpdatedPeersCallback(callback func()) } -func NewPeers(ctx context.Context, c config.Config) (Peers, error) { +func NewPeers(ctx context.Context, c config.Config, done chan struct{}) (Peers, error) { t, err := c.GetPeerManagementType() if err != nil { @@ -24,7 +25,7 @@ func NewPeers(ctx context.Context, c config.Config) (Peers, error) { case "file": return newFilePeers(c), nil case "redis": - return newRedisPeers(ctx, c) + return newRedisPeers(ctx, c, done) default: return nil, errors.New("invalid config option 'PeerManagement.Type'") } diff --git a/internal/peer/peers_test.go b/internal/peer/peers_test.go index 5ec7f8137a..accaad1097 100644 --- a/internal/peer/peers_test.go +++ b/internal/peer/peers_test.go @@ -2,6 +2,7 @@ package peer import ( "context" + "strings" "testing" "time" @@ -16,7 +17,9 @@ func TestNewPeers(t *testing.T) { PeerTimeout: 5 * time.Second, } - p, err := NewPeers(context.Background(), c) + done := make(chan struct{}) + defer close(done) + p, err := NewPeers(context.Background(), c, done) assert.NoError(t, err) require.NotNil(t, p) @@ -32,7 +35,7 @@ func TestNewPeers(t *testing.T) { PeerTimeout: 5 * time.Second, } - p, err = NewPeers(context.Background(), c) + p, err = NewPeers(context.Background(), c, done) assert.NoError(t, err) require.NotNil(t, p) @@ -42,3 +45,31 @@ func TestNewPeers(t *testing.T) { t.Errorf("received %T expected %T", i, &redisPeers{}) } } + +func TestPeerShutdown(t *testing.T) { + c := &config.MockConfig{ + GetPeerListenAddrVal: "0.0.0.0:8081", + PeerManagementType: "redis", + PeerTimeout: 5 * time.Second, + } + + done := make(chan struct{}) + p, err := NewPeers(context.Background(), c, done) + assert.NoError(t, err) + require.NotNil(t, p) + + peer, ok := p.(*redisPeers) + assert.True(t, ok) + + peers, err := peer.GetPeers() + assert.NoError(t, err) + assert.Equal(t, 1, len(peers)) 
+ assert.True(t, strings.HasPrefix(peers[0], "http")) + assert.True(t, strings.HasSuffix(peers[0], "8081")) + + close(done) + time.Sleep(100 * time.Millisecond) + peers, err = peer.GetPeers() + assert.NoError(t, err) + assert.Equal(t, 0, len(peers)) +} diff --git a/internal/peer/redis.go b/internal/peer/redis.go index 4d7be37715..405dc7f95a 100644 --- a/internal/peer/redis.go +++ b/internal/peer/redis.go @@ -45,7 +45,7 @@ type redisPeers struct { } // NewRedisPeers returns a peers collection backed by redis -func newRedisPeers(ctx context.Context, c config.Config) (Peers, error) { +func newRedisPeers(ctx context.Context, c config.Config, done chan struct{}) (Peers, error) { redisHost, _ := c.GetRedisHost() if redisHost == "" { @@ -108,7 +108,7 @@ func newRedisPeers(ctx context.Context, c config.Config) (Peers, error) { } // go establish a regular registration heartbeat to ensure I stay alive in redis - go peers.registerSelf() + go peers.registerSelf(done) // get our peer list once to seed ourselves peers.updatePeerListOnce() @@ -116,7 +116,7 @@ func newRedisPeers(ctx context.Context, c config.Config) (Peers, error) { // go watch the list of peers and trigger callbacks whenever it changes. // populate my local list of peers so each request can hit memory and only hit // redis on a ticker - go peers.watchPeers() + go peers.watchPeers(done) return peers, nil } @@ -135,15 +135,24 @@ func (p *redisPeers) RegisterUpdatedPeersCallback(cb func()) { // registerSelf inserts self into the peer list and updates self's entry on a // regular basis so it doesn't time out and get removed from the list of peers. -// If this function stops, this host will get ejected from other's peer lists. -func (p *redisPeers) registerSelf() { +// When this function stops, it tries to remove the registered key. +func (p *redisPeers) registerSelf(done chan struct{}) { tk := time.NewTicker(refreshCacheInterval) - for range tk.C { - ctx, cancel := context.WithTimeout(context.Background(), p.c.GetPeerTimeout()) - // every 5 seconds, insert a 30sec timeout record. we ignore the error - // here since Register() logs the error for us. - p.store.Register(ctx, p.publicAddr, peerEntryTimeout) - cancel() + for { + select { + case <-tk.C: + ctx, cancel := context.WithTimeout(context.Background(), p.c.GetPeerTimeout()) + // every interval, insert a timeout record. we ignore the error + // here since Register() logs the error for us. + p.store.Register(ctx, p.publicAddr, peerEntryTimeout) + cancel() + case <-done: + // unregister ourselves + ctx, cancel := context.WithTimeout(context.Background(), p.c.GetPeerTimeout()) + p.store.Unregister(ctx, p.publicAddr) + cancel() + return + } } } @@ -168,38 +177,46 @@ func (p *redisPeers) updatePeerListOnce() { p.peerLock.Unlock() } -func (p *redisPeers) watchPeers() { +func (p *redisPeers) watchPeers(done chan struct{}) { oldPeerList := p.peers sort.Strings(oldPeerList) tk := time.NewTicker(refreshCacheInterval) - for range tk.C { - ctx, cancel := context.WithTimeout(context.Background(), p.c.GetPeerTimeout()) - currentPeers, err := p.store.GetMembers(ctx) - cancel() - - if err != nil { - logrus.WithError(err). - WithFields(logrus.Fields{ - "name": p.publicAddr, - "timeout": p.c.GetPeerTimeout().String(), - "oldPeers": oldPeerList, - }). 
- Error("get members failed during watch") - continue - } + for { + select { + case <-tk.C: + ctx, cancel := context.WithTimeout(context.Background(), p.c.GetPeerTimeout()) + currentPeers, err := p.store.GetMembers(ctx) + cancel() + + if err != nil { + logrus.WithError(err). + WithFields(logrus.Fields{ + "name": p.publicAddr, + "timeout": p.c.GetPeerTimeout().String(), + "oldPeers": oldPeerList, + }). + Error("get members failed during watch") + continue + } - sort.Strings(currentPeers) - if !equal(oldPeerList, currentPeers) { - // update peer list and trigger callbacks saying the peer list has changed + sort.Strings(currentPeers) + if !equal(oldPeerList, currentPeers) { + // update peer list and trigger callbacks saying the peer list has changed + p.peerLock.Lock() + p.peers = currentPeers + oldPeerList = currentPeers + p.peerLock.Unlock() + for _, callback := range p.callbacks { + // don't block on any of the callbacks. + go callback() + } + } + case <-done: p.peerLock.Lock() - p.peers = currentPeers - oldPeerList = currentPeers + p.peers = []string{} p.peerLock.Unlock() - for _, callback := range p.callbacks { - // don't block on any of the callbacks. - go callback() - } + return } } } diff --git a/internal/redimem/redimem.go b/internal/redimem/redimem.go index 4def176a5a..ded96ed16d 100644 --- a/internal/redimem/redimem.go +++ b/internal/redimem/redimem.go @@ -20,6 +20,10 @@ type Membership interface { // in order to remain a member of the group. Register(ctx context.Context, memberName string, timeout time.Duration) error + // Unregister removes a name from the list immediately. It's intended to be + // used during shutdown so that there's no delay in the case of deliberate downsizing. + Unregister(ctx context.Context, memberName string) error + // GetMembers retrieves the list of all currently registered members. Members // that have registered but timed out will not be returned. GetMembers(ctx context.Context) ([]string, error) @@ -87,6 +91,27 @@ func (rm *RedisMembership) Register(ctx context.Context, memberName string, time return nil } +func (rm *RedisMembership) Unregister(ctx context.Context, memberName string) error { + err := rm.validateDefaults() + if err != nil { + return err + } + key := fmt.Sprintf("%s•%s•%s", globalPrefix, rm.Prefix, memberName) + conn, err := rm.Pool.GetContext(ctx) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.Do("DEL", key) + if err != nil { + logrus.WithField("name", memberName). + WithField("err", err). + Error("unregistration failed") + return err + } + return nil +} + // GetMembers reaches out to Redis to retrieve a list of all members in the // cluster. It does this multiple times (how many is configured on // initializition) and takes the union of the results returned. 
@@ -189,10 +214,8 @@ func (rm *RedisMembership) scan(conn redis.Conn, pattern, count string, timeout break } - if keys != nil { - for _, key := range keys { - keyChan <- key - } + for _, key := range keys { + keyChan <- key } // redis will return 0 when we have iterated over the entire set diff --git a/sample/rules_test.go b/sample/rules_test.go index 61b1aaec6c..ea1dc166fa 100644 --- a/sample/rules_test.go +++ b/sample/rules_test.go @@ -538,6 +538,59 @@ func TestRules(t *testing.T) { ExpectedKeep: true, ExpectedRate: 1, }, + { + Rules: &config.RulesBasedSamplerConfig{ + Rule: []*config.RulesBasedSamplerRule{ + { + Name: "Check root span for span count", + Drop: true, + SampleRate: 0, + Condition: []*config.RulesBasedSamplerCondition{ + { + Field: "meta.span_count", + Operator: ">=", + Value: int(2), + }, + }, + }, + }, + }, + Spans: []*types.Span{ + { + Event: types.Event{ + Data: map[string]interface{}{ + "trace.trace_id": "12345", + "trace.span_id": "54321", + "meta.span_count": int64(2), + "test": int64(2), + }, + }, + }, + { + Event: types.Event{ + Data: map[string]interface{}{ + "trace.trace_id": "12345", + "trace.span_id": "654321", + "trace.parent_id": "54321", + "test": int64(2), + }, + }, + }, + { + Event: types.Event{ + Data: map[string]interface{}{ + "trace.trace_id": "12345", + "trace.span_id": "754321", + "trace.parent_id": "54321", + "test": int64(3), + }, + }, + }, + }, + ExpectedName: "Check root span for span count", + ExpectedKeep: false, + ExpectedRate: 0, + }, } for _, d := range data { diff --git a/sharder/deterministic_test.go b/sharder/deterministic_test.go index 8d32287697..828252341b 100644 --- a/sharder/deterministic_test.go +++ b/sharder/deterministic_test.go @@ -26,7 +26,9 @@ func TestWhichShard(t *testing.T) { GetPeersVal: peers, PeerManagementType: "file", } - filePeers, err := peer.NewPeers(context.Background(), config) + done := make(chan struct{}) + defer close(done) + filePeers, err := peer.NewPeers(context.Background(), config, done) assert.Equal(t, nil, err) sharder := DeterministicSharder{ Config: config, @@ -67,7 +69,9 @@ func TestWhichShardAtEdge(t *testing.T) { GetPeersVal: peers, PeerManagementType: "file", } - filePeers, err := peer.NewPeers(context.Background(), config) + done := make(chan struct{}) + defer close(done) + filePeers, err := peer.NewPeers(context.Background(), config, done) assert.Equal(t, nil, err) sharder := DeterministicSharder{ Config: config,
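
For reviewers, the caller-side shape of the new contract in one place: NewPeers now takes a done channel; closing it makes the redis-backed peer list delete this host's membership key immediately (via the new Unregister) rather than waiting for peerEntryTimeout to lapse, and stops the registerSelf/watchPeers goroutines. The sketch below is not part of the diff — the run helper and its signal handling are illustrative, it assumes it lives inside the refinery module (internal/peer is an internal package), and the config/main wiring is elided; only peer.NewPeers, config.Config, and GetPeerTimeout come from the changes above.

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/honeycombio/refinery/config"
	"github.com/honeycombio/refinery/internal/peer"
)

// run is a hypothetical helper mirroring cmd/refinery/main.go above; real
// wiring (loading config, building the app, func main) is omitted.
func run(c config.Config) error {
	ctx, cancel := context.WithTimeout(context.Background(), c.GetPeerTimeout())
	defer cancel()

	// Closing done tells registerSelf to issue the Unregister (DEL) call and
	// tells watchPeers to clear its cached peer list and return.
	done := make(chan struct{})
	p, err := peer.NewPeers(ctx, c, done)
	if err != nil {
		return fmt.Errorf("unable to load peers: %w", err)
	}
	_ = p // in refinery this is handed to the sharder

	// Block until the process is asked to shut down.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
	<-sigs

	// Unregister before exiting; the short sleep gives the registerSelf
	// goroutine time to delete this host's key in Redis.
	close(done)
	time.Sleep(100 * time.Millisecond)
	return nil
}

The 100ms pause mirrors the one added to cmd/refinery/main.go: close(done) only signals the goroutine, so the caller gives the DEL a moment to reach Redis before the process exits.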