Skip to content

Commit

Permalink
wi: introduce workload identity handler (#18672)
Browse files Browse the repository at this point in the history
Any code that tracks workloads and their identities should not rely on string
comparisons, especially since we support 2 types of workload identities: those
that identify tasks and those that identify services. This means we cannot rely
on task.Name for workload-identity pairs.

The new type structs.WIHandle solves this problem by providing a uniform way of
identifying workloads and their identities.
  • Loading branch information
pkazmierczak committed Oct 6, 2023
1 parent 0ccf942 commit 597d835
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 98 deletions.
14 changes: 7 additions & 7 deletions client/allocrunner/consul_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ func (h *consulHook) prepareConsulTokensForTask(job *structs.Job, task *structs.
if i.Name != expectedIdentity {
continue
}
ti := widmgr.TaskIdentity{
TaskName: task.Name,
IdentityName: i.Name,
ti := structs.WIHandle{
WorkloadIdentifier: task.Name,
IdentityName: i.Name,
}

req, err := h.prepareConsulClientReq(ti, consulTasksAuthMethodName)
Expand Down Expand Up @@ -157,9 +157,9 @@ func (h *consulHook) prepareConsulTokensForServices(services []*structs.Service,
continue
}

ti := widmgr.TaskIdentity{
TaskName: service.TaskName,
IdentityName: service.Identity.Name,
ti := structs.WIHandle{
WorkloadIdentifier: service.TaskName,
IdentityName: service.Identity.Name,
}

req, err := h.prepareConsulClientReq(ti, consulServicesAuthMethodName)
Expand Down Expand Up @@ -207,7 +207,7 @@ func (h *consulHook) getConsulTokens(cluster, identityName string, tokens map[st
return nil
}

func (h *consulHook) prepareConsulClientReq(identity widmgr.TaskIdentity, authMethodName string) (map[string]consul.JWTLoginRequest, error) {
func (h *consulHook) prepareConsulClientReq(identity structs.WIHandle, authMethodName string) (map[string]consul.JWTLoginRequest, error) {
req := map[string]consul.JWTLoginRequest{}

jwt, err := h.widmgr.Get(identity)
Expand Down
14 changes: 8 additions & 6 deletions client/allocrunner/identity_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ func TestIdentityHook_Prerun(t *testing.T) {
// do the initial signing
_, err := mockSigner.SignIdentities(1, []*structs.WorkloadIdentityRequest{
{
AllocID: alloc.ID,
TaskName: task.Name,
IdentityName: task.Identities[0].Name,
AllocID: alloc.ID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: task.Name,
IdentityName: task.Identities[0].Name,
},
},
})
must.NoError(t, err)
Expand All @@ -67,9 +69,9 @@ func TestIdentityHook_Prerun(t *testing.T) {
must.NoError(t, hook.Prerun())

time.Sleep(time.Second) // give goroutines a moment to run
sid, err := hook.widmgr.Get(widmgr.TaskIdentity{
TaskName: task.Name,
IdentityName: task.Identities[0].Name},
sid, err := hook.widmgr.Get(structs.WIHandle{
WorkloadIdentifier: task.Name,
IdentityName: task.Identities[0].Name},
)
must.Nil(t, err)
must.Eq(t, sid.IdentityName, task.Identity.Name)
Expand Down
2 changes: 1 addition & 1 deletion client/allocrunner/taskrunner/identity_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (h *identityHook) Prestart(context.Context, *interfaces.TaskPrestartRequest
}

func (h *identityHook) watchIdentity(wid *structs.WorkloadIdentity) {
id := widmgr.TaskIdentity{TaskName: h.task.Name, IdentityName: wid.Name}
id := structs.WIHandle{WorkloadIdentifier: h.task.Name, IdentityName: wid.Name}
signedIdentitiesChan, stopWatching := h.widmgr.Watch(id)
defer stopWatching()

Expand Down
16 changes: 10 additions & 6 deletions client/allocrunner/taskrunner/identity_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ func TestIdentityHook_RenewAll(t *testing.T) {
for _, i := range task.Identities {
_, err := mockSigner.SignIdentities(1, []*structs.WorkloadIdentityRequest{
{
AllocID: alloc.ID,
TaskName: task.Name,
IdentityName: i.Name,
AllocID: alloc.ID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: task.Name,
IdentityName: i.Name,
},
},
})
must.NoError(t, err)
Expand Down Expand Up @@ -183,9 +185,11 @@ func TestIdentityHook_RenewOne(t *testing.T) {
for _, i := range task.Identities {
_, err := mockSigner.SignIdentities(1, []*structs.WorkloadIdentityRequest{
{
AllocID: alloc.ID,
TaskName: task.Name,
IdentityName: i.Name,
AllocID: alloc.ID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: task.Name,
IdentityName: i.Name,
},
},
})
must.NoError(t, err)
Expand Down
10 changes: 5 additions & 5 deletions client/widmgr/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (m *MockWIDSigner) SignIdentities(minIndex uint64, req []*structs.WorkloadI
Namespace: "default",
JobID: "test",
AllocationID: idReq.AllocID,
TaskName: idReq.TaskName,
TaskName: idReq.WorkloadIdentifier,
}
claims.ID = uuid.Generate()
// If test has set workload identities. Lookup claims or reject unknown
Expand Down Expand Up @@ -94,17 +94,17 @@ func (m *MockWIDSigner) SignIdentities(minIndex uint64, req []*structs.WorkloadI
// MockWIDMgr mocks IdentityManager interface allowing to only get identities
// signed by the mock signer.
type MockWIDMgr struct {
swids map[TaskIdentity]*structs.SignedWorkloadIdentity
swids map[structs.WIHandle]*structs.SignedWorkloadIdentity
}

func NewMockWIDMgr(swids map[TaskIdentity]*structs.SignedWorkloadIdentity) *MockWIDMgr {
func NewMockWIDMgr(swids map[structs.WIHandle]*structs.SignedWorkloadIdentity) *MockWIDMgr {
return &MockWIDMgr{swids: swids}
}

// Run does not run a renewal loop in this mock
func (m MockWIDMgr) Run() error { return nil }

func (m MockWIDMgr) Get(identity TaskIdentity) (*structs.SignedWorkloadIdentity, error) {
func (m MockWIDMgr) Get(identity structs.WIHandle) (*structs.SignedWorkloadIdentity, error) {
sid, ok := m.swids[identity]
if !ok {
return nil, fmt.Errorf("identity not found")
Expand All @@ -113,7 +113,7 @@ func (m MockWIDMgr) Get(identity TaskIdentity) (*structs.SignedWorkloadIdentity,
}

// Watch does not do anything, this mock doesn't support watching.
func (m MockWIDMgr) Watch(identity TaskIdentity) (<-chan *structs.SignedWorkloadIdentity, func()) {
func (m MockWIDMgr) Watch(identity structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func()) {
return nil, nil
}

Expand Down
83 changes: 36 additions & 47 deletions client/widmgr/widmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,12 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
)

// TaskIdentity maps the name of the task to the name of a workload identity. Any
// task can have multiple identities.
type TaskIdentity struct {
TaskName string
IdentityName string
}

// IdentityManager defines a manager responsible for signing and renewing
// signed identities. At runtime it is implemented by *widmgr.WIDMgr.
type IdentityManager interface {
Run() error
Get(TaskIdentity) (*structs.SignedWorkloadIdentity, error)
Watch(TaskIdentity) (<-chan *structs.SignedWorkloadIdentity, func())
Get(structs.WIHandle) (*structs.SignedWorkloadIdentity, error)
Watch(structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func())
Shutdown()
}

Expand All @@ -40,12 +33,12 @@ type WIDMgr struct {

// lastToken are the last retrieved signed workload identifiers keyed by
// TaskIdentity
lastToken map[TaskIdentity]*structs.SignedWorkloadIdentity
lastToken map[structs.WIHandle]*structs.SignedWorkloadIdentity
lastTokenLock sync.RWMutex

// watchers is a map of task identities to slices of channels (each identity
// can have multiple watchers)
watchers map[TaskIdentity][]chan *structs.SignedWorkloadIdentity
watchers map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity
watchersLock sync.Mutex

// minWait is the minimum amount of time to wait before renewing. Settable to
Expand Down Expand Up @@ -79,8 +72,8 @@ func NewWIDMgr(signer IdentitySigner, a *structs.Allocation, logger hclog.Logger
widSpecs: widspecs,
signer: signer,
minWait: 10 * time.Second,
lastToken: map[TaskIdentity]*structs.SignedWorkloadIdentity{},
watchers: map[TaskIdentity][]chan *structs.SignedWorkloadIdentity{},
lastToken: map[structs.WIHandle]*structs.SignedWorkloadIdentity{},
watchers: map[structs.WIHandle][]chan *structs.SignedWorkloadIdentity{},
stopCtx: stopCtx,
stop: stop,
logger: logger.Named("widmgr"),
Expand Down Expand Up @@ -120,18 +113,18 @@ func (m *WIDMgr) Run() error {
// For retrieving tokens which might be renewed callers should use Watch
// instead to avoid missing new tokens retrieved by Run between Get and Watch
// calls.
func (m *WIDMgr) Get(id TaskIdentity) (*structs.SignedWorkloadIdentity, error) {
func (m *WIDMgr) Get(id structs.WIHandle) (*structs.SignedWorkloadIdentity, error) {
token := m.get(id)
if token == nil {
// This is an error as every identity should have a token by the time Get
// is called.
return nil, fmt.Errorf("unable to find token for task %q and identity %q", id.TaskName, id.IdentityName)
return nil, fmt.Errorf("uble to find token for task %q and identity %q", id.WorkloadIdentifier, id.IdentityName)
}

return token, nil
}

func (m *WIDMgr) get(id TaskIdentity) *structs.SignedWorkloadIdentity {
func (m *WIDMgr) get(id structs.WIHandle) *structs.SignedWorkloadIdentity {
m.lastTokenLock.RLock()
defer m.lastTokenLock.RUnlock()

Expand All @@ -143,7 +136,7 @@ func (m *WIDMgr) get(id TaskIdentity) *structs.SignedWorkloadIdentity {
//
// The caller must call the returned func to stop watching and ensure the
// watched id actually exists, otherwise the channel never returns a result.
func (m *WIDMgr) Watch(id TaskIdentity) (<-chan *structs.SignedWorkloadIdentity, func()) {
func (m *WIDMgr) Watch(id structs.WIHandle) (<-chan *structs.SignedWorkloadIdentity, func()) {
// If Shutdown has been called return a closed chan
if m.stopCtx.Err() != nil {
c := make(chan *structs.SignedWorkloadIdentity)
Expand Down Expand Up @@ -196,16 +189,18 @@ func (m *WIDMgr) Shutdown() {
// getIdentities fetches all signed identities or returns an error.
func (m *WIDMgr) getIdentities() error {
// get the default identity signed by the plan applier
defaultTokens := map[TaskIdentity]*structs.SignedWorkloadIdentity{}
defaultTokens := map[structs.WIHandle]*structs.SignedWorkloadIdentity{}
for taskName, signature := range m.defaultSignedIdentities {
id := TaskIdentity{
TaskName: taskName,
IdentityName: "default",
id := structs.WIHandle{
WorkloadIdentifier: taskName,
IdentityName: "default",
}
widReq := structs.WorkloadIdentityRequest{
AllocID: m.allocID,
TaskName: taskName,
IdentityName: "default",
AllocID: m.allocID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: taskName,
IdentityName: "default",
},
}
defaultTokens[id] = &structs.SignedWorkloadIdentity{
WorkloadIdentityRequest: widReq,
Expand All @@ -225,9 +220,11 @@ func (m *WIDMgr) getIdentities() error {
for taskName, widspecs := range m.widSpecs {
for _, widspec := range widspecs {
reqs = append(reqs, &structs.WorkloadIdentityRequest{
AllocID: m.allocID,
TaskName: taskName,
IdentityName: widspec.Name,
AllocID: m.allocID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: taskName,
IdentityName: widspec.Name,
},
})
}
}
Expand All @@ -249,12 +246,7 @@ func (m *WIDMgr) getIdentities() error {

// Index initial workload identities by name
for _, swid := range signedWIDs {
id := TaskIdentity{
TaskName: swid.TaskName,
IdentityName: swid.IdentityName,
}

m.lastToken[id] = swid
m.lastToken[swid.WIHandle] = swid
}

// TODO: Persist signed identity token to client state
Expand All @@ -275,9 +267,11 @@ func (m *WIDMgr) renew() {
continue
}
reqs = append(reqs, &structs.WorkloadIdentityRequest{
AllocID: m.allocID,
TaskName: taskName,
IdentityName: widspec.Name,
AllocID: m.allocID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: taskName,
IdentityName: widspec.Name,
},
})
}
}
Expand All @@ -298,9 +292,9 @@ func (m *WIDMgr) renew() {
}

//FIXME make this less ugly
token := m.get(TaskIdentity{
TaskName: taskName,
IdentityName: wid.Name,
token := m.get(structs.WIHandle{
WorkloadIdentifier: taskName,
IdentityName: wid.Name,
})
if token == nil {
// Missing a signature, treat this case as already expired so
Expand Down Expand Up @@ -366,19 +360,14 @@ func (m *WIDMgr) renew() {
minExp = time.Time{}

for _, token := range tokens {
id := TaskIdentity{
TaskName: token.TaskName,
IdentityName: token.IdentityName,
}

// Set for getters
m.lastTokenLock.Lock()
m.lastToken[id] = token
m.lastToken[token.WIHandle] = token
m.lastTokenLock.Unlock()

// Send to watchers
m.watchersLock.Lock()
m.send(id, token)
m.send(token.WIHandle, token)
m.watchersLock.Unlock()

// Set next expiration time
Expand All @@ -394,7 +383,7 @@ func (m *WIDMgr) renew() {
}

// send must be called while holding the m.watchersLock
func (m *WIDMgr) send(id TaskIdentity, token *structs.SignedWorkloadIdentity) {
func (m *WIDMgr) send(id structs.WIHandle, token *structs.SignedWorkloadIdentity) {
w, ok := m.watchers[id]
if !ok {
// No watchers
Expand Down
24 changes: 15 additions & 9 deletions client/widmgr/widmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ func TestWIDMgr(t *testing.T) {

_, err = mgr.SignIdentities(1, []*structs.WorkloadIdentityRequest{
{
AllocID: uuid.Generate(),
TaskName: "web",
IdentityName: "foo",
AllocID: uuid.Generate(),
WIHandle: structs.WIHandle{
WorkloadIdentifier: "web",
IdentityName: "foo",
},
},
})
must.ErrorContains(t, err, "rejected")
Expand Down Expand Up @@ -92,14 +94,18 @@ func TestWIDMgr(t *testing.T) {
// Get signed identites for alloc
widreqs := []*structs.WorkloadIdentityRequest{
{
AllocID: allocs[0].ID,
TaskName: job.TaskGroups[0].Tasks[0].Name,
IdentityName: "consul",
AllocID: allocs[0].ID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: job.TaskGroups[0].Tasks[0].Name,
IdentityName: "consul",
},
},
{
AllocID: allocs[0].ID,
TaskName: job.TaskGroups[0].Tasks[0].Name,
IdentityName: "vault",
AllocID: allocs[0].ID,
WIHandle: structs.WIHandle{
WorkloadIdentifier: job.TaskGroups[0].Tasks[0].Name,
IdentityName: "vault",
},
},
}

Expand Down
4 changes: 2 additions & 2 deletions nomad/alloc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func (a *Alloc) SignIdentities(args *structs.AllocIdentitiesRequest, reply *stru
continue
}

task := out.LookupTask(idReq.TaskName)
task := out.LookupTask(idReq.WorkloadIdentifier)
if task == nil {
// Job has likely been updated to remove this task
reply.Rejections = append(reply.Rejections, &structs.WorkloadIdentityRejection{
Expand All @@ -582,7 +582,7 @@ func (a *Alloc) SignIdentities(args *structs.AllocIdentitiesRequest, reply *stru
}

widFound = true
claims := structs.NewIdentityClaims(out.Job, out, idReq.TaskName, wid, now)
claims := structs.NewIdentityClaims(out.Job, out, idReq.WorkloadIdentifier, wid, now)
token, _, err := a.srv.encrypter.SignClaims(claims)
if err != nil {
return err
Expand Down
Loading

0 comments on commit 597d835

Please sign in to comment.