Skip to content

Commit

Permalink
addedd support for sql.Null types
Browse files Browse the repository at this point in the history
  • Loading branch information
m-pavel committed Jun 23, 2023
1 parent a832856 commit ba2f991
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
39 changes: 39 additions & 0 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -861,3 +861,42 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) {
dbt.mustExec("DROP TABLE binding_test")
})
}

func TestFunctionNullSelect(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE OR REPLACE FUNCTION NULLPARAMFUNCTION("param1" text, "param2" number, "param3" double)
RETURNS TABLE ("res1" text, "res2" number, "res3" double)
LANGUAGE SQL
AS 'select param1, param2, param3';
`)
if rows, err := dbt.db.Query("select * from table(NULLPARAMFUNCTION(?, ?, ?))", sql.NullString{}, sql.NullInt64{}, sql.NullFloat64{}); err != nil {
t.Fatal(err)
} else {
if rows.Err() != nil {
t.Fatal(err)
} else {
if !rows.Next() {
t.Fatal()
} else {
var r1 sql.NullString
var r2 sql.NullInt64
var r3 sql.NullFloat64
err = rows.Scan(&r1, &r2, &r3)
if err != nil {
t.Fatal(err)
}
if r1.Valid {
t.Fatal(err)
}
if r2.Valid {
t.Fatal(err)
}
if r3.Valid {
t.Fatal(err)
}
}
}
}
})
}
6 changes: 6 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http"
"net/url"
"os"
"reflect"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -398,6 +399,11 @@ func (sc *snowflakeConn) Ping(ctx context.Context) error {
// CheckNamedValue determines which types are handled by this driver aside from
// the instances captured by driver.Value
func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error {
switch reflect.TypeOf(nv.Value) {
case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}),
reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}):
return nil
}
if supported := supportedArrayBind(nv); !supported {
return driver.ErrSkip
}
Expand Down
36 changes: 32 additions & 4 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gosnowflake

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -58,13 +59,13 @@ func isInterfaceArrayBinding(t interface{}) bool {
// goTypeToSnowflake translates Go data type to Snowflake data type.
func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType {
switch t := v.(type) {
case int64:
case int64, sql.NullInt64:
return fixedType
case float64:
case float64, sql.NullFloat64:
return realType
case bool:
case bool, sql.NullBool:
return booleanType
case string:
case string, sql.NullString:
return textType
case []byte:
if tsmode == binaryType {
Expand Down Expand Up @@ -171,6 +172,33 @@ func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) {
return &s, nil
}
}
if ns, ok := v.(sql.NullString); ok {
if !ns.Valid {
return nil, nil
}
return &ns.String, nil
}
if ns, ok := v.(sql.NullInt64); ok {
if !ns.Valid {
return nil, nil
}
s := strconv.FormatInt(ns.Int64, 10)
return &s, nil
}
if ns, ok := v.(sql.NullFloat64); ok {
if !ns.Valid {
return nil, nil
}
s := strconv.FormatFloat(ns.Float64, 'g', -1, 32)
return &s, nil
}
if ns, ok := v.(sql.NullBool); ok {
if !ns.Valid {
return nil, nil
}
s := strconv.FormatBool(ns.Bool)
return &s, nil
}
}
return nil, fmt.Errorf("unsupported type: %v", v1.Kind())
}
Expand Down

0 comments on commit ba2f991

Please sign in to comment.