diff --git a/bind_uploader.go b/bind_uploader.go index bcc26b3ce..414bbb83f 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -240,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, if t == nullType || t == unSupportedType { t = textType // if null or not supported, pass to GS as text } - bindValues[strconv.Itoa(idx)] = execBindParameter{ + bindValues[bindingName(binding, idx)] = execBindParameter{ Type: t.String(), Value: val, } @@ -250,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, return bindValues, nil } +func bindingName(nv driver.NamedValue, idx int) string { + if nv.Name != "" { + return nv.Name + } + return strconv.Itoa(idx) +} + func arrayBindValueCount(bindValues []driver.NamedValue) int { if !isArrayBind(bindValues) { return 0 diff --git a/bindings_test.go b/bindings_test.go index 975ddb40e..79d4be32e 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -301,7 +301,7 @@ func TestBindingInterface(t *testing.T) { if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -327,7 +327,7 @@ func TestBindingInterfaceString(t *testing.T) { if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -348,7 +348,7 @@ func TestBindingInterfaceString(t *testing.T) { } func TestBulkArrayBindingInterfaceNil(t *testing.T) { - nilArray := make([]interface{}, 1) + nilArray := make([]any, 1) runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec(createTableSQL) @@ -413,22 +413,22 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { } func TestBulkArrayBindingInterface(t *testing.T) { - intArray := make([]interface{}, 3) + intArray := make([]any, 3) intArray[0] = int32(100) intArray[1] = int32(200) - fltArray := make([]interface{}, 3) + fltArray := make([]any, 3) fltArray[0] = float64(0.1) fltArray[2] = float64(5.678) - boolArray := make([]interface{}, 3) + boolArray := make([]any, 3) boolArray[1] = false boolArray[2] = true - strArray := make([]interface{}, 3) + strArray := make([]any, 3) strArray[2] = "test3" - byteArray := make([]interface{}, 3) + byteArray := make([]any, 3) byteArray[0] = []byte{0x01, 0x02, 0x03} byteArray[2] = []byte{0x07, 0x08, 0x09} @@ -504,23 +504,23 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { if err != nil { t.Error(err) } - ntzArray := make([]interface{}, 3) + ntzArray := make([]any, 3) ntzArray[0] = now ntzArray[1] = now.Add(1) - ltzArray := make([]interface{}, 3) + ltzArray := make([]any, 3) ltzArray[1] = now.Add(2).In(loc) ltzArray[2] = now.Add(3).In(loc) - tzArray := make([]interface{}, 3) + tzArray := make([]any, 3) tzArray[0] = tz.Add(4).In(loc) tzArray[2] = tz.Add(5).In(loc) - dtArray := make([]interface{}, 3) + dtArray := make([]any, 3) dtArray[0] = tz.Add(6).In(loc) dtArray[1] = now.Add(7).In(loc) - tmArray := make([]interface{}, 3) + tmArray := make([]any, 3) tmArray[1] = now.Add(8).In(loc) tmArray[2] = now.Add(9).In(loc) @@ -810,8 +810,8 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { numRows := endNum - startNum // Define the integer and string arrays - intArr := make([]interface{}, numRows) - stringArr := make([]interface{}, numRows) + intArr := make([]any, numRows) + stringArr := make([]any, numRows) for i := startNum; i < endNum; i++ { intArr[i-startNum] = i stringArr[i-startNum] = fmt.Sprint(i) @@ -867,7 +867,7 @@ func TestFunctionParameters(t *testing.T) { testcases := []struct { testDesc string paramType string - input interface{} + input any nullResult bool }{ {"textAndNullStringResultInNull", "text", sql.NullString{}, true}, @@ -912,7 +912,7 @@ func TestFunctionParameters(t *testing.T) { if !rows.Next() { t.Fatal() } else { - var r1 interface{} + var r1 any err = rows.Scan(&r1) if err != nil { t.Fatal(err) @@ -930,3 +930,80 @@ func TestFunctionParameters(t *testing.T) { } }) } + +func TestVariousBindingModes(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input any + isNil bool + }{ + {"textAndString", "text", "string", false}, + {"numberAndInteger", "number", 123, false}, + {"floatAndFloat", "float", 123.01, false}, + {"booleanAndBoolean", "boolean", true, false}, + {"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false}, + {"datetimeAndTime", "datetime", time.Now(), false}, + {"timeAndTime", "time", "12:34:56", false}, + {"timestampAndTime", "timestamp", time.Now(), false}, + {"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false}, + {"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false}, + {"timestamp_tzAndTime", "timestamp_tz", time.Now(), false}, + {"textAndNullString", "text", sql.NullString{}, true}, + {"numberAndNullInt64", "number", sql.NullInt64{}, true}, + {"floatAndNullFloat64", "float", sql.NullFloat64{}, true}, + {"booleanAndAndNullBool", "boolean", sql.NullBool{}, true}, + {"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, + {"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, + {"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, + {"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, + } + + bindingModes := []struct { + param string + query string + transform func(any) any + }{ + { + param: "?", + transform: func(v any) any { return v }, + }, + { + param: ":1", + transform: func(v any) any { return v }, + }, + { + param: ":param", + transform: func(v any) any { return sql.Named("param", v) }, + }, + } + + runTests(t, dsn, func(dbt *DBTest) { + for _, tc := range testcases { + for _, bindingMode := range bindingModes { + t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) { + query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) + dbt.mustExec(query) + if _, err := dbt.db.Exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil { + t.Fatal(err) + } + if tc.isNil { + query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL" + } else { + query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param) + } + rows, err := dbt.db.Query(query, bindingMode.transform(tc.input)) + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("Expected to return a row") + } + }) + } + } + }) +}