Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check, diff: auto discover ansi-quotes #381

Merged
merged 12 commits into from
Aug 25, 2020
46 changes: 17 additions & 29 deletions pkg/check/table_structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,17 @@ func (o *incompatibilityOption) String() string {
// In generally we need to check definitions of columns, constraints and table options.
// Because of the early TiDB engineering design, we did not have a complete list of check items, which are all based on experience now.
type TablesChecker struct {
db *sql.DB
dbinfo *dbutil.DBConfig
tables map[string][]string // schema => []table; if []table is empty, query tables from db
enableANSIQuotes bool
db *sql.DB
dbinfo *dbutil.DBConfig
tables map[string][]string // schema => []table; if []table is empty, query tables from db
}

// NewTablesChecker returns a Checker
func NewTablesChecker(db *sql.DB, dbinfo *dbutil.DBConfig, tables map[string][]string, enableANSIQuotes bool) Checker {
func NewTablesChecker(db *sql.DB, dbinfo *dbutil.DBConfig, tables map[string][]string) Checker {
return &TablesChecker{
db: db,
dbinfo: dbinfo,
tables: tables,
enableANSIQuotes: enableANSIQuotes,
db: db,
dbinfo: dbinfo,
tables: tables,
}
}

Expand Down Expand Up @@ -161,11 +159,7 @@ func (c *TablesChecker) Name() string {
}

func (c *TablesChecker) checkCreateSQL(statement string) []*incompatibilityOption {
sqlMode := ""
if c.enableANSIQuotes {
sqlMode = "ANSI_QUOTES"
}
parser2, err := dbutil.GetParser(sqlMode)
parser2, err := dbutil.GetParserForDB(c.db)
if err != nil {
return []*incompatibilityOption{
{
Expand Down Expand Up @@ -291,18 +285,16 @@ type ShardingTablesCheck struct {
tables map[string]map[string][]string // instance => {schema: [table1, table2, ...]}
mapping map[string]*column.Mapping
checkAutoIncrementPrimaryKey bool
enableANSIQuotes bool
}

// NewShardingTablesCheck returns a Checker
func NewShardingTablesCheck(name string, dbs map[string]*sql.DB, tables map[string]map[string][]string, mapping map[string]*column.Mapping, checkAutoIncrementPrimaryKey bool, enableANSIQuotes bool) Checker {
func NewShardingTablesCheck(name string, dbs map[string]*sql.DB, tables map[string]map[string][]string, mapping map[string]*column.Mapping, checkAutoIncrementPrimaryKey bool) Checker {
return &ShardingTablesCheck{
name: name,
dbs: dbs,
tables: tables,
mapping: mapping,
checkAutoIncrementPrimaryKey: checkAutoIncrementPrimaryKey,
enableANSIQuotes: enableANSIQuotes,
}
}

Expand All @@ -320,24 +312,20 @@ func (c *ShardingTablesCheck) Check(ctx context.Context) *Result {
tableName string
)

sqlMode := ""
if c.enableANSIQuotes {
sqlMode = "ANSI_QUOTES"
}
parser2, err := dbutil.GetParser(sqlMode)
if err != nil {
markCheckError(r, err)
r.Extra = fmt.Sprintf("fail to get parser")
return r
}

for instance, schemas := range c.tables {
db, ok := c.dbs[instance]
if !ok {
markCheckError(r, errors.NotFoundf("client for instance %s", instance))
return r
}

parser2, err := dbutil.GetParserForDB(db)
if err != nil {
markCheckError(r, err)
r.Extra = fmt.Sprintf("fail to get parser for instance %s on sharding %s", instance, c.name)
return r
}

for schema, tables := range schemas {
for _, table := range tables {
statement, err := dbutil.GetCreateTableSQL(ctx, db, schema, table)
Expand All @@ -347,7 +335,7 @@ func (c *ShardingTablesCheck) Check(ctx context.Context) *Result {
return r
}

info, err := dbutil.GetTableInfoBySQL(statement, sqlMode)
info, err := dbutil.GetTableInfoBySQL(statement, parser2)
if err != nil {
markCheckError(r, err)
r.Extra = fmt.Sprintf("instance %s on sharding %s", instance, c.name)
Expand Down
74 changes: 70 additions & 4 deletions pkg/dbutil/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ type DBConfig struct {
Schema string `toml:"schema" json:"schema"`

Snapshot string `toml:"snapshot" json:"snapshot"`

SQLMode string `toml:"sql-mode" json:"sql-mode"`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do other places still need this (sql_mode may including others rather than ANSI_QUOTES)?

But we may remove it now and add it back if needed later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess other places don't require this SQLMode otherwise it won't compile. and in #280 this field is used for parser only

}

// String returns native format of database configuration
Expand Down Expand Up @@ -606,6 +604,61 @@ func GetDBVersion(ctx context.Context, db *sql.DB) (string, error) {
return "", ErrVersionNotFound
}

// GetSessionVariable gets server's session variable, although argument is *sql.DB, (session) system variables may be
// set through DSN
func GetSessionVariable(db *sql.DB, variable string) (value string, err error) {
query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable)
rows, err := db.Query(query)

if err != nil {
return "", errors.Trace(err)
}
defer rows.Close()

// Show an example.
/*
mysql> SHOW VARIABLES LIKE "binlog_format";
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| binlog_format | ROW |
+---------------+-------+
*/

for rows.Next() {
err = rows.Scan(&variable, &value)
if err != nil {
return "", errors.Trace(err)
}
}

if rows.Err() != nil {
return "", errors.Trace(err)
}

return value, nil
}

// GetSQLMode returns sql_mode.
func GetSQLMode(db *sql.DB) (tmysql.SQLMode, error) {
sqlMode, err := GetSessionVariable(db, "sql_mode")
if err != nil {
return tmysql.ModeNone, err
}

mode, err := tmysql.GetSQLMode(sqlMode)
return mode, errors.Trace(err)
}

// HasAnsiQuotesMode checks whether database has `ANSI_QUOTES` set
func HasAnsiQuotesMode(db *sql.DB) (bool, error) {
mode, err := GetSQLMode(db)
if err != nil {
return false, err
}
return mode.HasANSIQuotesMode(), nil
}

// IsTiDB returns true if this database is tidb
func IsTiDB(ctx context.Context, db *sql.DB) (bool, error) {
version, err := GetDBVersion(ctx, db)
Expand Down Expand Up @@ -766,8 +819,8 @@ func DeleteRows(ctx context.Context, db *sql.DB, schemaName string, tableName st
return DeleteRows(ctx, db, schemaName, tableName, where, args)
}

// GetParser gets parser according to sql mode
func GetParser(sqlModeStr string) (*parser.Parser, error) {
// getParser gets parser according to sql mode
func getParser(sqlModeStr string) (*parser.Parser, error) {
if len(sqlModeStr) == 0 {
return parser.New(), nil
}
Expand All @@ -780,3 +833,16 @@ func GetParser(sqlModeStr string) (*parser.Parser, error) {
parser2.SetSQLMode(sqlMode)
return parser2, nil
}

// GetParserForDB discovers ANSI_QUOTES in db's session variables and returns a proper parser
func GetParserForDB(db *sql.DB) (*parser.Parser, error) {
ansiQuotes, err := HasAnsiQuotesMode(db)
if err != nil {
return nil, err
}
sqlMode := ""
if ansiQuotes {
sqlMode = "ANSI_QUOTES"
}
return getParser(sqlMode)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set sql_mode for all modes from GetSQLMode rather than only ANSI_QUOTES?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f70d144

Okay, going to align DM's parser to this behaviour

}
2 changes: 1 addition & 1 deletion pkg/dbutil/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (s *testDBSuite) TestGetParser(c *C) {
}

for _, testCase := range testCases {
parser, err := GetParser(testCase.sqlModeStr)
parser, err := getParser(testCase.sqlModeStr)
if testCase.hasErr {
c.Assert(err, NotNil)
} else {
Expand Down
3 changes: 2 additions & 1 deletion pkg/dbutil/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package dbutil

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser"
)

func (*testDBSuite) TestIndex(c *C) {
Expand Down Expand Up @@ -81,7 +82,7 @@ func (*testDBSuite) TestIndex(c *C) {
}

for _, testCase := range testCases {
tableInfo, err := GetTableInfoBySQL(testCase.sql, "")
tableInfo, err := GetTableInfoBySQL(testCase.sql, parser.New())
c.Assert(err, IsNil)

indices := FindAllIndex(tableInfo)
Expand Down
14 changes: 7 additions & 7 deletions pkg/dbutil/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/tidb/ddl"
Expand All @@ -28,22 +29,21 @@ import (
)

// GetTableInfo returns table information.
func GetTableInfo(ctx context.Context, db *sql.DB, schemaName string, tableName string, sqlMode string) (*model.TableInfo, error) {
func GetTableInfo(ctx context.Context, db *sql.DB, schemaName string, tableName string) (*model.TableInfo, error) {
createTableSQL, err := GetCreateTableSQL(ctx, db, schemaName, tableName)
if err != nil {
return nil, errors.Trace(err)
}

return GetTableInfoBySQL(createTableSQL, sqlMode)
}

// GetTableInfoBySQL returns table information by given create table sql.
func GetTableInfoBySQL(createTableSQL string, sqlMode string) (table *model.TableInfo, err error) {
parser2, err := GetParser(sqlMode)
parser2, err := GetParserForDB(db)
if err != nil {
return nil, errors.Trace(err)
}
return GetTableInfoBySQL(createTableSQL, parser2)
}

// GetTableInfoBySQL returns table information by given create table sql.
func GetTableInfoBySQL(createTableSQL string, parser2 *parser.Parser) (table *model.TableInfo, err error) {
stmt, err := parser2.ParseOneStmt(createTableSQL, "", "")
if err != nil {
return nil, errors.Trace(err)
Expand Down
11 changes: 7 additions & 4 deletions pkg/dbutil/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/parser"
"github.com/pingcap/parser/mysql"
)

Expand Down Expand Up @@ -78,7 +79,7 @@ func (*testDBSuite) TestTable(c *C) {
}

for _, testCase := range testCases {
tableInfo, err := GetTableInfoBySQL(testCase.sql, "")
tableInfo, err := GetTableInfoBySQL(testCase.sql, parser.New())
c.Assert(err, IsNil)
for i, column := range tableInfo.Columns {
c.Assert(testCase.columns[i], Equals, column.Name.O)
Expand All @@ -95,15 +96,17 @@ func (*testDBSuite) TestTable(c *C) {

func (*testDBSuite) TestTableStructEqual(c *C) {
createTableSQL1 := "CREATE TABLE `test`.`atest` (`id` int(24), `name` varchar(24), `birthday` datetime, `update_time` time, `money` decimal(20,2), primary key(`id`))"
tableInfo1, err := GetTableInfoBySQL(createTableSQL1, "")
tableInfo1, err := GetTableInfoBySQL(createTableSQL1, parser.New())
c.Assert(err, IsNil)

createTableSQL2 := "CREATE TABLE `test`.`atest` (`id` int(24) NOT NULL, `name` varchar(24), `birthday` datetime, `update_time` time, `money` decimal(20,2), primary key(`id`))"
tableInfo2, err := GetTableInfoBySQL(createTableSQL2, "")
tableInfo2, err := GetTableInfoBySQL(createTableSQL2, parser.New())
c.Assert(err, IsNil)

createTableSQL3 := `CREATE TABLE "test"."atest" ("id" int(24), "name" varchar(24), "birthday" datetime, "update_time" time, "money" decimal(20,2), unique key("id"))`
tableInfo3, err := GetTableInfoBySQL(createTableSQL3, "ANSI_QUOTES")
p := parser.New()
p.SetSQLMode(mysql.ModeANSIQuotes)
tableInfo3, err := GetTableInfoBySQL(createTableSQL3, p)
c.Assert(err, IsNil)

equal, _ := EqualTableInfo(tableInfo1, tableInfo2)
Expand Down
2 changes: 1 addition & 1 deletion pkg/diff/chunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (*testChunkSuite) TestSplitRange(c *C) {
// only work on tidb, so don't assert err here
_, _ = conn.ExecContext(ctx, "ANALYZE TABLE `test`.`test_chunk`")

tableInfo, err := dbutil.GetTableInfo(ctx, conn, "test", "test_chunk", "")
tableInfo, err := dbutil.GetTableInfo(ctx, conn, "test", "test_chunk")
c.Assert(err, IsNil)

tableInstance := &TableInstance{
Expand Down
5 changes: 2 additions & 3 deletions pkg/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ type TableInstance struct {
Table string `json:"table"`
InstanceID string `json:"instance-id"`
info *model.TableInfo
SQLMode string `json:"sql-mode"`
}

// TableDiff saves config for diff table
Expand Down Expand Up @@ -212,14 +211,14 @@ func (t *TableDiff) adjustConfig() {
}

func (t *TableDiff) getTableInfo(ctx context.Context) error {
tableInfo, err := dbutil.GetTableInfo(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table, t.TargetTable.SQLMode)
tableInfo, err := dbutil.GetTableInfo(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table)
if err != nil {
return errors.Trace(err)
}
t.TargetTable.info = ignoreColumns(tableInfo, t.IgnoreColumns)

for _, sourceTable := range t.SourceTables {
tableInfo, err := dbutil.GetTableInfo(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table, sourceTable.SQLMode)
tableInfo, err := dbutil.GetTableInfo(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table)
if err != nil {
return errors.Trace(err)
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/diff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
_ "github.com/go-sql-driver/mysql"
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser"
"github.com/pingcap/tidb-tools/pkg/dbutil"
"github.com/pingcap/tidb-tools/pkg/importer"
)
Expand All @@ -37,7 +38,7 @@ type testDiffSuite struct{}

func (*testDiffSuite) TestGenerateSQLs(c *C) {
createTableSQL := "CREATE TABLE `diff_test`.`atest` (`id` int(24), `name` varchar(24), `birthday` datetime, `update_time` time, `money` decimal(20,2), `id_gen` int(11) GENERATED ALWAYS AS ((`id` + 1)) VIRTUAL, primary key(`id`, `name`))"
tableInfo, err := dbutil.GetTableInfoBySQL(createTableSQL, "")
tableInfo, err := dbutil.GetTableInfoBySQL(createTableSQL, parser.New())
c.Assert(err, IsNil)

rowsData := map[string]*dbutil.ColumnData{
Expand All @@ -56,7 +57,7 @@ func (*testDiffSuite) TestGenerateSQLs(c *C) {

// test the unique key
createTableSQL2 := "CREATE TABLE `diff_test`.`atest` (`id` int(24), `name` varchar(24), `birthday` datetime, `update_time` time, `money` decimal(20,2), unique key(`id`, `name`))"
tableInfo2, err := dbutil.GetTableInfoBySQL(createTableSQL2, "")
tableInfo2, err := dbutil.GetTableInfoBySQL(createTableSQL2, parser.New())
c.Assert(err, IsNil)
replaceSQL = generateDML("replace", rowsData, tableInfo2, "diff_test")
deleteSQL = generateDML("delete", rowsData, tableInfo2, "diff_test")
Expand Down Expand Up @@ -188,10 +189,10 @@ func testStructEqual(ctx context.Context, conn *sql.DB, c *C) {
_, err = conn.ExecContext(ctx, testCase.createTargetTable)
c.Assert(err, IsNil)

sourceInfo, err := dbutil.GetTableInfoBySQL(testCase.createSourceTable, "")
sourceInfo, err := dbutil.GetTableInfoBySQL(testCase.createSourceTable, parser.New())
c.Assert(err, IsNil)

targetInfo, err := dbutil.GetTableInfoBySQL(testCase.createTargetTable, "")
targetInfo, err := dbutil.GetTableInfoBySQL(testCase.createTargetTable, parser.New())
c.Assert(err, IsNil)

tableDiff := createTableDiff(conn, "diff_test", []string{sourceInfo.Name.O}, targetInfo.Name.O)
Expand Down
3 changes: 2 additions & 1 deletion pkg/diff/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"container/heap"

. "github.com/pingcap/check"
"github.com/pingcap/parser"
"github.com/pingcap/tidb-tools/pkg/dbutil"
)

Expand All @@ -26,7 +27,7 @@ type testMergerSuite struct{}

func (s *testMergerSuite) TestMerge(c *C) {
createTableSQL := "create table test.test(id int(24), name varchar(24), age int(24), primary key(id, name));"
tableInfo, err := dbutil.GetTableInfoBySQL(createTableSQL, "")
tableInfo, err := dbutil.GetTableInfoBySQL(createTableSQL, parser.New())
c.Assert(err, IsNil)

_, orderKeyCols := dbutil.SelectUniqueOrderKey(tableInfo)
Expand Down
Loading