Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] fix access request cache panic #45225

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions lib/services/access_request_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type AccessRequestCacheConfig struct {
Events types.Events
// Getter is an access request getter client.
Getter AccessRequestGetter
// MaxRetryPeriod is the maximum retry period on failed watches.
MaxRetryPeriod time.Duration
}

// CheckAndSetDefaults validates the config and provides reasonable defaults for optional fields.
Expand Down Expand Up @@ -87,8 +89,12 @@ type AccessRequestCache struct {
primaryCache *sortcache.SortCache[*types.AccessRequestV3]
ttlCache *utils.FnCache
initC chan struct{}
initOnce sync.Once
closeContext context.Context
cancel context.CancelFunc
// onInit is a callback used in tests to detect
// individual initializations.
onInit func()
}

// NewAccessRequestCache sets up a new [AccessRequestCache] instance based on the supplied
Expand Down Expand Up @@ -120,8 +126,9 @@ func NewAccessRequestCache(cfg AccessRequestCacheConfig) (*AccessRequestCache, e
}

if _, err := newResourceWatcher(ctx, c, ResourceWatcherConfig{
Component: "access-request-cache",
Client: cfg.Events,
Component: "access-request-cache",
Client: cfg.Events,
MaxRetryPeriod: cfg.MaxRetryPeriod,
}); err != nil {
cancel()
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -352,10 +359,22 @@ func (c *AccessRequestCache) getResourcesAndUpdateCurrent(ctx context.Context) e
c.rw.Lock()
defer c.rw.Unlock()
c.primaryCache = cache
close(c.initC)
c.initOnce.Do(func() {
close(c.initC)
})
if c.onInit != nil {
c.onInit()
}
return nil
}

// SetInitCallback installs cb as the cache's init hook. It exists for tests
// that need to observe individual cache initializations. The hook is stored
// while holding c.rw so the write is serialized with other users of that lock.
func (c *AccessRequestCache) SetInitCallback(cb func()) {
	c.rw.Lock()
	c.onInit = cb
	c.rw.Unlock()
}

// processEventsAndUpdateCurrent is part of the resourceCollector interface and is used to update the
// primary cache state when modification events occur.
func (c *AccessRequestCache) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) {
Expand Down Expand Up @@ -395,6 +414,7 @@ func (c *AccessRequestCache) notifyStale() {
}
c.primaryCache = nil
c.initC = make(chan struct{})
c.initOnce = sync.Once{}
}

// initializationChan is part of the resourceCollector interface and gets the channel
Expand Down
113 changes: 111 additions & 2 deletions lib/services/access_request_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
Expand All @@ -34,6 +37,8 @@ import (
// accessRequestServices bundles the services that the access request cache
// tests pass to the cache (as both its Events and Getter), together with the
// backend those services were built on.
type accessRequestServices struct {
// Events is the embedded event service.
types.Events
// DynamicAccessExt is the embedded access request service.
services.DynamicAccessExt

// bk is the in-memory backend underlying the services above; it is kept so
// tests can act on the backend directly (e.g. closing its watchers).
bk *memory.Memory
}

func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.AccessRequestCache) {
Expand All @@ -43,11 +48,13 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access
svcs := accessRequestServices{
Events: local.NewEventsService(bk),
DynamicAccessExt: local.NewDynamicAccessService(bk),
bk: bk,
}

cache, err := services.NewAccessRequestCache(services.AccessRequestCacheConfig{
Events: svcs,
Getter: svcs,
Events: svcs,
Getter: svcs,
MaxRetryPeriod: time.Millisecond * 100,
})
require.NoError(t, err)

Expand All @@ -60,6 +67,108 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access
return svcs, cache
}

// TestAccessRequestCacheResets exercises the access request cache while its
// underlying watchers are repeatedly closed. A fixed set of access requests is
// created and the cache is given time to pick them all up. Concurrent readers
// then list the cache continuously while the backend's watchers are closed
// several times; every read must succeed and return the full set of requests,
// and each watcher closure must be followed by a cache init.
func TestAccessRequestCacheResets(t *testing.T) {
	const requestCount = 100
	const workers = 20
	const resets = 3

	t.Parallel()

	svcs, cache := newAccessRequestPack(t)
	defer cache.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// seed the backend with the access requests the cache is expected to serve.
	for created := 0; created < requestCount; created++ {
		req, err := types.NewAccessRequest(uuid.New().String(), "alice@example.com", "some-role")
		require.NoError(t, err)

		_, err = svcs.CreateAccessRequestV2(ctx, req)
		require.NoError(t, err)
	}

	// countCached reports how many access requests a single list call against
	// the cache currently returns.
	countCached := func() (int, error) {
		rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{
			Limit: requestCount,
		})
		if err != nil {
			return 0, err
		}
		return len(rsp.AccessRequests), nil
	}

	// poll until the cache reflects every request created above.
	populateDeadline := time.After(30 * time.Second)
	for {
		got, err := countCached()
		require.NoError(t, err)
		if got == requestCount {
			break
		}

		select {
		case <-populateDeadline:
			require.FailNow(t, "timeout waiting for access request cache to populate")
		case <-time.After(200 * time.Millisecond):
		}
	}

	stop := make(chan struct{})
	readSignals := make(chan struct{}, workers)

	// reader lists the cache on a short interval until stop is closed. Any
	// failed read or short result is reported as an error; each successful read
	// is signaled on readSignals without blocking.
	reader := func() error {
		for {
			select {
			case <-stop:
				return nil
			case <-time.After(20 * time.Millisecond):
			}

			got, err := countCached()
			if err != nil {
				return trace.Errorf("unexpected read failure: %v", err)
			}

			select {
			case readSignals <- struct{}{}:
			default:
			}

			if got != requestCount {
				return trace.Errorf("unexpected number of access requests: %d (expected %d)", got, requestCount)
			}
		}
	}

	var readers errgroup.Group
	for started := 0; started < workers; started++ {
		readers.Go(reader)
	}

	// every cache init is reported on initSignals via the init callback.
	initSignals := make(chan struct{}, resets+1)
	cache.SetInitCallback(func() {
		initSignals <- struct{}{}
	})

	resetDeadline := time.After(30 * time.Second)
	for reset := 0; reset < resets; reset++ {
		// closing the backend watchers is what triggers the reset being tested;
		// wait for the init that follows it.
		svcs.bk.CloseWatchers()
		select {
		case <-initSignals:
		case <-resetDeadline:
			require.FailNowf(t, "timeout waiting for access request cache to reset", "reset=%d", reset)
		}

		// wait for a batch of reads to complete so that resets do not race too
		// far ahead of the readers when inits happen quickly.
		for observed := 0; observed < workers; observed++ {
			select {
			case <-readSignals:
			case <-resetDeadline:
				require.FailNowf(t, "timeout waiting for worker reads to catch up", "reset=%d", reset)
			}
		}
	}

	close(stop)
	require.NoError(t, readers.Wait())
}

// TestAccessRequestCacheBasics verifies the basic expected behaviors of the access request cache,
// including correct sorting and handling of put/delete events.
func TestAccessRequestCacheBasics(t *testing.T) {
Expand Down
Loading