diff --git a/lib/column/column_gen.go b/lib/column/column_gen.go index 2a56a3e..3c61e41 100644 --- a/lib/column/column_gen.go +++ b/lib/column/column_gen.go @@ -29,7 +29,6 @@ import ( "net" "reflect" "strings" - "time" "github.com/timeplus-io/proton-go-driver/v2/types" ) @@ -120,9 +119,9 @@ func (t Type) Column() (Interface, error) { case "string": return &String{}, nil case "json": - return (&Json{}).parse(false) - case "nullable_json": - return (&Json{}).parse(true) + return (&Json{}).parse(false) + case "nullable_json": + return (&Json{}).parse(true) } switch strType := string(t); { @@ -200,7 +199,7 @@ var ( scanTypeByte = reflect.TypeOf([]byte{}) scanTypeUUID = reflect.TypeOf(uuid.UUID{}) scanTypeDate = reflect.TypeOf(types.Date{}) - scanTypeTime = reflect.TypeOf(time.Time{}) + scanTypeTime = reflect.TypeOf(types.Datetime{}) scanTypeRing = reflect.TypeOf(orb.Ring{}) scanTypePoint = reflect.TypeOf(orb.Point{}) scanTypeSlice = reflect.TypeOf([]interface{}{}) diff --git a/lib/column/datetime.go b/lib/column/datetime.go index 0c098a6..d4e7d44 100644 --- a/lib/column/datetime.go +++ b/lib/column/datetime.go @@ -73,15 +73,15 @@ func (dt *DateTime) Row(i int, ptr bool) interface{} { func (dt *DateTime) ScanRow(dest interface{}, row int) error { switch d := dest.(type) { case *time.Time: - *d = dt.row(row) + *d = dt.row(row).Time case **time.Time: *d = new(time.Time) - **d = dt.row(row) + **d = dt.row(row).Time case *types.Datetime: - *d = types.Datetime{dt.row(row)} + *d = dt.row(row) case **types.Datetime: *d = new(types.Datetime) - **d = types.Datetime{dt.row(row)} + **d = dt.row(row) default: return &ColumnConverterError{ Op: "ScanRow", @@ -195,12 +195,12 @@ func (dt *DateTime) Encode(encoder *binary.Encoder) error { return dt.values.Encode(encoder) } -func (dt *DateTime) row(i int) time.Time { +func (dt *DateTime) row(i int) types.Datetime { v := time.Unix(int64(dt.values[i]), 0) if dt.timezone != nil { v = v.In(dt.timezone) } - return v + return types.Datetime{Time: v} } var _ Interface = (*DateTime)(nil) diff --git a/lib/column/datetime64.go b/lib/column/datetime64.go index de1eb9c..a9825ac 100644 --- a/lib/column/datetime64.go +++ b/lib/column/datetime64.go @@ -87,15 +87,15 @@ func (dt *DateTime64) Row(i int, ptr bool) interface{} { func (dt *DateTime64) ScanRow(dest interface{}, row int) error { switch d := dest.(type) { case *time.Time: - *d = dt.row(row) + *d = dt.row(row).Time case **time.Time: *d = new(time.Time) - **d = dt.row(row) + **d = dt.row(row).Time case *types.Datetime: - *d = types.Datetime{dt.row(row)} + *d = dt.row(row) case **types.Datetime: *d = new(types.Datetime) - **d = types.Datetime{dt.row(row)} + **d = dt.row(row) default: return &ColumnConverterError{ Op: "ScanRow", @@ -213,7 +213,7 @@ func (dt *DateTime64) Encode(encoder *binary.Encoder) error { return dt.values.Encode(encoder) } -func (dt *DateTime64) row(i int) time.Time { +func (dt *DateTime64) row(i int) types.Datetime { var nano int64 if dt.precision < 19 { nano = dt.values[i] * int64(math.Pow10(9-dt.precision)) @@ -226,7 +226,7 @@ func (dt *DateTime64) row(i int) time.Time { if dt.timezone != nil { time = time.In(dt.timezone) } - return time + return types.Datetime{Time: time} } func (dt *DateTime64) timeToInt64(t time.Time) int64 { diff --git a/tests/array_test.go b/tests/array_test.go index 6ccdb73..687d544 100644 --- a/tests/array_test.go +++ b/tests/array_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -56,14 +57,14 @@ func TestArray(t *testing.T) { if err := conn.Exec(ctx, ddl); assert.NoError(t, err) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_array (* except _tp_time)"); assert.NoError(t, err) { var ( - timestamp = time.Now().Truncate(time.Second) + timestamp = types.Datetime{Time: time.Now().Truncate(time.Second)} col1Data = []string{"A", "b", "c"} col2Data = [][]uint32{ {1, 2}, {3, 87}, {33, 3, 847}, } - col3Data = [][][]time.Time{ + col3Data = [][][]types.Datetime{ { { timestamp, @@ -96,7 +97,7 @@ func TestArray(t *testing.T) { var ( col1 []string col2 [][]uint32 - col3 [][][]time.Time + col3 [][][]types.Datetime ) if err := rows.Scan(&col1, &col2, &col3); assert.NoError(t, err) { assert.Equal(t, col1Data, col1) @@ -143,14 +144,14 @@ func TestColumnarArray(t *testing.T) { }() if err := conn.Exec(ctx, ddl); assert.NoError(t, err) { var ( - timestamp = time.Now().Truncate(time.Second) + timestamp = types.Datetime{Time: time.Now().Truncate(time.Second)} col1Data = []string{"A", "b", "c"} col2Data = [][]uint32{ {1, 2}, {3, 87}, {33, 3, 847}, } - col3Data = [][][]time.Time{ + col3Data = [][][]types.Datetime{ { { timestamp, @@ -174,7 +175,7 @@ func TestColumnarArray(t *testing.T) { col1DataColArr [][]string col2DataColArr [][][]uint32 - col3DataColArr [][][][]time.Time + col3DataColArr [][][][]types.Datetime ) for i := 0; i < 10; i++ { @@ -199,7 +200,7 @@ func TestColumnarArray(t *testing.T) { var ( col1 []string col2 [][]uint32 - col3 [][][]time.Time + col3 [][][]types.Datetime ) if err := rows.Scan(&col1, &col2, &col3); assert.NoError(t, err) { assert.Equal(t, col1Data, col1) diff --git a/tests/datetime64_test.go b/tests/datetime64_test.go index cd24c30..ddff25e 100644 --- a/tests/datetime64_test.go +++ b/tests/datetime64_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -63,26 +64,26 @@ func TestDateTime64(t *testing.T) { if err := conn.Exec(ctx, ddl); assert.NoError(t, err) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_datetime64 (* except _tp_time)"); assert.NoError(t, err) { var ( - datetime1 = time.Now().Truncate(time.Millisecond) - datetime2 = time.Now().Truncate(time.Nanosecond) - datetime3 = time.Now().Truncate(time.Second) + datetime1 = types.Datetime{time.Now().Truncate(time.Millisecond)} + datetime2 = types.Datetime{time.Now().Truncate(time.Nanosecond)} + datetime3 = types.Datetime{time.Now().Truncate(time.Second)} ) if err := batch.Append( datetime1, datetime2, datetime3, &datetime1, - []time.Time{datetime1, datetime1}, - []*time.Time{&datetime3, nil, &datetime3}, + []types.Datetime{datetime1, datetime1}, + []*types.Datetime{&datetime3, nil, &datetime3}, ); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col2 time.Time - col3 time.Time - col4 *time.Time - col5 []time.Time - col6 []*time.Time + col1 types.Datetime + col2 types.Datetime + col3 types.Datetime + col4 *types.Datetime + col5 []types.Datetime + col6 []*types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime64 WHERE _tp_time > earliest_ts() LIMIT 1").Scan(&col1, &col2, &col3, &col4, &col5, &col6); assert.NoError(t, err) { assert.Equal(t, datetime1, col1) @@ -154,12 +155,12 @@ func TestNullableDateTime64(t *testing.T) { if err := batch.Append(datetime1, datetime1, datetime2, datetime2, datetime3, datetime3); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col1Null *time.Time - col2 time.Time - col2Null *time.Time - col3 time.Time - col3Null *time.Time + col1 types.Datetime + col1Null *types.Datetime + col2 types.Datetime + col2Null *types.Datetime + col3 types.Datetime + col3Null *types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime64 WHERE _tp_time > earliest_ts() LIMIT 1").Scan( &col1, &col1Null, @@ -188,12 +189,12 @@ func TestNullableDateTime64(t *testing.T) { if err := batch.Append(datetime1, nil, datetime2, nil, datetime3, nil); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col1Null *time.Time - col2 time.Time - col2Null *time.Time - col3 time.Time - col3Null *time.Time + col1 types.Datetime + col1Null *types.Datetime + col2 types.Datetime + col2Null *types.Datetime + col3 types.Datetime + col3Null *types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime64 WHERE _tp_time > earliest_ts() LIMIT 1").Scan( &col1, &col1Null, @@ -259,14 +260,14 @@ func TestColumnarDateTime64(t *testing.T) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_datetime64 (* except _tp_time)"); assert.NoError(t, err) { var ( id []uint64 - col1Data []time.Time - col2Data []*time.Time - col3Data [][]time.Time - col4Data [][]*time.Time + col1Data []types.Datetime + col2Data []*types.Datetime + col3Data [][]types.Datetime + col4Data [][]*types.Datetime ) var ( - datetime1 = time.Now().Truncate(time.Millisecond) - datetime2 = time.Now().Truncate(time.Second) + datetime1 = types.Datetime{Time: time.Now().Truncate(time.Millisecond)} + datetime2 = types.Datetime{Time: time.Now().Truncate(time.Second)} ) for i := 0; i < 1000; i++ { id = append(id, uint64(i)) @@ -276,10 +277,10 @@ func TestColumnarDateTime64(t *testing.T) { } else { col2Data = append(col2Data, nil) } - col3Data = append(col3Data, []time.Time{ + col3Data = append(col3Data, []types.Datetime{ datetime1, datetime2, datetime1, }) - col4Data = append(col4Data, []*time.Time{ + col4Data = append(col4Data, []*types.Datetime{ &datetime2, nil, &datetime1, }) } @@ -302,16 +303,16 @@ func TestColumnarDateTime64(t *testing.T) { } if assert.NoError(t, batch.Send()) { var result struct { - Col1 time.Time - Col2 *time.Time - Col3 []time.Time - Col4 []*time.Time + Col1 types.Datetime + Col2 *types.Datetime + Col3 []types.Datetime + Col4 []*types.Datetime } if err := conn.QueryRow(ctx, "SELECT Col1, Col2, Col3, Col4 FROM test_datetime64 WHERE ID = $1 AND _tp_time > earliest_ts() LIMIT 1", 11).ScanStruct(&result); assert.NoError(t, err) { if assert.Nil(t, result.Col2) { assert.Equal(t, datetime1, result.Col1) - assert.Equal(t, []time.Time{datetime1, datetime2, datetime1}, result.Col3) - assert.Equal(t, []*time.Time{&datetime2, nil, &datetime1}, result.Col4) + assert.Equal(t, []types.Datetime{datetime1, datetime2, datetime1}, result.Col3) + assert.Equal(t, []*types.Datetime{&datetime2, nil, &datetime1}, result.Col4) } } } diff --git a/tests/datetime_test.go b/tests/datetime_test.go index 0d761a1..09e58d3 100644 --- a/tests/datetime_test.go +++ b/tests/datetime_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -58,23 +59,23 @@ func TestDateTime(t *testing.T) { }() if err := conn.Exec(ctx, ddl); assert.NoError(t, err) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_datetime (* except _tp_time)"); assert.NoError(t, err) { - datetime := time.Now().Truncate(time.Second) + datetime := types.Datetime{Time: time.Now().Truncate(time.Second)} if err := batch.Append( datetime, datetime, datetime, &datetime, - []time.Time{datetime, datetime}, - []*time.Time{&datetime, nil, &datetime}, + []types.Datetime{datetime, datetime}, + []*types.Datetime{&datetime, nil, &datetime}, ); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col2 time.Time - col3 time.Time - col4 *time.Time - col5 []time.Time - col6 []*time.Time + col1 types.Datetime + col2 types.Datetime + col3 types.Datetime + col4 *types.Datetime + col5 []types.Datetime + col6 []*types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime WHERE _tp_time > earliest_ts() LIMIT 1").Scan(&col1, &col2, &col3, &col4, &col5, &col6); assert.NoError(t, err) { assert.Equal(t, datetime, col1) @@ -138,12 +139,12 @@ func TestNullableDateTime(t *testing.T) { if err := batch.Append(datetime, datetime, datetime, datetime, datetime, datetime); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col1Null *time.Time - col2 time.Time - col2Null *time.Time - col3 time.Time - col3Null *time.Time + col1 types.Datetime + col1Null *types.Datetime + col2 types.Datetime + col2Null *types.Datetime + col3 types.Datetime + col3Null *types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime WHERE _tp_time > earliest_ts() LIMIT 1").Scan( &col1, &col1Null, @@ -168,12 +169,12 @@ func TestNullableDateTime(t *testing.T) { if err := batch.Append(datetime, nil, datetime, nil, datetime, nil); assert.NoError(t, err) { if err := batch.Send(); assert.NoError(t, err) { var ( - col1 time.Time - col1Null *time.Time - col2 time.Time - col2Null *time.Time - col3 time.Time - col3Null *time.Time + col1 types.Datetime + col1Null *types.Datetime + col2 types.Datetime + col2Null *types.Datetime + col3 types.Datetime + col3Null *types.Datetime ) if err := conn.QueryRow(ctx, "SELECT (* except _tp_time) FROM test_datetime WHERE _tp_time > earliest_ts() LIMIT 1").Scan( &col1, &col1Null, @@ -238,14 +239,14 @@ func TestColumnarDateTime(t *testing.T) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_datetime (* except _tp_time)"); assert.NoError(t, err) { var ( id []uint64 - col1Data []time.Time - col2Data []*time.Time - col3Data [][]time.Time - col4Data [][]*time.Time + col1Data []types.Datetime + col2Data []*types.Datetime + col3Data [][]types.Datetime + col4Data [][]*types.Datetime ) var ( - datetime1 = time.Now().Truncate(time.Second) - datetime2 = time.Now().Truncate(time.Second) + datetime1 = types.Datetime{Time: time.Now().Truncate(time.Second)} + datetime2 = types.Datetime{Time: time.Now().Truncate(time.Second)} ) for i := 0; i < 1000; i++ { id = append(id, uint64(i)) @@ -255,10 +256,10 @@ func TestColumnarDateTime(t *testing.T) { } else { col2Data = append(col2Data, nil) } - col3Data = append(col3Data, []time.Time{ + col3Data = append(col3Data, []types.Datetime{ datetime1, datetime2, datetime1, }) - col4Data = append(col4Data, []*time.Time{ + col4Data = append(col4Data, []*types.Datetime{ &datetime2, nil, &datetime1, }) } @@ -281,16 +282,16 @@ func TestColumnarDateTime(t *testing.T) { } if assert.NoError(t, batch.Send()) { var result struct { - Col1 time.Time - Col2 *time.Time - Col3 []time.Time - Col4 []*time.Time + Col1 types.Datetime + Col2 *types.Datetime + Col3 []types.Datetime + Col4 []*types.Datetime } if err := conn.QueryRow(ctx, "SELECT Col1, Col2, Col3, Col4 FROM test_datetime WHERE ID = $1 AND _tp_time > earliest_ts() LIMIT 1", 11).ScanStruct(&result); assert.NoError(t, err) { if assert.Nil(t, result.Col2) { assert.Equal(t, datetime1, result.Col1) - assert.Equal(t, []time.Time{datetime1, datetime2, datetime1}, result.Col3) - assert.Equal(t, []*time.Time{&datetime2, nil, &datetime1}, result.Col4) + assert.Equal(t, []types.Datetime{datetime1, datetime2, datetime1}, result.Col3) + assert.Equal(t, []*types.Datetime{&datetime2, nil, &datetime1}, result.Col4) } } } diff --git a/tests/issues/357_test.go b/tests/issues/357_test.go index 85627dd..7488fa6 100644 --- a/tests/issues/357_test.go +++ b/tests/issues/357_test.go @@ -19,6 +19,7 @@ package issues import ( "database/sql" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -49,7 +50,7 @@ func TestIssue357(t *testing.T) { if err := scope.Commit(); assert.NoError(t, err) { var ( col1 int32 - col2 time.Time + col2 types.Datetime ) if err := conn.QueryRow("SELECT (* except _tp_time) FROM issue_357 WHERE _tp_time > earliest_ts() LIMIT 1").Scan(&col1, &col2); assert.NoError(t, err) { assert.Equal(t, int32(42), col1) diff --git a/tests/marshal_test.go b/tests/marshal_test.go new file mode 100644 index 0000000..4e1b762 --- /dev/null +++ b/tests/marshal_test.go @@ -0,0 +1,116 @@ +package tests + +import ( + "encoding/json" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/timeplus-io/proton-go-driver/v2/types" + "gopkg.in/yaml.v3" + "testing" + "time" +) + +func TestDateTimeMarshal(t *testing.T) { + datetimes := []types.Datetime{{time.Date(2000, 1, 1, 1, 1, 0, 0, time.Local)}, {time.Date(2000, 1, 1, 1, 1, 0, 1, time.Local)}, {time.Date(2000, 1, 1, 1, 1, 0, 10, time.Local)}} + datetimestrs := []string{"2000-01-01 01:01:00", "2000-01-01 01:01:00.000000001", "2000-01-01 01:01:00.00000001"} + for i, datetime := range datetimes { + str := datetimestrs[i] + { + datetimeMap := map[string]types.Datetime{"time": datetime} + s, err := json.Marshal(datetimeMap) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("{\"time\":\"%s\"}", str), string(s)) + } + { + datetimeMap := map[string]types.Datetime{"time": datetime} + s, err := yaml.Marshal(datetimeMap) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("time: \"%s\"\n", str), string(s)) + } + } +} + +func TestDateTimeUnmarshal(t *testing.T) { + datetimes := []types.Datetime{{time.Date(2000, 1, 1, 1, 1, 0, 0, time.Local)}, {time.Date(2000, 1, 1, 1, 1, 0, 1, time.Local)}, {time.Date(2000, 1, 1, 1, 1, 0, 10, time.Local)}} + datetimestrs := []string{"2000-01-01 01:01:00", "2000-01-01 01:01:00.000000001", "2000-01-01 01:01:00.00000001"} + for i, datetime := range datetimes { + str := datetimestrs[i] + { + var actualDatetimeMap map[string]types.Datetime + j := fmt.Sprintf(`{"time": "%s"}`, str) + assert.NoError(t, json.Unmarshal([]byte(j), &actualDatetimeMap)) + assert.Equal(t, datetime, actualDatetimeMap["time"]) + } + { + var actualDatetimeMap map[string]types.Datetime + y := fmt.Sprintf(`"time": "%s"`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDatetimeMap)) + assert.Equal(t, datetime, actualDatetimeMap["time"]) + } + { + var actualDatetimeMap map[string]types.Datetime + y := fmt.Sprintf(`"time": '%s'`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDatetimeMap)) + assert.Equal(t, datetime, actualDatetimeMap["time"]) + } + { + var actualDatetimeMap map[string]types.Datetime + y := fmt.Sprintf(`"time": %s`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDatetimeMap)) + assert.Equal(t, datetime, actualDatetimeMap["time"]) + } + } +} + +func TestDateMarshal(t *testing.T) { + dates := []types.Date{{time.Date(2000, 1, 10, 1, 1, 1, 1, time.Local)}, {time.Date(2077, 1, 1, 1, 1, 0, 1, time.Local)}, {time.Date(1970, 1, 9, 1, 1, 0, 10, time.Local)}} + datestrs := []string{"2000-01-10", "2077-01-01", "1970-01-09"} + for i, date := range dates { + str := datestrs[i] + { + dateMap := map[string]types.Date{"time": date} + s, err := json.Marshal(dateMap) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("{\"time\":\"%s\"}", str), string(s)) + } + { + dateMap := map[string]types.Date{"time": date} + s, err := yaml.Marshal(dateMap) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("time: \"%s\"\n", str), string(s)) + } + } +} + +func TestDateUnmarshal(t *testing.T) { + dates := []types.Date{{time.Date(2000, 1, 10, 1, 1, 1, 1, time.Local)}, {time.Date(2077, 1, 1, 1, 1, 0, 1, time.Local)}, {time.Date(1970, 1, 9, 1, 1, 0, 10, time.Local)}} + datestrs := []string{"2000-01-10", "2077-01-01", "1970-01-09"} + for i, date := range dates { + date.Time = time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, time.Local) + str := datestrs[i] + { + var actualDateMap map[string]types.Date + j := fmt.Sprintf(`{"time": "%s"}`, str) + assert.NoError(t, json.Unmarshal([]byte(j), &actualDateMap)) + assert.Equal(t, date, actualDateMap["time"]) + } + { + var actualDateMap map[string]types.Date + y := fmt.Sprintf(`"time": "%s"`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDateMap)) + assert.Equal(t, date, actualDateMap["time"]) + } + { + var actualDateMap map[string]types.Date + y := fmt.Sprintf(`"time": '%s'`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDateMap)) + assert.Equal(t, date, actualDateMap["time"]) + } + { + var actualDateMap map[string]types.Date + y := fmt.Sprintf(`"time": %s`, str) + assert.NoError(t, yaml.Unmarshal([]byte(y), &actualDateMap)) + assert.Equal(t, date, actualDateMap["time"]) + } + } +} diff --git a/tests/nullable_array_test.go b/tests/nullable_array_test.go index 1a4ce3e..f900b66 100644 --- a/tests/nullable_array_test.go +++ b/tests/nullable_array_test.go @@ -81,7 +81,7 @@ func TestNullableArray(t *testing.T) { boolTrue = true boolFalse = false decimalVal = decimal.New(25, 0) - datetime = time.Now().Truncate(time.Second) + datetime = types.Datetime{Time: time.Now().Truncate(time.Second)} enum1Val = "click" enum2Val = "house" fixed1Val = "Click" @@ -101,8 +101,8 @@ func TestNullableArray(t *testing.T) { []*uint8{&uint8Val, nil, &uint8Val}, []*types.Date{&dateVal, nil, &dateVal}, []*types.Date{&dateVal, nil, &dateVal}, - []*time.Time{&datetime, nil, &datetime}, - []*time.Time{&datetime, nil, &datetime}, + []*types.Datetime{&datetime, nil, &datetime}, + []*types.Datetime{&datetime, nil, &datetime}, []*decimal.Decimal{&decimalVal, nil, &decimalVal}, []*string{&enum1Val, nil, &enum2Val}, []*string{&enum1Val, nil, &enum2Val}, @@ -121,8 +121,8 @@ func TestNullableArray(t *testing.T) { Col2 []*uint8 Col3 []*types.Date Col4 []*types.Date - Col5 []*time.Time - Col6 []*time.Time + Col5 []*types.Datetime + Col6 []*types.Datetime Col7 []*decimal.Decimal Col8 []*string Col9 []*string @@ -137,8 +137,8 @@ func TestNullableArray(t *testing.T) { assert.Equal(t, []*uint8{&uint8Val, nil, &uint8Val}, result.Col2) assert.Equal(t, []*types.Date{&dateVal, nil, &dateVal}, result.Col3) assert.Equal(t, []*types.Date{&dateVal, nil, &dateVal}, result.Col4) - assert.Equal(t, []*time.Time{&datetime, nil, &datetime}, result.Col5) - assert.Equal(t, []*time.Time{&datetime, nil, &datetime}, result.Col6) + assert.Equal(t, []*types.Datetime{&datetime, nil, &datetime}, result.Col5) + assert.Equal(t, []*types.Datetime{&datetime, nil, &datetime}, result.Col6) if assert.Nil(t, result.Col7[1]) { assert.True(t, decimalVal.Equal(*result.Col7[0])) assert.True(t, decimalVal.Equal(*result.Col7[2])) diff --git a/tests/std/array_test.go b/tests/std/array_test.go index de287c7..f0b20ba 100644 --- a/tests/std/array_test.go +++ b/tests/std/array_test.go @@ -19,6 +19,7 @@ package std import ( "database/sql" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -44,14 +45,14 @@ func TestStdArray(t *testing.T) { } if batch, err := scope.Prepare("INSERT INTO test_array (* except _tp_time)"); assert.NoError(t, err) { var ( - timestamp = time.Now().Truncate(time.Second) + timestamp = types.Datetime{Time: time.Now().Truncate(time.Second)} col1Data = []string{"A", "b", "c"} col2Data = [][]uint32{ {1, 2}, {3, 87}, {33, 3, 847}, } - col3Data = [][][]time.Time{ + col3Data = [][][]types.Datetime{ { { timestamp, @@ -84,7 +85,7 @@ func TestStdArray(t *testing.T) { var ( col1 []string col2 [][]uint32 - col3 [][][]time.Time + col3 [][][]types.Datetime ) if err := rows.Scan(&col1, &col2, &col3); assert.NoError(t, err) { assert.Equal(t, col1Data, col1) diff --git a/tests/std/datetime64_test.go b/tests/std/datetime64_test.go index 9f8dfd7..0777fd2 100644 --- a/tests/std/datetime64_test.go +++ b/tests/std/datetime64_test.go @@ -19,6 +19,7 @@ package std import ( "database/sql" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -47,26 +48,26 @@ func TestStdDateTime64(t *testing.T) { } if batch, err := scope.Prepare("INSERT INTO test_datetime64 (* except _tp_time)"); assert.NoError(t, err) { var ( - datetime1 = time.Now().Truncate(time.Millisecond) - datetime2 = time.Now().Truncate(time.Nanosecond) - datetime3 = time.Now().Truncate(time.Second) + datetime1 = types.Datetime{time.Now().Truncate(time.Millisecond)} + datetime2 = types.Datetime{time.Now().Truncate(time.Nanosecond)} + datetime3 = types.Datetime{time.Now().Truncate(time.Second)} ) if _, err := batch.Exec( datetime1, datetime2, datetime3, &datetime1, - []time.Time{datetime1, datetime1}, - []*time.Time{&datetime3, nil, &datetime3}, + []types.Datetime{datetime1, datetime1}, + []*types.Datetime{&datetime3, nil, &datetime3}, ); assert.NoError(t, err) { if err := scope.Commit(); assert.NoError(t, err) { var ( - col1 time.Time - col2 time.Time - col3 time.Time - col4 *time.Time - col5 []time.Time - col6 []*time.Time + col1 types.Datetime + col2 types.Datetime + col3 types.Datetime + col4 *types.Datetime + col5 []types.Datetime + col6 []*types.Datetime ) if err := conn.QueryRow("SELECT (* except _tp_time) FROM test_datetime64 WHERE _tp_time > earliest_ts() LIMIT 1").Scan(&col1, &col2, &col3, &col4, &col5, &col6); assert.NoError(t, err) { assert.Equal(t, datetime1, col1) diff --git a/tests/std/datetime_test.go b/tests/std/datetime_test.go index 79a9872..6971055 100644 --- a/tests/std/datetime_test.go +++ b/tests/std/datetime_test.go @@ -19,6 +19,7 @@ package std import ( "database/sql" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -46,23 +47,23 @@ func TestStdDateTime(t *testing.T) { return } if batch, err := scope.Prepare("INSERT INTO test_datetime (* except _tp_time)"); assert.NoError(t, err) { - datetime := time.Now().Truncate(time.Second) + datetime := types.Datetime{time.Now().Truncate(time.Second)} if _, err := batch.Exec( datetime, datetime, datetime, &datetime, - []time.Time{datetime, datetime}, - []*time.Time{&datetime, nil, &datetime}, + []types.Datetime{datetime, datetime}, + []*types.Datetime{&datetime, nil, &datetime}, ); assert.NoError(t, err) { if err := scope.Commit(); assert.NoError(t, err) { var ( - col1 time.Time - col2 time.Time - col3 time.Time - col4 *time.Time - col5 []time.Time - col6 []*time.Time + col1 types.Datetime + col2 types.Datetime + col3 types.Datetime + col4 *types.Datetime + col5 []types.Datetime + col6 []*types.Datetime ) if err := conn.QueryRow("SELECT (* except _tp_time) FROM test_datetime WHERE _tp_time > earliest_ts() LIMIT 1").Scan(&col1, &col2, &col3, &col4, &col5, &col6); assert.NoError(t, err) { assert.Equal(t, datetime, col1) diff --git a/tests/std/external_table_test.go b/tests/std/external_table_test.go index 5720df2..a35756f 100644 --- a/tests/std/external_table_test.go +++ b/tests/std/external_table_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -59,7 +60,7 @@ func TestStdExternalTable(t *testing.T) { var ( col1 uint8 col2 string - col3 time.Time + col3 types.Datetime ) if err := rows.Scan(&col1, &col2, &col3); assert.NoError(t, err) { t.Logf("row: col1=%d, col2=%s, col3=%s\n", col1, col2, col3) diff --git a/tests/std/lowcardinality_test.go b/tests/std/lowcardinality_test.go index e300809..f228be9 100644 --- a/tests/std/lowcardinality_test.go +++ b/tests/std/lowcardinality_test.go @@ -20,6 +20,7 @@ package std import ( "context" "database/sql" + "github.com/timeplus-io/proton-go-driver/v2/types" "math/rand" "testing" "time" @@ -56,7 +57,7 @@ func TestStdLowCardinality(t *testing.T) { if batch, err := scope.Prepare("INSERT INTO test_lowcardinality (* except _tp_time)"); assert.NoError(t, err) { var ( rnd = rand.Int31() - timestamp = time.Now() + timestamp = types.Datetime{Time: time.Now()} ) for i := 0; i < 10; i++ { var ( @@ -94,7 +95,7 @@ func TestStdLowCardinality(t *testing.T) { var ( col1 string col2 string - col3 time.Time + col3 types.Datetime col4 int32 col5 []string col6 [][]string @@ -104,7 +105,7 @@ func TestStdLowCardinality(t *testing.T) { if err := conn.QueryRow("SELECT (* except _tp_time) FROM test_lowcardinality WHERE _tp_time > earliest_ts() AND Col4 = $1 LIMIT 1", rnd+int32(i)).Scan(&col1, &col2, &col3, &col4, &col5, &col6, &col7, &col8); assert.NoError(t, err) { assert.Equal(t, timestamp.String(), col1) assert.Equal(t, "RU", col2) - assert.Equal(t, timestamp.Add(time.Duration(i)*time.Minute).Truncate(time.Second), col3) + assert.Equal(t, types.Datetime{Time: timestamp.Add(time.Duration(i) * time.Minute).Truncate(time.Second)}, col3) assert.Equal(t, rnd+int32(i), col4) assert.Equal(t, []string{"A", "B", "C"}, col5) assert.Equal(t, [][]string{ diff --git a/tests/tuple_test.go b/tests/tuple_test.go index f664fbc..3aba7a1 100644 --- a/tests/tuple_test.go +++ b/tests/tuple_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "fmt" + "github.com/timeplus-io/proton-go-driver/v2/types" "testing" "time" @@ -65,8 +66,8 @@ func TestTuple(t *testing.T) { if batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_tuple (* except _tp_time)"); assert.NoError(t, err) { var ( col1Data = []interface{}{"A", int64(42)} - col2Data = []interface{}{"B", int8(1), time.Now().Truncate(time.Second)} - col3Data = []interface{}{time.Now().Truncate(time.Second), "CH", map[string]string{ + col2Data = []interface{}{"B", int8(1), types.Datetime{Time: time.Now().Truncate(time.Second)}} + col3Data = []interface{}{types.Datetime{Time: time.Now().Truncate(time.Second)}, "CH", map[string]string{ "key": "value", }} col4Data = [][][]interface{}{ @@ -147,7 +148,7 @@ func TestColumnarTuple(t *testing.T) { col1Data = [][]interface{}{} col2Data = [][]interface{}{} col3Data = [][]interface{}{} - timestamp = time.Now().Truncate(time.Second) + timestamp = types.Datetime{time.Now().Truncate(time.Second)} ) for i := 0; i < 1000; i++ { id = append(id, uint64(i)) diff --git a/types/date.go b/types/date.go index f848dd5..865b240 100644 --- a/types/date.go +++ b/types/date.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "gopkg.in/yaml.v3" "time" ) @@ -9,6 +10,28 @@ type Date struct { time.Time } +const dateFormat = "2006-01-02" + +// MarshalJSON implements json.Marshaler func (d Date) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", d.Format("2006-01-02"))), nil + return []byte(fmt.Sprintf("\"%s\"", d.Format(dateFormat))), nil +} + +// UnmarshalJSON implements json.Unmarshaler +func (d *Date) UnmarshalJSON(data []byte) error { + t, err := time.ParseInLocation(`"`+dateFormat+`"`, string(data), time.Local) + d.Time = t + return err +} + +// MarshalYAML implements yaml.Marshaler +func (d Date) MarshalYAML() (interface{}, error) { + return fmt.Sprintf("%s", d.Format(dateFormat)), nil +} + +// UnmarshalYAML implements yaml.Unmarshaler +func (d *Date) UnmarshalYAML(value *yaml.Node) error { + t, err := time.ParseInLocation(dateFormat, value.Value, time.Local) + d.Time = t + return err } diff --git a/types/datetime.go b/types/datetime.go index 58654bc..f461f97 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "gopkg.in/yaml.v3" "time" ) @@ -9,6 +10,28 @@ type Datetime struct { time.Time } +const timeFormat = "2006-01-02 15:04:05.999999999" + +// MarshalJSON implements json.Marshaler func (dt Datetime) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", dt.Format("2006-01-02 15:04:05"))), nil + return []byte(fmt.Sprintf("\"%s\"", dt.Format(timeFormat))), nil +} + +// UnmarshalJSON implements json.Unmarshaler +func (dt *Datetime) UnmarshalJSON(data []byte) error { + t, err := time.ParseInLocation(`"`+timeFormat+`"`, string(data), time.Local) + dt.Time = t + return err +} + +// MarshalYAML implements yaml.Marshaler +func (dt Datetime) MarshalYAML() (interface{}, error) { + return fmt.Sprintf("%s", dt.Format(timeFormat)), nil +} + +// UnmarshalYAML implements yaml.Unmarshaler +func (dt *Datetime) UnmarshalYAML(value *yaml.Node) error { + t, err := time.ParseInLocation(timeFormat, value.Value, time.Local) + dt.Time = t + return err }