Skip to content

Commit

Permalink
Modify approle tidy to validate dangling accessors (#4981)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai authored and briankassouf committed Jul 24, 2018
1 parent 8d2d9fd commit 77e6124
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 10 deletions.
3 changes: 3 additions & 0 deletions builtin/credential/approle/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package approle
import (
"context"
"sync"
"time"

"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/locksutil"
Expand Down Expand Up @@ -56,6 +57,8 @@ type backend struct {
// secretIDListingLock is a dedicated lock for listing SecretIDAccessors
// for all the SecretIDs issued against an approle
secretIDListingLock sync.RWMutex

testTidyDelay time.Duration
}

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
Expand Down
83 changes: 74 additions & 9 deletions builtin/credential/approle/path_tidy_user_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,29 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
go func() {
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)

logger := b.Logger().Named("tidy")

checkCount := 0

defer func() {
if b.testTidyDelay > 0 {
logger.Trace("done checking entries", "num_entries", checkCount)
}
}()

// Don't cancel when the original client request goes away
ctx = context.Background()

logger := b.Logger().Named("tidy")

tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
logger.Trace("listing role HMACs", "prefix", secretIDPrefixToUse)

roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
if err != nil {
return err
}

logger.Trace("listing accessors", "prefix", accessorIDPrefixToUse)

// List all the accessors and add them all to a map
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
if err != nil {
Expand All @@ -59,7 +71,10 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
accessorMap[accessorHash] = true
}

time.Sleep(b.testTidyDelay)

secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
checkCount++
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
defer lock.Unlock()
Expand Down Expand Up @@ -91,6 +106,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
}
if accessorEntry == nil {
logger.Trace("found nil accessor")
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
}
Expand All @@ -99,6 +115,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi

// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
logger.Trace("found expired secret ID")
// Clean up the accessor of the secret ID first
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
if err != nil {
Expand Down Expand Up @@ -126,6 +143,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
}

for _, roleNameHMAC := range roleNameHMACs {
logger.Trace("listing secret ID HMACs", "role_hmac", roleNameHMAC)
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
Expand All @@ -140,13 +158,60 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi

// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
// to clean up the dangling accessor entries.
for accessorHash, _ := range accessorMap {
// Ideally, locking should be performed here. But for that, accessors
// are required in plaintext, which are not available. Hence performing
// a racy cleanup.
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
if err != nil {
return err
if len(accessorMap) > 0 {
for _, lock := range b.secretIDLocks {
lock.Lock()
defer lock.Unlock()
}
for accessorHash, _ := range accessorMap {
logger.Trace("found dangling accessor, verifying")
// Ideally, locking on accessors should be performed here too
// but for that, accessors are required in plaintext, which are
// not available. The code above helps but it may still be
// racy.
// ...
// Look up the secret again now that we have all the locks. The
// lock is held when writing accessor/secret so if we have the
// lock we know we're not in a
// wrote-accessor-but-not-yet-secret case, which can be racy.
var entry secretIDAccessorStorageEntry
entryIndex := accessorIDPrefixToUse + accessorHash
se, err := s.Get(ctx, entryIndex)
if err != nil {
return err
}
if se != nil {
err = se.DecodeJSON(&entry)
if err != nil {
return err
}

// The storage entry doesn't store the role ID, so we have
// to go about this the long way; fortunately we shouldn't
// actually hit this very often
var found bool
searchloop:
for _, roleNameHMAC := range roleNameHMACs {
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
}
for _, v := range secretIDHMACs {
if v == entry.SecretIDHMAC {
found = true
logger.Trace("accessor verified, not removing")
break searchloop
}
}
}
if !found {
logger.Trace("could not verify dangling accessor, removing")
err = s.Delete(ctx, entryIndex)
if err != nil {
return err
}
}
}
}
}

Expand Down
94 changes: 93 additions & 1 deletion builtin/credential/approle/path_tidy_user_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package approle

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/hashicorp/vault/logical"
)

func TestAppRole_TidyDanglingAccessors(t *testing.T) {
func TestAppRole_TidyDanglingAccessors_Normal(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
Expand Down Expand Up @@ -83,3 +85,93 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) {
t.Fatalf("bad: len(accessorHashes); expect 1, got %d", len(accessorHashes))
}
}

func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)

b.testTidyDelay = 300 * time.Millisecond

// Create a role
createRole(t, b, storage, "role1", "a,b,c")

// Create an initial entry
roleSecretIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/role1/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
count := 1

wg := sync.WaitGroup{}
now := time.Now()
started := false
for {
if time.Now().Sub(now) > 700*time.Millisecond {
break
}
if time.Now().Sub(now) > 100*time.Millisecond && !started {
started = true
_, err = b.tidySecretID(context.Background(), &logical.Request{
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
}
go func() {
wg.Add(1)
defer wg.Done()
roleSecretIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/role1/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
}()
count++
}

t.Logf("wrote %d entries", count)

wg.Wait()
// Let tidy finish
time.Sleep(1 * time.Second)

// Run tidy again
_, err = b.tidySecretID(context.Background(), &logical.Request{
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)

accessorHashes, err := storage.List(context.Background(), "accessor/")
if err != nil {
t.Fatal(err)
}
if len(accessorHashes) != count {
t.Fatalf("bad: len(accessorHashes); expect %d, got %d", count, len(accessorHashes))
}

roleHMACs, err := storage.List(context.Background(), secretIDPrefix)
if err != nil {
t.Fatal(err)
}
secretIDs, err := storage.List(context.Background(), fmt.Sprintf("%s%s", secretIDPrefix, roleHMACs[0]))
if err != nil {
t.Fatal(err)
}
if len(secretIDs) != count {
t.Fatalf("bad: len(secretIDs); expect %d, got %d", count, len(secretIDs))
}
}

0 comments on commit 77e6124

Please sign in to comment.