diff --git a/client/client.go b/client/client.go index dfc19a25c936..8028b75c86f8 100644 --- a/client/client.go +++ b/client/client.go @@ -1714,6 +1714,10 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err) return nil, fmt.Errorf("failed to derive vault tokens: %v", err) } + if resp.Error != nil { + c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", resp.Error) + return nil, resp.Error + } if resp.Tasks == nil { c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response") return nil, fmt.Errorf("failed to derive vault tokens: invalid response") diff --git a/client/driver/docker.go b/client/driver/docker.go index ffc05fe7e6bc..19902d556c7d 100644 --- a/client/driver/docker.go +++ b/client/driver/docker.go @@ -629,7 +629,7 @@ func (d *DockerDriver) recoverablePullError(err error, image string) error { if imageNotFoundMatcher.MatchString(err.Error()) { recoverable = false } - return dstructs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable) + return structs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable) } func (d *DockerDriver) Periodic() (bool, time.Duration) { diff --git a/client/driver/structs/structs.go b/client/driver/structs/structs.go index 7714d5ac597a..0fa67ff2c06d 100644 --- a/client/driver/structs/structs.go +++ b/client/driver/structs/structs.go @@ -37,26 +37,6 @@ func (r *WaitResult) String() string { r.ExitCode, r.Signal, r.Err) } -// RecoverableError wraps an error and marks whether it is recoverable and could -// be retried or it is fatal. -type RecoverableError struct { - Err error - Recoverable bool -} - -// NewRecoverableError is used to wrap an error and mark it as recoverable or -// not. -func NewRecoverableError(e error, recoverable bool) *RecoverableError { - return &RecoverableError{ - Err: e, - Recoverable: recoverable, - } -} - -func (r *RecoverableError) Error() string { - return r.Err.Error() -} - // CheckResult encapsulates the result of a check type CheckResult struct { diff --git a/client/restarts.go b/client/restarts.go index e80d801721e7..2c52cd1c9001 100644 --- a/client/restarts.go +++ b/client/restarts.go @@ -6,7 +6,7 @@ import ( "sync" "time" - cstructs "github.com/hashicorp/nomad/client/driver/structs" + dstructs "github.com/hashicorp/nomad/client/driver/structs" "github.com/hashicorp/nomad/nomad/structs" ) @@ -34,7 +34,7 @@ func newRestartTracker(policy *structs.RestartPolicy, jobType string) *RestartTr } type RestartTracker struct { - waitRes *cstructs.WaitResult + waitRes *dstructs.WaitResult startErr error restartTriggered bool // Whether the task has been signalled to be restarted count int // Current number of attempts. @@ -63,7 +63,7 @@ func (r *RestartTracker) SetStartError(err error) *RestartTracker { } // SetWaitResult is used to mark the most recent wait result. -func (r *RestartTracker) SetWaitResult(res *cstructs.WaitResult) *RestartTracker { +func (r *RestartTracker) SetWaitResult(res *dstructs.WaitResult) *RestartTracker { r.lock.Lock() defer r.lock.Unlock() r.waitRes = res @@ -149,7 +149,7 @@ func (r *RestartTracker) GetState() (string, time.Duration) { // infinitely try to start a task. func (r *RestartTracker) handleStartError() (string, time.Duration) { // If the error is not recoverable, do not restart. - if rerr, ok := r.startErr.(*cstructs.RecoverableError); !(ok && rerr.Recoverable) { + if rerr, ok := r.startErr.(*structs.RecoverableError); !(ok && rerr.Recoverable) { r.reason = ReasonUnrecoverableErrror return structs.TaskNotRestarting, 0 } diff --git a/client/restarts_test.go b/client/restarts_test.go index 86960b1f7e7e..851052576e6a 100644 --- a/client/restarts_test.go +++ b/client/restarts_test.go @@ -108,7 +108,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Fail(t *testing.T) { t.Parallel() p := testPolicy(true, structs.RestartPolicyModeFail) rt := newRestartTracker(p, structs.JobTypeSystem) - recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true) + recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true) for i := 0; i < p.Attempts; i++ { state, when := rt.SetStartError(recErr).GetState() if state != structs.TaskRestarting { @@ -129,7 +129,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Delay(t *testing.T) { t.Parallel() p := testPolicy(true, structs.RestartPolicyModeDelay) rt := newRestartTracker(p, structs.JobTypeSystem) - recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true) + recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true) for i := 0; i < p.Attempts; i++ { state, when := rt.SetStartError(recErr).GetState() if state != structs.TaskRestarting { diff --git a/client/task_runner.go b/client/task_runner.go index 4dd7abea2a5d..dfb84b94db08 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -509,10 +509,10 @@ OUTER: // restoring the TaskRunner if token == "" { // Get a token - var ok bool - token, ok = r.deriveVaultToken() - if !ok { - // We are shutting down + var exit bool + token, exit = r.deriveVaultToken() + if exit { + // Exit the manager return } @@ -589,12 +589,20 @@ OUTER: // deriveVaultToken derives the Vault token using exponential backoffs. It // returns the Vault token and whether the token is valid. If it is not valid we // are shutting down -func (r *TaskRunner) deriveVaultToken() (string, bool) { +func (r *TaskRunner) deriveVaultToken() (token string, exit bool) { attempts := 0 for { tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name}) if err == nil { - return tokens[r.task.Name], true + return tokens[r.task.Name], false + } + + // Check if we can't recover from the error + if rerr, ok := err.(*structs.RecoverableError); !ok || !rerr.Recoverable { + r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v", + r.task.Name, r.alloc.ID, err) + r.Kill("vault", fmt.Sprintf("failed to derive token: %v", err)) + return "", true } // Handle the retry case @@ -602,14 +610,15 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) { if backoff > vaultBackoffLimit { backoff = vaultBackoffLimit } - r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff) + r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", + r.task.Name, r.alloc.ID, err, backoff) attempts++ // Wait till retrying select { case <-r.waitCh: - return "", false + return "", true case <-time.After(backoff): } } @@ -706,7 +715,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) { if err := getter.GetArtifact(r.getTaskEnv(), artifact, r.taskDir); err != nil { r.setState(structs.TaskStatePending, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err)) - r.restartTracker.SetStartError(dstructs.NewRecoverableError(err, true)) + r.restartTracker.SetStartError(structs.NewRecoverableError(err, true)) goto RESTART } } diff --git a/client/task_runner_test.go b/client/task_runner_test.go index 819c35799055..dd85dee9d4aa 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -721,7 +721,7 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { } count++ - return nil, fmt.Errorf("Want a retry") + return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true) } tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler go tr.Run() @@ -770,6 +770,49 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { } } +func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "10s", + } + task.Vault = &structs.Vault{ + Policies: []string{"default"}, + ChangeMode: structs.VaultChangeModeRestart, + } + + upd, tr := testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + + // Error the token derivation + vc := tr.vaultClient.(*vaultclient.MockVaultClient) + vc.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable")) + go tr.Run() + + // Wait for the task to start + testutil.WaitForResult(func() (bool, error) { + if l := len(upd.events); l != 2 { + return false, fmt.Errorf("Expect two events; got %v", l) + } + + if upd.events[0].Type != structs.TaskReceived { + return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskKilling { + return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskKilling) + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} + func TestTaskRunner_Template_Block(t *testing.T) { alloc := mock.Alloc() task := alloc.Job.TaskGroups[0].Tasks[0] diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index be4a2c81887e..fc2e59038ab7 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/watch" + "github.com/hashicorp/raft" vapi "github.com/hashicorp/vault/api" ) @@ -940,22 +941,26 @@ func (b *batchFuture) Respond(index uint64, err error) { func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, reply *structs.DeriveVaultTokenResponse) error { if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done { - return err + reply.Error = structs.NewRecoverableError(err, err == structs.ErrNoLeader) + return nil } defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now()) // Verify the arguments if args.NodeID == "" { - return fmt.Errorf("missing node ID") + reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node ID"), false) } if args.SecretID == "" { - return fmt.Errorf("missing node SecretID") + reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node SecretID"), false) + return nil } if args.AllocID == "" { - return fmt.Errorf("missing allocation ID") + reply.Error = structs.NewRecoverableError(fmt.Errorf("missing allocation ID"), false) + return nil } if len(args.Tasks) == 0 { - return fmt.Errorf("no tasks specified") + reply.Error = structs.NewRecoverableError(fmt.Errorf("no tasks specified"), false) + return nil } // Verify the following: @@ -965,41 +970,51 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, // tokens snap, err := n.srv.fsm.State().Snapshot() if err != nil { - return err + reply.Error = structs.NewRecoverableError(err, false) + return nil } node, err := snap.NodeByID(args.NodeID) if err != nil { - return err + reply.Error = structs.NewRecoverableError(err, false) + return nil } if node == nil { - return fmt.Errorf("Node %q does not exist", args.NodeID) + reply.Error = structs.NewRecoverableError(fmt.Errorf("Node %q does not exist", args.NodeID), false) + return nil } if node.SecretID != args.SecretID { - return fmt.Errorf("SecretID mismatch") + reply.Error = structs.NewRecoverableError(fmt.Errorf("SecretID mismatch"), false) + return nil } alloc, err := snap.AllocByID(args.AllocID) if err != nil { - return err + reply.Error = structs.NewRecoverableError(err, false) + return nil } if alloc == nil { - return fmt.Errorf("Allocation %q does not exist", args.AllocID) + reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q does not exist", args.AllocID), false) + return nil } if alloc.NodeID != args.NodeID { - return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID) + reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID), false) + return nil } if alloc.TerminalStatus() { - return fmt.Errorf("Can't request Vault token for terminal allocation") + reply.Error = structs.NewRecoverableError(fmt.Errorf("Can't request Vault token for terminal allocation"), false) + return nil } // Check the policies policies := alloc.Job.VaultPolicies() if policies == nil { - return fmt.Errorf("Job doesn't require Vault policies") + reply.Error = structs.NewRecoverableError(fmt.Errorf("Job doesn't require Vault policies"), false) + return nil } tg, ok := policies[alloc.TaskGroup] if !ok { - return fmt.Errorf("Task group does not require Vault policies") + reply.Error = structs.NewRecoverableError(fmt.Errorf("Task group does not require Vault policies"), false) + return nil } var unneeded []string @@ -1011,8 +1026,10 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, } if len(unneeded) != 0 { - return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s", + e := fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s", strings.Join(unneeded, ", ")) + reply.Error = structs.NewRecoverableError(e, false) + return nil } // At this point the request is valid and we should contact Vault for @@ -1043,7 +1060,13 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, secret, err := n.srv.vault.CreateToken(ctx, alloc, task) if err != nil { - return fmt.Errorf("failed to create token for task %q: %v", task, err) + wrapped := fmt.Errorf("failed to create token for task %q: %v", task, err) + if rerr, ok := err.(*structs.RecoverableError); ok && rerr.Recoverable { + // If the error is recoverable, propogate it + return structs.NewRecoverableError(wrapped, true) + } + + return wrapped } results[task] = secret @@ -1068,9 +1091,9 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, }() // Wait for everything to complete or for an error - err = g.Wait() + createErr := g.Wait() - // Commit to Raft before returning any of the tokens + // Retrieve the results accessors := make([]*structs.VaultAccessor, 0, len(results)) tokens := make(map[string]string, len(results)) for task, secret := range results { @@ -1092,20 +1115,36 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, } // If there was an error revoke the created tokens - if err != nil { - var mErr multierror.Error - mErr.Errors = append(mErr.Errors, err) - if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil { - mErr.Errors = append(mErr.Errors, err) + if createErr != nil { + if revokeErr := n.srv.vault.RevokeTokens(context.Background(), accessors, false); revokeErr != nil { + n.srv.logger.Printf("[ERR] nomad.node: Vault token revocation failed: %v", revokeErr) + } + + if rerr, ok := createErr.(*structs.RecoverableError); ok { + reply.Error = rerr + } else { + reply.Error = structs.NewRecoverableError(createErr, false) } - return mErr.ErrorOrNil() + + return nil } + // Commit to Raft before returning any of the tokens req := structs.VaultAccessorsRequest{Accessors: accessors} _, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err) - return err + + // Determine if we can recover from the error + retry := false + switch err { + case raft.ErrNotLeader, raft.ErrLeadershipLost, raft.ErrRaftShutdown, raft.ErrEnqueueTimeout: + retry = true + default: + } + + reply.Error = structs.NewRecoverableError(err, retry) + return nil } reply.Index = index diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index bf4c22489a28..39c87c032d4b 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -1822,18 +1822,23 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { } var resp structs.DeriveVaultTokenResponse - err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) - if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") { - t.Fatalf("Expected SecretID mismatch: %v", err) + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + + if resp.Error == nil || !strings.Contains(resp.Error.Error(), "SecretID mismatch") { + t.Fatalf("Expected SecretID mismatch: %v", resp.Error) } // Put the correct SecretID req.SecretID = node.SecretID // Now we should get an error about the allocation not running on the node - err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) - if err == nil || !strings.Contains(err.Error(), "not running on Node") { - t.Fatalf("Expected not running on node error: %v", err) + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + if resp.Error == nil || !strings.Contains(resp.Error.Error(), "not running on Node") { + t.Fatalf("Expected not running on node error: %v", resp.Error) } // Update to be running on the node @@ -1843,9 +1848,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { } // Now we should get an error about the job not needing any Vault secrets - err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) - if err == nil || !strings.Contains(err.Error(), "does not require") { - t.Fatalf("Expected no policies error: %v", err) + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + if resp.Error == nil || !strings.Contains(resp.Error.Error(), "does not require") { + t.Fatalf("Expected no policies error: %v", resp.Error) } // Update to be terminal @@ -1855,9 +1862,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { } // Now we should get an error about the job not needing any Vault secrets - err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) - if err == nil || !strings.Contains(err.Error(), "terminal") { - t.Fatalf("Expected terminal allocation error: %v", err) + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + if resp.Error == nil || !strings.Contains(resp.Error.Error(), "terminal") { + t.Fatalf("Expected terminal allocation error: %v", resp.Error) } } @@ -1920,6 +1929,9 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) { if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { t.Fatalf("bad: %v", err) } + if resp.Error != nil { + t.Fatalf("bad: %v", resp.Error) + } // Check the state store and ensure that we created a VaultAccessor va, err := state.VaultAccessor(accessor) @@ -1947,3 +1959,59 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) { t.Fatalf("Got %#v; want %#v", va, expected) } } + +func TestClientEndpoint_DeriveVaultToken_VaultError(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault and allow authenticated + tr := true + s1.config.VaultConfig.Enabled = &tr + s1.config.VaultConfig.AllowUnauthenticated = &tr + + // Replace the Vault Client on the server + tvc := &TestVaultClient{} + s1.vault = tvc + + // Create the node + node := mock.Node() + if err := state.UpsertNode(2, node); err != nil { + t.Fatalf("err: %v", err) + } + + // Create an alloc an allocation that has vault policies required + alloc := mock.Alloc() + alloc.NodeID = node.ID + task := alloc.Job.TaskGroups[0].Tasks[0] + tasks := []string{task.Name} + task.Vault = &structs.Vault{Policies: []string{"a", "b"}} + if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Return an error when creating the token + tvc.SetCreateTokenError(alloc.ID, task.Name, + structs.NewRecoverableError(fmt.Errorf("recover"), true)) + + req := &structs.DeriveVaultTokenRequest{ + NodeID: node.ID, + SecretID: node.SecretID, + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{ + Region: "global", + }, + } + + var resp structs.DeriveVaultTokenResponse + err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err != nil { + t.Fatalf("bad: %v", err) + } + if resp.Error == nil || !resp.Error.Recoverable { + t.Fatalf("bad: %+v", resp.Error) + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index ac2c48443289..89d48772a49d 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -389,6 +389,11 @@ type VaultAccessor struct { type DeriveVaultTokenResponse struct { // Tasks is a mapping between the task name and the wrapped token Tasks map[string]string + + // Error stores any error that occured. Errors are stored here so we can + // communicate whether it is retriable + Error *RecoverableError + QueryMeta } @@ -3688,3 +3693,27 @@ type KeyringResponse struct { type KeyringRequest struct { Key string } + +// RecoverableError wraps an error and marks whether it is recoverable and could +// be retried or it is fatal. +type RecoverableError struct { + Err string + Recoverable bool +} + +// NewRecoverableError is used to wrap an error and mark it as recoverable or +// not. +func NewRecoverableError(e error, recoverable bool) *RecoverableError { + if e == nil { + return nil + } + + return &RecoverableError{ + Err: e.Error(), + Recoverable: recoverable, + } +} + +func (r *RecoverableError) Error() string { + return r.Err +} diff --git a/nomad/vault.go b/nomad/vault.go index 026e8343d564..0a1327b9d542 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "math/rand" + "strings" "sync" "sync/atomic" "time" @@ -45,6 +46,14 @@ const ( // vaultRevocationIntv is the interval at which Vault tokens that failed // initial revocation are retried vaultRevocationIntv = 5 * time.Minute + + // Errors returned by Vault + + // vaultErrInvalidRequest is returned if the request is invalid + vaultErrInvalidRequest = "invalid request" + + // vaultErrPermissionDenied is returned if the client is not authorized + vaultErrPermissionDenied = "permission denied" ) // VaultClient is the Servers interface for interfacing with Vault @@ -104,8 +113,11 @@ type vaultClient struct { config *config.VaultConfig // connEstablished marks whether we have an established connection to Vault. - // It should be accessed using a helper and updated atomically - connEstablished int32 + connEstablished bool + + // connEstablishedErr marks an error that can occur when establishing a + // connection + connEstablishedErr error // token is the raw token used by the client token string @@ -202,7 +214,7 @@ func (v *vaultClient) flush() { v.client = nil v.auth = nil - v.connEstablished = 0 + v.connEstablished = false v.token = "" v.tokenData = nil v.revoking = make(map[*structs.VaultAccessor]time.Time) @@ -225,7 +237,7 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error { if v.config.IsEnabled() { // Stop accepting any new request - atomic.StoreInt32(&v.connEstablished, 0) + v.connEstablished = false // Kill any background routine and create a new tomb v.tomb.Kill(nil) @@ -310,8 +322,8 @@ OUTER: case <-retryTimer.C: // Ensure the API is reachable if _, err := v.client.Sys().InitStatus(); err != nil { - v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v", - v.config.ConnectionRetryIntv) + v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v: %v", + v.config.ConnectionRetryIntv, err) retryTimer.Reset(v.config.ConnectionRetryIntv) continue OUTER } @@ -323,6 +335,10 @@ OUTER: // Retrieve our token, validate it and parse the lease duration if err := v.parseSelfToken(); err != nil { v.logger.Printf("[ERR] vault: failed to lookup self token and not retrying: %v", err) + v.l.Lock() + v.connEstablished = false + v.connEstablishedErr = err + v.l.Unlock() return } @@ -339,7 +355,9 @@ OUTER: v.tomb.Go(wrapNilError(v.renewalLoop)) } - atomic.StoreInt32(&v.connEstablished, 1) + v.l.Lock() + v.connEstablished = true + v.l.Unlock() } // renewalLoop runs the renew loop. This should only be called if we are given a @@ -407,7 +425,10 @@ func (v *vaultClient) renewalLoop() { // We have failed to renew the token past its expiration. Stop // renewing with Vault. v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Shutting down Vault client") - atomic.StoreInt32(&v.connEstablished, 0) + v.l.Lock() + v.connEstablished = false + v.connEstablishedErr = err + v.l.Unlock() return } else if backoff > maxBackoff.Seconds() { @@ -521,36 +542,42 @@ func (v *vaultClient) parseSelfToken() error { } // ConnectionEstablished returns whether a connection to Vault has been -// established. -func (v *vaultClient) ConnectionEstablished() bool { - return atomic.LoadInt32(&v.connEstablished) == 1 +// established and any error that potentially caused it to be false +func (v *vaultClient) ConnectionEstablished() (bool, error) { + v.l.Lock() + defer v.l.Unlock() + return v.connEstablished, v.connEstablishedErr } +// Enabled returns whether the client is active func (v *vaultClient) Enabled() bool { v.l.Lock() defer v.l.Unlock() return v.config.IsEnabled() } -// +// Active returns whether the client is active func (v *vaultClient) Active() bool { return atomic.LoadInt32(&v.active) == 1 } // CreateToken takes the allocation and task and returns an appropriate Vault -// token. The call is rate limited and may be canceled with the passed policy +// token. The call is rate limited and may be canceled with the passed policy. +// When the error is recoverable, it will be of type RecoverableError func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) { if !v.Enabled() { return nil, fmt.Errorf("Vault integration disabled") } if !v.Active() { - return nil, fmt.Errorf("Vault client not active") + return nil, structs.NewRecoverableError(fmt.Errorf("Vault client not active"), true) } // Check if we have established a connection with Vault - if !v.ConnectionEstablished() { - return nil, fmt.Errorf("Connection to Vault has not been established. Retry") + if established, err := v.ConnectionEstablished(); !established && err == nil { + return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true) + } else if !established { + return nil, fmt.Errorf("Connection to Vault failed: %v", err) } // Retrieve the Vault block for the task @@ -596,7 +623,19 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta secret, err = v.auth.CreateWithRole(req, v.tokenData.Role) } - return secret, err + // Determine whether it is unrecoverable + if err != nil { + eStr := err.Error() + if strings.Contains(eStr, vaultErrInvalidRequest) || + strings.Contains(eStr, vaultErrPermissionDenied) { + return secret, err + } + + // The error is recoverable + return nil, structs.NewRecoverableError(err, true) + } + + return secret, nil } // LookupToken takes a Vault token and does a lookup against Vault. The call is @@ -611,8 +650,10 @@ func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secr } // Check if we have established a connection with Vault - if !v.ConnectionEstablished() { - return nil, fmt.Errorf("Connection to Vault has not been established. Retry") + if established, err := v.ConnectionEstablished(); !established && err == nil { + return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true) + } else if !established { + return nil, fmt.Errorf("Connection to Vault failed: %v", err) } // Ensure we are under our rate limit @@ -652,7 +693,7 @@ func (v *vaultClient) RevokeTokens(ctx context.Context, accessors []*structs.Vau // Check if we have established a connection with Vault. If not just add it // to the queue - if !v.ConnectionEstablished() { + if established, err := v.ConnectionEstablished(); !established && err == nil { // Only bother tracking it for later revocation if the accessor was // committed if committed { @@ -709,8 +750,10 @@ func (v *vaultClient) parallelRevoke(ctx context.Context, accessors []*structs.V } // Check if we have established a connection with Vault - if !v.ConnectionEstablished() { - return fmt.Errorf("Connection to Vault has not been established. Retry") + if established, err := v.ConnectionEstablished(); !established && err == nil { + return structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true) + } else if !established { + return fmt.Errorf("Connection to Vault failed: %v", err) } g, pCtx := errgroup.WithContext(ctx) @@ -770,7 +813,7 @@ func (v *vaultClient) revokeDaemon() { case <-v.tomb.Dying(): return case now := <-ticker.C: - if !v.ConnectionEstablished() { + if established, _ := v.ConnectionEstablished(); !established { continue } diff --git a/nomad/vault_test.go b/nomad/vault_test.go index 41249e2ce574..2d50ee2d60b6 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -3,6 +3,7 @@ package nomad import ( "context" "encoding/json" + "fmt" "log" "os" "reflect" @@ -67,7 +68,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) { // Sleep a little while and check that no connection has been established. time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond) - if client.ConnectionEstablished() { + if established, _ := client.ConnectionEstablished(); established { t.Fatalf("ConnectionEstablished() returned true before Vault server started") } @@ -417,7 +418,7 @@ func TestVaultClient_CreateToken_Role(t *testing.T) { // Set the configs token in a new test role v.Config.Token = testVaultRoleAndToken(v, t, 5) - //testVaultRoleAndToken(v, t, 5) + // Start the client logger := log.New(os.Stderr, "", log.LstdFlags) client, err := NewVaultClient(v.Config, logger, nil) @@ -458,6 +459,74 @@ func TestVaultClient_CreateToken_Role(t *testing.T) { } } +func TestVaultClient_CreateToken_Role_InvalidToken(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + // Set the configs token in a new test role + testVaultRoleAndToken(v, t, 5) + v.Config.Token = "foo-bar" + + // Start the client + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + client.SetActive(true) + defer client.Stop() + + testutil.WaitForResult(func() (bool, error) { + established, err := client.ConnectionEstablished() + if established { + return false, fmt.Errorf("Shouldn't establish") + } + + return err != nil, nil + }, func(err error) { + t.Fatalf("Connection not established") + }) + + // Create an allocation that requires a Vault policy + a := mock.Alloc() + task := a.Job.TaskGroups[0].Tasks[0] + task.Vault = &structs.Vault{Policies: []string{"default"}} + + _, err = client.CreateToken(context.Background(), a, task.Name) + if err == nil || !strings.Contains(err.Error(), "Connection to Vault failed") { + t.Fatalf("CreateToken should have failed: %v", err) + } +} + +func TestVaultClient_CreateToken_Prestart(t *testing.T) { + v := testutil.NewTestVault(t) + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + client.SetActive(true) + defer client.Stop() + + // Create an allocation that requires a Vault policy + a := mock.Alloc() + task := a.Job.TaskGroups[0].Tasks[0] + task.Vault = &structs.Vault{Policies: []string{"default"}} + + _, err = client.CreateToken(context.Background(), a, task.Name) + if err == nil { + t.Fatalf("CreateToken should have failed: %v", err) + } + + if rerr, ok := err.(*structs.RecoverableError); !ok { + t.Fatalf("Err should have been type recoverable error") + } else if ok && !rerr.Recoverable { + t.Fatalf("Err should have been recoverable") + } +} + func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) { v := testutil.NewTestVault(t) logger := log.New(os.Stderr, "", log.LstdFlags) @@ -559,7 +628,7 @@ func TestVaultClient_RevokeTokens(t *testing.T) { func waitForConnection(v *vaultClient, t *testing.T) { testutil.WaitForResult(func() (bool, error) { - return v.ConnectionEstablished(), nil + return v.ConnectionEstablished() }, func(err error) { t.Fatalf("Connection not established") }) diff --git a/vendor/vendor.json b/vendor/vendor.json index c0624a8ce11b..c8afb90b98e9 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -270,14 +270,14 @@ { "checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=", "path": "github.com/docker/docker/pkg/ioutils", - "revision": "52debcd58ac91bf68503ce60561536911b74ff05", - "revisionTime": "2016-05-20T15:17:10Z" + "revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5", + "revisionTime": "2016-05-28T08:11:04Z" }, { "checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=", "path": "github.com/docker/docker/pkg/ioutils", - "revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5", - "revisionTime": "2016-05-28T08:11:04Z" + "revision": "52debcd58ac91bf68503ce60561536911b74ff05", + "revisionTime": "2016-05-20T15:17:10Z" }, { "checksumSHA1": "ndnAFCfsGC3upNQ6jAEwzxcurww=",