Skip to content

Commit

Permalink
session: move more session vars to stmt context for retrying (#8034) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackysp authored Dec 12, 2018
1 parent 4e4b225 commit 9ab1a50
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 46 deletions.
2 changes: 1 addition & 1 deletion ddl/ddl_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func testCheckJobCancelled(c *C, d *ddl, job *model.Job, state *model.SchemaStat
t := meta.NewMeta(txn)
historyJob, err := t.GetHistoryDDLJob(job.ID)
c.Assert(err, IsNil)
c.Assert(historyJob.IsCancelled() || historyJob.IsRollbackDone(), IsTrue, Commentf("histroy job %s", historyJob))
c.Assert(historyJob.IsCancelled() || historyJob.IsRollbackDone(), IsTrue, Commentf("history job %s", historyJob))
if state != nil {
c.Assert(historyJob.SchemaState, Equals, *state)
}
Expand Down
15 changes: 10 additions & 5 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,11 +1288,17 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.Priority = priority
}
}
if vars.LastInsertID > 0 {
vars.PrevLastInsertID = vars.LastInsertID
vars.LastInsertID = 0
if vars.StmtCtx.LastInsertID > 0 {
sc.PrevLastInsertID = vars.StmtCtx.LastInsertID
} else {
sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID
}
sc.PrevAffectedRows = 0
if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt {
sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows())
} else if vars.StmtCtx.InSelectStmt {
sc.PrevAffectedRows = -1
}
vars.ResetPrevAffectedRows()
err = vars.SetSystemVar("warning_count", fmt.Sprintf("%d", vars.StmtCtx.NumWarnings(false)))
if err != nil {
return errors.Trace(err)
Expand All @@ -1301,7 +1307,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
if err != nil {
return errors.Trace(err)
}
vars.InsertID = 0
vars.StmtCtx = sc
return
}
2 changes: 1 addition & 1 deletion executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c
if err != nil {
return types.Datum{}, errors.Trace(err)
}
e.ctx.GetSessionVars().InsertID = uint64(recordID)
e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID)
retryInfo.AddAutoIncrementID(recordID)
d.SetAutoID(recordID, c.Flag)
return d, nil
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (b *builtinLastInsertIDSig) Clone() builtinFunc {
// evalInt evals LAST_INSERT_ID().
// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_last-insert-id.
func (b *builtinLastInsertIDSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) {
res = int64(b.ctx.GetSessionVars().PrevLastInsertID)
res = int64(b.ctx.GetSessionVars().StmtCtx.PrevLastInsertID)
return res, false, nil
}

Expand Down Expand Up @@ -439,6 +439,6 @@ func (b *builtinRowCountSig) Clone() builtinFunc {
// evalInt evals ROW_COUNT().
// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_row-count.
func (b *builtinRowCountSig) evalInt(_ chunk.Row) (res int64, isNull bool, err error) {
res = int64(b.ctx.GetSessionVars().PrevAffectedRows)
res = int64(b.ctx.GetSessionVars().StmtCtx.PrevAffectedRows)
return res, false, nil
}
4 changes: 2 additions & 2 deletions expression/builtin_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (s *testEvaluatorSuite) TestRowCount(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
sessionVars := ctx.GetSessionVars()
sessionVars.PrevAffectedRows = 10
sessionVars.StmtCtx.PrevAffectedRows = 10

f, err := funcs[ast.RowCount].getFunction(ctx, nil)
c.Assert(err, IsNil)
Expand Down Expand Up @@ -203,7 +203,7 @@ func (s *testEvaluatorSuite) TestLastInsertID(c *C) {
err error
)
if t.insertID > 0 {
s.ctx.GetSessionVars().PrevLastInsertID = t.insertID
s.ctx.GetSessionVars().StmtCtx.PrevLastInsertID = t.insertID
}

if t.args != nil {
Expand Down
15 changes: 8 additions & 7 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ func (s *session) Status() uint16 {
}

func (s *session) LastInsertID() uint64 {
if s.sessionVars.LastInsertID > 0 {
return s.sessionVars.LastInsertID
if s.sessionVars.StmtCtx.LastInsertID > 0 {
return s.sessionVars.StmtCtx.LastInsertID
}
return s.sessionVars.InsertID
return s.sessionVars.StmtCtx.InsertID
}

func (s *session) AffectedRows() uint64 {
Expand Down Expand Up @@ -427,8 +427,8 @@ func (s *session) String() string {
if sessVars.SnapshotTS != 0 {
data["snapshotTS"] = sessVars.SnapshotTS
}
if sessVars.LastInsertID > 0 {
data["lastInsertID"] = sessVars.LastInsertID
if sessVars.StmtCtx.LastInsertID > 0 {
data["lastInsertID"] = sessVars.StmtCtx.LastInsertID
}
if len(sessVars.PreparedStmts) > 0 {
data["preparedStmtCount"] = len(sessVars.PreparedStmts)
Expand Down Expand Up @@ -486,6 +486,9 @@ func (s *session) retry(ctx context.Context, maxCnt uint) error {
if st.IsReadOnly() {
continue
}
s.sessionVars.StmtCtx = sr.stmtCtx
s.sessionVars.StmtCtx.ResetForRetry()
s.sessionVars.PreparedParams = s.sessionVars.PreparedParams[:0]
schemaVersion, err = st.RebuildPlan()
if err != nil {
return errors.Trace(err)
Expand All @@ -499,8 +502,6 @@ func (s *session) retry(ctx context.Context, maxCnt uint) error {
} else {
log.Warnf("con:%d schema_ver:%d retry_cnt:%d query_num:%d", connID, schemaVersion, retryCnt, i)
}
s.sessionVars.StmtCtx = sr.stmtCtx
s.sessionVars.StmtCtx.ResetForRetry()
_, err = st.Exec(ctx)
if err != nil {
s.StmtRollback()
Expand Down
42 changes: 36 additions & 6 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ func (s *testSessionSuite) TestLastInsertID(c *C) {
tk.MustExec("execute stmt1 using @v1")
tk.MustExec("execute stmt1 using @v2")
tk.MustExec("deallocate prepare stmt1")
currLastInsertID := tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID := tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
tk.MustQuery("select c1 from t where c2 = 20").Check(testkit.Rows(fmt.Sprint(currLastInsertID)))
c.Assert(lastInsertID+2, Equals, currLastInsertID)
}
Expand Down Expand Up @@ -778,7 +778,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) {
tk.MustExec("commit")

tk.MustQuery("select c1 from t where c2 = 11").Check(testkit.Rows("6"))
currLastInsertID := tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID := tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
c.Assert(lastInsertID+5, Equals, currLastInsertID)

// insert set
Expand All @@ -793,7 +793,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) {
tk.MustExec("commit")

tk.MustQuery("select c1 from t where c2 = 31").Check(testkit.Rows("9"))
currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
c.Assert(lastInsertID+3, Equals, currLastInsertID)

// replace
Expand All @@ -808,7 +808,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) {
tk.MustExec("commit")

tk.MustQuery("select c1 from t where c2 = 21").Check(testkit.Rows("10"))
currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
c.Assert(lastInsertID+1, Equals, currLastInsertID)

// update
Expand All @@ -824,7 +824,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) {
tk.MustExec("commit")

tk.MustQuery("select c1 from t where c2 = 41").Check(testkit.Rows("0"))
currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
c.Assert(lastInsertID+3, Equals, currLastInsertID)

// prepare
Expand All @@ -846,7 +846,7 @@ func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) {
tk.MustExec("commit")

tk.MustQuery("select c1 from t where c2 = 12").Check(testkit.Rows("7"))
currLastInsertID = tk.Se.GetSessionVars().PrevLastInsertID
currLastInsertID = tk.Se.GetSessionVars().StmtCtx.PrevLastInsertID
c.Assert(lastInsertID+3, Equals, currLastInsertID)
}

Expand Down Expand Up @@ -1306,6 +1306,36 @@ func (s *testSessionSuite) TestDelete(c *C) {
tk1.MustQuery("select * from t;").Check(testkit.Rows("1"))
}

func (s *testSessionSuite) TestResetCtx(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk1 := testkit.NewTestKitWithInit(c, s.store)

tk.MustExec("create table t (i int auto_increment not null key);")
tk.MustExec("insert into t values (1);")
tk.MustExec("begin;")
tk.MustExec("insert into t values (10);")
tk.MustExec("update t set i = i + row_count();")
tk.MustQuery("select * from t;").Check(testkit.Rows("2", "11"))

tk1.MustExec("update t set i = 0 where i = 1;")
tk1.MustQuery("select * from t;").Check(testkit.Rows("0"))

tk.MustExec("commit;")
tk.MustQuery("select * from t;").Check(testkit.Rows("1", "11"))

tk.MustExec("delete from t where i = 11;")
tk.MustExec("begin;")
tk.MustExec("insert into t values ();")
tk.MustExec("update t set i = i + last_insert_id() + 1;")
tk.MustQuery("select * from t;").Check(testkit.Rows("14", "25"))

tk1.MustExec("update t set i = 0 where i = 1;")
tk1.MustQuery("select * from t;").Check(testkit.Rows("0"))

tk.MustExec("commit;")
tk.MustQuery("select * from t;").Check(testkit.Rows("13", "25"))
}

func (s *testSessionSuite) TestUnique(c *C) {
// test for https://github.com/pingcap/tidb/pull/461

Expand Down
10 changes: 10 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ type StatementContext struct {
histogramsNotLoad bool
execDetails execdetails.ExecDetails
}
// PrevAffectedRows is the affected-rows value(DDL is 0, DML is the number of affected rows).
PrevAffectedRows int64
// PrevLastInsertID is the last insert ID of previous statement.
PrevLastInsertID uint64
// LastInsertID is the auto-generated ID in the current statement.
LastInsertID uint64
// InsertID is the given insert ID of an auto_increment column.
InsertID uint64

// Copied from SessionVars.TimeZone.
TimeZone *time.Location
Expand Down Expand Up @@ -239,6 +247,8 @@ func (sc *StatementContext) ResetForRetry() {
sc.mu.foundRows = 0
sc.mu.warnings = nil
sc.mu.Unlock()
sc.TableIDs = sc.TableIDs[:0]
sc.IndexIDs = sc.IndexIDs[:0]
}

// MergeExecDetails merges a single region execution details into self, used to print
Expand Down
24 changes: 3 additions & 21 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,8 @@ type SessionVars struct {
Value string
}

// Following variables are special for current session.

Status uint16
PrevLastInsertID uint64 // PrevLastInsertID is the last insert ID of previous statement.
LastInsertID uint64 // LastInsertID is the auto-generated ID in the current statement.
InsertID uint64 // InsertID is the given insert ID of an auto_increment column.
// PrevAffectedRows is the affected-rows value(DDL is 0, DML is the number of affected rows).
PrevAffectedRows int64
// Status stands for the session status. e.g. in transaction or not, auto commit is on or off, and so on.
Status uint16

// ClientCapability is client's capability.
ClientCapability uint32
Expand Down Expand Up @@ -405,7 +399,7 @@ func (s *SessionVars) GetCharsetInfo() (charset, collation string) {
// SetLastInsertID saves the last insert id to the session context.
// TODO: we may store the result for last_insert_id sys var later.
func (s *SessionVars) SetLastInsertID(insertID uint64) {
s.LastInsertID = insertID
s.StmtCtx.LastInsertID = insertID
}

// SetStatusFlag sets the session server status variable.
Expand Down Expand Up @@ -449,18 +443,6 @@ func (s *SessionVars) Location() *time.Location {
return loc
}

// ResetPrevAffectedRows reset the prev-affected-rows variable.
func (s *SessionVars) ResetPrevAffectedRows() {
s.PrevAffectedRows = 0
if s.StmtCtx != nil {
if s.StmtCtx.InUpdateOrDeleteStmt || s.StmtCtx.InInsertStmt {
s.PrevAffectedRows = int64(s.StmtCtx.AffectedRows())
} else if s.StmtCtx.InSelectStmt {
s.PrevAffectedRows = -1
}
}
}

// GetExecuteArgumentsInfo gets the argument list as a string of execute statement.
func (s *SessionVars) GetExecuteArgumentsInfo() string {
if len(s.PreparedParams) == 0 {
Expand Down
2 changes: 1 addition & 1 deletion sessionctx/variable/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (*testSessionSuite) TestSession(c *C) {

// For last insert id
ctx.GetSessionVars().SetLastInsertID(1)
c.Assert(ctx.GetSessionVars().LastInsertID, Equals, uint64(1))
c.Assert(ctx.GetSessionVars().StmtCtx.LastInsertID, Equals, uint64(1))

ss.ResetForRetry()
c.Assert(ss.AffectedRows(), Equals, uint64(0))
Expand Down

0 comments on commit 9ab1a50

Please sign in to comment.