Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-645253 Handle binding named parameters #850

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter,
if t == nullType || t == unSupportedType {
t = textType // if null or not supported, pass to GS as text
}
bindValues[strconv.Itoa(idx)] = execBindParameter{
bindValues[bindingName(binding, idx)] = execBindParameter{
Type: t.String(),
Value: val,
}
Expand All @@ -250,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter,
return bindValues, nil
}

func bindingName(nv driver.NamedValue, idx int) string {
if nv.Name != "" {
return nv.Name
}
return strconv.Itoa(idx)
}

func arrayBindValueCount(bindValues []driver.NamedValue) int {
if !isArrayBind(bindValues) {
return 0
Expand Down
111 changes: 94 additions & 17 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func TestBindingInterface(t *testing.T) {
if !rows.Next() {
dbt.Error("failed to query")
}
var v1, v2, v3, v4, v5, v6 interface{}
var v1, v2, v3, v4, v5, v6 any
if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil {
dbt.Errorf("failed to scan: %#v", err)
}
Expand All @@ -327,7 +327,7 @@ func TestBindingInterfaceString(t *testing.T) {
if !rows.Next() {
dbt.Error("failed to query")
}
var v1, v2, v3, v4, v5, v6 interface{}
var v1, v2, v3, v4, v5, v6 any
if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil {
dbt.Errorf("failed to scan: %#v", err)
}
Expand All @@ -348,7 +348,7 @@ func TestBindingInterfaceString(t *testing.T) {
}

func TestBulkArrayBindingInterfaceNil(t *testing.T) {
nilArray := make([]interface{}, 1)
nilArray := make([]any, 1)

runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec(createTableSQL)
Expand Down Expand Up @@ -413,22 +413,22 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) {
}

func TestBulkArrayBindingInterface(t *testing.T) {
intArray := make([]interface{}, 3)
intArray := make([]any, 3)
intArray[0] = int32(100)
intArray[1] = int32(200)

fltArray := make([]interface{}, 3)
fltArray := make([]any, 3)
fltArray[0] = float64(0.1)
fltArray[2] = float64(5.678)

boolArray := make([]interface{}, 3)
boolArray := make([]any, 3)
boolArray[1] = false
boolArray[2] = true

strArray := make([]interface{}, 3)
strArray := make([]any, 3)
strArray[2] = "test3"

byteArray := make([]interface{}, 3)
byteArray := make([]any, 3)
byteArray[0] = []byte{0x01, 0x02, 0x03}
byteArray[2] = []byte{0x07, 0x08, 0x09}

Expand Down Expand Up @@ -504,23 +504,23 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) {
if err != nil {
t.Error(err)
}
ntzArray := make([]interface{}, 3)
ntzArray := make([]any, 3)
ntzArray[0] = now
ntzArray[1] = now.Add(1)

ltzArray := make([]interface{}, 3)
ltzArray := make([]any, 3)
ltzArray[1] = now.Add(2).In(loc)
ltzArray[2] = now.Add(3).In(loc)

tzArray := make([]interface{}, 3)
tzArray := make([]any, 3)
tzArray[0] = tz.Add(4).In(loc)
tzArray[2] = tz.Add(5).In(loc)

dtArray := make([]interface{}, 3)
dtArray := make([]any, 3)
dtArray[0] = tz.Add(6).In(loc)
dtArray[1] = now.Add(7).In(loc)

tmArray := make([]interface{}, 3)
tmArray := make([]any, 3)
tmArray[1] = now.Add(8).In(loc)
tmArray[2] = now.Add(9).In(loc)

Expand Down Expand Up @@ -810,8 +810,8 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) {
numRows := endNum - startNum

// Define the integer and string arrays
intArr := make([]interface{}, numRows)
stringArr := make([]interface{}, numRows)
intArr := make([]any, numRows)
stringArr := make([]any, numRows)
for i := startNum; i < endNum; i++ {
intArr[i-startNum] = i
stringArr[i-startNum] = fmt.Sprint(i)
Expand Down Expand Up @@ -867,7 +867,7 @@ func TestFunctionParameters(t *testing.T) {
testcases := []struct {
testDesc string
paramType string
input interface{}
input any
nullResult bool
}{
{"textAndNullStringResultInNull", "text", sql.NullString{}, true},
Expand Down Expand Up @@ -912,7 +912,7 @@ func TestFunctionParameters(t *testing.T) {
if !rows.Next() {
t.Fatal()
} else {
var r1 interface{}
var r1 any
err = rows.Scan(&r1)
if err != nil {
t.Fatal(err)
Expand All @@ -930,3 +930,80 @@ func TestFunctionParameters(t *testing.T) {
}
})
}

func TestVariousBindingModes(t *testing.T) {
testcases := []struct {
testDesc string
paramType string
input any
isNil bool
}{
{"textAndString", "text", "string", false},
{"numberAndInteger", "number", 123, false},
{"floatAndFloat", "float", 123.01, false},
{"booleanAndBoolean", "boolean", true, false},
{"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false},
{"datetimeAndTime", "datetime", time.Now(), false},
{"timeAndTime", "time", "12:34:56", false},
{"timestampAndTime", "timestamp", time.Now(), false},
{"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false},
{"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false},
{"timestamp_tzAndTime", "timestamp_tz", time.Now(), false},
{"textAndNullString", "text", sql.NullString{}, true},
{"numberAndNullInt64", "number", sql.NullInt64{}, true},
{"floatAndNullFloat64", "float", sql.NullFloat64{}, true},
{"booleanAndAndNullBool", "boolean", sql.NullBool{}, true},
{"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true},
{"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true},
{"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true},
{"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true},
}

bindingModes := []struct {
param string
query string
transform func(any) any
}{
{
param: "?",
transform: func(v any) any { return v },
},
{
param: ":1",
transform: func(v any) any { return v },
},
{
param: ":param",
transform: func(v any) any { return sql.Named("param", v) },
},
}

runTests(t, dsn, func(dbt *DBTest) {
for _, tc := range testcases {
for _, bindingMode := range bindingModes {
t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) {
query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType)
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved
dbt.mustExec(query)
if _, err := dbt.db.Exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil {
t.Fatal(err)
}
if tc.isNil {
query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL"
} else {
query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param)
}
rows, err := dbt.db.Query(query, bindingMode.transform(tc.input))
if err != nil {
t.Fatal(err)
}
if !rows.Next() {
t.Fatal("Expected to return a row")
}
})
}
}
})
}