diff --git a/bindings_test.go b/bindings_test.go index 0ad7d5aa8..39187570a 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -861,3 +861,42 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { dbt.mustExec("DROP TABLE binding_test") }) } + +func TestFunctionNullSelect(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + CREATE OR REPLACE FUNCTION NULLPARAMFUNCTION("param1" text, "param2" number, "param3" double) + RETURNS TABLE ("res1" text, "res2" number, "res3" double) + LANGUAGE SQL + AS 'select param1, param2, param3'; + `) + if rows, err := dbt.db.Query("select * from table(NULLPARAMFUNCTION(?, ?, ?))", sql.NullString{}, sql.NullInt64{}, sql.NullFloat64{}); err != nil { + t.Fatal(err) + } else { + if rows.Err() != nil { + t.Fatal(err) + } else { + if !rows.Next() { + t.Fatal() + } else { + var r1 sql.NullString + var r2 sql.NullInt64 + var r3 sql.NullFloat64 + err = rows.Scan(&r1, &r2, &r3) + if err != nil { + t.Fatal(err) + } + if r1.Valid { + t.Fatal(err) + } + if r2.Valid { + t.Fatal(err) + } + if r3.Valid { + t.Fatal(err) + } + } + } + } + }) +} diff --git a/connection.go b/connection.go index a62b85c62..d5a5ceacb 100644 --- a/connection.go +++ b/connection.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "os" + "reflect" "regexp" "strconv" "strings" @@ -398,6 +399,11 @@ func (sc *snowflakeConn) Ping(ctx context.Context) error { // CheckNamedValue determines which types are handled by this driver aside from // the instances captured by driver.Value func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error { + switch reflect.TypeOf(nv.Value) { + case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}), + reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}): + return nil + } if supported := supportedArrayBind(nv); !supported { return driver.ErrSkip } diff --git a/converter.go b/converter.go index bc01a5c3b..80dd936e7 100644 --- a/converter.go +++ b/converter.go @@ -4,6 +4,7 @@ package gosnowflake import ( "context" + "database/sql" "database/sql/driver" "encoding/hex" "fmt" @@ -58,13 +59,13 @@ func isInterfaceArrayBinding(t interface{}) bool { // goTypeToSnowflake translates Go data type to Snowflake data type. func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType { switch t := v.(type) { - case int64: + case int64, sql.NullInt64: return fixedType - case float64: + case float64, sql.NullFloat64: return realType - case bool: + case bool, sql.NullBool: return booleanType - case string: + case string, sql.NullString: return textType case []byte: if tsmode == binaryType { @@ -171,6 +172,33 @@ func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) { return &s, nil } } + if ns, ok := v.(sql.NullString); ok { + if !ns.Valid { + return nil, nil + } + return &ns.String, nil + } + if ns, ok := v.(sql.NullInt64); ok { + if !ns.Valid { + return nil, nil + } + s := strconv.FormatInt(ns.Int64, 10) + return &s, nil + } + if ns, ok := v.(sql.NullFloat64); ok { + if !ns.Valid { + return nil, nil + } + s := strconv.FormatFloat(ns.Float64, 'g', -1, 32) + return &s, nil + } + if ns, ok := v.(sql.NullBool); ok { + if !ns.Valid { + return nil, nil + } + s := strconv.FormatBool(ns.Bool) + return &s, nil + } } return nil, fmt.Errorf("unsupported type: %v", v1.Kind()) }