Skip to content

Commit

Permalink
Merge pull request #1659 from hashicorp/f-revoke-accessors
Browse files Browse the repository at this point in the history
Token revocation and keeping only a single Vault client active among servers
  • Loading branch information
dadgar authored Aug 31, 2016
2 parents daf3ca2 + 743cbeb commit 67481cd
Show file tree
Hide file tree
Showing 18 changed files with 1,305 additions and 183 deletions.
20 changes: 19 additions & 1 deletion nomad/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} {
return n.applyReconcileSummaries(buf[1:], log.Index)
case structs.VaultAccessorRegisterRequestType:
return n.applyUpsertVaultAccessor(buf[1:], log.Index)
case structs.VaultAccessorDegisterRequestType:
return n.applyDeregisterVaultAccessor(buf[1:], log.Index)
default:
if ignoreUnknown {
n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType)
Expand Down Expand Up @@ -472,7 +474,7 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{}
// and task
func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} {
defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now())
var req structs.VaultAccessorRegisterRequest
var req structs.VaultAccessorsRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
Expand All @@ -485,6 +487,22 @@ func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{
return nil
}

// applyDeregisterVaultAccessor applies a Raft log entry that removes a batch
// of Vault accessors from the state store. buf is the encoded
// VaultAccessorsRequest (Apply passes the log data with its leading
// message-type byte sliced off) and index is the Raft index of the entry.
//
// A request that cannot be decoded panics. Otherwise the result is nil on
// success, or the error reported by the state store after it has been logged.
func (n *nomadFSM) applyDeregisterVaultAccessor(buf []byte, index uint64) interface{} {
	defer metrics.MeasureSince([]string{"nomad", "fsm", "deregister_vault_accessor"}, time.Now())

	var request structs.VaultAccessorsRequest
	decodeErr := structs.Decode(buf, &request)
	if decodeErr != nil {
		panic(fmt.Errorf("failed to decode request: %v", decodeErr))
	}

	// Happy path first: nothing to report when the delete succeeds.
	deleteErr := n.state.DeleteVaultAccessors(index, request.Accessors)
	if deleteErr == nil {
		return nil
	}

	n.logger.Printf("[ERR] nomad.fsm: DeregisterVaultAccessor failed: %v", deleteErr)
	return deleteErr
}

func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) {
// Create a new snapshot
snap, err := n.state.Snapshot()
Expand Down
43 changes: 42 additions & 1 deletion nomad/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) {

va := mock.VaultAccessor()
va2 := mock.VaultAccessor()
req := structs.VaultAccessorRegisterRequest{
req := structs.VaultAccessorsRequest{
Accessors: []*structs.VaultAccessor{va, va2},
}
buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req)
Expand Down Expand Up @@ -819,6 +819,47 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) {
}
}

// TestFSM_DeregisterVaultAccessor verifies that applying a
// VaultAccessorDegisterRequestType log entry removes every accessor named in
// the request from the state store and records the entry in the time table.
func TestFSM_DeregisterVaultAccessor(t *testing.T) {
	fsm := testFSM(t)
	fsm.blockedEvals.SetEnabled(true)

	va := mock.VaultAccessor()
	va2 := mock.VaultAccessor()
	accessors := []*structs.VaultAccessor{va, va2}

	// Insert the accessors
	if err := fsm.State().UpsertVaultAccessor(1000, accessors); err != nil {
		t.Fatalf("bad: %v", err)
	}

	req := structs.VaultAccessorsRequest{
		Accessors: accessors,
	}
	buf, err := structs.Encode(structs.VaultAccessorDegisterRequestType, req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	resp := fsm.Apply(makeLog(buf))
	if resp != nil {
		t.Fatalf("resp: %v", resp)
	}

	// Every accessor in the request must be gone, not just the first one;
	// otherwise a partial delete would go unnoticed.
	for _, a := range accessors {
		out, err := fsm.State().VaultAccessor(a.Accessor)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if out != nil {
			t.Fatalf("not deleted!")
		}
	}

	// The applied log entry should be witnessed by the time table.
	tt := fsm.TimeTable()
	index := tt.NearestIndex(time.Now().UTC())
	if index != 1 {
		t.Fatalf("bad: %d", index)
	}
}

func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM {
// Snapshot
snap, err := fsm.Snapshot()
Expand Down
61 changes: 61 additions & 0 deletions nomad/leader.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nomad

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -132,6 +133,12 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error {
return err
}

// Activate the vault client
s.vault.SetActive(true)
if err := s.restoreRevokingAccessors(); err != nil {
return err
}

// Enable the periodic dispatcher, since we are now the leader.
s.periodicDispatcher.SetEnabled(true)
s.periodicDispatcher.Start()
Expand Down Expand Up @@ -205,6 +212,57 @@ func (s *Server) restoreEvals() error {
return nil
}

// restoreRevokingAccessors is used to restore Vault accessors that should be
// revoked. It walks every accessor in the state store and collects those
// whose allocation is missing or terminated, or whose node is missing or in a
// terminal status, then hands the whole batch to the Vault client in a single
// RevokeTokens call.
//
// It returns an error if listing the accessors, looking up an allocation or
// node, or the revocation itself fails.
func (s *Server) restoreRevokingAccessors() error {
	// An accessor should be revoked if its allocation or node is terminal
	state := s.fsm.State()
	iter, err := state.VaultAccessors()
	if err != nil {
		return fmt.Errorf("failed to get vault accessors: %v", err)
	}

	var revoke []*structs.VaultAccessor
	for {
		raw := iter.Next()
		if raw == nil {
			break
		}

		va := raw.(*structs.VaultAccessor)

		// Check the allocation
		alloc, err := state.AllocByID(va.AllocID)
		if err != nil {
			// Name the allocation and the cause separately so the
			// underlying error is not lost as a surplus format argument.
			return fmt.Errorf("failed to lookup allocation %q: %v", va.AllocID, err)
		}
		if alloc == nil || alloc.Terminated() {
			// No longer running and should be revoked
			revoke = append(revoke, va)
			continue
		}

		// Check the node
		node, err := state.NodeByID(va.NodeID)
		if err != nil {
			return fmt.Errorf("failed to lookup node %q: %v", va.NodeID, err)
		}
		if node == nil || node.TerminalStatus() {
			// Node is terminal so any accessor from it should be revoked
			revoke = append(revoke, va)
		}
	}

	// Only call out to Vault when there is something to revoke.
	if len(revoke) != 0 {
		if err := s.vault.RevokeTokens(context.Background(), revoke, true); err != nil {
			return fmt.Errorf("failed to revoke tokens: %v", err)
		}
	}

	return nil
}

// restorePeriodicDispatcher is used to restore all periodic jobs into the
// periodic dispatcher. It also determines if a periodic job should have been
// created during the leadership transition and force runs them. The periodic
Expand Down Expand Up @@ -409,6 +467,9 @@ func (s *Server) revokeLeadership() error {
// Disable the periodic dispatcher, since it is only useful as a leader
s.periodicDispatcher.SetEnabled(false)

// Disable the Vault client as it is only useful as a leader.
s.vault.SetActive(false)

// Clear the heartbeat timers on either shutdown or step down,
// since we are no longer responsible for TTL expirations.
if err := s.clearAllHeartbeatTimers(); err != nil {
Expand Down
28 changes: 28 additions & 0 deletions nomad/leader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,31 @@ func TestLeader_ReapDuplicateEval(t *testing.T) {
t.Fatalf("err: %v", err)
})
}

// TestLeader_RestoreVaultAccessors verifies that restoreRevokingAccessors
// revokes an accessor stored in the state store. The accessor is upserted
// without any matching allocation being inserted, and the server's Vault
// client is replaced by a TestVaultClient so the revoked tokens can be
// inspected.
func TestLeader_RestoreVaultAccessors(t *testing.T) {
	s1 := testServer(t, func(c *Config) {
		c.NumSchedulers = 0
	})
	defer s1.Shutdown()
	testutil.WaitForLeader(t, s1.RPC)

	// Insert a vault accessor that should be revoked
	state := s1.fsm.State()
	va := mock.VaultAccessor()
	if err := state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va}); err != nil {
		t.Fatalf("bad: %v", err)
	}

	// Swap the Vault client
	tvc := &TestVaultClient{}
	s1.vault = tvc

	// Do a restore
	if err := s1.restoreRevokingAccessors(); err != nil {
		t.Fatalf("Failed to restore: %v", err)
	}

	// Fail when the count is wrong OR the accessor differs. The length check
	// comes first so the index below is never evaluated on an empty slice.
	if len(tvc.RevokedTokens) != 1 || tvc.RevokedTokens[0].Accessor != va.Accessor {
		t.Fatalf("Bad revoked accessors: %v", tvc.RevokedTokens)
	}
}
80 changes: 72 additions & 8 deletions nomad/node_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/armon/go-metrics"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/watch"
Expand Down Expand Up @@ -215,7 +216,7 @@ func (n *Node) constructNodeServerInfoResponse(snap *state.StateSnapshot, reply
return nil
}

// Deregister is used to remove a client from the client. If a client should
// Deregister is used to remove a client from the cluster. If a client should
// just be made unavailable for scheduling, a status update is preferred.
func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.NodeUpdateResponse) error {
if done, err := n.srv.forward("Node.Deregister", args, args, reply); done {
Expand Down Expand Up @@ -245,6 +246,20 @@ func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.No
return err
}

// Determine if there are any Vault accessors on the node
accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err)
return err
}

if len(accessors) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err)
return err
}
}

// Setup the reply
reply.EvalIDs = evalIDs
reply.EvalCreateIndex = evalIndex
Expand Down Expand Up @@ -311,7 +326,22 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct
}

// Check if we need to setup a heartbeat
if args.Status != structs.NodeStatusDown {
switch args.Status {
case structs.NodeStatusDown:
// Determine if there are any Vault accessors on the node
accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err)
return err
}

if len(accessors) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err)
return err
}
}
default:
ttl, err := n.srv.resetHeartbeatTimer(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: heartbeat reset failed: %v", err)
Expand Down Expand Up @@ -686,13 +716,41 @@ func (n *Node) batchUpdate(future *batchFuture, updates []*structs.Allocation) {
}

// Commit this update via Raft
var mErr multierror.Error
_, index, err := n.srv.raftApply(structs.AllocClientUpdateRequestType, batch)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: alloc update failed: %v", err)
mErr.Errors = append(mErr.Errors, err)
}

// For each allocation we are updating check if we should revoke any
// Vault Accessors
var revoke []*structs.VaultAccessor
for _, alloc := range updates {
// Skip any allocation that isn't dead on the client
if !alloc.Terminated() {
continue
}

// Determine if there are any Vault accessors for the allocation
accessors, err := n.srv.State().VaultAccessorsByAlloc(alloc.ID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for alloc %q failed: %v", alloc.ID, err)
mErr.Errors = append(mErr.Errors, err)
}

revoke = append(revoke, accessors...)
}

if len(revoke) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), revoke, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: batched accessor revocation failed: %v", err)
mErr.Errors = append(mErr.Errors, err)
}
}

// Respond to the future
future.Respond(index, err)
future.Respond(index, mErr.ErrorOrNil())
}

// List is used to list the available nodes
Expand Down Expand Up @@ -1011,10 +1069,6 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,

// Wait for everything to complete or for an error
err = g.Wait()
if err != nil {
// TODO Revoke any created token
return err
}

// Commit to Raft before returning any of the tokens
accessors := make([]*structs.VaultAccessor, 0, len(results))
Expand All @@ -1037,7 +1091,17 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
accessors = append(accessors, accessor)
}

req := structs.VaultAccessorRegisterRequest{Accessors: accessors}
// If there was an error revoke the created tokens
if err != nil {
var mErr multierror.Error
mErr.Errors = append(mErr.Errors, err)
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil {
mErr.Errors = append(mErr.Errors, err)
}
return mErr.ErrorOrNil()
}

req := structs.VaultAccessorsRequest{Accessors: accessors}
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
Expand Down
Loading

0 comments on commit 67481cd

Please sign in to comment.