Skip to content

Commit

Permalink
executer: fix the last_insert_id in insert on duplicate key update (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackysp authored Aug 30, 2018
1 parent 4a31302 commit 9070bb6
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 69 deletions.
24 changes: 4 additions & 20 deletions executor/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ type InsertExec struct {
finished bool
}

func (e *InsertExec) insertOneRow(row []types.Datum) (int64, error) {
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
return h, nil
}

func (e *InsertExec) exec(rows [][]types.Datum) error {
// If tidb_batch_insert is ON and not in a transaction, we could use BatchInsert mode.
sessVars := e.ctx.GetSessionVars()
Expand All @@ -67,20 +57,17 @@ func (e *InsertExec) exec(rows [][]types.Datum) error {
return errors.Trace(err)
}
} else if ignoreErr {
err := e.batchCheckAndInsert(rows, e.insertOneRow)
err := e.batchCheckAndInsert(rows, e.addRecord)
if err != nil {
return errors.Trace(err)
}
} else {
for _, row := range rows {
if _, err := e.insertOneRow(row); err != nil {
if _, err := e.addRecord(row); err != nil {
return errors.Trace(err)
}
}
}
if e.lastInsertID != 0 {
sessVars.SetLastInsertID(e.lastInsertID)
}
e.finished = true
return nil
}
Expand Down Expand Up @@ -131,7 +118,7 @@ func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error {
// and key-values should be filled back to dupOldRowValues for the further row check,
// due to there may be duplicate keys inside the insert statement.
if newRows[i] != nil {
newHandle, err := e.insertOneRow(newRows[i])
newHandle, err := e.addRecord(newRows[i])
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -220,13 +207,10 @@ func (e *InsertExec) doDupRowUpdate(handle int64, oldRow []types.Datum, newRow [
}

newData := row4Update[:len(oldRow)]
_, handleChanged, newHandle, lastInsertID, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true)
_, handleChanged, newHandle, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true)
if err != nil {
return nil, false, 0, errors.Trace(err)
}
if lastInsertID != 0 {
e.lastInsertID = lastInsertID
}
return newData, handleChanged, newHandle, nil
}

Expand Down
21 changes: 17 additions & 4 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows [][]types

rows := make([][]types.Datum, len(e.Lists))
for i, list := range e.Lists {
e.rowCount = uint64(i)
e.rowCount++
rows[i], err = e.getRow(cols, list, i)
if err != nil {
return errors.Trace(err)
Expand Down Expand Up @@ -445,7 +445,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab
return errors.Trace(err)
}
// It's compatible with mysql. So it sets last insert id to the first row.
if e.rowCount == 0 {
if e.rowCount == 1 {
e.lastInsertID = uint64(recordID)
}
}
Expand Down Expand Up @@ -474,7 +474,7 @@ func (e *InsertValues) handleWarning(err error, logInfo string) {

// batchCheckAndInsert checks rows with duplicate errors.
// All duplicate rows will be ignored and appended as duplicate warnings.
func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow func(row []types.Datum) (int64, error)) error {
func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, addRecord func(row []types.Datum) (int64, error)) error {
// all the rows will be checked, so it is safe to set BatchCheck = true
e.ctx.GetSessionVars().StmtCtx.BatchCheck = true
err := e.batchGetInsertKeys(e.ctx, e.Table, rows)
Expand Down Expand Up @@ -502,7 +502,7 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu
// it should be add to values map for the further row check.
// There may be duplicate keys inside the insert statement.
if rows[i] != nil {
_, err = insertOneRow(rows[i])
_, err = addRecord(rows[i])
if err != nil {
return errors.Trace(err)
}
Expand All @@ -516,3 +516,16 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu
}
return nil
}

func (e *InsertValues) addRecord(row []types.Datum) (int64, error) {
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
if e.lastInsertID != 0 {
e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}
return h, nil
}
10 changes: 3 additions & 7 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,10 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
break
}
}
err := e.batchCheckAndInsert(rows, e.insertData)
err := e.batchCheckAndInsert(rows, e.addRecordLD)
if err != nil {
return nil, reachLimit, errors.Trace(err)
}
if e.lastInsertID != 0 {
e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}

return curData, reachLimit, nil
}

Expand All @@ -282,11 +278,11 @@ func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum {
return row
}

func (e *LoadDataInfo) insertData(row []types.Datum) (int64, error) {
func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) {
if row == nil {
return 0, nil
}
h, err := e.Table.AddRecord(e.ctx, row, false)
h, err := e.addRecord(row)
if err != nil {
e.handleWarning(err,
fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err)))
Expand Down
18 changes: 1 addition & 17 deletions executor/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package executor

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -83,18 +82,6 @@ func (e *ReplaceExec) removeRow(handle int64, r toBeCheckedRow) (bool, error) {
return false, nil
}

// addRow adds a row when all the duplicate key were checked.
func (e *ReplaceExec) addRow(row []types.Datum) (int64, error) {
// Set kv.PresumeKeyNotExists is safe here, because we've already removed all duplicated rows.
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
return h, nil
}

// replaceRow removes all duplicate rows for one row, then inserts it.
func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error {
if r.handleKey != nil {
Expand Down Expand Up @@ -129,7 +116,7 @@ func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error {
}

// No duplicated rows now, insert the row.
newHandle, err := e.addRow(r.row)
newHandle, err := e.addRecord(r.row)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -190,9 +177,6 @@ func (e *ReplaceExec) exec(newRows [][]types.Datum) error {
return errors.Trace(err)
}
}
if e.lastInsertID != 0 {
e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}
e.finished = true
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion executor/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (e *UpdateExec) exec(schema *expression.Schema) ([]types.Datum, error) {
}

// Update row
changed, _, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false)
changed, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false)
if err1 == nil {
if changed {
e.updatedRowKeys[id][handle] = struct{}{}
Expand Down
33 changes: 15 additions & 18 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,15 @@ var (
// 1. changed (bool) : does the update really change the row values. e.g. update set i = 1 where i = 1;
// 2. handleChanged (bool) : is the handle changed after the update.
// 3. newHandle (int64) : if handleChanged == true, the newHandle means the new handle after update.
// 4. lastInsertID (uint64) : the lastInsertID should be set by the newData.
// 5. err (error) : error in the update.
// 4. err (error) : error in the update.
func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, modified []bool, t table.Table,
onDup bool) (bool, bool, int64, uint64, error) {
onDup bool) (bool, bool, int64, error) {
var sc = ctx.GetSessionVars().StmtCtx
var changed, handleChanged = false, false
// onUpdateSpecified is for "UPDATE SET ts_field = old_value", the
// timestamp field is explicitly set, but not changed in fact.
var onUpdateSpecified = make(map[int]bool)
var newHandle int64
var lastInsertID uint64

// We can iterate on public columns not writable columns,
// because all of them are sorted by their `Offset`, which
Expand All @@ -61,7 +59,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// Cast changed fields with respective columns.
v, err := table.CastValue(ctx, newData[i], col.ToInfo())
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
newData[i] = v
}
Expand All @@ -70,27 +68,26 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
var err error
newData[i], err = table.GetColDefaultValue(ctx, col.ToInfo())
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
}
// Rebase auto increment id if the field is changed.
if mysql.HasAutoIncrementFlag(col.Flag) {
if newData[i].IsNull() {
return false, handleChanged, newHandle, 0, table.ErrColumnCantNull.GenByArgs(col.Name)
return false, handleChanged, newHandle, table.ErrColumnCantNull.GenByArgs(col.Name)
}
val, errTI := newData[i].ToInt64(sc)
if errTI != nil {
return false, handleChanged, newHandle, 0, errors.Trace(errTI)
return false, handleChanged, newHandle, errors.Trace(errTI)
}
lastInsertID = uint64(val)
err := t.RebaseAutoID(ctx, val, true)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
}
cmp, err := newData[i].CompareDatum(sc, &oldData[i])
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
if cmp != 0 {
changed = true
Expand All @@ -111,23 +108,23 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// Check the not-null constraints.
err := table.CheckNotNull(t.Cols(), newData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}

if !changed {
// See https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS
if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 {
sc.AddAffectedRows(1)
}
return false, handleChanged, newHandle, lastInsertID, nil
return false, handleChanged, newHandle, nil
}

// Fill values into on-update-now fields, only if they are really changed.
for i, col := range t.Cols() {
if mysql.HasOnUpdateNowFlag(col.Flag) && !modified[i] && !onUpdateSpecified[i] {
v, errGT := expression.GetTimeValue(ctx, strings.ToUpper(ast.CurrentTimestamp), col.Tp, col.Decimal)
if errGT != nil {
return false, handleChanged, newHandle, 0, errors.Trace(errGT)
return false, handleChanged, newHandle, errors.Trace(errGT)
}
newData[i] = v
modified[i] = true
Expand All @@ -140,21 +137,21 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// if the new handle exists. `UPDATE IGNORE` will avoid removing record, and do nothing.
err = tables.CheckHandleExists(ctx, t, newHandle, newData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
skipHandleCheck = true
}
err = t.RemoveRecord(ctx, h, oldData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
newHandle, err = t.AddRecord(ctx, newData, skipHandleCheck)
} else {
// Update record to new value and update index.
err = t.UpdateRecord(ctx, h, oldData, newData, modified)
}
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}

if onDup {
Expand All @@ -173,7 +170,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
}
}
ctx.GetSessionVars().TxnCtx.UpdateDeltaForTable(t.Meta().ID, 0, 1, colSize)
return true, handleChanged, newHandle, lastInsertID, nil
return true, handleChanged, newHandle, nil
}

// resetErrDataTooLong reset ErrDataTooLong error msg.
Expand Down
5 changes: 5 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ commit;`
testSQL = `SELECT LAST_INSERT_ID();`
r = tk.MustQuery(testSQL)
r.Check(testkit.Rows("1"))
testSQL = `INSERT t1 (f2) VALUES ('test') ON DUPLICATE KEY UPDATE f1 = 2;`
tk.MustExec(testSQL)
testSQL = `SELECT LAST_INSERT_ID();`
r = tk.MustQuery(testSQL)
r.Check(testkit.Rows("1"))

testSQL = `DROP TABLE IF EXISTS t1;
CREATE TABLE t1 (f1 INT);
Expand Down
4 changes: 2 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, IsNil)
lastID, err = rs.LastInsertId()
dbt.Assert(err, IsNil)
dbt.Assert(lastID, Equals, int64(6))
dbt.Assert(lastID, Equals, int64(7))
affectedRows, err = rs.RowsAffected()
dbt.Assert(err, IsNil)
dbt.Assert(affectedRows, Equals, int64(4))
Expand Down Expand Up @@ -464,7 +464,7 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, IsNil)
lastID, err = rs.LastInsertId()
dbt.Assert(err, IsNil)
dbt.Assert(lastID, Equals, int64(10))
dbt.Assert(lastID, Equals, int64(11))
affectedRows, err = rs.RowsAffected()
dbt.Assert(err, IsNil)
dbt.Assert(affectedRows, Equals, int64(799))
Expand Down

0 comments on commit 9070bb6

Please sign in to comment.