diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 486868f49..4cf864c75 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -223,7 +223,7 @@ func (dml *DML) updateOracleSQL() (sql string, args []interface{}) { builder.WriteByte(',') } arg := dml.Values[name] - fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos) + fmt.Fprintf(builder, "%s = :%d", name, oracleHolderPos) oracleHolderPos++ args = append(args, arg) } @@ -268,9 +268,9 @@ func (dml *DML) buildOracleWhere(builder *strings.Builder, oracleHolderPos int) builder.WriteString(" AND ") } if wargs[i] == nil || wargs[i] == "" { - builder.WriteString(escapeName(wnames[i]) + " IS NULL") + builder.WriteString(wnames[i] + " IS NULL") } else { - builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) + builder.WriteString(fmt.Sprintf("%s = :%d", dml.processOracleColumn(wnames[i]), pOracleHolderPos)) pOracleHolderPos++ args = append(args, wargs[i]) } @@ -278,6 +278,15 @@ func (dml *DML) buildOracleWhere(builder *strings.Builder, oracleHolderPos int) return } +func (dml *DML) processOracleColumn(colName string) string { + dataType := dml.info.dataTypeMap[colName] + switch dataType { + case "CHAR", "NCHAR": + return fmt.Sprintf("RTRIM(%s)", colName) + } + return colName +} + func (dml *DML) whereValues(names []string) (values []interface{}) { valueMap := dml.Values if dml.Tp == UpdateDMLType { @@ -381,9 +390,9 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string, args []interface{}) { builder.WriteString(" AND ") } if colValues[i] == nil || colValues[i] == "" { - builder.WriteString(escapeName(colNames[i]) + " IS NULL") + builder.WriteString(colNames[i] + " IS NULL") } else { - builder.WriteString(fmt.Sprintf("%s = :%d", colNames[i], oracleHolderPos)) + builder.WriteString(fmt.Sprintf("%s = :%d", dml.processOracleColumn(colNames[i]), oracleHolderPos)) oracleHolderPos++ args = append(args, colValues[i]) } diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 810655446..cafa383da 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -271,6 +271,59 @@ func (s *SQLSuite) TestUpdateMarkSQL(c *check.C) { c.Assert(mock.ExpectationsWereMet(), check.IsNil) } +func (s *SQLSuite) TestOracleUpdateSQLCharType(c *check.C) { + dml := DML{ + Tp: UpdateDMLType, + Database: "db", + Table: "tbl", + Values: map[string]interface{}{ + "ID": 123, + "NAME": "pc", + "OFFER": "oo", + "ADDRESS": "aa", + }, + OldValues: map[string]interface{}{ + "ID": 123, + "NAME": "pingcap", + "OFFER": "o", + "ADDRESS": "a", + }, + info: &tableInfo{ + columns: []string{"ID", "NAME", "OFFER", "ADDRESS"}, + dataTypeMap: map[string]string{ + "ID": "VARCHAR2", + "NAME": "VARCHAR2", + "OFFER": "CHAR", + "ADDRESS": "NCHAR", + }, + }, + UpColumnsInfoMap: map[string]*model.ColumnInfo{ + "ID": { + FieldType: types.FieldType{Tp: mysql.TypeInt24}}, + "NAME": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "OFFER": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "ADDRESS": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + }, + DestDBType: OracleDB, + } + sql, args := dml.sql() + c.Assert( + sql, check.Equals, + "UPDATE db.tbl SET ADDRESS = :1,ID = :2,NAME = :3,OFFER = :4 WHERE RTRIM(ADDRESS) = :5 AND ID = :6 AND NAME = :7 AND RTRIM(OFFER) = :8 AND rownum <=1") + c.Assert(args, check.HasLen, 8) + c.Assert(args[0], check.Equals, "aa") + c.Assert(args[1], check.Equals, 123) + c.Assert(args[2], check.Equals, "pc") + c.Assert(args[3], check.Equals, "oo") + c.Assert(args[4], check.Equals, "a") + c.Assert(args[5], check.Equals, 123) + c.Assert(args[6], check.Equals, "pingcap") + c.Assert(args[7], check.Equals, "o") +} + func (s *SQLSuite) TestOracleUpdateSQL(c *check.C) { dml := DML{ Tp: UpdateDMLType, @@ -389,6 +442,49 @@ func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { c.Assert(args[2], check.Equals, 123) } +func (s *SQLSuite) TestOracleDeleteSQLCharType(c *check.C) { + dml := DML{ + Tp: DeleteDMLType, + Database: "db", + Table: "tbl", + Values: map[string]interface{}{ + "ID": 123, + "NAME": "pc", + "OFFER": "o", + "ADDRESS": "a", + }, + info: &tableInfo{ + columns: []string{"ID", "NAME", "OFFER", "ADDRESS"}, + dataTypeMap: map[string]string{ + "ID": "VARCHAR2", + "NAME": "VARCHAR2", + "OFFER": "CHAR", + "ADDRESS": "NCHAR", + }, + }, + UpColumnsInfoMap: map[string]*model.ColumnInfo{ + "ID": { + FieldType: types.FieldType{Tp: mysql.TypeInt24}}, + "NAME": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "OFFER": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "ADDRESS": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + }, + DestDBType: OracleDB, + } + sql, args := dml.sql() + c.Assert( + sql, check.Equals, + "DELETE FROM db.tbl WHERE RTRIM(ADDRESS) = :1 AND ID = :2 AND NAME = :3 AND RTRIM(OFFER) = :4 AND rownum <=1") + c.Assert(args, check.HasLen, 4) + c.Assert(args[0], check.Equals, "a") + c.Assert(args[1], check.Equals, 123) + c.Assert(args[2], check.Equals, "pc") + c.Assert(args[3], check.Equals, "o") +} + func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { dml := DML{ Tp: DeleteDMLType, @@ -671,3 +767,47 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLEmptyString(c *check.C) { c.Assert(args[0], check.Equals, 123) c.Assert(args[1], check.Equals, "456") } + +func (s *SQLSuite) TestOracleDeleteNewValueSQLCharType(c *check.C) { + dml := DML{ + Tp: InsertDMLType, + Database: "db", + Table: "tbl", + Values: map[string]interface{}{ + "ID": 123, + "ID2": "456", + "NAME": "n", + "C2": "c", + }, + info: &tableInfo{ + columns: []string{"ID", "ID2", "NAME", "C2"}, + dataTypeMap: map[string]string{ + "ID": "VARCHAR2", + "ID2": "VARCHAR2", + "NAME": "CHAR", + "C2": "NCHAR", + }, + }, + UpColumnsInfoMap: map[string]*model.ColumnInfo{ + "ID": { + FieldType: types.FieldType{Tp: mysql.TypeInt24}}, + "ID2": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "NAME": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + "C2": { + FieldType: types.FieldType{Tp: mysql.TypeVarString}}, + }, + DestDBType: OracleDB, + } + + sql, args := dml.oracleDeleteNewValueSQL() + c.Assert( + sql, check.Equals, + "DELETE FROM db.tbl WHERE RTRIM(C2) = :1 AND ID = :2 AND ID2 = :3 AND RTRIM(NAME) = :4 AND rownum <=1") + c.Assert(args, check.HasLen, 4) + c.Assert(args[0], check.Equals, "c") + c.Assert(args[1], check.Equals, 123) + c.Assert(args[2], check.Equals, "456") + c.Assert(args[3], check.Equals, "n") +} diff --git a/pkg/loader/util.go b/pkg/loader/util.go index b2b1c5a4f..81aad3e02 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -51,7 +51,7 @@ WHERE table_schema = ? AND table_name = ? ORDER BY seq_in_index ASC;` //for oracle db - colsOracleSQL = `SELECT column_name FROM dba_tab_cols WHERE owner = upper(:1) AND table_name = upper(:2) AND virtual_column = 'NO'` + colsOracleSQL = `SELECT column_name, data_type FROM dba_tab_cols WHERE owner = upper(:1) AND table_name = upper(:2) AND virtual_column = 'NO'` uniqKeyOracleSQL = `select c.constraint_type || i.uniqueness index_type, i.index_name, ic.column_position, ic.column_name from dba_indexes i left join dba_constraints c @@ -70,6 +70,9 @@ type tableInfo struct { primaryKey *indexInfo // include primary key if have uniqueKeys []indexInfo + + //colum name -> data type map used in oracle db + dataTypeMap map[string]string } type indexInfo struct { @@ -106,7 +109,7 @@ func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, e func getOracleTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, err error) { info = new(tableInfo) - if info.columns, err = getOracleColsOfTbl(db, schema, table); err != nil { + if info.columns, info.dataTypeMap, err = getOracleColsOfTbl(db, schema, table); err != nil { return nil, errors.Annotatef(err, "table %s.%s", schema, table) } @@ -303,7 +306,7 @@ func buildColumnList(names []string, destDBType DBType) string { b.WriteString(",") } if destDBType == OracleDB { - b.WriteString(escapeName(name)) + b.WriteString(name) } else { b.WriteString(quoteName(name)) } @@ -349,32 +352,34 @@ func getColsOfTbl(db *gosql.DB, schema, table string) ([]string, error) { return cols, nil } -func getOracleColsOfTbl(db *gosql.DB, schema, table string) ([]string, error) { +func getOracleColsOfTbl(db *gosql.DB, schema, table string) ([]string, map[string]string, error) { rows, err := db.Query(colsOracleSQL, schema, table) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } defer rows.Close() cols := make([]string, 0, 1) + dataTypeMap := make(map[string]string) for rows.Next() { - var name string - err = rows.Scan(&name) + var name, dataType string + err = rows.Scan(&name, &dataType) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } cols = append(cols, name) + dataTypeMap[name] = dataType } if err = rows.Err(); err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } // if no any columns returns, means the table not exist. if len(cols) == 0 { - return nil, ErrTableNotExist + return nil, nil, ErrTableNotExist } - return cols, nil + return cols, dataTypeMap, nil } diff --git a/pkg/loader/util_test.go b/pkg/loader/util_test.go index 963ecfd0d..cf9f22434 100644 --- a/pkg/loader/util_test.go +++ b/pkg/loader/util_test.go @@ -90,15 +90,15 @@ func (cs *UtilSuite) TestGetOracleTableInfo(c *check.C) { c.Assert(err, check.IsNil) defer db.Close() - columnRows := sqlmock.NewRows([]string{"column_name"}). - AddRow("C1"). - AddRow("C2"). - AddRow("C3"). - AddRow("C4"). - AddRow("C5"). - AddRow("C6"). - AddRow("C7"). - AddRow("C8") + columnRows := sqlmock.NewRows([]string{"column_name", "data_type"}). + AddRow("C1", "VARCHAR2"). + AddRow("C2", "VARCHAR2"). + AddRow("C3", "VARCHAR2"). + AddRow("C4", "VARCHAR2"). + AddRow("C5", "VARCHAR2"). + AddRow("C6", "NUMBER"). + AddRow("C7", "CHAR"). + AddRow("C8", "NCHAR") mock.ExpectQuery(regexp.QuoteMeta(colsOracleSQL)).WithArgs("test", "t3").WillReturnRows(columnRows) indexRows := sqlmock.NewRows([]string{"index_type", "index_name", "column_position", "column_name"}). @@ -125,6 +125,16 @@ func (cs *UtilSuite) TestGetOracleTableInfo(c *check.C) { {name: "T3_C3_C4_UINDEX", columns: []string{"C3", "C4"}}, {name: "T3_C5_C6_UINDEX", columns: []string{"C5", "C6"}}, }, + dataTypeMap: map[string]string{ + "C1": "VARCHAR2", + "C2": "VARCHAR2", + "C3": "VARCHAR2", + "C4": "VARCHAR2", + "C5": "VARCHAR2", + "C6": "NUMBER", + "C7": "CHAR", + "C8": "NCHAR", + }, }) }