From b7e6d3d6131c3a9d01456ed83358293a0cbddf2b Mon Sep 17 00:00:00 2001 From: Daniel Joos Date: Wed, 21 Feb 2024 11:55:15 +0100 Subject: [PATCH] Add tests for `KillWithContext` logic Signed-off-by: Daniel Joos --- go/mysql/fakesqldb/server.go | 9 +- .../vttablet/tabletserver/connpool/dbconn.go | 24 +++-- .../tabletserver/connpool/dbconn_test.go | 98 +++++++++++++++++++ 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index cb3d20ae04b..4dc71dd3662 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -376,11 +376,11 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } key := strings.ToLower(query) db.mu.Lock() - defer db.mu.Unlock() db.queryCalled[key]++ db.querylog = append(db.querylog, key) // Check if we should close the connection and provoke errno 2013. if db.shouldClose.Load() { + defer db.mu.Unlock() c.Close() // log error @@ -394,6 +394,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // The driver may send this at connection time, and we don't want it to // interfere. if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") { + defer db.mu.Unlock() + // log error if err := callback(&sqltypes.Result{}); err != nil { log.Errorf("callback failed : %v", err) @@ -403,12 +405,14 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // check if we should reject it. if err, ok := db.rejectedData[key]; ok { + db.mu.Unlock() return err } // Check explicit queries from AddQuery(). result, ok := db.data[key] if ok { + db.mu.Unlock() if f := result.BeforeFunc; f != nil { f() } @@ -419,6 +423,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R for _, pat := range db.patternData { if pat.expr.MatchString(query) { userCallback, ok := db.queryPatternUserCallback[pat.expr] + db.mu.Unlock() if ok { userCallback(query) } @@ -429,6 +434,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } } + defer db.mu.Unlock() + if db.neverFail.Load() { return callback(&sqltypes.Result{}) } diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index 5632c651e84..b39299e49e6 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -40,6 +40,8 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) +const defaultKillTimeout = 5 * time.Second + // Conn is a db connection for tabletserver. // It performs automatic reconnects as needed. // Its Execute function has a timeout that can kill @@ -57,6 +59,8 @@ type Conn struct { // err will be set if a query is killed through a Kill. errmu sync.Mutex err error + + killTimeout time.Duration } // NewConnection creates a new DBConn. It triggers a CheckMySQL if creation fails. @@ -71,10 +75,11 @@ func newPooledConn(ctx context.Context, pool *Pool, appParams dbconfigs.Connecto return nil, err } db := &Conn{ - conn: c, - env: pool.env, - stats: pool.env.Stats(), - dbaPool: pool.dbaPool, + conn: c, + env: pool.env, + stats: pool.env.Stats(), + dbaPool: pool.dbaPool, + killTimeout: defaultKillTimeout, } return db, nil } @@ -86,9 +91,10 @@ func NewConn(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpoo return nil, err } dbconn := &Conn{ - conn: c, - dbaPool: dbaPool, - stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), + conn: c, + dbaPool: dbaPool, + stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), + killTimeout: defaultKillTimeout, } if setting == nil { return dbconn, nil @@ -175,7 +181,7 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi select { case <-ctx.Done(): - killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) defer cancel() _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) @@ -274,7 +280,7 @@ func (dbc *Conn) streamOnce(ctx context.Context, query string, callback func(*sq select { case <-ctx.Done(): - killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) defer cancel() _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go index 9717c95d9f7..b0e10c4c0d8 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go @@ -33,6 +33,8 @@ import ( "vitess.io/vitess/go/mysql/fakesqldb" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) func compareTimingCounts(t *testing.T, op string, delta int64, before, after map[string]int64) { @@ -291,6 +293,57 @@ func TestDBConnKill(t *testing.T) { } } +func TestDBKillWithContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddQuery(query, &sqltypes.Result{}) + db.SetBeforeFunc(query, func() { + // should take longer than our context deadline below. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // KillWithContext should return context.DeadlineExceeded + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestDBKillWithContextDoneContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddRejectedQuery(query, errors.New("rejected")) + + contextErr := errors.New("context error") + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(contextErr) // cancel the context immediately + + // KillWithContext should return the cancellation cause + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, contextErr) +} + // TestDBConnClose tests that an Exec returns immediately if a connection // is asynchronously killed (and closed) in the middle of an execution. func TestDBConnClose(t *testing.T) { @@ -519,3 +572,48 @@ func TestDBConnReApplySetting(t *testing.T) { db.VerifyAllExecutedOrFail() } + +func TestDBExecOnceKillTimeout(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + // A very long running query that will be killed. + expectedQuery := "select 1" + var timestampQuery time.Time + db.AddQuery(expectedQuery, &sqltypes.Result{}) + db.SetBeforeFunc(expectedQuery, func() { + timestampQuery = time.Now() + // should take longer than our context deadline below. + time.Sleep(1000 * time.Millisecond) + }) + + // We expect a kill-query to be fired, too. + // It should also run into a timeout. + var timestampKill time.Time + dbConn.killTimeout = 100 * time.Millisecond + db.AddQueryPatternWithCallback(`kill \d+`, &sqltypes.Result{}, func(string) { + timestampKill = time.Now() + // should take longer than the configured kill timeout above. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + result, err := dbConn.ExecOnce(ctx, "select 1", 1, false) + timestampDone := time.Now() + + require.Error(t, err) + require.Equal(t, vtrpcpb.Code_CANCELED, vterrors.Code(err)) + require.Nil(t, result) + require.WithinDuration(t, timestampQuery, timestampKill, 150*time.Millisecond) + require.WithinDuration(t, timestampKill, timestampDone, 150*time.Millisecond) +}