Skip to content

Commit

Permalink
Introduce bindingValue
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jun 10, 2024
1 parent 484a24b commit 7536401
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 70 deletions.
16 changes: 11 additions & 5 deletions bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ type bindingSchema struct {
Fields []fieldMetadata `json:"fields"`
}

type bindingValue struct {
value *string
format string
schema *bindingSchema
}

func (bu *bindUploader) upload(bindings []driver.NamedValue) (*execResponse, error) {
bindingRows, err := bu.buildRowsAsBytes(bindings)
if err != nil {
Expand Down Expand Up @@ -237,13 +243,13 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map
}
} else {
var val interface{}
var schema *bindingSchema
fmt := ""
var bv bindingValue
if t == sliceType {
// retrieve array binding data
t, val = snowflakeArrayToString(&binding, false)
} else {
val, fmt, schema, err = valueToString(binding.Value, tsmode, params)
bv, err = valueToString(binding.Value, tsmode, params)
val = bv.value
if err != nil {
return nil, err
}
Expand All @@ -256,8 +262,8 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map
bindValues[bindingName(binding, idx)] = execBindParameter{
Type: t.String(),
Value: val,
Format: fmt,
Schema: schema,
Format: bv.format,
Schema: bv.schema,
}
idx++
}
Expand Down
60 changes: 30 additions & 30 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,129 +210,129 @@ func snowflakeTypeToGoForMaps[K comparable](ctx context.Context, valueMetadata f

// valueToString converts arbitrary golang type to a string. This is mainly used in binding data with placeholders
// in queries.
func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (*string, string, *bindingSchema, error) {
func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (bindingValue, error) {
logger.Debugf("TYPE: %v, %v", reflect.TypeOf(v), reflect.ValueOf(v))
if v == nil {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
v1 := reflect.ValueOf(v)
switch v1.Kind() {
case reflect.Bool:
s := strconv.FormatBool(v1.Bool())
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case reflect.Int64:
s := strconv.FormatInt(v1.Int(), 10)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case reflect.Float64:
s := strconv.FormatFloat(v1.Float(), 'g', -1, 32)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case reflect.String:
s := v1.String()
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case reflect.Slice, reflect.Map:
if v1.IsNil() {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
if bd, ok := v.([]byte); ok {
if tsmode == binaryType {
s := hex.EncodeToString(bd)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
}
}
// TODO: is this good enough?
s := v1.String()
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case reflect.Struct:
switch typedVal := v.(type) {
case time.Time:
return timeTypeValueToString(typedVal, tsmode)
case sql.NullTime:
if !typedVal.Valid {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
return timeTypeValueToString(typedVal.Time, tsmode)
case sql.NullBool:
if !typedVal.Valid {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
s := strconv.FormatBool(typedVal.Bool)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case sql.NullInt64:
if !typedVal.Valid {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
s := strconv.FormatInt(typedVal.Int64, 10)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case sql.NullFloat64:
if !typedVal.Valid {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case sql.NullString:
if !typedVal.Valid {
return nil, "", nil, nil
return bindingValue{nil, "", nil}, nil
}
return &typedVal.String, "", nil, nil
return bindingValue{&typedVal.String, "", nil}, nil
}
}
if sow, ok := v.(StructuredObjectWriter); ok {
sowc := &structuredObjectWriterContext{}
sowc.init(params)
err := sow.Write(sowc)
if err != nil {
return nil, "", nil, err
return bindingValue{nil, "", nil}, err
}
jsonBytes, err := json.Marshal(sowc.values)
if err != nil {
return nil, "", nil, err
return bindingValue{nil, "", nil}, err
}
jsonString := string(jsonBytes)
schema := bindingSchema{
Typ: "object",
Nullable: true,
Fields: sowc.toFields(),
}
return &jsonString, "json", &schema, nil
return bindingValue{&jsonString, "json", &schema}, nil
} else if typ, ok := v.(reflect.Type); ok {
sowc, err := buildSowcFromType(params, typ)
if err != nil {
return nil, "", nil, err
return bindingValue{nil, "", nil}, err
}
schema := bindingSchema{
Typ: "object",
Nullable: true,
Fields: sowc.toFields(),
}
return nil, "json", &schema, nil
return bindingValue{nil, "json", &schema}, nil
}
return nil, "", nil, fmt.Errorf("unsupported type: %v", v1.Kind())
return bindingValue{nil, "", nil}, fmt.Errorf("unsupported type: %v", v1.Kind())
}

func timeTypeValueToString(tm time.Time, tsmode snowflakeType) (*string, string, *bindingSchema, error) {
func timeTypeValueToString(tm time.Time, tsmode snowflakeType) (bindingValue, error) {
switch tsmode {
case dateType:
_, offset := tm.Zone()
tm = tm.Add(time.Second * time.Duration(offset))
s := strconv.FormatInt(tm.Unix()*1000, 10)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case timeType:
s := fmt.Sprintf("%d",
(tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond())
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case timestampNtzType, timestampLtzType:
unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10)
m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10)
unixTime.Mul(unixTime, m)
tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10)
s := unixTime.Add(unixTime, tmNanos).String()
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
case timestampTzType:
_, offset := tm.Zone()
s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440)
return &s, "", nil, nil
return bindingValue{&s, "", nil}, nil
}
return nil, "", nil, fmt.Errorf("unsupported time type: %v", tsmode)
return bindingValue{nil, "", nil}, fmt.Errorf("unsupported time type: %v", tsmode)
}

// extractTimestamp extracts the internal timestamp data to epoch time in seconds and milliseconds
Expand Down
66 changes: 33 additions & 33 deletions converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (o *testValueToStringStructuredObject) Write(sowc StructuredObjectWriterCon

func TestValueToString(t *testing.T) {
v := cmplx.Sqrt(-5 + 12i) // should never happen as Go sql package must have already validated.
_, _, _, err := valueToString(v, nullType, nil)
_, err := valueToString(v, nullType, nil)
if err == nil {
t.Errorf("should raise error: %v", v)
}
Expand All @@ -226,46 +226,46 @@ func TestValueToString(t *testing.T) {
expectedFloat64 := "1.1"
expectedString := "teststring"

s, fmt, schema, err := valueToString(localTime, timestampLtzType, nil)
bv, err := valueToString(localTime, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedUnixTime)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedUnixTime)

s, fmt, schema, err = valueToString(utcTime, timestampLtzType, nil)
bv, err = valueToString(utcTime, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedUnixTime)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedUnixTime)

s, fmt, schema, err = valueToString(sql.NullBool{Bool: true, Valid: true}, timestampLtzType, nil)
bv, err = valueToString(sql.NullBool{Bool: true, Valid: true}, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedBool)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedBool)

s, fmt, schema, err = valueToString(sql.NullInt64{Int64: 1, Valid: true}, timestampLtzType, nil)
bv, err = valueToString(sql.NullInt64{Int64: 1, Valid: true}, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedInt64)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedInt64)

s, fmt, schema, err = valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, timestampLtzType, nil)
bv, err = valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedFloat64)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedFloat64)

s, fmt, schema, err = valueToString(sql.NullString{String: "teststring", Valid: true}, timestampLtzType, nil)
bv, err = valueToString(sql.NullString{String: "teststring", Valid: true}, timestampLtzType, nil)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, *s, expectedString)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, *bv.value, expectedString)

s, fmt, schema, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, timestampLtzType, params)
bv, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, timestampLtzType, params)
assertNilF(t, err)
assertEqualE(t, fmt, "json")
assertDeepEqualE(t, *schema, bindingSchema{
assertEqualE(t, bv.format, "json")
assertDeepEqualE(t, *bv.schema, bindingSchema{
Typ: "object",
Nullable: true,
Fields: []fieldMetadata{
Expand All @@ -290,7 +290,7 @@ func TestValueToString(t *testing.T) {
},
},
})
assertEqualIgnoringWhitespaceE(t, *s, `{"date": "2024-05-24", "i": 123, "s": "some string"}`)
assertEqualIgnoringWhitespaceE(t, *bv.value, `{"date": "2024-05-24", "i": 123, "s": "some string"}`)
}

func TestExtractTimestamp(t *testing.T) {
Expand Down Expand Up @@ -2388,11 +2388,11 @@ func TestTimeTypeValueToString(t *testing.T) {

for _, tc := range testcases {
t.Run(tc.out, func(t *testing.T) {
output, fmt, schema, err := timeTypeValueToString(tc.in, tc.tsmode)
bv, err := timeTypeValueToString(tc.in, tc.tsmode)
assertNilF(t, err)
assertEmptyStringE(t, fmt)
assertNilE(t, schema)
assertEqualE(t, tc.out, *output)
assertEmptyStringE(t, bv.format)
assertNilE(t, bv.schema)
assertEqualE(t, tc.out, *bv.value)
})
}
}
2 changes: 1 addition & 1 deletion structured_type_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func TestObjectWithAllTypes(t *testing.T) {
assertTrueE(t, res.ltz.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz)))
assertTrueE(t, res.tz.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz)))
assertTrueE(t, res.ntz.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC)))
assertEqualE(t, res.so, simpleObject{s: "child", i: 9})
assertDeepEqualE(t, res.so, &simpleObject{s: "child", i: 9})
assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"})
assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3})
assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false})
Expand Down
2 changes: 1 addition & 1 deletion structured_type_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestBindingObjectWithSchema(t *testing.T) {
assertNilF(t, err)
skipStructuredTypesTestsOnGHActions(t)
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMP, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER)))")
dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER)))")
defer func() {
dbt.mustExec("DROP TABLE IF EXISTS test_object_binding")
}()
Expand Down

0 comments on commit 7536401

Please sign in to comment.