diff --git a/dm/pkg/log/log.go b/dm/pkg/log/log.go index 98920087430..62c63ca42d5 100644 --- a/dm/pkg/log/log.go +++ b/dm/pkg/log/log.go @@ -104,6 +104,7 @@ var ( // InitLogger initializes DM's and also the TiDB library's loggers. func InitLogger(cfg *Config) error { + inDev := strings.ToLower(cfg.Level) == "debug" // init DM logger logger, props, err := pclog.InitLogger(&pclog.Config{ Level: cfg.Level, @@ -114,6 +115,7 @@ func InitLogger(cfg *Config) error { MaxDays: cfg.FileMaxDays, MaxBackups: cfg.FileMaxBackups, }, + Development: inDev, }) if err != nil { return terror.ErrInitLoggerFail.Delegate(err) @@ -125,7 +127,7 @@ func InitLogger(cfg *Config) error { appLevel = props.Level appProps = props // init and set tidb slow query logger to stdout if log level is debug - if cfg.Level == "debug" { + if inDev { slowQueryLogger := zap.NewExample() slowQueryLogger = slowQueryLogger.With(zap.String("component", "slow query logger")) logutil.SlowQueryLogger = slowQueryLogger diff --git a/dm/pkg/schema/tracker.go b/dm/pkg/schema/tracker.go index 1d0640cc6f9..1ee9d1015a4 100644 --- a/dm/pkg/schema/tracker.go +++ b/dm/pkg/schema/tracker.go @@ -403,7 +403,7 @@ func (tr *Tracker) GetDownStreamTableInfo(tctx *tcontext.Context, tableID string return nil, err } - dti = GetDownStreamTi(ti, originTi) + dti = GetDownStreamTI(ti, originTi) tr.dsTracker.tableInfos[tableID] = dti } return dti, nil @@ -412,9 +412,14 @@ func (tr *Tracker) GetDownStreamTableInfo(tctx *tcontext.Context, tableID string // GetAvailableDownStreamUKIndexInfo gets available downstream UK whose data is not null. // note. this function will not init downstreamTrack. func (tr *Tracker) GetAvailableDownStreamUKIndexInfo(tableID string, data []interface{}) *model.IndexInfo { - dti, ok := tr.dsTracker.tableInfos[tableID] + dti := tr.dsTracker.tableInfos[tableID] + + return GetIdentityUKByData(dti, data) +} - if !ok || len(dti.AvailableUKIndexList) == 0 { +// GetIdentityUKByData gets available downstream UK whose data is not null. +func GetIdentityUKByData(downstreamTI *DownstreamTableInfo, data []interface{}) *model.IndexInfo { + if downstreamTI == nil || len(downstreamTI.AvailableUKIndexList) == 0 { return nil } // func for check data is not null @@ -422,7 +427,7 @@ func (tr *Tracker) GetAvailableDownStreamUKIndexInfo(tableID string, data []inte return data[i] != nil } - for _, uk := range dti.AvailableUKIndexList { + for _, uk := range downstreamTI.AvailableUKIndexList { // check uk's column data is not null if isSpecifiedIndexColumn(uk, fn) { return uk @@ -499,8 +504,8 @@ func (tr *Tracker) initDownStreamSQLModeAndParser(tctx *tcontext.Context) error return nil } -// GetDownStreamTi constructs downstreamTable index cache by tableinfo. -func GetDownStreamTi(ti *model.TableInfo, originTi *model.TableInfo) *DownstreamTableInfo { +// GetDownStreamTI constructs downstreamTable index cache by tableinfo. +func GetDownStreamTI(downstreamTI *model.TableInfo, originTi *model.TableInfo) *DownstreamTableInfo { var ( absoluteUKIndexInfo *model.IndexInfo availableUKIndexList = []*model.IndexInfo{} @@ -510,10 +515,10 @@ func GetDownStreamTi(ti *model.TableInfo, originTi *model.TableInfo) *Downstream // func for check not null constraint fn := func(i int) bool { - return mysql.HasNotNullFlag(ti.Columns[i].Flag) + return mysql.HasNotNullFlag(downstreamTI.Columns[i].Flag) } - for i, idx := range ti.Indices { + for i, idx := range downstreamTI.Indices { if !idx.Primary && !idx.Unique { continue } @@ -536,7 +541,7 @@ func GetDownStreamTi(ti *model.TableInfo, originTi *model.TableInfo) *Downstream // handle pk exceptional case. // e.g. "create table t(a int primary key, b int)". if !hasPk { - exPk := redirectIndexKeys(handlePkExCase(ti), originTi) + exPk := redirectIndexKeys(handlePkExCase(downstreamTI), originTi) if exPk != nil { absoluteUKIndexInfo = exPk absoluteUKPosition = len(availableUKIndexList) @@ -550,7 +555,7 @@ func GetDownStreamTi(ti *model.TableInfo, originTi *model.TableInfo) *Downstream } return &DownstreamTableInfo{ - TableInfo: ti, + TableInfo: downstreamTI, AbsoluteUKIndexInfo: absoluteUKIndexInfo, AvailableUKIndexList: availableUKIndexList, } diff --git a/dm/pkg/utils/common.go b/dm/pkg/utils/common.go index 1718f2d3fb6..25c4ecb348a 100644 --- a/dm/pkg/utils/common.go +++ b/dm/pkg/utils/common.go @@ -37,6 +37,10 @@ import ( "github.com/pingcap/tiflow/dm/pkg/terror" ) +func init() { + ZeroSessionCtx = NewSessionCtx(nil) +} + // TrimCtrlChars returns a slice of the string s with all leading // and trailing control characters removed. func TrimCtrlChars(s string) string { @@ -322,6 +326,9 @@ func (se *session) GetBuiltinFunctionUsage() map[string]uint32 { return se.builtinFunctionUsage } +// ZeroSessionCtx is used when the session variables is not important. +var ZeroSessionCtx sessionctx.Context + // NewSessionCtx return a session context with specified session variables. func NewSessionCtx(vars map[string]string) sessionctx.Context { variables := variable.NewSessionVars() diff --git a/dm/syncer/causality_test.go b/dm/syncer/causality_test.go index 1354ca7a979..1109b8c1b80 100644 --- a/dm/syncer/causality_test.go +++ b/dm/syncer/causality_test.go @@ -83,7 +83,7 @@ func (s *testSyncerSuite) TestCasuality(c *C) { Length: types.UnspecifiedLength, }}, } - downTi := schema.GetDownStreamTi(ti, ti) + downTi := schema.GetDownStreamTI(ti, ti) c.Assert(downTi, NotNil) jobCh := make(chan *job, 10) @@ -152,7 +152,7 @@ func (s *testSyncerSuite) TestCasualityWithPrefixIndex(c *C) { schemaStr := "create table t (c1 text, c2 int unique, unique key c1(c1(3)));" ti, err := createTableInfo(p, se, int64(0), schemaStr) c.Assert(err, IsNil) - downTi := schema.GetDownStreamTi(ti, ti) + downTi := schema.GetDownStreamTI(ti, ti) c.Assert(downTi, NotNil) c.Assert(len(downTi.AvailableUKIndexList) == 2, IsTrue) tiIndex := downTi.AvailableUKIndexList[0] diff --git a/dm/syncer/compactor_test.go b/dm/syncer/compactor_test.go index 06ad3993791..506f9581a7f 100644 --- a/dm/syncer/compactor_test.go +++ b/dm/syncer/compactor_test.go @@ -91,7 +91,7 @@ func (s *testSyncerSuite) TestCompactJob(c *C) { Length: types.UnspecifiedLength, }}, } - downTi := schema.GetDownStreamTi(ti, ti) + downTi := schema.GetDownStreamTI(ti, ti) c.Assert(downTi, NotNil) var dml *DML @@ -208,7 +208,7 @@ func (s *testSyncerSuite) TestCompactorSafeMode(c *C) { Length: types.UnspecifiedLength, }}, } - downTi := schema.GetDownStreamTi(ti, ti) + downTi := schema.GetDownStreamTI(ti, ti) c.Assert(downTi, NotNil) testCases := []struct { diff --git a/dm/syncer/dml_test.go b/dm/syncer/dml_test.go index a7748abb4bf..148c2538eff 100644 --- a/dm/syncer/dml_test.go +++ b/dm/syncer/dml_test.go @@ -224,7 +224,7 @@ func (s *testSyncerSuite) TestGenMultipleKeys(c *C) { ti, err := createTableInfo(p, se, int64(i+1), tc.schema) assert(err, IsNil) - dti := schema.GetDownStreamTi(ti, ti) + dti := schema.GetDownStreamTI(ti, ti) assert(dti, NotNil) keys := genMultipleKeys(sessCtx, dti, ti, tc.values, "table") assert(keys, DeepEquals, tc.keys) @@ -619,7 +619,7 @@ func (s *testSyncerSuite) TestTruncateIndexValues(c *C) { } ti, err := createTableInfo(p, se, int64(i+1), tc.schema) assert(err, IsNil) - dti := schema.GetDownStreamTi(ti, ti) + dti := schema.GetDownStreamTI(ti, ti) assert(dti, NotNil) assert(dti.AvailableUKIndexList, NotNil) cols := make([]*model.ColumnInfo, 0, len(dti.AvailableUKIndexList[0].Columns)) diff --git a/dm/syncer/syncer.go b/dm/syncer/syncer.go index 212a3fa9bb1..0c0ca3d9392 100644 --- a/dm/syncer/syncer.go +++ b/dm/syncer/syncer.go @@ -871,7 +871,9 @@ func (s *Syncer) updateReplicationLagMetric() { func (s *Syncer) saveTablePoint(table *filter.Table, location binlog.Location) { ti, err := s.schemaTracker.GetTableInfo(table) if err != nil && table.Name != "" { - s.tctx.L().DPanic("table info missing from schema tracker", + // TODO: if we RENAME tb1 TO tb2, the tracker will remove TableInfo of tb1 but we still save the table + // checkpoint for tb1. We can delete the table checkpoint in future. + s.tctx.L().Warn("table info missing from schema tracker", zap.Stringer("table", table), zap.Stringer("location", location), zap.Error(err)) diff --git a/pkg/sqlmodel/causality.go b/pkg/sqlmodel/causality.go new file mode 100644 index 00000000000..cc8d8f4572b --- /dev/null +++ b/pkg/sqlmodel/causality.go @@ -0,0 +1,166 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "fmt" + "strconv" + "strings" + + timodel "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/tablecodec" + "go.uber.org/zap" + + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/pkg/utils" +) + +// CausalityKeys returns all string representation of causality keys. If two row +// changes has the same causality keys, they must be replicated sequentially. +func (r *RowChange) CausalityKeys() []string { + r.lazyInitIdentityInfo() + + ret := make([]string, 0, 1) + if r.preValues != nil { + ret = append(ret, r.getCausalityString(r.preValues)...) + } + if r.postValues != nil { + ret = append(ret, r.getCausalityString(r.postValues)...) + } + return ret +} + +func columnValue2String(value interface{}) string { + var data string + switch v := value.(type) { + case nil: + data = "null" + case bool: + if v { + data = "1" + } else { + data = "0" + } + case int: + data = strconv.FormatInt(int64(v), 10) + case int8: + data = strconv.FormatInt(int64(v), 10) + case int16: + data = strconv.FormatInt(int64(v), 10) + case int32: + data = strconv.FormatInt(int64(v), 10) + case int64: + data = strconv.FormatInt(v, 10) + case uint8: + data = strconv.FormatUint(uint64(v), 10) + case uint16: + data = strconv.FormatUint(uint64(v), 10) + case uint32: + data = strconv.FormatUint(uint64(v), 10) + case uint64: + data = strconv.FormatUint(v, 10) + case float32: + data = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + data = strconv.FormatFloat(v, 'f', -1, 64) + case string: + data = v + case []byte: + data = string(v) + default: + data = fmt.Sprintf("%v", v) + } + + return data +} + +func genKeyString( + table string, + columns []*timodel.ColumnInfo, + values []interface{}, +) string { + var buf strings.Builder + for i, data := range values { + if data == nil { + log.L().Debug("ignore null value", + zap.String("column", columns[i].Name.O), + zap.String("table", table)) + continue // ignore `null` value. + } + // one column key looks like:`column_val.column_name.` + buf.WriteString(columnValue2String(data)) + buf.WriteString(".") + buf.WriteString(columns[i].Name.L) + buf.WriteString(".") + } + if buf.Len() == 0 { + log.L().Debug("all value are nil, no key generated", + zap.String("table", table)) + return "" // all values are `null`. + } + buf.WriteString(table) + return buf.String() +} + +// truncateIndexValues truncate prefix index from data. +func truncateIndexValues( + ctx sessionctx.Context, + ti *timodel.TableInfo, + indexColumns *timodel.IndexInfo, + tiColumns []*timodel.ColumnInfo, + data []interface{}, +) []interface{} { + values := make([]interface{}, 0, len(indexColumns.Columns)) + datums, err := utils.AdjustBinaryProtocolForDatum(ctx, data, tiColumns) + if err != nil { + log.L().Warn("adjust binary protocol for datum error", zap.Error(err)) + return data + } + tablecodec.TruncateIndexValues(ti, indexColumns, datums) + for _, datum := range datums { + values = append(values, datum.GetValue()) + } + return values +} + +func (r *RowChange) getCausalityString(values []interface{}) []string { + pkAndUks := r.identityInfo.AvailableUKIndexList + if len(pkAndUks) == 0 { + // the table has no PK/UK, all values of the row consists the causality key + return []string{genKeyString(r.sourceTable.String(), r.sourceTableInfo.Columns, values)} + } + + ret := make([]string, 0, len(pkAndUks)) + + for _, indexCols := range pkAndUks { + cols, vals := getColsAndValuesOfIdx(r.sourceTableInfo.Columns, indexCols, values) + // handle prefix index + truncVals := truncateIndexValues(r.tiSessionCtx, r.sourceTableInfo, indexCols, cols, vals) + key := genKeyString(r.sourceTable.String(), cols, truncVals) + if len(key) > 0 { // ignore `null` value. + ret = append(ret, key) + } else { + log.L().Debug("ignore empty key", zap.String("table", r.sourceTable.String())) + } + } + + if len(ret) == 0 { + // the table has no PK/UK, or all UK are NULL. all values of the row + // consists the causality key + return []string{genKeyString(r.sourceTable.String(), r.sourceTableInfo.Columns, values)} + } + + return ret +} diff --git a/pkg/sqlmodel/causality_test.go b/pkg/sqlmodel/causality_test.go new file mode 100644 index 00000000000..4a49e95640e --- /dev/null +++ b/pkg/sqlmodel/causality_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" +) + +func TestCausalityKeys(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + + cases := []struct { + createSQL string + preValue []interface{} + postValue []interface{} + + causalityKeys []string + }{ + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT, c3 VARCHAR(10) UNIQUE)", + []interface{}{1, 2, "abc"}, + []interface{}{3, 4, "abc"}, + []string{"1.c.db.tb1", "abc.c3.db.tb1", "3.c.db.tb1", "abc.c3.db.tb1"}, + }, + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT, c3 VARCHAR(10), UNIQUE INDEX(c3(1)))", + []interface{}{1, 2, "abc"}, + []interface{}{3, 4, "adef"}, + []string{"1.c.db.tb1", "a.c3.db.tb1", "3.c.db.tb1", "a.c3.db.tb1"}, + }, + } + + for _, ca := range cases { + ti := mockTableInfo(t, ca.createSQL) + change := NewRowChange(source, nil, ca.preValue, ca.postValue, ti, nil, nil) + require.Equal(t, ca.causalityKeys, change.CausalityKeys()) + } +} + +func TestCausalityKeysNoRace(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + ti := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT, c3 VARCHAR(10) UNIQUE)") + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + change := NewRowChange(source, nil, []interface{}{1, 2, "abc"}, []interface{}{3, 4, "abc"}, ti, nil, nil) + change.CausalityKeys() + wg.Done() + }() + } + wg.Wait() +} diff --git a/pkg/sqlmodel/multivalue.go b/pkg/sqlmodel/multivalue.go new file mode 100644 index 00000000000..8c77387b276 --- /dev/null +++ b/pkg/sqlmodel/multivalue.go @@ -0,0 +1,202 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "strings" + + "go.uber.org/zap" + + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/pkg/quotes" +) + +// SameTypeTargetAndColumns check whether two row changes have same type, target +// and columns, so they can be merged to a multi-value DML. +func SameTypeTargetAndColumns(lhs *RowChange, rhs *RowChange) bool { + if lhs.tp != rhs.tp { + return false + } + if lhs.sourceTable.Schema == rhs.sourceTable.Schema && + lhs.sourceTable.Table == rhs.sourceTable.Table { + return true + } + if lhs.targetTable.Schema != rhs.targetTable.Schema || + lhs.targetTable.Table != rhs.targetTable.Table { + return false + } + + // when the targets are the same and the sources are not the same (same + // group of shard tables), this piece of code is run. + var lhsCols, rhsCols []string + switch lhs.tp { + case RowChangeDelete: + lhsCols, _ = lhs.whereColumnsAndValues() + rhsCols, _ = rhs.whereColumnsAndValues() + case RowChangeUpdate: + // not supported yet + return false + case RowChangeInsert: + for _, col := range lhs.sourceTableInfo.Columns { + lhsCols = append(lhsCols, col.Name.L) + } + for _, col := range rhs.sourceTableInfo.Columns { + rhsCols = append(rhsCols, col.Name.L) + } + } + + if len(lhsCols) != len(rhsCols) { + return false + } + for i := 0; i < len(lhsCols); i++ { + if lhsCols[i] != rhsCols[i] { + return false + } + } + return true +} + +// GenDeleteSQL generates the DELETE SQL and its arguments. +// Input `changes` should have same target table and same columns for WHERE +// (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. +func GenDeleteSQL(changes ...*RowChange) (string, []interface{}) { + if len(changes) == 0 { + log.L().DPanic("row changes is empty") + return "", nil + } + + first := changes[0] + + var buf strings.Builder + buf.Grow(1024) + buf.WriteString("DELETE FROM ") + buf.WriteString(first.targetTable.QuoteString()) + buf.WriteString(" WHERE (") + + whereColumns, _ := first.whereColumnsAndValues() + for i, column := range whereColumns { + if i != len(whereColumns)-1 { + buf.WriteString(quotes.QuoteName(column) + ",") + } else { + buf.WriteString(quotes.QuoteName(column) + ")") + } + } + buf.WriteString(" IN (") + // TODO: can't handle NULL by IS NULL, should use WHERE OR + args := make([]interface{}, 0, len(changes)*len(whereColumns)) + holder := valuesHolder(len(whereColumns)) + for i, change := range changes { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString(holder) + _, whereValues := change.whereColumnsAndValues() + // a simple check about different number of WHERE values, not trying to + // cover all cases + if len(whereValues) != len(whereColumns) { + log.L().DPanic("len(whereValues) != len(whereColumns)", + zap.Int("len(whereValues)", len(whereValues)), + zap.Int("len(whereColumns)", len(whereColumns)), + zap.Any("whereValues", whereValues), + zap.Stringer("sourceTable", change.sourceTable)) + return "", nil + } + args = append(args, whereValues...) + } + buf.WriteString(")") + return buf.String(), args +} + +// TODO: support GenUpdateSQL(changes ...*RowChange) using UPDATE SET CASE WHEN + +// GenInsertSQL generates the INSERT SQL and its arguments. +// Input `changes` should have same target table and same modifiable columns, +// otherwise the behaviour is undefined. +func GenInsertSQL(tp DMLType, changes ...*RowChange) (string, []interface{}) { + if len(changes) == 0 { + log.L().DPanic("row changes is empty") + return "", nil + } + + first := changes[0] + + var buf strings.Builder + buf.Grow(1024) + if tp == DMLReplace { + buf.WriteString("REPLACE INTO ") + } else { + buf.WriteString("INSERT INTO ") + } + buf.WriteString(first.targetTable.QuoteString()) + buf.WriteString(" (") + columnNum := 0 + var skipColIdx []int + for i, col := range first.sourceTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, col.Name) { + skipColIdx = append(skipColIdx, i) + continue + } + + if columnNum != 0 { + buf.WriteByte(',') + } + columnNum++ + buf.WriteString(quotes.QuoteName(col.Name.O)) + } + buf.WriteString(") VALUES ") + holder := valuesHolder(columnNum) + for i := range changes { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString(holder) + } + if tp == DMLInsertOnDuplicateUpdate { + buf.WriteString(" ON DUPLICATE KEY UPDATE ") + i := 0 // used as index of skipColIdx + writtenFirstCol := false + + for j, col := range first.sourceTableInfo.Columns { + if i < len(skipColIdx) && skipColIdx[i] == j { + i++ + continue + } + + if writtenFirstCol { + buf.WriteByte(',') + } + writtenFirstCol = true + + colName := quotes.QuoteName(col.Name.O) + buf.WriteString(colName + "=VALUES(" + colName + ")") + } + } + + args := make([]interface{}, 0, len(changes)*(len(first.sourceTableInfo.Columns)-len(skipColIdx))) + for _, change := range changes { + i := 0 // used as index of skipColIdx + for j, val := range change.postValues { + if i >= len(skipColIdx) { + args = append(args, change.postValues[j:]...) + break + } + if skipColIdx[i] == j { + i++ + continue + } + args = append(args, val) + } + } + return buf.String(), args +} diff --git a/pkg/sqlmodel/multivalue_test.go b/pkg/sqlmodel/multivalue_test.go new file mode 100644 index 00000000000..a06326d5ee9 --- /dev/null +++ b/pkg/sqlmodel/multivalue_test.go @@ -0,0 +1,68 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" +) + +func TestGenDeleteMultiValue(t *testing.T) { + t.Parallel() + + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT PRIMARY KEY, c2 INT)") + + change1 := NewRowChange(source1, target, []interface{}{1, 2}, nil, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, []interface{}{3, 4}, nil, sourceTI2, targetTI, nil) + sql, args := GenDeleteSQL(change1, change2) + + require.Equal(t, "DELETE FROM `db`.`tb` WHERE (`c`) IN ((?),(?))", sql) + require.Equal(t, []interface{}{1, 3}, args) +} + +func TestGenInsertMultiValue(t *testing.T) { + t.Parallel() + + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tb (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + + change1 := NewRowChange(source1, target, nil, []interface{}{2, 1, 2}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, nil, []interface{}{4, 3, 4}, sourceTI2, targetTI, nil) + + sql, args := GenInsertSQL(DMLInsert, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLReplace, change1, change2) + require.Equal(t, "REPLACE INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLInsertOnDuplicateUpdate, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) +} diff --git a/pkg/sqlmodel/reduce.go b/pkg/sqlmodel/reduce.go new file mode 100644 index 00000000000..6e146e2d54c --- /dev/null +++ b/pkg/sqlmodel/reduce.go @@ -0,0 +1,154 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "fmt" + "strings" + + "go.uber.org/zap" + + "github.com/pingcap/tiflow/dm/pkg/log" +) + +// HasNotNullUniqueIdx returns true when the target table structure has PK or UK +// whose columns are all NOT NULL. +func (r *RowChange) HasNotNullUniqueIdx() bool { + r.lazyInitIdentityInfo() + + return r.identityInfo.AbsoluteUKIndexInfo != nil +} + +// IdentityValues returns the two group of values that can be used to identify +// the row. That is to say, if two row changes has same IdentityValues, they are +// changes of the same row. We can use this property to only replicate latest +// changes of one row. +// We always use same index for same table structure to get IdentityValues. +// two groups returned are from preValues and postValues. +func (r *RowChange) IdentityValues() ([]interface{}, []interface{}) { + r.lazyInitIdentityInfo() + + indexInfo := r.identityInfo.AbsoluteUKIndexInfo + if indexInfo == nil { + return r.preValues, r.postValues + } + + pre := make([]interface{}, 0, len(indexInfo.Columns)) + post := make([]interface{}, 0, len(indexInfo.Columns)) + + for _, column := range indexInfo.Columns { + if r.preValues != nil { + pre = append(pre, r.preValues[column.Offset]) + } + if r.postValues != nil { + post = append(post, r.postValues[column.Offset]) + } + } + return pre, post +} + +func (r *RowChange) IsIdentityUpdated() bool { + if r.tp != RowChangeUpdate { + return false + } + + r.lazyInitIdentityInfo() + pre, post := r.IdentityValues() + if len(pre) != len(post) { + // should not happen + return true + } + for i := range pre { + if pre[i] != post[i] { + return true + } + } + return false +} + +// genKey gens key by values e.g. "a.1.b". +func genKey(values []interface{}) string { + builder := new(strings.Builder) + for i, v := range values { + if i != 0 { + builder.WriteString(".") + } + fmt.Fprintf(builder, "%v", v) + } + + return builder.String() +} + +// IdentityKey returns a string generated by IdentityValues. +// If RowChange.IsIdentityUpdated, the behaviour is undefined. +func (r *RowChange) IdentityKey() string { + pre, post := r.IdentityValues() + if len(pre) != 0 { + return genKey(pre) + } + return genKey(post) +} + +// Reduce will merge two row changes of same row into one row changes, +// e.g., INSERT{1} + UPDATE{1 -> 2} -> INSERT{2}. Receiver will be changed +// in-place. +func (r *RowChange) Reduce(preRowChange *RowChange) { + if r.IdentityKey() != preRowChange.IdentityKey() { + log.L().DPanic("reduce row change failed, identity key not match", + zap.String("preID", preRowChange.IdentityKey()), + zap.String("curID", r.IdentityKey())) + return + } + + // special handle INSERT + DELETE -> DELETE + if r.tp == RowChangeDelete && preRowChange.tp == RowChangeInsert { + return + } + + r.preValues = preRowChange.preValues + r.calculateType() +} + +// Split will split current RowChangeUpdate into two RowChangeDelete and +// RowChangeInsert one. The behaviour is undefined for other types of RowChange. +func (r *RowChange) Split() (*RowChange, *RowChange) { + if r.tp != RowChangeUpdate { + log.L().DPanic("Split should only be called on RowChangeUpdate", + zap.Stringer("rowChange", r)) + return nil, nil + } + + pre := &RowChange{ + sourceTable: r.sourceTable, + targetTable: r.targetTable, + preValues: r.preValues, + sourceTableInfo: r.sourceTableInfo, + targetTableInfo: r.targetTableInfo, + tiSessionCtx: r.tiSessionCtx, + tp: RowChangeDelete, + identityInfo: r.identityInfo, + } + post := &RowChange{ + sourceTable: r.sourceTable, + targetTable: r.targetTable, + postValues: r.postValues, + sourceTableInfo: r.sourceTableInfo, + targetTableInfo: r.targetTableInfo, + tiSessionCtx: r.tiSessionCtx, + tp: RowChangeInsert, + identityInfo: r.identityInfo, + } + + return pre, post +} diff --git a/pkg/sqlmodel/reduce_test.go b/pkg/sqlmodel/reduce_test.go new file mode 100644 index 00000000000..876e2089252 --- /dev/null +++ b/pkg/sqlmodel/reduce_test.go @@ -0,0 +1,130 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" +) + +func TestIdentity(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") + + change := NewRowChange(source, nil, []interface{}{1, 2}, nil, sourceTI1, nil, nil) + pre, post := change.IdentityValues() + require.Equal(t, []interface{}{1}, pre) + require.Len(t, post, 0) + + change = NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{1, 4}, sourceTI1, nil, nil) + pre, post = change.IdentityValues() + require.Equal(t, []interface{}{1}, pre) + require.Equal(t, []interface{}{1}, post) + require.False(t, change.IsIdentityUpdated()) + + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c2 INT)") + change = NewRowChange(source, nil, nil, []interface{}{5, 6}, sourceTI2, nil, nil) + pre, post = change.IdentityValues() + require.Len(t, pre, 0) + require.Equal(t, []interface{}{5, 6}, post) +} + +func TestSplit(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") + + change := NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{3, 4}, sourceTI1, nil, nil) + require.True(t, change.IsIdentityUpdated()) + del, ins := change.Split() + delIDKey := del.IdentityKey() + require.NotZero(t, delIDKey) + insIDKey := ins.IdentityKey() + require.NotZero(t, insIDKey) + require.NotEqual(t, delIDKey, insIDKey) +} + +func (s *dpanicSuite) TestReduce() { + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + sourceTI := mockTableInfo(s.T(), "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") + + cases := []struct { + pre1 []interface{} + post1 []interface{} + pre2 []interface{} + post2 []interface{} + preAfter []interface{} + postAfter []interface{} + }{ + // INSERT + UPDATE + { + nil, + []interface{}{1, 2}, + []interface{}{1, 2}, + []interface{}{3, 4}, + nil, + []interface{}{3, 4}, + }, + // INSERT + DELETE + { + nil, + []interface{}{1, 2}, + []interface{}{1, 2}, + nil, + []interface{}{1, 2}, + nil, + }, + // UPDATE + UPDATE + { + []interface{}{1, 2}, + []interface{}{1, 3}, + []interface{}{1, 3}, + []interface{}{1, 4}, + []interface{}{1, 2}, + []interface{}{1, 4}, + }, + // UPDATE + DELETE + { + []interface{}{1, 2}, + []interface{}{1, 3}, + []interface{}{1, 3}, + nil, + []interface{}{1, 2}, + nil, + }, + } + + for _, c := range cases { + change1 := NewRowChange(source, nil, c.pre1, c.post1, sourceTI, nil, nil) + change2 := NewRowChange(source, nil, c.pre2, c.post2, sourceTI, nil, nil) + changeAfter := NewRowChange(source, nil, c.preAfter, c.postAfter, sourceTI, nil, nil) + changeAfter.lazyInitIdentityInfo() + + change2.Reduce(change1) + s.Equal(changeAfter, change2) + } + + // test reduce on IdentityUpdated will DPanic + change1 := NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{3, 4}, sourceTI, nil, nil) + change2 := NewRowChange(source, nil, []interface{}{3, 4}, []interface{}{5, 6}, sourceTI, nil, nil) + s.Panics(func() { + change2.Reduce(change1) + }) +} diff --git a/pkg/sqlmodel/row_change.go b/pkg/sqlmodel/row_change.go new file mode 100644 index 00000000000..a4a3cbfd8ca --- /dev/null +++ b/pkg/sqlmodel/row_change.go @@ -0,0 +1,377 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "fmt" + "strings" + + "github.com/pingcap/failpoint" + timodel "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "go.uber.org/zap" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/pkg/schema" + "github.com/pingcap/tiflow/dm/pkg/utils" + "github.com/pingcap/tiflow/pkg/quotes" +) + +type RowChangeType int + +// these constants represent types of row change. +const ( + RowChangeNull RowChangeType = iota + RowChangeInsert + RowChangeUpdate + RowChangeDelete +) + +// String implements fmt.Stringer interface. +func (t RowChangeType) String() string { + switch t { + case RowChangeInsert: + return "ChangeInsert" + case RowChangeUpdate: + return "ChangeUpdate" + case RowChangeDelete: + return "ChangeDelete" + } + + return "" +} + +// RowChange represents a row change, it can be further converted into DML SQL. +type RowChange struct { + sourceTable *cdcmodel.TableName + targetTable *cdcmodel.TableName + + preValues []interface{} + postValues []interface{} + + sourceTableInfo *timodel.TableInfo + targetTableInfo *timodel.TableInfo + + tiSessionCtx sessionctx.Context + + tp RowChangeType + identityInfo *schema.DownstreamTableInfo +} + +// NewRowChange creates a new RowChange. +// preValues stands for values exists before this change, postValues stands for +// values exists after this change. +// These parameters can be nil: +// - targetTable: when same as sourceTable or not applicable +// - preValues: when INSERT +// - postValues: when DELETE +// - targetTableInfo: when same as sourceTableInfo or not applicable +// - tiSessionCtx: will use default sessionCtx which is UTC timezone +// All arguments must not be changed after assigned to RowChange, any +// modification (like convert []byte to string) should be done before +// NewRowChange. +func NewRowChange( + sourceTable *cdcmodel.TableName, + targetTable *cdcmodel.TableName, + preValues []interface{}, + postValues []interface{}, + sourceTableInfo *timodel.TableInfo, + downstreamTableInfo *timodel.TableInfo, + tiCtx sessionctx.Context, +) *RowChange { + ret := &RowChange{ + sourceTable: sourceTable, + preValues: preValues, + postValues: postValues, + sourceTableInfo: sourceTableInfo, + } + + if preValues != nil && len(preValues) != len(sourceTableInfo.Columns) { + log.L().DPanic("preValues length not equal to sourceTableInfo columns", + zap.Int("preValues", len(preValues)), + zap.Int("sourceTableInfo", len(sourceTableInfo.Columns)), + zap.Stringer("sourceTable", sourceTable)) + } + if postValues != nil && len(postValues) != len(sourceTableInfo.Columns) { + log.L().DPanic("postValues length not equal to sourceTableInfo columns", + zap.Int("postValues", len(postValues)), + zap.Int("sourceTableInfo", len(sourceTableInfo.Columns)), + zap.Stringer("sourceTable", sourceTable)) + } + + if targetTable != nil { + ret.targetTable = targetTable + } else { + ret.targetTable = sourceTable + } + + if downstreamTableInfo != nil { + ret.targetTableInfo = downstreamTableInfo + } else { + ret.targetTableInfo = sourceTableInfo + } + + if tiCtx != nil { + ret.tiSessionCtx = tiCtx + } else { + ret.tiSessionCtx = utils.ZeroSessionCtx + } + + ret.calculateType() + + return ret +} + +func (r *RowChange) calculateType() { + switch { + case r.preValues == nil && r.postValues != nil: + r.tp = RowChangeInsert + case r.preValues != nil && r.postValues != nil: + r.tp = RowChangeUpdate + case r.preValues != nil && r.postValues == nil: + r.tp = RowChangeDelete + default: + log.L().DPanic("preValues and postValues can't both be nil", + zap.Stringer("sourceTable", r.sourceTable)) + } +} + +// Type returns the RowChangeType of this RowChange. Caller can future decide +// the DMLType when generate DML from it. +func (r *RowChange) Type() RowChangeType { + return r.tp +} + +// String implements Stringer interface. +func (r *RowChange) String() string { + return fmt.Sprintf("type: %s, source table: %s, target table: %s, preValues: %v, postValues: %v", + r.tp, r.sourceTable, r.targetTable, r.preValues, r.postValues) +} + +// TargetTableID returns a ID string for target table. +func (r *RowChange) TargetTableID() string { + return r.targetTable.QuoteString() +} + +// SetIdentifyInfo can be used when caller has calculated and cached +// identityInfo, to avoid every RowChange lazily initialize it. +func (r *RowChange) SetIdentifyInfo(info *schema.DownstreamTableInfo) { + r.identityInfo = info +} + +func (r *RowChange) lazyInitIdentityInfo() { + if r.identityInfo != nil { + return + } + + r.identityInfo = schema.GetDownStreamTI(r.targetTableInfo, r.sourceTableInfo) +} + +func getColsAndValuesOfIdx( + columns []*timodel.ColumnInfo, + indexColumns *timodel.IndexInfo, + data []interface{}, +) ([]*timodel.ColumnInfo, []interface{}) { + cols := make([]*timodel.ColumnInfo, 0, len(indexColumns.Columns)) + values := make([]interface{}, 0, len(indexColumns.Columns)) + for _, col := range indexColumns.Columns { + cols = append(cols, columns[col.Offset]) + values = append(values, data[col.Offset]) + } + + return cols, values +} + +// whereColumnsAndValues returns columns and values to identify the row, to form +// the WHERE clause. +func (r *RowChange) whereColumnsAndValues() ([]string, []interface{}) { + r.lazyInitIdentityInfo() + + uniqueIndex := r.identityInfo.AbsoluteUKIndexInfo + if uniqueIndex == nil { + uniqueIndex = schema.GetIdentityUKByData(r.identityInfo, r.preValues) + } + + columns, values := r.sourceTableInfo.Columns, r.preValues + if uniqueIndex != nil { + columns, values = getColsAndValuesOfIdx(r.sourceTableInfo.Columns, uniqueIndex, values) + } + + columnNames := make([]string, 0, len(columns)) + for _, column := range columns { + columnNames = append(columnNames, column.Name.O) + } + + failpoint.Inject("DownstreamTrackerWhereCheck", func() { + if r.tp == RowChangeUpdate { + log.L().Info("UpdateWhereColumnsCheck", + zap.String("Columns", fmt.Sprintf("%v", columnNames))) + } else if r.tp == RowChangeDelete { + log.L().Info("DeleteWhereColumnsCheck", + zap.String("Columns", fmt.Sprintf("%v", columnNames))) + } + }) + + return columnNames, values +} + +// genWhere generates WHERE clause for UPDATE and DELETE to identify the row. +// the SQL part is written to `buf` and the args part is returned. +func (r *RowChange) genWhere(buf *strings.Builder) []interface{} { + whereColumns, whereValues := r.whereColumnsAndValues() + + for i, col := range whereColumns { + if i != 0 { + buf.WriteString(" AND ") + } + buf.WriteString(quotes.QuoteName(col)) + if whereValues[i] == nil { + buf.WriteString(" IS ?") + } else { + buf.WriteString(" = ?") + } + } + return whereValues +} + +// valuesHolder gens values holder like (?,?,?). +func valuesHolder(n int) string { + var builder strings.Builder + builder.Grow((n-1)*2 + 3) + builder.WriteByte('(') + for i := 0; i < n; i++ { + if i > 0 { + builder.WriteString(",") + } + builder.WriteString("?") + } + builder.WriteByte(')') + return builder.String() +} + +func (r *RowChange) genDeleteSQL() (string, []interface{}) { + if r.tp != RowChangeDelete && r.tp != RowChangeUpdate { + log.L().DPanic("illegal type for genDeleteSQL", + zap.String("sourceTable", r.sourceTable.String()), + zap.Stringer("changeType", r.tp)) + return "", nil + } + + var buf strings.Builder + buf.Grow(1024) + buf.WriteString("DELETE FROM ") + buf.WriteString(r.targetTable.QuoteString()) + buf.WriteString(" WHERE ") + whereArgs := r.genWhere(&buf) + buf.WriteString(" LIMIT 1") + + return buf.String(), whereArgs +} + +func isGenerated(columns []*timodel.ColumnInfo, name timodel.CIStr) bool { + for _, col := range columns { + if col.Name.L == name.L { + return col.IsGenerated() + } + } + return false +} + +func (r *RowChange) genUpdateSQL() (string, []interface{}) { + if r.tp != RowChangeUpdate { + log.L().DPanic("illegal type for genUpdateSQL", + zap.String("sourceTable", r.sourceTable.String()), + zap.Stringer("changeType", r.tp)) + return "", nil + } + + var buf strings.Builder + buf.Grow(2048) + buf.WriteString("UPDATE ") + buf.WriteString(r.targetTable.QuoteString()) + buf.WriteString(" SET ") + + args := make([]interface{}, 0, len(r.preValues)+len(r.postValues)) + writtenFirstCol := false + for i, col := range r.sourceTableInfo.Columns { + if isGenerated(r.targetTableInfo.Columns, col.Name) { + continue + } + + if writtenFirstCol { + buf.WriteString(", ") + } + writtenFirstCol = true + fmt.Fprintf(&buf, "%s = ?", quotes.QuoteName(col.Name.O)) + args = append(args, r.postValues[i]) + } + + buf.WriteString(" WHERE ") + whereArgs := r.genWhere(&buf) + buf.WriteString(" LIMIT 1") + + args = append(args, whereArgs...) + return buf.String(), args +} + +func (r *RowChange) genInsertSQL(tp DMLType) (string, []interface{}) { + return GenInsertSQL(tp, r) +} + +type DMLType int + +// these constants represent types of row change. +const ( + DMLNull DMLType = iota + DMLInsert + DMLReplace + DMLInsertOnDuplicateUpdate + DMLUpdate + DMLDelete +) + +// String implements fmt.Stringer interface. +func (t DMLType) String() string { + switch t { + case DMLInsert: + return "DMLInsert" + case DMLReplace: + return "DMLReplace" + case DMLUpdate: + return "DMLUpdate" + case DMLInsertOnDuplicateUpdate: + return "DMLInsertOnDuplicateUpdate" + case DMLDelete: + return "DMLDelete" + } + + return "" +} + +// GenSQL generated a DML SQL for this RowChange. +func (r *RowChange) GenSQL(tp DMLType) (string, []interface{}) { + switch tp { + case DMLInsert, DMLReplace, DMLInsertOnDuplicateUpdate: + return r.genInsertSQL(tp) + case DMLUpdate: + return r.genUpdateSQL() + case DMLDelete: + return r.genDeleteSQL() + } + log.L().DPanic("illegal type for GenSQL", + zap.String("sourceTable", r.sourceTable.String()), + zap.Stringer("DMLType", tp)) + return "", nil +} diff --git a/pkg/sqlmodel/row_change_test.go b/pkg/sqlmodel/row_change_test.go new file mode 100644 index 00000000000..5d04a71930c --- /dev/null +++ b/pkg/sqlmodel/row_change_test.go @@ -0,0 +1,345 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "testing" + + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + timodel "github.com/pingcap/tidb/parser/model" + timock "github.com/pingcap/tidb/util/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/pkg/utils" +) + +func mockTableInfo(t *testing.T, sql string) *timodel.TableInfo { + p := parser.New() + se := timock.NewContext() + node, err := p.ParseOneStmt(sql, "", "") + require.NoError(t, err) + ti, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 1) + require.NoError(t, err) + return ti +} + +type dpanicSuite struct { + suite.Suite +} + +func (s *dpanicSuite) SetupSuite() { + err := log.InitLogger(&log.Config{Level: "debug"}) + s.NoError(err) +} + +func TestDpanicSuite(t *testing.T) { + suite.Run(t, new(dpanicSuite)) +} + +func TestNewRowChange(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tbl"} + target := &cdcmodel.TableName{Schema: "db", Table: "tbl_routed"} + sourceTI := mockTableInfo(t, "CREATE TABLE tbl (id INT PRIMARY KEY, name INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tbl_routed (id INT PRIMARY KEY, name INT)") + tiSession := utils.NewSessionCtx(map[string]string{ + "time_zone": "+08:00", + }) + + expected := &RowChange{ + sourceTable: source, + targetTable: target, + preValues: []interface{}{1, 2}, + postValues: []interface{}{1, 3}, + sourceTableInfo: sourceTI, + targetTableInfo: targetTI, + tiSessionCtx: tiSession, + tp: RowChangeUpdate, + identityInfo: nil, + } + + actual := NewRowChange(source, target, []interface{}{1, 2}, []interface{}{1, 3}, sourceTI, targetTI, tiSession) + require.Equal(t, expected, actual) + + actual.lazyInitIdentityInfo() + require.NotNil(t, actual.identityInfo) + + // test some arguments of NewRowChange can be nil + + expected.targetTable = expected.sourceTable + expected.targetTableInfo = expected.sourceTableInfo + expected.tiSessionCtx = utils.ZeroSessionCtx + expected.identityInfo = nil + actual = NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{1, 3}, sourceTI, nil, nil) + require.Equal(t, expected, actual) +} + +func (s *dpanicSuite) TestRowChangeType() { + change := &RowChange{preValues: []interface{}{1}} + change.calculateType() + s.Equal(RowChangeDelete, change.tp) + change = &RowChange{preValues: []interface{}{1}, postValues: []interface{}{2}} + change.calculateType() + s.Equal(RowChangeUpdate, change.tp) + change = &RowChange{postValues: []interface{}{1}} + change.calculateType() + s.Equal(RowChangeInsert, change.tp) + + s.Panics(func() { + change = &RowChange{} + change.calculateType() + }) +} + +func (s *dpanicSuite) TestGenDelete() { + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + + cases := []struct { + sourceCreateSQL string + targetCreateSQL string + preValues []interface{} + + expectedSQL string + expectedArgs []interface{} + }{ + { + "CREATE TABLE tb1 (id INT PRIMARY KEY, name INT)", + "CREATE TABLE tb2 (id INT PRIMARY KEY, name INT, extra VARCHAR(20))", + []interface{}{1, 2}, + + "DELETE FROM `db`.`tb2` WHERE `id` = ? LIMIT 1", + []interface{}{1}, + }, + { + "CREATE TABLE tb1 (c INT, c2 INT UNIQUE)", + "CREATE TABLE tb2 (c INT, c2 INT UNIQUE)", + []interface{}{1, 2}, + + "DELETE FROM `db`.`tb2` WHERE `c2` = ? LIMIT 1", + []interface{}{2}, + }, + // next 2 cases test NULL value + { + "CREATE TABLE tb1 (c INT, c2 INT UNIQUE)", + "CREATE TABLE tb2 (c INT, c2 INT UNIQUE)", + []interface{}{1, nil}, + + "DELETE FROM `db`.`tb2` WHERE `c` = ? AND `c2` IS ? LIMIT 1", + []interface{}{1, nil}, + }, + { + "CREATE TABLE tb1 (c INT, c2 INT)", + "CREATE TABLE tb2 (c INT, c2 INT)", + []interface{}{1, nil}, + + "DELETE FROM `db`.`tb2` WHERE `c` = ? AND `c2` IS ? LIMIT 1", + []interface{}{1, nil}, + }, + // next 2 cases test using downstream table to generate WHERE + { + "CREATE TABLE tb1 (id INT PRIMARY KEY, user_id INT NOT NULL UNIQUE)", + "CREATE TABLE tb2 (new_id INT PRIMARY KEY, id INT, user_id INT NOT NULL UNIQUE)", + []interface{}{1, 2}, + + "DELETE FROM `db`.`tb2` WHERE `user_id` = ? LIMIT 1", + []interface{}{2}, + }, + { + "CREATE TABLE tb1 (id INT PRIMARY KEY, c2 INT)", + "CREATE TABLE tb2 (new_id INT PRIMARY KEY, id INT, c2 INT)", + []interface{}{1, 2}, + + "DELETE FROM `db`.`tb2` WHERE `id` = ? AND `c2` = ? LIMIT 1", + []interface{}{1, 2}, + }, + } + + for _, c := range cases { + sourceTI := mockTableInfo(s.T(), c.sourceCreateSQL) + targetTI := mockTableInfo(s.T(), c.targetCreateSQL) + change := NewRowChange(source, target, c.preValues, nil, sourceTI, targetTI, nil) + sql, args := change.GenSQL(DMLDelete) + s.Equal(c.expectedSQL, sql) + s.Equal(c.expectedArgs, args) + } + + // a RowChangeUpdate can still generate DELETE SQL + sourceTI := mockTableInfo(s.T(), "CREATE TABLE tb1 (id INT PRIMARY KEY, name INT)") + change := NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{3, 4}, sourceTI, nil, nil) + sql, args := change.GenSQL(DMLDelete) + s.Equal("DELETE FROM `db`.`tb1` WHERE `id` = ? LIMIT 1", sql) + s.Equal([]interface{}{1}, args) + + change = NewRowChange(source, nil, nil, []interface{}{3, 4}, sourceTI, nil, nil) + s.Panics(func() { + change.GenSQL(DMLDelete) + }) +} + +func (s *dpanicSuite) TestGenUpdate() { + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + + cases := []struct { + sourceCreateSQL string + targetCreateSQL string + preValues []interface{} + postValues []interface{} + + expectedSQL string + expectedArgs []interface{} + }{ + { + "CREATE TABLE tb1 (id INT PRIMARY KEY, name INT)", + "CREATE TABLE tb2 (id INT PRIMARY KEY, name INT, extra VARCHAR(20))", + []interface{}{1, 2}, + []interface{}{3, 4}, + + "UPDATE `db`.`tb2` SET `id` = ?, `name` = ? WHERE `id` = ? LIMIT 1", + []interface{}{3, 4, 1}, + }, + { + "CREATE TABLE tb1 (id INT UNIQUE, name INT)", + "CREATE TABLE tb2 (id INT UNIQUE, name INT)", + []interface{}{nil, 2}, + []interface{}{3, 4}, + + "UPDATE `db`.`tb2` SET `id` = ?, `name` = ? WHERE `id` IS ? AND `name` = ? LIMIT 1", + []interface{}{3, 4, nil, 2}, + }, + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)", + "CREATE TABLE tb2 (c INT, c2 INT)", + []interface{}{1, 2}, + []interface{}{3, 4}, + + "UPDATE `db`.`tb2` SET `c` = ?, `c2` = ? WHERE `c` = ? AND `c2` = ? LIMIT 1", + []interface{}{3, 4, 1, 2}, + }, + // next 2 cases test generated column + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT AS (c+1))", + "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT AS (c+1))", + []interface{}{1, 2}, + []interface{}{3, 4}, + + "UPDATE `db`.`tb2` SET `c` = ? WHERE `c` = ? LIMIT 1", + []interface{}{3, 1}, + }, + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT AS (c+1))", + "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)", + []interface{}{1, 2}, + []interface{}{3, 4}, + + "UPDATE `db`.`tb2` SET `c` = ?, `c2` = ? WHERE `c` = ? LIMIT 1", + []interface{}{3, 4, 1}, + }, + } + + for _, c := range cases { + sourceTI := mockTableInfo(s.T(), c.sourceCreateSQL) + targetTI := mockTableInfo(s.T(), c.targetCreateSQL) + change := NewRowChange(source, target, c.preValues, c.postValues, sourceTI, targetTI, nil) + sql, args := change.GenSQL(DMLUpdate) + s.Equal(c.expectedSQL, sql) + s.Equal(c.expectedArgs, args) + } + + sourceTI := mockTableInfo(s.T(), "CREATE TABLE tb1 (id INT PRIMARY KEY, name INT)") + change := NewRowChange(source, nil, nil, []interface{}{3, 4}, sourceTI, nil, nil) + s.Panics(func() { + change.GenSQL(DMLUpdate) + }) +} + +func TestGenInsert(t *testing.T) { + t.Parallel() + + source := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + + cases := []struct { + sourceCreateSQL string + targetCreateSQL string + postValues []interface{} + + expectedInsertSQL string + expectedReplaceSQL string + expectedInsertOnDupSQL string + expectedArgs []interface{} + }{ + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)", + "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT, extra VARCHAR(20))", + []interface{}{1, 2}, + + "INSERT INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?)", + "REPLACE INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?)", + "INSERT INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", + []interface{}{1, 2}, + }, + // next 2 cases test generated column + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT AS (c+1))", + "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT AS (c+1))", + []interface{}{1, 2}, + + "INSERT INTO `db`.`tb2` (`c`) VALUES (?)", + "REPLACE INTO `db`.`tb2` (`c`) VALUES (?)", + "INSERT INTO `db`.`tb2` (`c`) VALUES (?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`)", + []interface{}{1}, + }, + { + "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT AS (c+1))", + "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)", + []interface{}{1, 2}, + + "INSERT INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?)", + "REPLACE INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?)", + "INSERT INTO `db`.`tb2` (`c`,`c2`) VALUES (?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", + []interface{}{1, 2}, + }, + } + + for _, c := range cases { + sourceTI := mockTableInfo(t, c.sourceCreateSQL) + targetTI := mockTableInfo(t, c.targetCreateSQL) + change := NewRowChange(source, target, nil, c.postValues, sourceTI, targetTI, nil) + sql, args := change.GenSQL(DMLInsert) + require.Equal(t, c.expectedInsertSQL, sql) + require.Equal(t, c.expectedArgs, args) + sql, args = change.GenSQL(DMLReplace) + require.Equal(t, c.expectedReplaceSQL, sql) + require.Equal(t, c.expectedArgs, args) + sql, args = change.GenSQL(DMLInsertOnDuplicateUpdate) + require.Equal(t, c.expectedInsertOnDupSQL, sql) + require.Equal(t, c.expectedArgs, args) + } +} + +func TestValuesHolder(t *testing.T) { + t.Parallel() + + require.Equal(t, "()", valuesHolder(0)) + require.Equal(t, "(?)", valuesHolder(1)) + require.Equal(t, "(?,?)", valuesHolder(2)) +}