Skip to content

Commit

Permalink
Fix data race in TestWarmingReads (#14187)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 committed Oct 5, 2023
1 parent d84f6d7 commit 22f1f7d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 22 deletions.
60 changes: 40 additions & 20 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
_flag "vitess.io/vitess/go/internal/flag"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/streamlog"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vtgate/logstats"

"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -4157,62 +4158,81 @@ func TestWarmingReads(t *testing.T) {

executor.normalize = true
session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded})
// Since queries on the replica will run in a separate go-routine, we need sycnronization for the Queries field in the sandboxconn.
replica.RequireQueriesLocking()

_, err := executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{
{Sql: "select age, city from `user`"},
}
utils.MustMatch(t, wantQueries, primary.Queries)
primary.Queries = nil
utils.MustMatch(t, wantQueries, primary.GetQueries())
primary.ClearQueries()

waitUntilQueryCount(t, replica, 1)
wantQueriesReplica := []*querypb.BoundQuery{
{Sql: "select age, city from `user`/* warming read */"},
}
utils.MustMatch(t, wantQueriesReplica, replica.Queries)
replica.Queries = nil
utils.MustMatch(t, wantQueriesReplica, replica.GetQueries())
replica.ClearQueries()

_, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user /* already has a comment */ ", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{
{Sql: "select age, city from `user` /* already has a comment */"},
}
utils.MustMatch(t, wantQueries, primary.Queries)
primary.Queries = nil
utils.MustMatch(t, wantQueries, primary.GetQueries())
primary.ClearQueries()

waitUntilQueryCount(t, replica, 1)
wantQueriesReplica = []*querypb.BoundQuery{
{Sql: "select age, city from `user` /* already has a comment *//* warming read */"},
}
utils.MustMatch(t, wantQueriesReplica, replica.Queries)
replica.Queries = nil
utils.MustMatch(t, wantQueriesReplica, replica.GetQueries())
replica.ClearQueries()

_, err = executor.Execute(ctx, nil, "TestSelect", session, "insert into user (age, city) values (5, 'Boston')", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
waitUntilQueryCount(t, replica, 0)
require.NoError(t, err)
require.Nil(t, replica.Queries)
require.Nil(t, replica.GetQueries())

_, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "update user set age=5 where city='Boston'", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
waitUntilQueryCount(t, replica, 0)
require.NoError(t, err)
require.Nil(t, replica.Queries)
require.Nil(t, replica.GetQueries())

_, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "delete from user where city='Boston'", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
waitUntilQueryCount(t, replica, 0)
require.NoError(t, err)
require.Nil(t, replica.Queries)
primary.Queries = nil
require.Nil(t, replica.GetQueries())
primary.ClearQueries()

executor, primary, replica = createExecutorEnvWithPrimaryReplicaConn(t, ctx, 0)
replica.RequireQueriesLocking()
_, err = executor.Execute(ctx, nil, "TestWarmingReads", session, "select age, city from user", map[string]*querypb.BindVariable{})
time.Sleep(10 * time.Millisecond)
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{
{Sql: "select age, city from `user`"},
}
utils.MustMatch(t, wantQueries, primary.Queries)
require.Nil(t, replica.Queries)
utils.MustMatch(t, wantQueries, primary.GetQueries())
waitUntilQueryCount(t, replica, 0)
require.Nil(t, replica.GetQueries())
}

// waitUntilQueryCount waits until the number of queries run on the tablet reach the specified count.
func waitUntilQueryCount(t *testing.T, tab *sandboxconn.SandboxConn, count int) {
timeout := time.After(1 * time.Second)
for {
select {
case <-timeout:
t.Fatalf("Timed out waiting for tablet %v query count to reach %v", topoproto.TabletAliasString(tab.Tablet().Alias), count)
default:
time.Sleep(10 * time.Millisecond)
if len(tab.GetQueries()) == count {
return
}
}
}
}

func TestMain(m *testing.M) {
Expand Down
43 changes: 41 additions & 2 deletions go/vt/vttablet/sandboxconn/sandboxconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ type SandboxConn struct {
ReleaseCount atomic.Int64
GetSchemaCount atomic.Int64

queriesRequireLocking bool
queriesMu sync.Mutex
// Queries stores the non-batch requests received.
Queries []*querypb.BoundQuery

Expand Down Expand Up @@ -140,6 +142,39 @@ func NewSandboxConn(t *topodatapb.Tablet) *SandboxConn {
}
}

// RequireQueriesLocking sets the sandboxconn to require locking the access of Queries field.
func (sbc *SandboxConn) RequireQueriesLocking() {
sbc.queriesRequireLocking = true
sbc.queriesMu = sync.Mutex{}
}

// GetQueries gets the Queries from sandboxconn.
func (sbc *SandboxConn) GetQueries() []*querypb.BoundQuery {
if sbc.queriesRequireLocking {
sbc.queriesMu.Lock()
defer sbc.queriesMu.Unlock()
}
return sbc.Queries
}

// ClearQueries clears the Queries in sandboxconn.
func (sbc *SandboxConn) ClearQueries() {
if sbc.queriesRequireLocking {
sbc.queriesMu.Lock()
defer sbc.queriesMu.Unlock()
}
sbc.Queries = nil
}

// appendToQueries appends to the Queries in sandboxconn.
func (sbc *SandboxConn) appendToQueries(q *querypb.BoundQuery) {
if sbc.queriesRequireLocking {
sbc.queriesMu.Lock()
defer sbc.queriesMu.Unlock()
}
sbc.Queries = append(sbc.Queries, q)
}

func (sbc *SandboxConn) getError() error {
for code, count := range sbc.MustFailCodes {
if count == 0 {
Expand Down Expand Up @@ -181,7 +216,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que
for k, v := range bindVars {
bv[k] = v
}
sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{
sbc.appendToQueries(&querypb.BoundQuery{
Sql: query,
BindVariables: bv,
})
Expand All @@ -206,7 +241,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe
for k, v := range bindVars {
bv[k] = v
}
sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{
sbc.appendToQueries(&querypb.BoundQuery{
Sql: query,
BindVariables: bv,
})
Expand Down Expand Up @@ -673,6 +708,10 @@ func (sbc *SandboxConn) getTxReservedID(txID int64) int64 {

// StringQueries returns the queries executed as a slice of strings
func (sbc *SandboxConn) StringQueries() []string {
if sbc.queriesRequireLocking {
sbc.queriesMu.Lock()
defer sbc.queriesMu.Unlock()
}
result := make([]string, len(sbc.Queries))
for i, query := range sbc.Queries {
result[i] = query.Sql
Expand Down

0 comments on commit 22f1f7d

Please sign in to comment.