Skip to content

Commit

Permalink
[v15] fix access request race & panic (#45493)
Browse files Browse the repository at this point in the history
* Revert "Revert "fix access request cache test race (#44653)" (#44787)"

This reverts commit 5472633.

* fix access request cache panic (#45225)
  • Loading branch information
fspmarshall authored Aug 14, 2024
1 parent 1fdedad commit 4190e38
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 4 deletions.
31 changes: 29 additions & 2 deletions lib/services/access_request_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type AccessRequestCacheConfig struct {
Events types.Events
// Getter is an access request getter client.
Getter AccessRequestGetter
// MaxRetryPeriod is the maximum retry period on failed watches.
MaxRetryPeriod time.Duration
}

// CheckAndSetDefaults valides the config and provides reasonable defaults for optional fields.
Expand Down Expand Up @@ -87,8 +89,12 @@ type AccessRequestCache struct {
primaryCache *sortcache.SortCache[*types.AccessRequestV3]
ttlCache *utils.FnCache
initC chan struct{}
initOnce sync.Once
closeContext context.Context
cancel context.CancelFunc
// onInit is a callback used in tests to detect
// individual initializations.
onInit func()
}

// NewAccessRequestCache sets up a new [AccessRequestCache] instance based on the supplied
Expand Down Expand Up @@ -120,8 +126,9 @@ func NewAccessRequestCache(cfg AccessRequestCacheConfig) (*AccessRequestCache, e
}

if _, err := newResourceWatcher(ctx, c, ResourceWatcherConfig{
Component: "access-request-cache",
Client: cfg.Events,
Component: "access-request-cache",
Client: cfg.Events,
MaxRetryPeriod: cfg.MaxRetryPeriod,
}); err != nil {
cancel()
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -352,9 +359,22 @@ func (c *AccessRequestCache) getResourcesAndUpdateCurrent(ctx context.Context) e
c.rw.Lock()
defer c.rw.Unlock()
c.primaryCache = cache
c.initOnce.Do(func() {
close(c.initC)
})
if c.onInit != nil {
c.onInit()
}
return nil
}

// SetInitCallback is used in tests that care about cache inits.
func (c *AccessRequestCache) SetInitCallback(cb func()) {
c.rw.Lock()
defer c.rw.Unlock()
c.onInit = cb
}

// processEventsAndUpdateCurrent is part of the resourceCollector interface and is used to update the
// primary cache state when modification events occur.
func (c *AccessRequestCache) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) {
Expand Down Expand Up @@ -394,6 +414,7 @@ func (c *AccessRequestCache) notifyStale() {
}
c.primaryCache = nil
c.initC = make(chan struct{})
c.initOnce = sync.Once{}
}

// initializationChan is part of the resourceCollector interface and gets the channel
Expand All @@ -404,6 +425,12 @@ func (c *AccessRequestCache) initializationChan() <-chan struct{} {
return c.initC
}

// InitializationChan is part of the resourceCollector interface and gets the channel
// used to signal that the accessRequestCache has been initialized.
func (c *AccessRequestCache) InitializationChan() <-chan struct{} {
return c.initializationChan()
}

// Close terminates the background process that keeps the access request cache up to
// date, and terminates any inflight load operations.
func (c *AccessRequestCache) Close() error {
Expand Down
119 changes: 117 additions & 2 deletions lib/services/access_request_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
Expand All @@ -34,6 +37,8 @@ import (
type accessRequestServices struct {
types.Events
services.DynamicAccessExt

bk *memory.Memory
}

func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.AccessRequestCache) {
Expand All @@ -43,17 +48,127 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access
svcs := accessRequestServices{
Events: local.NewEventsService(bk),
DynamicAccessExt: local.NewDynamicAccessService(bk),
bk: bk,
}

cache, err := services.NewAccessRequestCache(services.AccessRequestCacheConfig{
Events: svcs,
Getter: svcs,
Events: svcs,
Getter: svcs,
MaxRetryPeriod: time.Millisecond * 100,
})
require.NoError(t, err)

select {
case <-cache.InitializationChan():
case <-time.After(time.Second * 30):
require.FailNow(t, "timeout waiting for access request cache to initialize")
}

return svcs, cache
}

func TestAccessRequestCacheResets(t *testing.T) {
const (
requestCount = 100
workers = 20
resets = 3
)

t.Parallel()

svcs, cache := newAccessRequestPack(t)
defer cache.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

for i := 0; i < requestCount; i++ {
r, err := types.NewAccessRequest(uuid.New().String(), "alice@example.com", "some-role")
require.NoError(t, err)

_, err = svcs.CreateAccessRequestV2(ctx, r)
require.NoError(t, err)
}

timeout := time.After(time.Second * 30)

for {
rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{
Limit: requestCount,
})
require.NoError(t, err)
if len(rsp.AccessRequests) == requestCount {
break
}

select {
case <-timeout:
require.FailNow(t, "timeout waiting for access request cache to populate")
case <-time.After(time.Millisecond * 200):
}
}

doneC := make(chan struct{})
reads := make(chan struct{}, workers)
var eg errgroup.Group

for i := 0; i < workers; i++ {
eg.Go(func() error {
for {
select {
case <-doneC:
return nil
case <-time.After(time.Millisecond * 20):
}

rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{
Limit: int32(requestCount),
})
if err != nil {
return trace.Errorf("unexpected read failure: %v", err)
}

select {
case reads <- struct{}{}:
default:
}

if len(rsp.AccessRequests) != requestCount {
return trace.Errorf("unexpected number of access requests: %d (expected %d)", len(rsp.AccessRequests), requestCount)
}
}
})
}

inits := make(chan struct{}, resets+1)
cache.SetInitCallback(func() {
inits <- struct{}{}
})

timeout = time.After(time.Second * 30)
for i := 0; i < resets; i++ {
svcs.bk.CloseWatchers()
select {
case <-inits:
case <-timeout:
require.FailNowf(t, "timeout waiting for access request cache to reset", "reset=%d", i)
}

for j := 0; j < workers; j++ {
// ensure that we're not racing ahead of worker reads too
// much if inits are happening quickly.
select {
case <-reads:
case <-timeout:
require.FailNowf(t, "timeout waiting for worker reads to catch up", "reset=%d", i)
}
}
}

close(doneC)
require.NoError(t, eg.Wait())
}

// TestAccessRequestCacheBasics verifies the basic expected behaviors of the access request cache,
// including correct sorting and handling of put/delete events.
func TestAccessRequestCacheBasics(t *testing.T) {
Expand Down

0 comments on commit 4190e38

Please sign in to comment.