diff --git a/pkg/ccl/multitenantccl/tenantcostclient/tenant_side_test.go b/pkg/ccl/multitenantccl/tenantcostclient/tenant_side_test.go
index 94b5bf1f3f78..9460945cb07a 100644
--- a/pkg/ccl/multitenantccl/tenantcostclient/tenant_side_test.go
+++ b/pkg/ccl/multitenantccl/tenantcostclient/tenant_side_test.go
@@ -322,7 +322,7 @@ func (ts *testState) request(
 	return ""
 }
 
-func (ts *testState) externalIngress(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) externalIngress(t *testing.T, _ *datadriven.TestData, args cmdArgs) string {
 	usage := multitenant.ExternalIOUsage{IngressBytes: args.bytes}
 	if err := ts.controller.OnExternalIOWait(context.Background(), usage); err != nil {
 		t.Errorf("OnExternalIOWait error: %s", err)
@@ -341,12 +341,12 @@ func (ts *testState) externalEgress(t *testing.T, d *datadriven.TestData, args c
 	return ""
 }
 
-func (ts *testState) enableRUAccounting(t *testing.T, _ *datadriven.TestData, _ cmdArgs) string {
+func (ts *testState) enableRUAccounting(_ *testing.T, _ *datadriven.TestData, _ cmdArgs) string {
 	tenantcostclient.ExternalIORUAccountingMode.Override(context.Background(), &ts.settings.SV, "on")
 	return ""
 }
 
-func (ts *testState) disableRUAccounting(t *testing.T, _ *datadriven.TestData, _ cmdArgs) string {
+func (ts *testState) disableRUAccounting(_ *testing.T, _ *datadriven.TestData, _ cmdArgs) string {
 	tenantcostclient.ExternalIORUAccountingMode.Override(context.Background(), &ts.settings.SV, "off")
 	return ""
 }
@@ -424,7 +424,7 @@ func (ts *testState) advance(t *testing.T, d *datadriven.TestData, args cmdArgs)
 
 // waitForEvent waits until the tenant controller reports the given event
 // type(s), at the current time.
-func (ts *testState) waitForEvent(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) waitForEvent(t *testing.T, d *datadriven.TestData, _ cmdArgs) string {
 	typs := make(map[string]tenantcostclient.TestEventType)
 	for ev, evStr := range eventTypeStr {
 		typs[evStr] = ev
@@ -444,7 +444,7 @@ func (ts *testState) waitForEvent(t *testing.T, d *datadriven.TestData, args cmd
 
 // unblockRequest resumes a token bucket request that was blocked by the
 // "blockRequest" configuration option.
-func (ts *testState) unblockRequest(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) unblockRequest(t *testing.T, _ *datadriven.TestData, _ cmdArgs) string {
 	ts.provider.unblockRequest(t)
 	return ""
 }
@@ -461,7 +461,7 @@ func (ts *testState) unblockRequest(t *testing.T, d *datadriven.TestData, args c
 // ----
 // 00:00:01.000
 // 00:00:02.000
-func (ts *testState) timers(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) timers(t *testing.T, d *datadriven.TestData, _ cmdArgs) string {
 	// If we are rewriting the test, just sleep a bit before returning the
 	// timers.
 	if d.Rewrite {
@@ -491,7 +491,7 @@ func timesToString(times []time.Time) string {
 }
 
 // configure the test provider.
-func (ts *testState) configure(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) configure(t *testing.T, d *datadriven.TestData, _ cmdArgs) string {
 	var cfg testProviderConfig
 	if err := yaml.UnmarshalStrict([]byte(d.Input), &cfg); err != nil {
 		d.Fatalf(t, "failed to parse request yaml: %v", err)
@@ -501,13 +501,13 @@ func (ts *testState) configure(t *testing.T, d *datadriven.TestData, args cmdArg
 }
 
 // tokenBucket dumps the current state of the tenant's token bucket.
-func (ts *testState) tokenBucket(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) tokenBucket(*testing.T, *datadriven.TestData, cmdArgs) string {
 	return tenantcostclient.TestingTokenBucketString(ts.controller)
 }
 
 // cpu adds CPU usage which will be observed by the controller on the next main
 // loop tick.
-func (ts *testState) cpu(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) cpu(t *testing.T, d *datadriven.TestData, _ cmdArgs) string {
 	duration, err := time.ParseDuration(d.Input)
 	if err != nil {
 		d.Fatalf(t, "error parsing cpu duration: %v", err)
@@ -518,7 +518,7 @@ func (ts *testState) cpu(t *testing.T, d *datadriven.TestData, args cmdArgs) str
 
 // pgwire adds PGWire egress usage which will be observed by the controller on the next
 // main loop tick.
-func (ts *testState) pgwireEgress(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) pgwireEgress(t *testing.T, d *datadriven.TestData, _ cmdArgs) string {
 	bytes, err := strconv.Atoi(d.Input)
 	if err != nil {
 		d.Fatalf(t, "error parsing pgwire bytes value: %v", err)
@@ -529,7 +529,7 @@ func (ts *testState) pgwireEgress(t *testing.T, d *datadriven.TestData, args cmd
 
 // usage prints out the latest consumption. Callers are responsible for
 // triggering calls to the token bucket provider and waiting for responses.
-func (ts *testState) usage(t *testing.T, d *datadriven.TestData, args cmdArgs) string {
+func (ts *testState) usage(*testing.T, *datadriven.TestData, cmdArgs) string {
 	c := ts.provider.consumption()
 	return fmt.Sprintf(""+
 		"RU: %.2f\n"+
@@ -695,7 +695,7 @@ func (tp *testProvider) unblockRequest(t *testing.T) {
 
 // TokenBucket implements the kvtenant.TokenBucketProvider interface.
 func (tp *testProvider) TokenBucket(
-	ctx context.Context, in *roachpb.TokenBucketRequest,
+	_ context.Context, in *roachpb.TokenBucketRequest,
 ) (*roachpb.TokenBucketResponse, error) {
 	tp.mu.Lock()
 	defer tp.mu.Unlock()
@@ -930,7 +930,7 @@ func TestSQLLivenessExemption(t *testing.T) {
 	// Make the tenant heartbeat like crazy.
 	ctx := context.Background()
 	//slinstance.DefaultTTL.Override(ctx, &st.SV, 20*time.Millisecond)
-	slinstance.DefaultHeartBeat.Override(ctx, &st.SV, time.Millisecond)
+	slinstance.DefaultHeartBeat.Override(ctx, &st.SV, 50*time.Millisecond)
 
 	_, tenantDB := serverutils.StartTenant(t, hostServer, base.TestTenantArgs{
 		TenantID: tenantID,
@@ -960,7 +960,6 @@ func TestSQLLivenessExemption(t *testing.T) {
 
 	// Verify that heartbeats can go through and update the expiration time.
 	val := livenessValue()
-	time.Sleep(2 * time.Millisecond)
 	testutils.SucceedsSoon(
 		t,
 		func() error {
diff --git a/pkg/sql/sqlliveness/slinstance/BUILD.bazel b/pkg/sql/sqlliveness/slinstance/BUILD.bazel
index 522c3e91a3ab..70bec476289e 100644
--- a/pkg/sql/sqlliveness/slinstance/BUILD.bazel
+++ b/pkg/sql/sqlliveness/slinstance/BUILD.bazel
@@ -10,6 +10,7 @@ go_library(
         "//pkg/settings",
         "//pkg/settings/cluster",
         "//pkg/sql/sqlliveness",
+        "//pkg/util/contextutil",
         "//pkg/util/grpcutil",
         "//pkg/util/hlc",
         "//pkg/util/log",
@@ -18,6 +19,7 @@ go_library(
         "//pkg/util/syncutil",
         "//pkg/util/timeutil",
         "//pkg/util/uuid",
+        "@com_github_cockroachdb_errors//:errors",
     ],
 )
 
@@ -35,6 +37,7 @@ go_test(
         "//pkg/settings/cluster",
         "//pkg/sql/sqlliveness",
         "//pkg/sql/sqlliveness/slstorage",
+        "//pkg/testutils",
         "//pkg/util/hlc",
         "//pkg/util/leaktest",
         "//pkg/util/log",
diff --git a/pkg/sql/sqlliveness/slinstance/slinstance.go b/pkg/sql/sqlliveness/slinstance/slinstance.go
index d76e8ca5a1c1..105e3549278d 100644
--- a/pkg/sql/sqlliveness/slinstance/slinstance.go
+++ b/pkg/sql/sqlliveness/slinstance/slinstance.go
@@ -21,6 +21,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/settings"
 	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
+	"github.com/cockroachdb/cockroach/pkg/util/contextutil"
 	"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
 	"github.com/cockroachdb/cockroach/pkg/util/hlc"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
@@ -29,6 +30,7 @@
 	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/cockroach/pkg/util/uuid"
+	"github.com/cockroachdb/errors"
 )
 
 var (
@@ -151,6 +153,14 @@ func (l *Instance) setSession(s *session) {
 }
 
 func (l *Instance) clearSession(ctx context.Context) {
+	l.checkExpiry(ctx)
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	l.mu.s = nil
+	l.mu.blockCh = make(chan struct{})
+}
+
+func (l *Instance) checkExpiry(ctx context.Context) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	if expiration := l.mu.s.Expiration(); expiration.Less(l.clock.Now()) {
@@ -158,8 +168,6 @@ func (l *Instance) clearSession(ctx context.Context) {
 		// associated with the session.
 		l.mu.s.invokeSessionExpiryCallbacks(ctx)
 	}
-	l.mu.s = nil
-	l.mu.blockCh = make(chan struct{})
 }
 
 // createSession tries until it can create a new session and returns an error
@@ -253,8 +261,13 @@ func (l *Instance) heartbeatLoop(ctx context.Context) {
 			t.Read = true
 			s, _ := l.getSessionOrBlockCh()
 			if s == nil {
-				newSession, err := l.createSession(ctx)
-				if err != nil {
+				var newSession *session
+				if err := contextutil.RunWithTimeout(ctx, "sqlliveness create session", l.hb(), func(ctx context.Context) error {
+					var err error
+					newSession, err = l.createSession(ctx)
+					return err
+				}); err != nil {
+					log.Errorf(ctx, "sqlliveness failed to create new session: %v", err)
 					func() {
 						l.mu.Lock()
 						defer l.mu.Unlock()
@@ -270,21 +283,37 @@ func (l *Instance) heartbeatLoop(ctx context.Context) {
 				t.Reset(l.hb())
 				continue
 			}
-			found, err := l.extendSession(ctx, s)
-			if err != nil {
+			var found bool
+			err := contextutil.RunWithTimeout(ctx, "sqlliveness extend session", l.hb(), func(ctx context.Context) error {
+				var err error
+				found, err = l.extendSession(ctx, s)
+				return err
+			})
+			switch {
+			case errors.HasType(err, (*contextutil.TimeoutError)(nil)):
+				// Retry without clearing the session because we don't know the current status.
+				l.checkExpiry(ctx)
+				t.Reset(0)
+				continue
+			case err != nil && ctx.Err() == nil:
+				log.Errorf(ctx, "sqlliveness failed to extend session: %v", err)
+				fallthrough
+			case err != nil:
+				// TODO(ajwerner): Decide whether we actually should exit the heartbeat loop here if the context is not
+				// canceled. Consider the case of an ambiguous result error: shouldn't we try again?
 				l.clearSession(ctx)
 				return
-			}
-			if !found {
+			case !found:
+				// No existing session found, immediately create one.
 				l.clearSession(ctx)
 				// Start next loop iteration immediately to insert a new session.
 				t.Reset(0)
-				continue
-			}
-			if log.V(2) {
-				log.Infof(ctx, "extended SQL liveness session %s", s.ID())
+			default:
+				if log.V(2) {
+					log.Infof(ctx, "extended SQL liveness session %s", s.ID())
+				}
+				t.Reset(l.hb())
 			}
-			t.Reset(l.hb())
 		}
 	}
 }
diff --git a/pkg/sql/sqlliveness/slinstance/slinstance_test.go b/pkg/sql/sqlliveness/slinstance/slinstance_test.go
index 75b069bab914..c0283c41ea38 100644
--- a/pkg/sql/sqlliveness/slinstance/slinstance_test.go
+++ b/pkg/sql/sqlliveness/slinstance/slinstance_test.go
@@ -12,6 +12,7 @@ package slinstance_test
 
 import (
 	"context"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -20,6 +21,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slinstance"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slstorage"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
 	"github.com/cockroachdb/cockroach/pkg/util/hlc"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
@@ -40,8 +42,8 @@ func TestSQLInstance(t *testing.T) {
 		clusterversion.TestingBinaryVersion,
 		clusterversion.TestingBinaryMinSupportedVersion,
 		true /* initializeVersion */)
-	slinstance.DefaultTTL.Override(ctx, &settings.SV, 2*time.Microsecond)
-	slinstance.DefaultHeartBeat.Override(ctx, &settings.SV, time.Microsecond)
+	slinstance.DefaultTTL.Override(ctx, &settings.SV, 20*time.Millisecond)
+	slinstance.DefaultHeartBeat.Override(ctx, &settings.SV, 10*time.Millisecond)
 
 	fakeStorage := slstorage.NewFakeStorage()
 	sqlInstance := slinstance.NewSQLInstance(stopper, clock, fakeStorage, settings, nil)
@@ -91,3 +93,113 @@ func TestSQLInstance(t *testing.T) {
 	_, err = sqlInstance.Session(ctx)
 	require.Error(t, err)
 }
+
+// TestSQLInstanceDeadlines tests that we have proper deadlines set on the
+// create and extend session operations. This is done by blocking the fake
+// storage layer and ensuring that no sessions get created because the
+// timeouts are constantly triggered.
+func TestSQLInstanceDeadlines(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	ctx, stopper := context.Background(), stop.NewStopper()
+	defer stopper.Stop(ctx)
+
+	clock := hlc.NewClock(timeutil.NewManualTime(timeutil.Unix(0, 42)), time.Nanosecond /* maxOffset */)
+	settings := cluster.MakeTestingClusterSettingsWithVersions(
+		clusterversion.TestingBinaryVersion,
+		clusterversion.TestingBinaryMinSupportedVersion,
+		true /* initializeVersion */)
+	slinstance.DefaultTTL.Override(ctx, &settings.SV, 20*time.Millisecond)
+	slinstance.DefaultHeartBeat.Override(ctx, &settings.SV, 10*time.Millisecond)
+
+	fakeStorage := slstorage.NewFakeStorage()
+	// block the fake storage
+	fakeStorage.SetBlockCh()
+	cleanUpFunc := func() {
+		fakeStorage.CloseBlockCh()
+	}
+	defer cleanUpFunc()
+
+	sqlInstance := slinstance.NewSQLInstance(stopper, clock, fakeStorage, settings, nil)
+	sqlInstance.Start(ctx)
+
+	// verify that we do not create a session
+	require.Never(
+		t,
+		func() bool {
+			_, err := sqlInstance.Session(ctx)
+			return err == nil
+		},
+		100*time.Millisecond, 10*time.Millisecond,
+	)
+}
+
+// TestSQLInstanceDeadlinesExtend tests that we have proper deadlines set on the
+// create and extend session operations. This tests the case where the session is
+// successfully created first and then blocks indefinitely.
+func TestSQLInstanceDeadlinesExtend(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	ctx, stopper := context.Background(), stop.NewStopper()
+	defer stopper.Stop(ctx)
+
+	mt := timeutil.NewManualTime(timeutil.Unix(0, 42))
+	clock := hlc.NewClock(mt, time.Nanosecond /* maxOffset */)
+	settings := cluster.MakeTestingClusterSettingsWithVersions(
+		clusterversion.TestingBinaryVersion,
+		clusterversion.TestingBinaryMinSupportedVersion,
+		true /* initializeVersion */)
+	slinstance.DefaultTTL.Override(ctx, &settings.SV, 20*time.Millisecond)
+	// Must be shorter than the storage sleep amount below
+	slinstance.DefaultHeartBeat.Override(ctx, &settings.SV, 10*time.Millisecond)
+
+	fakeStorage := slstorage.NewFakeStorage()
+	sqlInstance := slinstance.NewSQLInstance(stopper, clock, fakeStorage, settings, nil)
+	sqlInstance.Start(ctx)
+
+	// verify that eventually session is created successfully
+	testutils.SucceedsSoon(
+		t,
+		func() error {
+			_, err := sqlInstance.Session(ctx)
+			return err
+		},
+	)
+
+	// verify that session is also extended successfully a few times
+	require.Never(
+		t,
+		func() bool {
+			_, err := sqlInstance.Session(ctx)
+			return err != nil
+		},
+		100*time.Millisecond, 10*time.Millisecond,
+	)
+
+	// register a callback for verification that this session expired
+	var sessionExpired atomic.Bool
+	s, _ := sqlInstance.Session(ctx)
+	s.RegisterCallbackForSessionExpiry(func(ctx context.Context) {
+		sessionExpired.Store(true)
+	})
+
+	// block the fake storage
+	fakeStorage.SetBlockCh()
+	cleanUpFunc := func() {
+		fakeStorage.CloseBlockCh()
+	}
+	defer cleanUpFunc()
+	// advance manual clock so that session expires
+	mt.Advance(20 * time.Millisecond)
+
+	// expect session to expire
+	require.Eventually(
+		t,
+		func() bool {
+			return sessionExpired.Load()
+		},
+		testutils.DefaultSucceedsSoonDuration, 10*time.Millisecond,
+	)
+}
diff --git a/pkg/sql/sqlliveness/slstorage/test_helpers.go b/pkg/sql/sqlliveness/slstorage/test_helpers.go
index 23eb1115f217..91952ff538d1 100644
--- a/pkg/sql/sqlliveness/slstorage/test_helpers.go
+++ b/pkg/sql/sqlliveness/slstorage/test_helpers.go
@@ -24,6 +24,7 @@ type FakeStorage struct {
 	mu struct {
 		syncutil.Mutex
 		sessions map[sqlliveness.SessionID]hlc.Timestamp
+		blockCh  chan struct{}
 	}
 }
 
@@ -46,8 +47,16 @@ func (s *FakeStorage) IsAlive(
 
 // Insert implements the sqlliveness.Storage interface.
 func (s *FakeStorage) Insert(
-	_ context.Context, sid sqlliveness.SessionID, expiration hlc.Timestamp,
+	ctx context.Context, sid sqlliveness.SessionID, expiration hlc.Timestamp,
 ) error {
+	if ch := s.getBlockCh(); ch != nil {
+		select {
+		case <-ch:
+			break
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if _, ok := s.mu.sessions[sid]; ok {
@@ -59,8 +68,16 @@
 
 // Update implements the sqlliveness.Storage interface.
 func (s *FakeStorage) Update(
-	_ context.Context, sid sqlliveness.SessionID, expiration hlc.Timestamp,
+	ctx context.Context, sid sqlliveness.SessionID, expiration hlc.Timestamp,
 ) (bool, error) {
+	if ch := s.getBlockCh(); ch != nil {
+		select {
+		case <-ch:
+			break
+		case <-ctx.Done():
+			return false, ctx.Err()
+		}
+	}
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if _, ok := s.mu.sessions[sid]; !ok {
@@ -77,3 +94,23 @@ func (s *FakeStorage) Delete(_ context.Context, sid sqlliveness.SessionID) error
 	delete(s.mu.sessions, sid)
 	return nil
 }
+
+// SetBlockCh is used to block the storage for testing purposes
+func (s *FakeStorage) SetBlockCh() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.blockCh = make(chan struct{})
+}
+
+// CloseBlockCh is used to unblock the storage for testing purposes
+func (s *FakeStorage) CloseBlockCh() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	close(s.mu.blockCh)
+}
+
+func (s *FakeStorage) getBlockCh() chan struct{} {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.mu.blockCh
+}
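
Not part of the patch above: a minimal, standalone Go sketch of the pattern the slinstance.go change introduces in heartbeatLoop. Each storage call runs under its own per-attempt deadline equal to the heartbeat interval, and a deadline error is retried without discarding the current session, while other failures (or a missing session row) clear it. The sketch uses only the standard library; the storage interface, extendOnce, heartbeatOnce, and slowStorage names are illustrative stand-ins, not CockroachDB APIs (the real code uses contextutil.RunWithTimeout and errors.HasType against *contextutil.TimeoutError).

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// storage is a stand-in for the sqlliveness storage layer.
type storage interface {
	Update(ctx context.Context, id string, expiration time.Time) (found bool, err error)
}

// extendOnce runs a single extend attempt under its own deadline, mirroring
// the per-call timeout the patch wraps around extendSession.
func extendOnce(ctx context.Context, s storage, id string, hb time.Duration) (bool, error) {
	ctx, cancel := context.WithTimeout(ctx, hb)
	defer cancel()
	return s.Update(ctx, id, time.Now().Add(2*hb))
}

// heartbeatOnce classifies the result the way the patched loop does:
// timeouts retry without clearing the session, other errors clear it, and a
// missing row triggers re-creation.
func heartbeatOnce(ctx context.Context, s storage, id string, hb time.Duration) string {
	found, err := extendOnce(ctx, s, id, hb)
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		// Storage did not answer in time; the session may still be live,
		// so keep it and retry immediately.
		return "retry"
	case err != nil:
		// Unexpected error: drop the session and let the caller recreate it.
		return "clear"
	case !found:
		// The session row is gone; recreate it on the next iteration.
		return "recreate"
	default:
		return "extended"
	}
}

// slowStorage simulates a storage layer that stalls for a fixed delay, much
// like the blocked FakeStorage in the new tests.
type slowStorage struct{ delay time.Duration }

func (s slowStorage) Update(ctx context.Context, _ string, _ time.Time) (bool, error) {
	select {
	case <-time.After(s.delay):
		return true, nil
	case <-ctx.Done():
		return false, ctx.Err()
	}
}

func main() {
	ctx := context.Background()
	// With a 10ms heartbeat and a 50ms storage stall, every attempt times out
	// and is retried instead of tearing the session down.
	fmt.Println(heartbeatOnce(ctx, slowStorage{delay: 50 * time.Millisecond}, "session-1", 10*time.Millisecond))
	// A responsive storage layer extends the session normally.
	fmt.Println(heartbeatOnce(ctx, slowStorage{delay: time.Millisecond}, "session-1", 10*time.Millisecond))
}

The per-attempt deadline is what keeps a wedged storage layer from blocking the heartbeat loop indefinitely, which is exactly the behavior TestSQLInstanceDeadlines exercises by blocking FakeStorage and asserting that no session is ever created.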