From 7536401e78109b04c49a9404ddc504930f018d77 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Thu, 6 Jun 2024 08:47:00 +0200 Subject: [PATCH] Introduce bindingValue --- bind_uploader.go | 16 ++++++--- converter.go | 60 +++++++++++++++---------------- converter_test.go | 66 +++++++++++++++++------------------ structured_type_read_test.go | 2 +- structured_type_write_test.go | 2 +- 5 files changed, 76 insertions(+), 70 deletions(-) diff --git a/bind_uploader.go b/bind_uploader.go index 7f6d75bd9..a2680d8f7 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -36,6 +36,12 @@ type bindingSchema struct { Fields []fieldMetadata `json:"fields"` } +type bindingValue struct { + value *string + format string + schema *bindingSchema +} + func (bu *bindUploader) upload(bindings []driver.NamedValue) (*execResponse, error) { bindingRows, err := bu.buildRowsAsBytes(bindings) if err != nil { @@ -237,13 +243,13 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map } } else { var val interface{} - var schema *bindingSchema - fmt := "" + var bv bindingValue if t == sliceType { // retrieve array binding data t, val = snowflakeArrayToString(&binding, false) } else { - val, fmt, schema, err = valueToString(binding.Value, tsmode, params) + bv, err = valueToString(binding.Value, tsmode, params) + val = bv.value if err != nil { return nil, err } @@ -256,8 +262,8 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map bindValues[bindingName(binding, idx)] = execBindParameter{ Type: t.String(), Value: val, - Format: fmt, - Schema: schema, + Format: bv.format, + Schema: bv.schema, } idx++ } diff --git a/converter.go b/converter.go index b7470eae9..351f2eb10 100644 --- a/converter.go +++ b/converter.go @@ -210,70 +210,70 @@ func snowflakeTypeToGoForMaps[K comparable](ctx context.Context, valueMetadata f // valueToString converts arbitrary golang type to a string. This is mainly used in binding data with placeholders // in queries. -func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (*string, string, *bindingSchema, error) { +func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (bindingValue, error) { logger.Debugf("TYPE: %v, %v", reflect.TypeOf(v), reflect.ValueOf(v)) if v == nil { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } v1 := reflect.ValueOf(v) switch v1.Kind() { case reflect.Bool: s := strconv.FormatBool(v1.Bool()) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case reflect.Int64: s := strconv.FormatInt(v1.Int(), 10) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case reflect.Float64: s := strconv.FormatFloat(v1.Float(), 'g', -1, 32) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case reflect.String: s := v1.String() - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case reflect.Slice, reflect.Map: if v1.IsNil() { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } if bd, ok := v.([]byte); ok { if tsmode == binaryType { s := hex.EncodeToString(bd) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil } } // TODO: is this good enough? s := v1.String() - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case reflect.Struct: switch typedVal := v.(type) { case time.Time: return timeTypeValueToString(typedVal, tsmode) case sql.NullTime: if !typedVal.Valid { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } return timeTypeValueToString(typedVal.Time, tsmode) case sql.NullBool: if !typedVal.Valid { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } s := strconv.FormatBool(typedVal.Bool) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case sql.NullInt64: if !typedVal.Valid { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } s := strconv.FormatInt(typedVal.Int64, 10) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case sql.NullFloat64: if !typedVal.Valid { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case sql.NullString: if !typedVal.Valid { - return nil, "", nil, nil + return bindingValue{nil, "", nil}, nil } - return &typedVal.String, "", nil, nil + return bindingValue{&typedVal.String, "", nil}, nil } } if sow, ok := v.(StructuredObjectWriter); ok { @@ -281,11 +281,11 @@ func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*stri sowc.init(params) err := sow.Write(sowc) if err != nil { - return nil, "", nil, err + return bindingValue{nil, "", nil}, err } jsonBytes, err := json.Marshal(sowc.values) if err != nil { - return nil, "", nil, err + return bindingValue{nil, "", nil}, err } jsonString := string(jsonBytes) schema := bindingSchema{ @@ -293,46 +293,46 @@ func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*stri Nullable: true, Fields: sowc.toFields(), } - return &jsonString, "json", &schema, nil + return bindingValue{&jsonString, "json", &schema}, nil } else if typ, ok := v.(reflect.Type); ok { sowc, err := buildSowcFromType(params, typ) if err != nil { - return nil, "", nil, err + return bindingValue{nil, "", nil}, err } schema := bindingSchema{ Typ: "object", Nullable: true, Fields: sowc.toFields(), } - return nil, "json", &schema, nil + return bindingValue{nil, "json", &schema}, nil } - return nil, "", nil, fmt.Errorf("unsupported type: %v", v1.Kind()) + return bindingValue{nil, "", nil}, fmt.Errorf("unsupported type: %v", v1.Kind()) } -func timeTypeValueToString(tm time.Time, tsmode snowflakeType) (*string, string, *bindingSchema, error) { +func timeTypeValueToString(tm time.Time, tsmode snowflakeType) (bindingValue, error) { switch tsmode { case dateType: _, offset := tm.Zone() tm = tm.Add(time.Second * time.Duration(offset)) s := strconv.FormatInt(tm.Unix()*1000, 10) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case timeType: s := fmt.Sprintf("%d", (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case timestampNtzType, timestampLtzType: unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) unixTime.Mul(unixTime, m) tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) s := unixTime.Add(unixTime, tmNanos).String() - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil case timestampTzType: _, offset := tm.Zone() s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) - return &s, "", nil, nil + return bindingValue{&s, "", nil}, nil } - return nil, "", nil, fmt.Errorf("unsupported time type: %v", tsmode) + return bindingValue{nil, "", nil}, fmt.Errorf("unsupported time type: %v", tsmode) } // extractTimestamp extracts the internal timestamp data to epoch time in seconds and milliseconds diff --git a/converter_test.go b/converter_test.go index bcf558fc2..619308ebb 100644 --- a/converter_test.go +++ b/converter_test.go @@ -209,7 +209,7 @@ func (o *testValueToStringStructuredObject) Write(sowc StructuredObjectWriterCon func TestValueToString(t *testing.T) { v := cmplx.Sqrt(-5 + 12i) // should never happen as Go sql package must have already validated. - _, _, _, err := valueToString(v, nullType, nil) + _, err := valueToString(v, nullType, nil) if err == nil { t.Errorf("should raise error: %v", v) } @@ -226,46 +226,46 @@ func TestValueToString(t *testing.T) { expectedFloat64 := "1.1" expectedString := "teststring" - s, fmt, schema, err := valueToString(localTime, timestampLtzType, nil) + bv, err := valueToString(localTime, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedUnixTime) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedUnixTime) - s, fmt, schema, err = valueToString(utcTime, timestampLtzType, nil) + bv, err = valueToString(utcTime, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedUnixTime) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedUnixTime) - s, fmt, schema, err = valueToString(sql.NullBool{Bool: true, Valid: true}, timestampLtzType, nil) + bv, err = valueToString(sql.NullBool{Bool: true, Valid: true}, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedBool) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedBool) - s, fmt, schema, err = valueToString(sql.NullInt64{Int64: 1, Valid: true}, timestampLtzType, nil) + bv, err = valueToString(sql.NullInt64{Int64: 1, Valid: true}, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedInt64) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedInt64) - s, fmt, schema, err = valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, timestampLtzType, nil) + bv, err = valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedFloat64) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedFloat64) - s, fmt, schema, err = valueToString(sql.NullString{String: "teststring", Valid: true}, timestampLtzType, nil) + bv, err = valueToString(sql.NullString{String: "teststring", Valid: true}, timestampLtzType, nil) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, *s, expectedString) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedString) - s, fmt, schema, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, timestampLtzType, params) + bv, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, timestampLtzType, params) assertNilF(t, err) - assertEqualE(t, fmt, "json") - assertDeepEqualE(t, *schema, bindingSchema{ + assertEqualE(t, bv.format, "json") + assertDeepEqualE(t, *bv.schema, bindingSchema{ Typ: "object", Nullable: true, Fields: []fieldMetadata{ @@ -290,7 +290,7 @@ func TestValueToString(t *testing.T) { }, }, }) - assertEqualIgnoringWhitespaceE(t, *s, `{"date": "2024-05-24", "i": 123, "s": "some string"}`) + assertEqualIgnoringWhitespaceE(t, *bv.value, `{"date": "2024-05-24", "i": 123, "s": "some string"}`) } func TestExtractTimestamp(t *testing.T) { @@ -2388,11 +2388,11 @@ func TestTimeTypeValueToString(t *testing.T) { for _, tc := range testcases { t.Run(tc.out, func(t *testing.T) { - output, fmt, schema, err := timeTypeValueToString(tc.in, tc.tsmode) + bv, err := timeTypeValueToString(tc.in, tc.tsmode) assertNilF(t, err) - assertEmptyStringE(t, fmt) - assertNilE(t, schema) - assertEqualE(t, tc.out, *output) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, tc.out, *bv.value) }) } } diff --git a/structured_type_read_test.go b/structured_type_read_test.go index bca461b86..e59a3a8d2 100644 --- a/structured_type_read_test.go +++ b/structured_type_read_test.go @@ -230,7 +230,7 @@ func TestObjectWithAllTypes(t *testing.T) { assertTrueE(t, res.ltz.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz))) assertTrueE(t, res.tz.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz))) assertTrueE(t, res.ntz.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC))) - assertEqualE(t, res.so, simpleObject{s: "child", i: 9}) + assertDeepEqualE(t, res.so, &simpleObject{s: "child", i: 9}) assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"}) assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3}) assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false}) diff --git a/structured_type_write_test.go b/structured_type_write_test.go index 896718c67..f8676de27 100644 --- a/structured_type_write_test.go +++ b/structured_type_write_test.go @@ -132,7 +132,7 @@ func TestBindingObjectWithSchema(t *testing.T) { assertNilF(t, err) skipStructuredTypesTestsOnGHActions(t) runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMP, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER)))") + dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }()