From 9070bb6b86b27e59f35ebc2f984e369d3c5a115d Mon Sep 17 00:00:00 2001 From: Jack Yu Date: Thu, 30 Aug 2018 14:16:30 +0800 Subject: [PATCH] executer: fix the last_insert_id in insert on duplicate key update (#7534) --- executor/insert.go | 24 ++++-------------------- executor/insert_common.go | 21 +++++++++++++++++---- executor/load_data.go | 10 +++------- executor/replace.go | 18 +----------------- executor/update.go | 2 +- executor/write.go | 33 +++++++++++++++------------------ executor/write_test.go | 5 +++++ server/server_test.go | 4 ++-- 8 files changed, 48 insertions(+), 69 deletions(-) diff --git a/executor/insert.go b/executor/insert.go index 66a91ee65830c..6587b9650e020 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -34,16 +34,6 @@ type InsertExec struct { finished bool } -func (e *InsertExec) insertOneRow(row []types.Datum) (int64, error) { - e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil) - h, err := e.Table.AddRecord(e.ctx, row, false) - e.ctx.Txn().DelOption(kv.PresumeKeyNotExists) - if err != nil { - return 0, errors.Trace(err) - } - return h, nil -} - func (e *InsertExec) exec(rows [][]types.Datum) error { // If tidb_batch_insert is ON and not in a transaction, we could use BatchInsert mode. sessVars := e.ctx.GetSessionVars() @@ -67,20 +57,17 @@ func (e *InsertExec) exec(rows [][]types.Datum) error { return errors.Trace(err) } } else if ignoreErr { - err := e.batchCheckAndInsert(rows, e.insertOneRow) + err := e.batchCheckAndInsert(rows, e.addRecord) if err != nil { return errors.Trace(err) } } else { for _, row := range rows { - if _, err := e.insertOneRow(row); err != nil { + if _, err := e.addRecord(row); err != nil { return errors.Trace(err) } } } - if e.lastInsertID != 0 { - sessVars.SetLastInsertID(e.lastInsertID) - } e.finished = true return nil } @@ -131,7 +118,7 @@ func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error { // and key-values should be filled back to dupOldRowValues for the further row check, // due to there may be duplicate keys inside the insert statement. if newRows[i] != nil { - newHandle, err := e.insertOneRow(newRows[i]) + newHandle, err := e.addRecord(newRows[i]) if err != nil { return errors.Trace(err) } @@ -220,13 +207,10 @@ func (e *InsertExec) doDupRowUpdate(handle int64, oldRow []types.Datum, newRow [ } newData := row4Update[:len(oldRow)] - _, handleChanged, newHandle, lastInsertID, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true) + _, handleChanged, newHandle, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true) if err != nil { return nil, false, 0, errors.Trace(err) } - if lastInsertID != 0 { - e.lastInsertID = lastInsertID - } return newData, handleChanged, newHandle, nil } diff --git a/executor/insert_common.go b/executor/insert_common.go index 7859cc68da855..c011c1aa6fe37 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -157,7 +157,7 @@ func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows [][]types rows := make([][]types.Datum, len(e.Lists)) for i, list := range e.Lists { - e.rowCount = uint64(i) + e.rowCount++ rows[i], err = e.getRow(cols, list, i) if err != nil { return errors.Trace(err) @@ -445,7 +445,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab return errors.Trace(err) } // It's compatible with mysql. So it sets last insert id to the first row. - if e.rowCount == 0 { + if e.rowCount == 1 { e.lastInsertID = uint64(recordID) } } @@ -474,7 +474,7 @@ func (e *InsertValues) handleWarning(err error, logInfo string) { // batchCheckAndInsert checks rows with duplicate errors. // All duplicate rows will be ignored and appended as duplicate warnings. -func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow func(row []types.Datum) (int64, error)) error { +func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, addRecord func(row []types.Datum) (int64, error)) error { // all the rows will be checked, so it is safe to set BatchCheck = true e.ctx.GetSessionVars().StmtCtx.BatchCheck = true err := e.batchGetInsertKeys(e.ctx, e.Table, rows) @@ -502,7 +502,7 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu // it should be add to values map for the further row check. // There may be duplicate keys inside the insert statement. if rows[i] != nil { - _, err = insertOneRow(rows[i]) + _, err = addRecord(rows[i]) if err != nil { return errors.Trace(err) } @@ -516,3 +516,16 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu } return nil } + +func (e *InsertValues) addRecord(row []types.Datum) (int64, error) { + e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil) + h, err := e.Table.AddRecord(e.ctx, row, false) + e.ctx.Txn().DelOption(kv.PresumeKeyNotExists) + if err != nil { + return 0, errors.Trace(err) + } + if e.lastInsertID != 0 { + e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID) + } + return h, nil +} diff --git a/executor/load_data.go b/executor/load_data.go index 10ee2a4260a85..afe7b61f5bce3 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -248,14 +248,10 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error break } } - err := e.batchCheckAndInsert(rows, e.insertData) + err := e.batchCheckAndInsert(rows, e.addRecordLD) if err != nil { return nil, reachLimit, errors.Trace(err) } - if e.lastInsertID != 0 { - e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID) - } - return curData, reachLimit, nil } @@ -282,11 +278,11 @@ func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum { return row } -func (e *LoadDataInfo) insertData(row []types.Datum) (int64, error) { +func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) { if row == nil { return 0, nil } - h, err := e.Table.AddRecord(e.ctx, row, false) + h, err := e.addRecord(row) if err != nil { e.handleWarning(err, fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err))) diff --git a/executor/replace.go b/executor/replace.go index a37f22dcc7e79..1debefb3cf9e5 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -15,7 +15,6 @@ package executor import ( "github.com/juju/errors" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -83,18 +82,6 @@ func (e *ReplaceExec) removeRow(handle int64, r toBeCheckedRow) (bool, error) { return false, nil } -// addRow adds a row when all the duplicate key were checked. -func (e *ReplaceExec) addRow(row []types.Datum) (int64, error) { - // Set kv.PresumeKeyNotExists is safe here, because we've already removed all duplicated rows. - e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil) - h, err := e.Table.AddRecord(e.ctx, row, false) - e.ctx.Txn().DelOption(kv.PresumeKeyNotExists) - if err != nil { - return 0, errors.Trace(err) - } - return h, nil -} - // replaceRow removes all duplicate rows for one row, then inserts it. func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error { if r.handleKey != nil { @@ -129,7 +116,7 @@ func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error { } // No duplicated rows now, insert the row. - newHandle, err := e.addRow(r.row) + newHandle, err := e.addRecord(r.row) if err != nil { return errors.Trace(err) } @@ -190,9 +177,6 @@ func (e *ReplaceExec) exec(newRows [][]types.Datum) error { return errors.Trace(err) } } - if e.lastInsertID != 0 { - e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID) - } e.finished = true return nil } diff --git a/executor/update.go b/executor/update.go index ab697d7db3cb4..39be67de247fb 100644 --- a/executor/update.go +++ b/executor/update.go @@ -80,7 +80,7 @@ func (e *UpdateExec) exec(schema *expression.Schema) ([]types.Datum, error) { } // Update row - changed, _, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false) + changed, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false) if err1 == nil { if changed { e.updatedRowKeys[id][handle] = struct{}{} diff --git a/executor/write.go b/executor/write.go index f1bc187617d85..ae010cf3ae048 100644 --- a/executor/write.go +++ b/executor/write.go @@ -41,17 +41,15 @@ var ( // 1. changed (bool) : does the update really change the row values. e.g. update set i = 1 where i = 1; // 2. handleChanged (bool) : is the handle changed after the update. // 3. newHandle (int64) : if handleChanged == true, the newHandle means the new handle after update. -// 4. lastInsertID (uint64) : the lastInsertID should be set by the newData. -// 5. err (error) : error in the update. +// 4. err (error) : error in the update. func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, modified []bool, t table.Table, - onDup bool) (bool, bool, int64, uint64, error) { + onDup bool) (bool, bool, int64, error) { var sc = ctx.GetSessionVars().StmtCtx var changed, handleChanged = false, false // onUpdateSpecified is for "UPDATE SET ts_field = old_value", the // timestamp field is explicitly set, but not changed in fact. var onUpdateSpecified = make(map[int]bool) var newHandle int64 - var lastInsertID uint64 // We can iterate on public columns not writable columns, // because all of them are sorted by their `Offset`, which @@ -61,7 +59,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu // Cast changed fields with respective columns. v, err := table.CastValue(ctx, newData[i], col.ToInfo()) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } newData[i] = v } @@ -70,27 +68,26 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu var err error newData[i], err = table.GetColDefaultValue(ctx, col.ToInfo()) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } } // Rebase auto increment id if the field is changed. if mysql.HasAutoIncrementFlag(col.Flag) { if newData[i].IsNull() { - return false, handleChanged, newHandle, 0, table.ErrColumnCantNull.GenByArgs(col.Name) + return false, handleChanged, newHandle, table.ErrColumnCantNull.GenByArgs(col.Name) } val, errTI := newData[i].ToInt64(sc) if errTI != nil { - return false, handleChanged, newHandle, 0, errors.Trace(errTI) + return false, handleChanged, newHandle, errors.Trace(errTI) } - lastInsertID = uint64(val) err := t.RebaseAutoID(ctx, val, true) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } } cmp, err := newData[i].CompareDatum(sc, &oldData[i]) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } if cmp != 0 { changed = true @@ -111,7 +108,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu // Check the not-null constraints. err := table.CheckNotNull(t.Cols(), newData) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } if !changed { @@ -119,7 +116,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } - return false, handleChanged, newHandle, lastInsertID, nil + return false, handleChanged, newHandle, nil } // Fill values into on-update-now fields, only if they are really changed. @@ -127,7 +124,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu if mysql.HasOnUpdateNowFlag(col.Flag) && !modified[i] && !onUpdateSpecified[i] { v, errGT := expression.GetTimeValue(ctx, strings.ToUpper(ast.CurrentTimestamp), col.Tp, col.Decimal) if errGT != nil { - return false, handleChanged, newHandle, 0, errors.Trace(errGT) + return false, handleChanged, newHandle, errors.Trace(errGT) } newData[i] = v modified[i] = true @@ -140,13 +137,13 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu // if the new handle exists. `UPDATE IGNORE` will avoid removing record, and do nothing. err = tables.CheckHandleExists(ctx, t, newHandle, newData) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } skipHandleCheck = true } err = t.RemoveRecord(ctx, h, oldData) if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } newHandle, err = t.AddRecord(ctx, newData, skipHandleCheck) } else { @@ -154,7 +151,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu err = t.UpdateRecord(ctx, h, oldData, newData, modified) } if err != nil { - return false, handleChanged, newHandle, 0, errors.Trace(err) + return false, handleChanged, newHandle, errors.Trace(err) } if onDup { @@ -173,7 +170,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu } } ctx.GetSessionVars().TxnCtx.UpdateDeltaForTable(t.Meta().ID, 0, 1, colSize) - return true, handleChanged, newHandle, lastInsertID, nil + return true, handleChanged, newHandle, nil } // resetErrDataTooLong reset ErrDataTooLong error msg. diff --git a/executor/write_test.go b/executor/write_test.go index 896032f2e8478..ff2ef8d552ca8 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -622,6 +622,11 @@ commit;` testSQL = `SELECT LAST_INSERT_ID();` r = tk.MustQuery(testSQL) r.Check(testkit.Rows("1")) + testSQL = `INSERT t1 (f2) VALUES ('test') ON DUPLICATE KEY UPDATE f1 = 2;` + tk.MustExec(testSQL) + testSQL = `SELECT LAST_INSERT_ID();` + r = tk.MustQuery(testSQL) + r.Check(testkit.Rows("1")) testSQL = `DROP TABLE IF EXISTS t1; CREATE TABLE t1 (f1 INT); diff --git a/server/server_test.go b/server/server_test.go index 1b15a99056375..2a358e80ba1ff 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -425,7 +425,7 @@ func runTestLoadData(c *C, server *Server) { dbt.Assert(err, IsNil) lastID, err = rs.LastInsertId() dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(6)) + dbt.Assert(lastID, Equals, int64(7)) affectedRows, err = rs.RowsAffected() dbt.Assert(err, IsNil) dbt.Assert(affectedRows, Equals, int64(4)) @@ -464,7 +464,7 @@ func runTestLoadData(c *C, server *Server) { dbt.Assert(err, IsNil) lastID, err = rs.LastInsertId() dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(10)) + dbt.Assert(lastID, Equals, int64(11)) affectedRows, err = rs.RowsAffected() dbt.Assert(err, IsNil) dbt.Assert(affectedRows, Equals, int64(799))