Skip to content

Commit

Permalink
drainer: rtrim char type column in sql (#1165) (#1167)
Browse files Browse the repository at this point in the history
ref #1164
  • Loading branch information
ti-chi-bot authored May 29, 2022
1 parent 0e44656 commit 1012ec7
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 25 deletions.
19 changes: 14 additions & 5 deletions pkg/loader/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (dml *DML) updateOracleSQL() (sql string, args []interface{}) {
builder.WriteByte(',')
}
arg := dml.Values[name]
fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos)
fmt.Fprintf(builder, "%s = :%d", name, oracleHolderPos)
oracleHolderPos++
args = append(args, arg)
}
Expand Down Expand Up @@ -268,16 +268,25 @@ func (dml *DML) buildOracleWhere(builder *strings.Builder, oracleHolderPos int)
builder.WriteString(" AND ")
}
if wargs[i] == nil || wargs[i] == "" {
builder.WriteString(escapeName(wnames[i]) + " IS NULL")
builder.WriteString(wnames[i] + " IS NULL")
} else {
builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos))
builder.WriteString(fmt.Sprintf("%s = :%d", dml.processOracleColumn(wnames[i]), pOracleHolderPos))
pOracleHolderPos++
args = append(args, wargs[i])
}
}
return
}

func (dml *DML) processOracleColumn(colName string) string {
dataType := dml.info.dataTypeMap[colName]
switch dataType {
case "CHAR", "NCHAR":
return fmt.Sprintf("RTRIM(%s)", colName)
}
return colName
}

func (dml *DML) whereValues(names []string) (values []interface{}) {
valueMap := dml.Values
if dml.Tp == UpdateDMLType {
Expand Down Expand Up @@ -381,9 +390,9 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string, args []interface{}) {
builder.WriteString(" AND ")
}
if colValues[i] == nil || colValues[i] == "" {
builder.WriteString(escapeName(colNames[i]) + " IS NULL")
builder.WriteString(colNames[i] + " IS NULL")
} else {
builder.WriteString(fmt.Sprintf("%s = :%d", colNames[i], oracleHolderPos))
builder.WriteString(fmt.Sprintf("%s = :%d", dml.processOracleColumn(colNames[i]), oracleHolderPos))
oracleHolderPos++
args = append(args, colValues[i])
}
Expand Down
140 changes: 140 additions & 0 deletions pkg/loader/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,59 @@ func (s *SQLSuite) TestUpdateMarkSQL(c *check.C) {
c.Assert(mock.ExpectationsWereMet(), check.IsNil)
}

func (s *SQLSuite) TestOracleUpdateSQLCharType(c *check.C) {
dml := DML{
Tp: UpdateDMLType,
Database: "db",
Table: "tbl",
Values: map[string]interface{}{
"ID": 123,
"NAME": "pc",
"OFFER": "oo",
"ADDRESS": "aa",
},
OldValues: map[string]interface{}{
"ID": 123,
"NAME": "pingcap",
"OFFER": "o",
"ADDRESS": "a",
},
info: &tableInfo{
columns: []string{"ID", "NAME", "OFFER", "ADDRESS"},
dataTypeMap: map[string]string{
"ID": "VARCHAR2",
"NAME": "VARCHAR2",
"OFFER": "CHAR",
"ADDRESS": "NCHAR",
},
},
UpColumnsInfoMap: map[string]*model.ColumnInfo{
"ID": {
FieldType: types.FieldType{Tp: mysql.TypeInt24}},
"NAME": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"OFFER": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"ADDRESS": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
},
DestDBType: OracleDB,
}
sql, args := dml.sql()
c.Assert(
sql, check.Equals,
"UPDATE db.tbl SET ADDRESS = :1,ID = :2,NAME = :3,OFFER = :4 WHERE RTRIM(ADDRESS) = :5 AND ID = :6 AND NAME = :7 AND RTRIM(OFFER) = :8 AND rownum <=1")
c.Assert(args, check.HasLen, 8)
c.Assert(args[0], check.Equals, "aa")
c.Assert(args[1], check.Equals, 123)
c.Assert(args[2], check.Equals, "pc")
c.Assert(args[3], check.Equals, "oo")
c.Assert(args[4], check.Equals, "a")
c.Assert(args[5], check.Equals, 123)
c.Assert(args[6], check.Equals, "pingcap")
c.Assert(args[7], check.Equals, "o")
}

func (s *SQLSuite) TestOracleUpdateSQL(c *check.C) {
dml := DML{
Tp: UpdateDMLType,
Expand Down Expand Up @@ -389,6 +442,49 @@ func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) {
c.Assert(args[2], check.Equals, 123)
}

func (s *SQLSuite) TestOracleDeleteSQLCharType(c *check.C) {
dml := DML{
Tp: DeleteDMLType,
Database: "db",
Table: "tbl",
Values: map[string]interface{}{
"ID": 123,
"NAME": "pc",
"OFFER": "o",
"ADDRESS": "a",
},
info: &tableInfo{
columns: []string{"ID", "NAME", "OFFER", "ADDRESS"},
dataTypeMap: map[string]string{
"ID": "VARCHAR2",
"NAME": "VARCHAR2",
"OFFER": "CHAR",
"ADDRESS": "NCHAR",
},
},
UpColumnsInfoMap: map[string]*model.ColumnInfo{
"ID": {
FieldType: types.FieldType{Tp: mysql.TypeInt24}},
"NAME": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"OFFER": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"ADDRESS": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
},
DestDBType: OracleDB,
}
sql, args := dml.sql()
c.Assert(
sql, check.Equals,
"DELETE FROM db.tbl WHERE RTRIM(ADDRESS) = :1 AND ID = :2 AND NAME = :3 AND RTRIM(OFFER) = :4 AND rownum <=1")
c.Assert(args, check.HasLen, 4)
c.Assert(args[0], check.Equals, "a")
c.Assert(args[1], check.Equals, 123)
c.Assert(args[2], check.Equals, "pc")
c.Assert(args[3], check.Equals, "o")
}

func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) {
dml := DML{
Tp: DeleteDMLType,
Expand Down Expand Up @@ -671,3 +767,47 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLEmptyString(c *check.C) {
c.Assert(args[0], check.Equals, 123)
c.Assert(args[1], check.Equals, "456")
}

func (s *SQLSuite) TestOracleDeleteNewValueSQLCharType(c *check.C) {
dml := DML{
Tp: InsertDMLType,
Database: "db",
Table: "tbl",
Values: map[string]interface{}{
"ID": 123,
"ID2": "456",
"NAME": "n",
"C2": "c",
},
info: &tableInfo{
columns: []string{"ID", "ID2", "NAME", "C2"},
dataTypeMap: map[string]string{
"ID": "VARCHAR2",
"ID2": "VARCHAR2",
"NAME": "CHAR",
"C2": "NCHAR",
},
},
UpColumnsInfoMap: map[string]*model.ColumnInfo{
"ID": {
FieldType: types.FieldType{Tp: mysql.TypeInt24}},
"ID2": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"NAME": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
"C2": {
FieldType: types.FieldType{Tp: mysql.TypeVarString}},
},
DestDBType: OracleDB,
}

sql, args := dml.oracleDeleteNewValueSQL()
c.Assert(
sql, check.Equals,
"DELETE FROM db.tbl WHERE RTRIM(C2) = :1 AND ID = :2 AND ID2 = :3 AND RTRIM(NAME) = :4 AND rownum <=1")
c.Assert(args, check.HasLen, 4)
c.Assert(args[0], check.Equals, "c")
c.Assert(args[1], check.Equals, 123)
c.Assert(args[2], check.Equals, "456")
c.Assert(args[3], check.Equals, "n")
}
27 changes: 16 additions & 11 deletions pkg/loader/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ WHERE table_schema = ? AND table_name = ?
ORDER BY seq_in_index ASC;`

//for oracle db
colsOracleSQL = `SELECT column_name FROM dba_tab_cols WHERE owner = upper(:1) AND table_name = upper(:2) AND virtual_column = 'NO'`
colsOracleSQL = `SELECT column_name, data_type FROM dba_tab_cols WHERE owner = upper(:1) AND table_name = upper(:2) AND virtual_column = 'NO'`
uniqKeyOracleSQL = `select c.constraint_type || i.uniqueness index_type, i.index_name, ic.column_position, ic.column_name
from dba_indexes i
left join dba_constraints c
Expand All @@ -70,6 +70,9 @@ type tableInfo struct {
primaryKey *indexInfo
// include primary key if have
uniqueKeys []indexInfo

//colum name -> data type map used in oracle db
dataTypeMap map[string]string
}

type indexInfo struct {
Expand Down Expand Up @@ -106,7 +109,7 @@ func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, e
func getOracleTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, err error) {
info = new(tableInfo)

if info.columns, err = getOracleColsOfTbl(db, schema, table); err != nil {
if info.columns, info.dataTypeMap, err = getOracleColsOfTbl(db, schema, table); err != nil {
return nil, errors.Annotatef(err, "table %s.%s", schema, table)
}

Expand Down Expand Up @@ -303,7 +306,7 @@ func buildColumnList(names []string, destDBType DBType) string {
b.WriteString(",")
}
if destDBType == OracleDB {
b.WriteString(escapeName(name))
b.WriteString(name)
} else {
b.WriteString(quoteName(name))
}
Expand Down Expand Up @@ -349,32 +352,34 @@ func getColsOfTbl(db *gosql.DB, schema, table string) ([]string, error) {
return cols, nil
}

func getOracleColsOfTbl(db *gosql.DB, schema, table string) ([]string, error) {
func getOracleColsOfTbl(db *gosql.DB, schema, table string) ([]string, map[string]string, error) {
rows, err := db.Query(colsOracleSQL, schema, table)
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
defer rows.Close()
cols := make([]string, 0, 1)
dataTypeMap := make(map[string]string)
for rows.Next() {
var name string
err = rows.Scan(&name)
var name, dataType string
err = rows.Scan(&name, &dataType)
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
cols = append(cols, name)
dataTypeMap[name] = dataType
}

if err = rows.Err(); err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}

// if no any columns returns, means the table not exist.
if len(cols) == 0 {
return nil, ErrTableNotExist
return nil, nil, ErrTableNotExist
}

return cols, nil
return cols, dataTypeMap, nil

}

Expand Down
28 changes: 19 additions & 9 deletions pkg/loader/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ func (cs *UtilSuite) TestGetOracleTableInfo(c *check.C) {
c.Assert(err, check.IsNil)
defer db.Close()

columnRows := sqlmock.NewRows([]string{"column_name"}).
AddRow("C1").
AddRow("C2").
AddRow("C3").
AddRow("C4").
AddRow("C5").
AddRow("C6").
AddRow("C7").
AddRow("C8")
columnRows := sqlmock.NewRows([]string{"column_name", "data_type"}).
AddRow("C1", "VARCHAR2").
AddRow("C2", "VARCHAR2").
AddRow("C3", "VARCHAR2").
AddRow("C4", "VARCHAR2").
AddRow("C5", "VARCHAR2").
AddRow("C6", "NUMBER").
AddRow("C7", "CHAR").
AddRow("C8", "NCHAR")
mock.ExpectQuery(regexp.QuoteMeta(colsOracleSQL)).WithArgs("test", "t3").WillReturnRows(columnRows)

indexRows := sqlmock.NewRows([]string{"index_type", "index_name", "column_position", "column_name"}).
Expand All @@ -125,6 +125,16 @@ func (cs *UtilSuite) TestGetOracleTableInfo(c *check.C) {
{name: "T3_C3_C4_UINDEX", columns: []string{"C3", "C4"}},
{name: "T3_C5_C6_UINDEX", columns: []string{"C5", "C6"}},
},
dataTypeMap: map[string]string{
"C1": "VARCHAR2",
"C2": "VARCHAR2",
"C3": "VARCHAR2",
"C4": "VARCHAR2",
"C5": "VARCHAR2",
"C6": "NUMBER",
"C7": "CHAR",
"C8": "NCHAR",
},
})

}
Expand Down

0 comments on commit 1012ec7

Please sign in to comment.