diff --git a/nomad/vault.go b/nomad/vault.go index d618f5badd76..bb8a540b754a 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -207,6 +207,11 @@ type vaultClient struct { // running indicates whether the vault client is started. running bool + // renewLoopActive indicates whether the renewal goroutine is running + // It should be accessed and updated atomically + // used for testing purposes only + renewLoopActive int32 + // childTTL is the TTL for child tokens. childTTL string @@ -458,9 +463,16 @@ OUTER: v.l.Unlock() } +func (v *vaultClient) isRenewLoopActive() bool { + return atomic.LoadInt32(&v.renewLoopActive) == 1 +} + // renewalLoop runs the renew loop. This should only be called if we are given a // non-root token. func (v *vaultClient) renewalLoop() { + atomic.StoreInt32(&v.renewLoopActive, 1) + defer atomic.StoreInt32(&v.renewLoopActive, 0) + // Create the renewal timer and set initial duration to zero so it fires // immediately authRenewTimer := time.NewTimer(0) @@ -475,7 +487,7 @@ func (v *vaultClient) renewalLoop() { return case <-authRenewTimer.C: // Renew the token and determine the new expiration - err := v.renew() + recoverable, err := v.renew() currentExpiration := v.currentExpiration // Successfully renewed @@ -492,7 +504,12 @@ func (v *vaultClient) renewalLoop() { } metrics.IncrCounter([]string{"nomad", "vault", "renew_failed"}, 1) - v.logger.Warn("got error or bad auth, so backing off", "error", err) + v.logger.Warn("got error or bad auth, so backing off", "error", err, "recoverable", recoverable) + + if !recoverable { + return + } + backoff = nextBackoff(backoff, currentExpiration) if backoff < 0 { // We have failed to renew the token past its expiration. Stop @@ -518,6 +535,8 @@ func (v *vaultClient) renewalLoop() { // // It should increase the amount of backoff each time, with the following rules: // +// * If token expired already despite earlier renewal attempts, +// back off for 1 minute + jitter // * If we have an existing authentication that is going to expire, // never back off more than half of the amount of time remaining // until expiration (with 5s floor) @@ -527,8 +546,10 @@ func (v *vaultClient) renewalLoop() { // at the same time func nextBackoff(backoff float64, expiry time.Time) float64 { maxBackoff := time.Until(expiry) / 2 + if maxBackoff < 0 { - return -1 + // expiry passed + return 60 * (1.0 + rand.Float64()) } switch { @@ -553,28 +574,33 @@ func nextBackoff(backoff float64, expiry time.Time) float64 { } // renew attempts to renew our Vault token. If the renewal fails, an error is -// returned. This method updates the currentExpiration time -func (v *vaultClient) renew() error { +// returned. The boolean indicates whether it's safe to attempt to renew again. +// This method updates the currentExpiration time +func (v *vaultClient) renew() (bool, error) { // Track how long the request takes defer metrics.MeasureSince([]string{"nomad", "vault", "renew"}, time.Now()) // Attempt to renew the token secret, err := v.auth.RenewSelf(v.tokenData.CreationTTL) if err != nil { - return fmt.Errorf("failed to renew the vault token: %v", err) + + // Check if there is a permission denied + recoverable := !structs.VaultUnrecoverableError.MatchString(err.Error()) + return recoverable, fmt.Errorf("failed to renew the vault token: %v", err) } + // these treated as transient errors, where can keep renewing auth := secret.Auth if auth == nil { - return fmt.Errorf("renewal successful but not auth information returned") + return true, fmt.Errorf("renewal successful but not auth information returned") } else if auth.LeaseDuration == 0 { - return fmt.Errorf("renewal successful but no lease duration returned") + return true, fmt.Errorf("renewal successful but no lease duration returned") } v.currentExpiration = time.Now().Add(time.Duration(auth.LeaseDuration) * time.Second) v.logger.Debug("successfully renewed server token") - return nil + return true, nil } // getWrappingFn returns an appropriate wrapping function for Nomad Servers diff --git a/nomad/vault_test.go b/nomad/vault_test.go index 67dcc0b1d315..089efca3e8db 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -3,6 +3,7 @@ package nomad import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "reflect" @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/time/rate" @@ -562,19 +564,89 @@ func TestVaultClientRenewUpdatesExpiration(t *testing.T) { time.Sleep(1 * time.Second) - err = client.renew() + _, err = client.renew() require.NoError(t, err) exp1 := client.currentExpiration require.True(t, exp0.Before(exp1)) time.Sleep(1 * time.Second) - err = client.renew() + _, err = client.renew() require.NoError(t, err) exp2 := client.currentExpiration require.True(t, exp1.Before(exp2)) } +func TestVaultClient_StopsAfterPermissionError(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 2) + + // Start the client + logger := testlog.HCLogger(t) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + time.Sleep(500 * time.Millisecond) + + assert.True(t, client.isRenewLoopActive()) + + // Get the current TTL + a := v.Client.Auth().Token() + assert.NoError(t, a.RevokeSelf("")) + + testutil.WaitForResult(func() (bool, error) { + if !client.isRenewLoopActive() { + return true, nil + } else { + return false, errors.New("renew loop should terminate after token is revoked") + } + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} +func TestVaultClient_LoopsUntilCannotRenew(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 5) + + // Start the client + logger := testlog.HCLogger(t) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + // Sleep 8 seconds and ensure we have a non-zero TTL + time.Sleep(8 * time.Second) + + // Get the current TTL + a := v.Client.Auth().Token() + s2, err := a.Lookup(v.Config.Token) + if err != nil { + t.Fatalf("failed to lookup token: %v", err) + } + + ttl := parseTTLFromLookup(s2, t) + if ttl == 0 { + t.Fatalf("token renewal failed; ttl %v", ttl) + } + + if client.currentExpiration.Before(time.Now()) { + t.Fatalf("found current expiration to be in past %s", time.Until(client.currentExpiration)) + } +} + func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 { if s == nil { t.Fatalf("nil secret") @@ -1337,8 +1409,8 @@ func TestVaultClient_nextBackoff(t *testing.T) { t.Run("past expiry", func(t *testing.T) { b := nextBackoff(20, time.Now().Add(-1100*time.Millisecond)) - if b >= 0.0 { - t.Fatalf("Expected backoff to be negative but found %v", b) + if !(60 <= b && b <= 120) { + t.Fatalf("Expected backoff within [%v, %v] but found %v", 60, 120, b) } }) }