diff --git a/bind_uploader.go b/bind_uploader.go index bcc26b3ce..414bbb83f 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -240,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, if t == nullType || t == unSupportedType { t = textType // if null or not supported, pass to GS as text } - bindValues[strconv.Itoa(idx)] = execBindParameter{ + bindValues[bindingName(binding, idx)] = execBindParameter{ Type: t.String(), Value: val, } @@ -250,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, return bindValues, nil } +func bindingName(nv driver.NamedValue, idx int) string { + if nv.Name != "" { + return nv.Name + } + return strconv.Itoa(idx) +} + func arrayBindValueCount(bindValues []driver.NamedValue) int { if !isArrayBind(bindValues) { return 0 diff --git a/bindings_test.go b/bindings_test.go index 975ddb40e..e424e40dc 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -930,3 +930,137 @@ func TestFunctionParameters(t *testing.T) { } }) } + +func TestVariousBindingModesForNonNullValues(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input interface{} + }{ + {"textAndString", "text", "string"}, + {"numberAndInteger", "number", 123}, + {"floatAndFloat", "float", 123.01}, + {"booleanAndBoolean", "boolean", true}, + //{"dateAndTime", "date", time.Now()}, // does not bind in any binding mode + {"datetimeAndTime", "datetime", time.Now()}, + //{"timeAndTime", "time", time.Now()}, // does not bind in any binding mode + {"timestampAndTime", "timestamp", time.Now()}, + {"timestamp_ntzAndTime", "timestamp_ntz", time.Now()}, + {"timestamp_ltzAndTime", "timestamp_ltz", time.Now()}, + {"timestamp_tzAndTime", "timestamp_tz", time.Now()}, + } + + bindingModes := []struct { + desc string + insert string + query string + transform func(interface{}) interface{} + }{ + { + desc: "?", + insert: "INSERT INTO BINDING_MODES VALUES(?)", + query: "SELECT * FROM BINDING_MODES WHERE param1 = ?", + transform: func(v interface{}) interface{} { return v }, + }, + { + desc: ":1", + insert: "INSERT INTO BINDING_MODES VALUES(:1)", + query: "SELECT * FROM BINDING_MODES WHERE param1 = :1", + transform: func(v interface{}) interface{} { return v }, + }, + { + desc: ":id", + insert: "INSERT INTO BINDING_MODES VALUES(:param)", + query: "SELECT * FROM BINDING_MODES WHERE param1 = :param", + transform: func(v interface{}) interface{} { return sql.Named("param", v) }, + }, + } + + runTests(t, dsn, func(dbt *DBTest) { + for _, tc := range testcases { + for _, bindingMode := range bindingModes { + t.Run(tc.testDesc+" "+bindingMode.desc, func(t *testing.T) { + query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) + dbt.mustExec(query) + if _, err := dbt.db.Exec(bindingMode.insert, bindingMode.transform(tc.input)); err != nil { + t.Fatal(err) + } + rows, err := dbt.db.Query(bindingMode.query, bindingMode.transform(tc.input)) + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("Expected to return a row") + } + }) + } + } + }) +} + +func TestVariousBindingModesForNullValues(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input interface{} + }{ + {"textAndNullString", "text", sql.NullString{}}, + {"numberAndNullInt64", "number", sql.NullInt64{}}, + {"floatAndNullFloat64", "float", sql.NullFloat64{}}, + {"booleanAndAndNullBool", "boolean", sql.NullBool{}}, + {"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}}, + {"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}}, + {"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}}, + {"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}}, + {"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}}, + {"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}}, + {"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}}, + } + + bindingModes := []struct { + desc string + insert string + query string + transform func(interface{}) interface{} + }{ + { + desc: "?", + insert: "INSERT INTO BINDING_MODES VALUES(?)", + query: "SELECT * FROM BINDING_MODES WHERE param1 IS NULL", + transform: func(v interface{}) interface{} { return v }, + }, + { + desc: ":1", + insert: "INSERT INTO BINDING_MODES VALUES(:1)", + query: "SELECT * FROM BINDING_MODES WHERE param1 IS NULL", + transform: func(v interface{}) interface{} { return v }, + }, + { + desc: ":id", + insert: "INSERT INTO BINDING_MODES VALUES(:param)", + query: "SELECT * FROM BINDING_MODES WHERE param1 IS NULL", + transform: func(v interface{}) interface{} { return sql.Named("param", v) }, + }, + } + + runTests(t, dsn, func(dbt *DBTest) { + for _, tc := range testcases { + for _, bindingMode := range bindingModes { + t.Run(tc.testDesc+" "+bindingMode.desc, func(t *testing.T) { + query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) + dbt.mustExec(query) + if _, err := dbt.db.Exec(bindingMode.insert, bindingMode.transform(tc.input)); err != nil { + t.Fatal(err) + } + rows, err := dbt.db.Query(bindingMode.query, bindingMode.transform(tc.input)) + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("Expected to return a row") + } + }) + } + } + }) +}