Skip to content

Commit

Permalink
Renew past recorded expiry till unrecoverable error
Browse files Browse the repository at this point in the history
Keep attempting to renew Vault token past locally recorded expiry, just
in case the token was renewed out of band, e.g. on another Nomad server,
until Vault returns an unrecoverable error.
  • Loading branch information
Mahmood Ali committed Nov 20, 2018
1 parent 7f699e4 commit 7838eb0
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 13 deletions.
44 changes: 35 additions & 9 deletions nomad/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ type vaultClient struct {
// running indicates whether the vault client is started.
running bool

// renewLoopActive indicates whether the renewal goroutine is running.
// It must be accessed and updated atomically, and is used for testing
// purposes only.
renewLoopActive int32

// childTTL is the TTL for child tokens.
childTTL string

Expand Down Expand Up @@ -458,9 +463,16 @@ OUTER:
v.l.Unlock()
}

// isRenewLoopActive reports whether the renewal goroutine is currently
// running. The underlying flag is read atomically; per the field's own
// documentation it exists for testing purposes only.
func (v *vaultClient) isRenewLoopActive() bool {
	active := atomic.LoadInt32(&v.renewLoopActive)
	return active == 1
}

// renewalLoop runs the renew loop. This should only be called if we are given a
// non-root token.
func (v *vaultClient) renewalLoop() {
atomic.StoreInt32(&v.renewLoopActive, 1)
defer atomic.StoreInt32(&v.renewLoopActive, 0)

// Create the renewal timer and set initial duration to zero so it fires
// immediately
authRenewTimer := time.NewTimer(0)
Expand All @@ -475,7 +487,7 @@ func (v *vaultClient) renewalLoop() {
return
case <-authRenewTimer.C:
// Renew the token and determine the new expiration
err := v.renew()
recoverable, err := v.renew()
currentExpiration := v.currentExpiration

// Successfully renewed
Expand All @@ -492,7 +504,12 @@ func (v *vaultClient) renewalLoop() {
}

metrics.IncrCounter([]string{"nomad", "vault", "renew_failed"}, 1)
v.logger.Warn("got error or bad auth, so backing off", "error", err)
v.logger.Warn("got error or bad auth, so backing off", "error", err, "recoverable", recoverable)

if !recoverable {
return
}

backoff = nextBackoff(backoff, currentExpiration)
if backoff < 0 {
// We have failed to renew the token past its expiration. Stop
Expand All @@ -518,6 +535,8 @@ func (v *vaultClient) renewalLoop() {
//
// It should increase the amount of backoff each time, with the following rules:
//
// * If token expired already despite earlier renewal attempts,
// back off for 1 minute + jitter
// * If we have an existing authentication that is going to expire,
// never back off more than half of the amount of time remaining
// until expiration (with 5s floor)
Expand All @@ -527,8 +546,10 @@ func (v *vaultClient) renewalLoop() {
// at the same time
func nextBackoff(backoff float64, expiry time.Time) float64 {
maxBackoff := time.Until(expiry) / 2

if maxBackoff < 0 {
return -1
// expiry passed
return 60 * (1.0 + rand.Float64())
}

switch {
Expand All @@ -553,28 +574,33 @@ func nextBackoff(backoff float64, expiry time.Time) float64 {
}

// renew attempts to renew our Vault token. If the renewal fails, an error is
// returned. This method updates the currentExpiration time
func (v *vaultClient) renew() error {
// returned. The boolean indicates whether it's safe to attempt to renew again.
// This method updates the currentExpiration time
func (v *vaultClient) renew() (bool, error) {
// Track how long the request takes
defer metrics.MeasureSince([]string{"nomad", "vault", "renew"}, time.Now())

// Attempt to renew the token
secret, err := v.auth.RenewSelf(v.tokenData.CreationTTL)
if err != nil {
return fmt.Errorf("failed to renew the vault token: %v", err)

// Check whether the error is unrecoverable (e.g. permission denied)
recoverable := !structs.VaultUnrecoverableError.MatchString(err.Error())
return recoverable, fmt.Errorf("failed to renew the vault token: %v", err)
}

// These are treated as transient errors, so we can keep renewing
auth := secret.Auth
if auth == nil {
return fmt.Errorf("renewal successful but not auth information returned")
return true, fmt.Errorf("renewal successful but not auth information returned")
} else if auth.LeaseDuration == 0 {
return fmt.Errorf("renewal successful but no lease duration returned")
return true, fmt.Errorf("renewal successful but no lease duration returned")
}

v.currentExpiration = time.Now().Add(time.Duration(auth.LeaseDuration) * time.Second)

v.logger.Debug("successfully renewed server token")
return nil
return true, nil
}

// getWrappingFn returns an appropriate wrapping function for Nomad Servers
Expand Down
80 changes: 76 additions & 4 deletions nomad/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package nomad
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"reflect"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"golang.org/x/time/rate"
Expand Down Expand Up @@ -562,19 +564,89 @@ func TestVaultClientRenewUpdatesExpiration(t *testing.T) {

time.Sleep(1 * time.Second)

err = client.renew()
_, err = client.renew()
require.NoError(t, err)
exp1 := client.currentExpiration
require.True(t, exp0.Before(exp1))

time.Sleep(1 * time.Second)

err = client.renew()
_, err = client.renew()
require.NoError(t, err)
exp2 := client.currentExpiration
require.True(t, exp1.Before(exp2))
}

// TestVaultClient_StopsAfterPermissionError asserts that the renewal loop
// terminates after the Vault token is revoked.
func TestVaultClient_StopsAfterPermissionError(t *testing.T) {
	t.Parallel()
	v := testutil.NewTestVault(t)
	defer v.Stop()

	// Set the config's token in a new test role
	v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 2)

	// Start the client
	logger := testlog.HCLogger(t)
	client, err := NewVaultClient(v.Config, logger, nil)
	if err != nil {
		t.Fatalf("failed to build vault client: %v", err)
	}
	defer client.Stop()

	// Allow some time for the renewal loop to come up before checking it
	// (presumably it is started in the background by NewVaultClient).
	time.Sleep(500 * time.Millisecond)

	assert.True(t, client.isRenewLoopActive())

	// Revoke the token used by the test Vault API client so that further
	// renewals are expected to fail.
	a := v.Client.Auth().Token()
	assert.NoError(t, a.RevokeSelf(""))

	// Poll until the renewal loop reports that it has exited.
	testutil.WaitForResult(func() (bool, error) {
		if client.isRenewLoopActive() {
			return false, errors.New("renew loop should terminate after token is revoked")
		}
		return true, nil
	}, func(err error) {
		t.Fatalf("err: %v", err)
	})
}
// TestVaultClient_LoopsUntilCannotRenew starts a Vault client, waits eight
// seconds, and then verifies that the token still reports a non-zero TTL and
// that the client's recorded expiration is not in the past.
func TestVaultClient_LoopsUntilCannotRenew(t *testing.T) {
	t.Parallel()

	vault := testutil.NewTestVault(t)
	defer vault.Stop()

	// Point the config at a token created in a fresh test role
	vault.Config.Token = defaultTestVaultWhitelistRoleAndToken(vault, t, 5)

	// Bring up the client under test
	client, err := NewVaultClient(vault.Config, testlog.HCLogger(t), nil)
	if err != nil {
		t.Fatalf("failed to build vault client: %v", err)
	}
	defer client.Stop()

	// Wait 8 seconds, then confirm the token still has a non-zero TTL
	time.Sleep(8 * time.Second)

	// Look the token up to read its current TTL
	tokenAPI := vault.Client.Auth().Token()
	lookup, err := tokenAPI.Lookup(vault.Config.Token)
	if err != nil {
		t.Fatalf("failed to lookup token: %v", err)
	}

	if ttl := parseTTLFromLookup(lookup, t); ttl == 0 {
		t.Fatalf("token renewal failed; ttl %v", ttl)
	}

	// The locally recorded expiration must not already have passed
	if time.Now().After(client.currentExpiration) {
		t.Fatalf("found current expiration to be in past %s", time.Until(client.currentExpiration))
	}
}

func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 {
if s == nil {
t.Fatalf("nil secret")
Expand Down Expand Up @@ -1337,8 +1409,8 @@ func TestVaultClient_nextBackoff(t *testing.T) {

t.Run("past expiry", func(t *testing.T) {
b := nextBackoff(20, time.Now().Add(-1100*time.Millisecond))
if b >= 0.0 {
t.Fatalf("Expected backoff to be negative but found %v", b)
if !(60 <= b && b <= 120) {
t.Fatalf("Expected backoff within [%v, %v] but found %v", 60, 120, b)
}
})
}

0 comments on commit 7838eb0

Please sign in to comment.