Skip to content

Commit

Permalink
Bugfix/issue 34 col name conflict (slingdata-io#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
yokofly authored Oct 25, 2024
1 parent 31be7b2 commit aa5ebc5
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 117 deletions.
127 changes: 84 additions & 43 deletions core/dbio/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1481,36 +1481,70 @@ func (conn *BaseConn) GetTableColumns(table *Table, fields ...string) (columns i

// if fields provided, check if exists in table
colMap := map[string]map[string]any{}
for _, rec := range colData.Records() {
colName := cast.ToString(rec["column_name"])
colMap[strings.ToLower(colName)] = rec
caseSensitive := conn.GetType().DBNameCaseSensitive()

if caseSensitive {
for _, rec := range colData.Records() {
colName := cast.ToString(rec["column_name"])
colMap[colName] = rec
}
} else {
for _, rec := range colData.Records() {
colName := cast.ToString(rec["column_name"])
colMap[strings.ToLower(colName)] = rec
}
}

var colTypes []ColumnType

// if fields provided, filter, keep order
if len(fields) > 0 {
for _, field := range fields {
rec, ok := colMap[strings.ToLower(field)]
if !ok {
err = g.Error(
"provided field '%s' not found in table %s",
strings.ToLower(field), table.FullName(),
)
return
}
if caseSensitive {
for _, field := range fields {
rec, ok := colMap[(field)]
if !ok {
err = g.Error(
"provided field '%s' not found in table %s",
(field), table.FullName(),
)
return
}

if conn.Type == dbio.TypeDbSnowflake {
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
if conn.Type == dbio.TypeDbSnowflake {
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
}

colTypes = append(colTypes, ColumnType{
Name: cast.ToString(rec["column_name"]),
DatabaseTypeName: cast.ToString(rec["data_type"]),
Precision: cast.ToInt(rec["precision"]),
Scale: cast.ToInt(rec["scale"]),
Sourced: true,
})
}
} else {
for _, field := range fields {
rec, ok := colMap[strings.ToLower(field)]
if !ok {
err = g.Error(
"provided field '%s' not found in table %s",
strings.ToLower(field), table.FullName(),
)
return
}

colTypes = append(colTypes, ColumnType{
Name: cast.ToString(rec["column_name"]),
DatabaseTypeName: cast.ToString(rec["data_type"]),
Precision: cast.ToInt(rec["precision"]),
Scale: cast.ToInt(rec["scale"]),
Sourced: true,
})
if conn.Type == dbio.TypeDbSnowflake {
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
}

colTypes = append(colTypes, ColumnType{
Name: cast.ToString(rec["column_name"]),
DatabaseTypeName: cast.ToString(rec["data_type"]),
Precision: cast.ToInt(rec["precision"]),
Scale: cast.ToInt(rec["scale"]),
Sourced: true,
})
}
}
} else {
colTypes = lo.Map(colData.Records(), func(rec map[string]interface{}, i int) ColumnType {
Expand Down Expand Up @@ -2052,31 +2086,38 @@ func (conn *BaseConn) CastColumnsForSelect(srcColumns iop.Columns, tgtColumns io
// ValidateColumnNames verifies that source fields are present in the target table
// It will return quoted field names as `newColNames`, the same length as `colNames`
func (conn *BaseConn) ValidateColumnNames(tgtCols iop.Columns, colNames []string, quote bool) (newCols iop.Columns, err error) {

tgtFields := map[string]string{}
for _, colName := range tgtCols.Names() {
colName = conn.Self().Unquote(colName)
if quote {
tgtFields[strings.ToLower(colName)] = conn.Self().Quote(colName)
} else {
tgtFields[strings.ToLower(colName)] = colName
}
}

mismatches := []string{}
for _, colName := range colNames {
newCol := tgtCols.GetColumn(colName)
if newCol == nil || newCol.Name == "" {
// src field is missing in tgt field
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
continue
caseSensitive := conn.GetType().DBNameCaseSensitive()
if caseSensitive {
for _, colName := range colNames {
newCol := tgtCols.GetColumnWithOriginalCase(colName)
if newCol == nil || newCol.Name == "" {
// src field is missing in tgt field
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
continue
}
if quote {
newCol.Name = conn.Self().Quote(newCol.Name)
} else {
newCol.Name = conn.Self().Unquote(newCol.Name)
}
newCols = append(newCols, *newCol)
}
if quote {
newCol.Name = conn.Self().Quote(newCol.Name)
} else {
newCol.Name = conn.Self().Unquote(newCol.Name)
} else {
for _, colName := range colNames {
newCol := tgtCols.GetColumn(colName)
if newCol == nil || newCol.Name == "" {
// src field is missing in tgt field
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
continue
}
if quote {
newCol.Name = conn.Self().Quote(newCol.Name)
} else {
newCol.Name = conn.Self().Unquote(newCol.Name)
}
newCols = append(newCols, *newCol)
}
newCols = append(newCols, *newCol)
}

if len(mismatches) > 0 {
Expand Down
Loading

0 comments on commit aa5ebc5

Please sign in to comment.