diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index a0da617c883..9c8ae647d39 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -29,6 +29,7 @@ import ( _flag "vitess.io/vitess/go/internal/flag" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/streamlog" + "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/sqlparser" @@ -4157,62 +4158,81 @@ func TestWarmingReads(t *testing.T) { executor.normalize = true session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) + // Since queries on the replica will run in a separate go-routine, we need sycnronization for the Queries field in the sandboxconn. + replica.RequireQueriesLocking() _, err := executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) require.NoError(t, err) wantQueries := []*querypb.BoundQuery{ {Sql: "select age, city from `user`"}, } - utils.MustMatch(t, wantQueries, primary.Queries) - primary.Queries = nil + utils.MustMatch(t, wantQueries, primary.GetQueries()) + primary.ClearQueries() + waitUntilQueryCount(t, replica, 1) wantQueriesReplica := []*querypb.BoundQuery{ {Sql: "select age, city from `user`/* warming read */"}, } - utils.MustMatch(t, wantQueriesReplica, replica.Queries) - replica.Queries = nil + utils.MustMatch(t, wantQueriesReplica, replica.GetQueries()) + replica.ClearQueries() _, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user /* already has a comment */ ", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{ {Sql: "select age, city from `user` /* already has a comment */"}, } - utils.MustMatch(t, wantQueries, primary.Queries) - primary.Queries = nil + utils.MustMatch(t, wantQueries, primary.GetQueries()) + primary.ClearQueries() + waitUntilQueryCount(t, replica, 1) wantQueriesReplica = []*querypb.BoundQuery{ {Sql: "select age, city from `user` /* already has a comment *//* warming read */"}, } - utils.MustMatch(t, wantQueriesReplica, replica.Queries) - replica.Queries = nil + utils.MustMatch(t, wantQueriesReplica, replica.GetQueries()) + replica.ClearQueries() _, err = executor.Execute(ctx, nil, "TestSelect", session, "insert into user (age, city) values (5, 'Boston')", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) + waitUntilQueryCount(t, replica, 0) require.NoError(t, err) - require.Nil(t, replica.Queries) + require.Nil(t, replica.GetQueries()) _, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "update user set age=5 where city='Boston'", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) + waitUntilQueryCount(t, replica, 0) require.NoError(t, err) - require.Nil(t, replica.Queries) + require.Nil(t, replica.GetQueries()) _, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "delete from user where city='Boston'", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) + waitUntilQueryCount(t, replica, 0) require.NoError(t, err) - require.Nil(t, replica.Queries) - primary.Queries = nil + require.Nil(t, replica.GetQueries()) + primary.ClearQueries() executor, primary, replica = createExecutorEnvWithPrimaryReplicaConn(t, ctx, 0) + replica.RequireQueriesLocking() _, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user", map[string]*querypb.BindVariable{}) - time.Sleep(10 * time.Millisecond) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{ {Sql: "select age, city from `user`"}, } - utils.MustMatch(t, wantQueries, primary.Queries) - require.Nil(t, replica.Queries) + utils.MustMatch(t, wantQueries, primary.GetQueries()) + waitUntilQueryCount(t, replica, 0) + require.Nil(t, replica.GetQueries()) +} + +// waitUntilQueryCount waits until the number of queries run on the tablet reach the specified count. +func waitUntilQueryCount(t *testing.T, tab *sandboxconn.SandboxConn, count int) { + timeout := time.After(1 * time.Second) + for { + select { + case <-timeout: + t.Fatalf("Timed out waiting for tablet %v query count to reach %v", topoproto.TabletAliasString(tab.Tablet().Alias), count) + default: + time.Sleep(10 * time.Millisecond) + if len(tab.GetQueries()) == count { + return + } + } + } } func TestMain(m *testing.M) { diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index b58a793db43..eea6a17bcfc 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -82,6 +82,8 @@ type SandboxConn struct { ReleaseCount atomic.Int64 GetSchemaCount atomic.Int64 + queriesRequireLocking bool + queriesMu sync.Mutex // Queries stores the non-batch requests received. Queries []*querypb.BoundQuery @@ -140,6 +142,39 @@ func NewSandboxConn(t *topodatapb.Tablet) *SandboxConn { } } +// RequireQueriesLocking sets the sandboxconn to require locking the access of Queries field. +func (sbc *SandboxConn) RequireQueriesLocking() { + sbc.queriesRequireLocking = true + sbc.queriesMu = sync.Mutex{} +} + +// GetQueries gets the Queries from sandboxconn. +func (sbc *SandboxConn) GetQueries() []*querypb.BoundQuery { + if sbc.queriesRequireLocking { + sbc.queriesMu.Lock() + defer sbc.queriesMu.Unlock() + } + return sbc.Queries +} + +// ClearQueries clears the Queries in sandboxconn. +func (sbc *SandboxConn) ClearQueries() { + if sbc.queriesRequireLocking { + sbc.queriesMu.Lock() + defer sbc.queriesMu.Unlock() + } + sbc.Queries = nil +} + +// appendToQueries appends to the Queries in sandboxconn. +func (sbc *SandboxConn) appendToQueries(q *querypb.BoundQuery) { + if sbc.queriesRequireLocking { + sbc.queriesMu.Lock() + defer sbc.queriesMu.Unlock() + } + sbc.Queries = append(sbc.Queries, q) +} + func (sbc *SandboxConn) getError() error { for code, count := range sbc.MustFailCodes { if count == 0 { @@ -181,7 +216,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que for k, v := range bindVars { bv[k] = v } - sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ + sbc.appendToQueries(&querypb.BoundQuery{ Sql: query, BindVariables: bv, }) @@ -206,7 +241,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe for k, v := range bindVars { bv[k] = v } - sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ + sbc.appendToQueries(&querypb.BoundQuery{ Sql: query, BindVariables: bv, }) @@ -673,6 +708,10 @@ func (sbc *SandboxConn) getTxReservedID(txID int64) int64 { // StringQueries returns the queries executed as a slice of strings func (sbc *SandboxConn) StringQueries() []string { + if sbc.queriesRequireLocking { + sbc.queriesMu.Lock() + defer sbc.queriesMu.Unlock() + } result := make([]string, len(sbc.Queries)) for i, query := range sbc.Queries { result[i] = query.Sql