diff --git a/contributors/list b/contributors/list index 641b656ac9..fa0dae613b 100644 --- a/contributors/list +++ b/contributors/list @@ -22,7 +22,6 @@ Chao Wang Chris Duncan Daguang <28806852+DGuang21@users.noreply.github.com> Dale McDiarmid -Dale Mcdiarmid Damir Sayfutdinov Dan Walters Daniel Bershatsky @@ -30,6 +29,7 @@ Danila Migalin Danny.Dunn DarĂ­o Dave Josephsen +Dean Karn Denis Gukov Denis Krivak Denys diff --git a/lib/column/array.go b/lib/column/array.go index 7e9084cac5..e2db182529 100644 --- a/lib/column/array.go +++ b/lib/column/array.go @@ -191,13 +191,23 @@ func appendNullableRowPlain[T any](col *Array, arr []*T) error { func (col *Array) append(elem reflect.Value, level int) error { if level < col.depth { - col.appendOffset(level, uint64(elem.Len())) - for i := 0; i < elem.Len(); i++ { - if err := col.append(elem.Index(i), level+1); err != nil { - return err + switch elem.Kind() { + // reflect.Value.Len() & reflect.Value.Index() is called in `append` method which is only valid for + // Slice, Array and String that make sense here. + case reflect.Slice, reflect.Array, reflect.String: + col.appendOffset(level, uint64(elem.Len())) + for i := 0; i < elem.Len(); i++ { + if err := col.append(elem.Index(i), level+1); err != nil { + return err + } } + return nil + } + return &ColumnConverterError{ + Op: "AppendRow", + To: "Array", + From: fmt.Sprintf("%T", elem), } - return nil } if elem.Kind() == reflect.Ptr && elem.IsNil() { return col.values.AppendRow(nil) diff --git a/lib/column/array_gen.go b/lib/column/array_gen.go index 8b9d11c6ce..eeba8e0fdc 100644 --- a/lib/column/array_gen.go +++ b/lib/column/array_gen.go @@ -22,6 +22,8 @@ package column import ( "database/sql" + "database/sql/driver" + "fmt" "github.com/ClickHouse/ch-go/proto" "github.com/google/uuid" "github.com/paulmach/orb" @@ -156,6 +158,18 @@ func (col *Array) appendRowPlain(v any) error { case []*orb.Ring: return appendNullableRowPlain(col, tv) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Array", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.appendRowPlain(val) + } return col.appendRowDefault(v) } } diff --git a/lib/column/bigint.go b/lib/column/bigint.go index 08c3db632f..e09c96ae1d 100644 --- a/lib/column/bigint.go +++ b/lib/column/bigint.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "encoding/binary" "fmt" "github.com/ClickHouse/ch-go/proto" @@ -97,6 +98,18 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: string(col.chType), @@ -120,6 +133,18 @@ func (col *BigInt) AppendRow(v any) error { case nil: col.append(big.NewInt(0)) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: string(col.chType), diff --git a/lib/column/bool.go b/lib/column/bool.go index f99da947aa..1c69bfff65 100644 --- a/lib/column/bool.go +++ b/lib/column/bool.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -110,6 +111,18 @@ func (col *Bool) Append(v any) (nulls []uint8, err error) { col.Append(v[i]) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Bool", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Bool", @@ -140,6 +153,18 @@ func (col *Bool) AppendRow(v any) error { } case nil: default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Bool", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "Bool", diff --git a/lib/column/codegen/array.tpl b/lib/column/codegen/array.tpl index b3df3fec37..f518c91c5c 100644 --- a/lib/column/codegen/array.tpl +++ b/lib/column/codegen/array.tpl @@ -22,6 +22,7 @@ package column import ( "database/sql" + "database/sql/driver" "github.com/ClickHouse/ch-go/proto" "github.com/google/uuid" "github.com/paulmach/orb" @@ -30,6 +31,7 @@ import ( "net" "net/netip" "time" + "fmt" ) // appendRowPlain is a reflection-free realisation of append for plain arrays. @@ -42,6 +44,18 @@ func (col *Array) appendRowPlain(v any) error { return appendNullableRowPlain(col, tv) {{- end }} default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Array", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.appendRowPlain(val) + } return col.appendRowDefault(v) } } diff --git a/lib/column/codegen/column.tpl b/lib/column/codegen/column.tpl index e4e8141bbc..2ae117db1c 100644 --- a/lib/column/codegen/column.tpl +++ b/lib/column/codegen/column.tpl @@ -31,6 +31,7 @@ import ( "github.com/paulmach/orb" "github.com/shopspring/decimal" "database/sql" + "database/sql/driver" "github.com/ClickHouse/ch-go/proto" ) @@ -315,6 +316,20 @@ func (col *{{ .ChType }}) Append(v any) (nulls []uint8,err error) { } {{- end }} default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "{{ .ChType }}", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "{{ .ChType }}", @@ -382,6 +397,20 @@ func (col *{{ .ChType }}) AppendRow(v any) error { col.col.Append(val) {{- end }} default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "{{ .ChType }}", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().({{ .GoType }})) } else { diff --git a/lib/column/column_gen.go b/lib/column/column_gen.go index 03830f6734..d13781af91 100644 --- a/lib/column/column_gen.go +++ b/lib/column/column_gen.go @@ -22,6 +22,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "github.com/google/uuid" @@ -325,6 +326,20 @@ func (col *Float32) Append(v any) (nulls []uint8, err error) { } } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Float32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Float32", @@ -348,6 +363,20 @@ func (col *Float32) AppendRow(v any) error { case nil: col.col.Append(0) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Float32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(float32)) } else { @@ -453,6 +482,20 @@ func (col *Float64) Append(v any) (nulls []uint8, err error) { col.AppendRow(v[i]) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Float64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Float64", @@ -490,6 +533,20 @@ func (col *Float64) AppendRow(v any) error { col.col.Append(0) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Float64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(float64)) } else { @@ -605,6 +662,20 @@ func (col *Int8) Append(v any) (nulls []uint8, err error) { col.col.Append(val) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Int8", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Int8", @@ -640,6 +711,20 @@ func (col *Int8) AppendRow(v any) error { } col.col.Append(val) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Int8", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(int8)) } else { @@ -745,6 +830,20 @@ func (col *Int16) Append(v any) (nulls []uint8, err error) { col.AppendRow(v[i]) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Int16", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Int16", @@ -782,6 +881,20 @@ func (col *Int16) AppendRow(v any) error { col.col.Append(0) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Int16", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(int16)) } else { @@ -887,6 +1000,20 @@ func (col *Int32) Append(v any) (nulls []uint8, err error) { col.AppendRow(v[i]) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Int32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Int32", @@ -924,6 +1051,20 @@ func (col *Int32) AppendRow(v any) error { col.col.Append(0) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Int32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(int32)) } else { @@ -1031,6 +1172,20 @@ func (col *Int64) Append(v any) (nulls []uint8, err error) { col.AppendRow(v[i]) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Int64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "Int64", @@ -1072,6 +1227,20 @@ func (col *Int64) AppendRow(v any) error { case *time.Duration: col.col.Append(int64(*v)) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Int64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(int64)) } else { @@ -1169,6 +1338,20 @@ func (col *UInt8) Append(v any) (nulls []uint8, err error) { } } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "UInt8", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "UInt8", @@ -1198,6 +1381,20 @@ func (col *UInt8) AppendRow(v any) error { } col.col.Append(t) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "UInt8", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(uint8)) } else { @@ -1288,6 +1485,20 @@ func (col *UInt16) Append(v any) (nulls []uint8, err error) { } } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "UInt16", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "UInt16", @@ -1311,6 +1522,20 @@ func (col *UInt16) AppendRow(v any) error { case nil: col.col.Append(0) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "UInt16", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(uint16)) } else { @@ -1401,6 +1626,20 @@ func (col *UInt32) Append(v any) (nulls []uint8, err error) { } } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "UInt32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "UInt32", @@ -1424,6 +1663,20 @@ func (col *UInt32) AppendRow(v any) error { case nil: col.col.Append(0) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "UInt32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(uint32)) } else { @@ -1514,6 +1767,20 @@ func (col *UInt64) Append(v any) (nulls []uint8, err error) { } } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "UInt64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "UInt64", @@ -1537,6 +1804,20 @@ func (col *UInt64) AppendRow(v any) error { case nil: col.col.Append(0) default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "UInt64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if rv := reflect.ValueOf(v); rv.Kind() == col.ScanType().Kind() || rv.CanConvert(col.ScanType()) { col.col.Append(rv.Convert(col.ScanType()).Interface().(uint64)) } else { diff --git a/lib/column/date.go b/lib/column/date.go index 52ff708a6a..bc4f77db49 100644 --- a/lib/column/date.go +++ b/lib/column/date.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -156,6 +157,18 @@ func (col *Date) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Date", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Date", @@ -215,6 +228,18 @@ func (col *Date) AppendRow(v any) error { col.col.Append(datetime) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Date", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } s, ok := v.(fmt.Stringer) if ok { return col.AppendRow(s.String()) diff --git a/lib/column/date32.go b/lib/column/date32.go index 4e3fa1c208..e23429ded3 100644 --- a/lib/column/date32.go +++ b/lib/column/date32.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -146,6 +147,18 @@ func (col *Date32) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Date32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Date32", @@ -205,6 +218,18 @@ func (col *Date32) AppendRow(v any) error { col.col.Append(value) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Date32", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } s, ok := v.(fmt.Stringer) if ok { return col.AppendRow(s.String()) diff --git a/lib/column/datetime.go b/lib/column/datetime.go index 48b4b17b92..d5dfffad22 100644 --- a/lib/column/datetime.go +++ b/lib/column/datetime.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -188,6 +189,18 @@ func (col *DateTime) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "DateTime", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "DateTime", @@ -257,6 +270,18 @@ func (col *DateTime) AppendRow(v any) error { col.col.Append(dateTime) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "DateTime", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } s, ok := v.(fmt.Stringer) if ok { return col.AppendRow(s.String()) diff --git a/lib/column/datetime64.go b/lib/column/datetime64.go index bcba4cc944..f5a5a94877 100644 --- a/lib/column/datetime64.go +++ b/lib/column/datetime64.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "math" @@ -193,6 +194,18 @@ func (col *DateTime64) Append(v any) (nulls []uint8, err error) { col.AppendRow(v[i]) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Datetime64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Datetime64", @@ -251,6 +264,18 @@ func (col *DateTime64) AppendRow(v any) error { case nil: col.col.Append(time.Time{}) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Datetime64", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } s, ok := v.(fmt.Stringer) if ok { return col.AppendRow(s.String()) diff --git a/lib/column/decimal.go b/lib/column/decimal.go index a64f58265f..74b7d75f4d 100644 --- a/lib/column/decimal.go +++ b/lib/column/decimal.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "encoding/binary" "errors" "fmt" @@ -170,6 +171,18 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: string(col.chType), @@ -190,6 +203,18 @@ func (col *Decimal) AppendRow(v any) error { } case nil: default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: string(col.chType), diff --git a/lib/column/enum16.go b/lib/column/enum16.go index 5bbfe0fa3f..c394e7fff3 100644 --- a/lib/column/enum16.go +++ b/lib/column/enum16.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -153,6 +154,24 @@ func (col *Enum16) Append(v any) (nulls []uint8, err error) { nulls[i] = 1 } } + default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Enum16", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ + Op: "Append", + To: "Enum16", + From: fmt.Sprintf("%T", v), + } } return } @@ -214,6 +233,18 @@ func (col *Enum16) AppendRow(elem any) error { case nil: col.col.Append(0) default: + if valuer, ok := elem.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Enum16", + From: fmt.Sprintf("%T", elem), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } if s, ok := elem.(fmt.Stringer); ok { return col.AppendRow(s.String()) } else { diff --git a/lib/column/enum8.go b/lib/column/enum8.go index 9880c6fec0..4aee561ad7 100644 --- a/lib/column/enum8.go +++ b/lib/column/enum8.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -154,6 +155,18 @@ func (col *Enum8) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Enum8", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Enum8", @@ -220,6 +233,19 @@ func (col *Enum8) AppendRow(elem any) error { case nil: col.col.Append(0) default: + if valuer, ok := elem.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Enum8", + From: fmt.Sprintf("%T", elem), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + if s, ok := elem.(fmt.Stringer); ok { return col.AppendRow(s.String()) } else { diff --git a/lib/column/fixed_string.go b/lib/column/fixed_string.go index a836f74860..8ddb0d1ce6 100644 --- a/lib/column/fixed_string.go +++ b/lib/column/fixed_string.go @@ -126,6 +126,18 @@ func (col *FixedString) Append(v any) (nulls []uint8, err error) { col.col.Append(data) nulls = make([]uint8, len(data)/col.col.Size) default: + if s, ok := v.(driver.Valuer); ok { + val, err := s.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "FixedString", + From: fmt.Sprintf("%T", s), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "FixedString", @@ -159,22 +171,12 @@ func (col *FixedString) AppendRow(v any) (err error) { if err != nil { return &ColumnConverterError{ Op: "AppendRow", - To: "String", + To: "FixedString", From: fmt.Sprintf("%T", s), Hint: "could not get driver.Valuer value", } } - - if s, ok := val.(string); ok { - return col.AppendRow(s) - } - - return &ColumnConverterError{ - Op: "AppendRow", - To: "String", - From: fmt.Sprintf("%T", v), - Hint: "driver.Valuer value is not a string", - } + return col.AppendRow(val) } if s, ok := v.(fmt.Stringer); ok { @@ -183,7 +185,7 @@ func (col *FixedString) AppendRow(v any) (err error) { return &ColumnConverterError{ Op: "AppendRow", - To: "String", + To: "FixedString", From: fmt.Sprintf("%T", v), } } diff --git a/lib/column/geo_multi_polygon.go b/lib/column/geo_multi_polygon.go index 9b5ebe7b4f..2839a41c52 100644 --- a/lib/column/geo_multi_polygon.go +++ b/lib/column/geo_multi_polygon.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -91,6 +92,18 @@ func (col *MultiPolygon) Append(v any) (nulls []uint8, err error) { } return col.set.Append(values) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "MultiPolygon", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "MultiPolygon", @@ -106,6 +119,18 @@ func (col *MultiPolygon) AppendRow(v any) error { case *orb.MultiPolygon: return col.set.AppendRow([]orb.Polygon(*v)) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "MultiPolygon", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "MultiPolygon", diff --git a/lib/column/geo_point.go b/lib/column/geo_point.go index 9d3e1d347c..c93a715ace 100644 --- a/lib/column/geo_point.go +++ b/lib/column/geo_point.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -95,6 +96,18 @@ func (col *Point) Append(v any) (nulls []uint8, err error) { }) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Point", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Point", @@ -116,6 +129,18 @@ func (col *Point) AppendRow(v any) error { Y: v.Lat(), }) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Point", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "Point", diff --git a/lib/column/geo_polygon.go b/lib/column/geo_polygon.go index accc14f827..542260815c 100644 --- a/lib/column/geo_polygon.go +++ b/lib/column/geo_polygon.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -91,6 +92,18 @@ func (col *Polygon) Append(v any) (nulls []uint8, err error) { } return col.set.Append(values) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Polygon", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Polygon", @@ -106,6 +119,18 @@ func (col *Polygon) AppendRow(v any) error { case *orb.Polygon: return col.set.AppendRow([]orb.Ring(*v)) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Polygon", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "Polygon", diff --git a/lib/column/geo_ring.go b/lib/column/geo_ring.go index 4580e41c89..0f190a8e25 100644 --- a/lib/column/geo_ring.go +++ b/lib/column/geo_ring.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -91,6 +92,18 @@ func (col *Ring) Append(v any) (nulls []uint8, err error) { } return col.set.Append(values) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "Ring", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "Ring", @@ -106,6 +119,18 @@ func (col *Ring) AppendRow(v any) error { case *orb.Ring: return col.set.AppendRow([]orb.Point(*v)) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "Ring", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "Ring", diff --git a/lib/column/ipv4.go b/lib/column/ipv4.go index f7b78945bd..a15f6d3e92 100644 --- a/lib/column/ipv4.go +++ b/lib/column/ipv4.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "encoding/binary" "fmt" "github.com/ClickHouse/ch-go/proto" @@ -203,6 +204,18 @@ func (col *IPv4) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "IPv4", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "IPv4", @@ -267,6 +280,18 @@ func (col *IPv4) AppendRow(v any) (err error) { col.col.Append(0) } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "IPv4", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "IPv4", diff --git a/lib/column/ipv6.go b/lib/column/ipv6.go index 0002e07308..a67d17abc4 100644 --- a/lib/column/ipv6.go +++ b/lib/column/ipv6.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "net" @@ -231,6 +232,18 @@ func (col *IPv6) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "IPv6", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "IPv6", @@ -315,6 +328,18 @@ func (col *IPv6) AppendRow(v any) (err error) { case nil: col.col.Append([16]byte{}) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "IPv6", + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.Type()), + } + } + return col.AppendRow(val) + } return &ColumnConverterError{ Op: "AppendRow", To: "IPv6", diff --git a/lib/column/map.go b/lib/column/map.go index c25f386e7f..a727d47f41 100644 --- a/lib/column/map.go +++ b/lib/column/map.go @@ -18,6 +18,7 @@ package column import ( + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -111,6 +112,18 @@ func (col *Map) ScanRow(dest any, i int) error { func (col *Map) Append(v any) (nulls []uint8, err error) { value := reflect.Indirect(reflect.ValueOf(v)) if value.Kind() != reflect.Slice { + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.scanType), + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: string(col.chType), @@ -173,6 +186,19 @@ func (col *Map) AppendRow(v any) error { return nil } + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: fmt.Sprintf("could not get driver.Valuer value, try using %s", col.scanType), + } + } + return col.AppendRow(val) + } + return &ColumnConverterError{ Op: "AppendRow", To: string(col.chType), diff --git a/lib/column/string.go b/lib/column/string.go index 9d7ad734bd..5ce480b0e6 100644 --- a/lib/column/string.go +++ b/lib/column/string.go @@ -116,27 +116,17 @@ func (col *String) AppendRow(v any) error { case nil: col.col.Append("") default: - if s, ok := v.(driver.Valuer); ok { - val, err := s.Value() + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() if err != nil { return &ColumnConverterError{ Op: "AppendRow", To: "String", - From: fmt.Sprintf("%T", s), + From: fmt.Sprintf("%T", v), Hint: "could not get driver.Valuer value", } } - - if s, ok := val.(string); ok { - return col.AppendRow(s) - } - - return &ColumnConverterError{ - Op: "AppendRow", - To: "String", - From: fmt.Sprintf("%T", v), - Hint: "driver.Valuer value is not a string", - } + return col.AppendRow(val) } if s, ok := v.(fmt.Stringer); ok { @@ -187,6 +177,19 @@ func (col *String) Append(v any) (nulls []uint8, err error) { col.col.Append(string(v[i])) } default: + + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "String", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: "String", diff --git a/lib/column/tuple.go b/lib/column/tuple.go index 80149f8a8f..95e00db04b 100644 --- a/lib/column/tuple.go +++ b/lib/column/tuple.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "net" "reflect" @@ -491,6 +492,18 @@ func (col *Tuple) Append(v any) (nulls []uint8, err error) { } return nil, nil } + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } return nil, &ColumnConverterError{ Op: "Append", To: string(col.chType), @@ -553,6 +566,19 @@ func (col *Tuple) AppendRow(v any) error { return nil } + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: string(col.chType), + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } + return &ColumnConverterError{ Op: "AppendRow", To: string(col.chType), diff --git a/lib/column/uuid.go b/lib/column/uuid.go index 3ee88eff45..bf2a1c4897 100644 --- a/lib/column/uuid.go +++ b/lib/column/uuid.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -130,6 +131,19 @@ func (col *UUID) Append(v any) (nulls []uint8, err error) { } } default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return nil, &ColumnConverterError{ + Op: "Append", + To: "UUID", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.Append(val) + } + return nil, &ColumnConverterError{ Op: "Append", To: "UUID", @@ -170,6 +184,18 @@ func (col *UUID) AppendRow(v any) error { case nil: col.col.Append(uuid.UUID{}) default: + if valuer, ok := v.(driver.Valuer); ok { + val, err := valuer.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "UUID", + From: fmt.Sprintf("%T", v), + Hint: "could not get driver.Valuer value", + } + } + return col.AppendRow(val) + } if s, ok := v.(fmt.Stringer); ok { return col.AppendRow(s.String()) } diff --git a/tests/array_test.go b/tests/array_test.go index 2d4deb9867..3f8617a18f 100644 --- a/tests/array_test.go +++ b/tests/array_test.go @@ -19,6 +19,8 @@ package tests import ( "context" + "database/sql/driver" + "fmt" "github.com/stretchr/testify/require" "testing" "time" @@ -315,3 +317,59 @@ func TestColumnarArray(t *testing.T) { require.NoError(t, rows.Close()) assert.NoError(t, rows.Err()) } + +type testArraySerializer struct { + val []string +} + +func (c testArraySerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testArraySerializer) Scan(src any) error { + if t, ok := src.([]string); ok { + *c = testArraySerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testArraySerializer", src) +} + +func TestSimpleArrayValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_array_valuer ( + Col1 Array(String) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_array_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_array_valuer") + require.NoError(t, err) + var ( + col1Data = []string{"A", "b", "c"} + ) + for i := 0; i < 10; i++ { + require.NoError(t, batch.Append(testArraySerializer{val: col1Data})) + require.Equal(t, 1, batch.Rows()) + batch.Flush() + } + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_array_valuer") + require.NoError(t, err) + for rows.Next() { + var ( + col1 []string + ) + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, col1Data, col1) + + } + require.NoError(t, rows.Close()) + require.NoError(t, rows.Err()) +} diff --git a/tests/bigint_test.go b/tests/bigint_test.go index aa7fa4535b..e1c77d4200 100644 --- a/tests/bigint_test.go +++ b/tests/bigint_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "math/big" @@ -306,3 +307,58 @@ func TestBigIntFlush(t *testing.T) { i += 1 } } + +type testBigIntSerializer struct { + val *big.Int +} + +func (c testBigIntSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testBigIntSerializer) Scan(src any) error { + if t, ok := src.(*big.Int); ok { + *c = testBigIntSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testBigIntSerializer", src) +} + +func TestBigIntValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS big_int_flush") + }() + const ddl = ` + CREATE TABLE big_int_flush ( + Col1 UInt128 + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO big_int_flush") + require.NoError(t, err) + vals := [1000]*big.Int{} + for i := 0; i < 1000; i++ { + bigUint128Val := big.NewInt(0) + bigUint128Val.SetString(RandIntString(20), 10) + vals[i] = bigUint128Val + batch.Append(testBigIntSerializer{val: vals[i]}) + require.Equal(t, 1, batch.Rows()) + batch.Flush() + } + require.Equal(t, 0, batch.Rows()) + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM big_int_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 big.Int + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, *vals[i], col1) + i += 1 + } +} diff --git a/tests/bool_test.go b/tests/bool_test.go index ed9242fbab..3e895effce 100644 --- a/tests/bool_test.go +++ b/tests/bool_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "database/sql" + "database/sql/driver" "fmt" "github.com/ClickHouse/clickhouse-go/v2" "github.com/stretchr/testify/assert" @@ -199,3 +200,59 @@ func TestBoolFlush(t *testing.T) { i += 1 } } + +type testBoolSerializer struct { + val bool +} + +func (c testBoolSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testBoolSerializer) Scan(src any) error { + if t, ok := src.(bool); ok { + *c = testBoolSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testBoolSerializer", src) +} + +func TestBoolValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS bool_flush") + }() + const ddl = ` + CREATE TABLE bool_flush ( + Col1 Bool + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO bool_flush") + require.NoError(t, err) + vals := [1000]bool{} + var src = rand.NewSource(time.Now().UnixNano()) + var r = rand.New(src) + + for i := 0; i < 1000; i++ { + vals[i] = r.Intn(2) != 0 + require.NoError(t, batch.Append(testBoolSerializer{val: vals[i]})) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Flush()) + } + require.Equal(t, 0, batch.Rows()) + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM bool_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 bool + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, vals[i], col1) + i += 1 + } +} diff --git a/tests/date32_test.go b/tests/date32_test.go index 408ab5ce05..1792a00c98 100644 --- a/tests/date32_test.go +++ b/tests/date32_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -400,3 +401,62 @@ func TestDate32WithUserLocation(t *testing.T) { assert.Equal(t, "2022-07-01T00:00:00", col1.Format(dateTimeNoZoneFormat)) assert.Equal(t, userLocation.String(), col1.Location().String()) } + +type testDate32Serializer struct { + val time.Time +} + +func (c testDate32Serializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testDate32Serializer) Scan(src any) error { + if t, ok := src.(time.Time); ok { + *c = testDate32Serializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testDate32Serializer", src) +} + +func TestDate32Valuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS date_32_valuer") + }() + const ddl = ` + CREATE TABLE date_32_valuer ( + Col1 Date32 + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO date_32_valuer") + require.NoError(t, err) + vals := [1000]time.Time{} + var now = time.Now() + + for i := 0; i < 1000; i++ { + vals[i] = now.Add(time.Duration(i) * time.Hour) + batch.Append(testDate32Serializer{val: vals[i]}) + require.Equal(t, 1, batch.Rows()) + batch.Flush() + } + require.Equal(t, 0, batch.Rows()) + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM date_32_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 time.Time + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, vals[i].Format("2016-02-01"), col1.Format("2016-02-01")) + i += 1 + } +} diff --git a/tests/date_test.go b/tests/date_test.go index 2eebb280c8..04a9acdc8d 100644 --- a/tests/date_test.go +++ b/tests/date_test.go @@ -19,6 +19,8 @@ package tests import ( "context" + "database/sql/driver" + "fmt" "github.com/stretchr/testify/require" "testing" "time" @@ -361,3 +363,55 @@ func TestDateWithUserLocation(t *testing.T) { assert.Equal(t, "2022-07-01T00:00:00", col1.Format(dateTimeNoZoneFormat)) assert.Equal(t, userLocation.String(), col1.Location().String()) } + +type testDateSerializer struct { + val time.Time +} + +func (c testDateSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testDateSerializer) Scan(src any) error { + if t, ok := src.(time.Time); ok { + *c = testDateSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testDateSerializer", src) +} + +func TestDateValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS date_valuer") + }() + const ddl = ` + CREATE TABLE date_valuer ( + Col1 Date + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO date_valuer") + require.NoError(t, err) + vals := [1000]time.Time{} + var now = time.Now() + + for i := 0; i < 1000; i++ { + vals[i] = now.Add(time.Duration(i) * time.Hour) + batch.Append(testDateSerializer{val: vals[i]}) + batch.Flush() + } + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM date_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 time.Time + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, vals[i].Format("2016-02-01"), col1.Format("2016-02-01")) + i += 1 + } +} diff --git a/tests/datetime64_test.go b/tests/datetime64_test.go index 2dc11e0c23..d080c51cc2 100644 --- a/tests/datetime64_test.go +++ b/tests/datetime64_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -459,3 +460,55 @@ func TestCustomDateTime64(t *testing.T) { require.NoError(t, row.Scan(&col1)) require.Equal(t, now, time.Time(col1)) } + +type testDateTime64Serializer struct { + val time.Time +} + +func (c testDateTime64Serializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testDateTime64Serializer) Scan(src any) error { + if t, ok := src.(time.Time); ok { + *c = testDateTime64Serializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testDateTime64Serializer", src) +} + +func TestDateTime64Valuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS datetime_64_valuer") + }() + const ddl = ` + CREATE TABLE datetime_64_valuer ( + Col1 DateTime64(3) + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO datetime_64_valuer") + require.NoError(t, err) + vals := [1000]time.Time{} + var now = time.Now() + for i := 0; i < 1000; i++ { + vals[i] = now.Add(time.Duration(i) * time.Hour).Truncate(time.Millisecond) + batch.Append(testDateTime64Serializer{val: vals[i]}) + batch.Flush() + } + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM datetime_64_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 time.Time + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, vals[i].In(time.UTC), col1) + i += 1 + } +} diff --git a/tests/datetime_test.go b/tests/datetime_test.go index bb5593af52..92f896bbfd 100644 --- a/tests/datetime_test.go +++ b/tests/datetime_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -456,3 +457,57 @@ func TestCustomDateTime(t *testing.T) { require.NoError(t, row.Scan(&col1)) require.Equal(t, now, time.Time(col1)) } + +type testDateTimeSerializer struct { + val time.Time +} + +func (c testDateTimeSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testDateTimeSerializer) Scan(src any) error { + if t, ok := src.(time.Time); ok { + *c = testDateTimeSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testDateTimeSerializer", src) +} + +func TestDateTimeValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE datetime_valuer") + }() + const ddl = ` + CREATE TABLE datetime_valuer ( + Col1 DateTime + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO datetime_valuer") + require.NoError(t, err) + vals := [1000]time.Time{} + var now = time.Now() + for i := 0; i < 1000; i++ { + vals[i] = now.Add(time.Duration(i) * time.Hour).Truncate(time.Second) + batch.Append(testDateTimeSerializer{val: vals[i]}) + require.Equal(t, 1, batch.Rows()) + batch.Flush() + } + require.Equal(t, 0, batch.Rows()) + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM datetime_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 time.Time + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, vals[i].In(time.UTC), col1) + i += 1 + } +} diff --git a/tests/decimal_test.go b/tests/decimal_test.go index 7c911b7f69..2c9d31090a 100644 --- a/tests/decimal_test.go +++ b/tests/decimal_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "testing" @@ -284,3 +285,70 @@ func TestRoundDecimals(t *testing.T) { } } + +type testDecimalSerializer struct { + val decimal.Decimal +} + +func (c testDecimalSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testDecimalSerializer) Scan(src any) error { + if t, ok := src.(decimal.Decimal); ok { + *c = testDecimalSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testDecimalSerializer", src) +} + +func TestDecimalValuer(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{ + "allow_experimental_bigint_types": 1, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 1, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_decimal ( + Col1 Decimal32(3) + , Col2 Decimal(18,6) + , Col3 Decimal(15,7) + , Col4 Decimal128(8) + , Col5 Decimal256(9) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_decimal") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_decimal") + require.NoError(t, err) + require.NoError(t, batch.Append( + testDecimalSerializer{val: decimal.New(25, 4)}, + testDecimalSerializer{val: decimal.New(30, 5)}, + testDecimalSerializer{val: decimal.New(35, 6)}, + testDecimalSerializer{val: decimal.New(135, 7)}, + testDecimalSerializer{val: decimal.New(256, 8)}, + )) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + var ( + col1 decimal.Decimal + col2 decimal.Decimal + col3 decimal.Decimal + col4 decimal.Decimal + col5 decimal.Decimal + ) + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_decimal").Scan(&col1, &col2, &col3, &col4, &col5)) + assert.True(t, decimal.New(25, 4).Equal(col1)) + assert.True(t, decimal.New(30, 5).Equal(col2)) + assert.True(t, decimal.New(35, 6).Equal(col3)) + assert.True(t, decimal.New(135, 7).Equal(col4)) + assert.True(t, decimal.New(256, 8).Equal(col5)) +} diff --git a/tests/enum_test.go b/tests/enum_test.go index acbaf87a58..9569addd58 100644 --- a/tests/enum_test.go +++ b/tests/enum_test.go @@ -19,6 +19,8 @@ package tests import ( "context" + "database/sql/driver" + "fmt" "github.com/stretchr/testify/require" "testing" @@ -292,3 +294,84 @@ func TestColumnarEnum(t *testing.T) { assert.Equal(t, col6Data, col6) assert.Equal(t, col7Data, col7) } + +type testEnumSerializer struct { + val []string +} + +func (c testEnumSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testEnumSerializer) Scan(src any) error { + if t, ok := src.([]string); ok { + *c = testEnumSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testEnum8Serializer", src) +} + +func TestEnumValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_enum_valuer ( + Col1 Enum ('hello' = 1, 'world' = 2) + , Col2 Enum8 ('click' = 5, 'house' = 25) + , Col3 Enum16('house' = 10, 'value' = 50) + , Col4 Array(Enum8 ('click' = 1, 'house' = 2)) + , Col5 Array(Enum16 ('click' = 1, 'house' = 2)) + , Col6 Array(Nullable(Enum8 ('click' = 1, 'house' = 2))) + , Col7 Array(Nullable(Enum16 ('click' = 1, 'house' = 2))) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_enum_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_enum_valuer") + require.NoError(t, err) + var ( + col1Data = "hello" + col2Data = "click" + col3Data = "house" + col4Data = []string{"click", "house"} + col5Data = []string{"house", "click"} + col6Data = []*string{&col2Data, nil, &col3Data} + col7Data = []*string{&col3Data, nil, &col2Data} + ) + require.NoError(t, batch.Append( + col1Data, + col2Data, + col3Data, + testEnumSerializer{val: col4Data}, + testEnumSerializer{val: col5Data}, + col6Data, + col7Data, + )) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + var ( + col1 string + col2 string + col3 string + col4 []string + col5 []string + col6 []*string + col7 []*string + ) + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_enum_valuer").Scan( + &col1, &col2, &col3, &col4, + &col5, &col6, &col7, + )) + assert.Equal(t, col1Data, col1) + assert.Equal(t, col2Data, col2) + assert.Equal(t, col3Data, col3) + assert.Equal(t, col4Data, col4) + assert.Equal(t, col5Data, col5) + assert.Equal(t, col6Data, col6) + assert.Equal(t, col7Data, col7) +} diff --git a/tests/fixed_string_test.go b/tests/fixed_string_test.go index cbcad3bb67..782faf1520 100644 --- a/tests/fixed_string_test.go +++ b/tests/fixed_string_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "crypto/rand" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -409,3 +410,61 @@ func TestFixedStringFromDriverValuerType(t *testing.T) { assert.Equal(t, "Value", dest.Col1) assert.Equal(t, testStringSerializer{"Value"}, dest.Col2) } + +type testFixedStringPtrSerializer struct { + val string +} + +func (c testFixedStringPtrSerializer) Value() (driver.Value, error) { + return &c.val, nil +} + +func (c *testFixedStringPtrSerializer) Scan(src any) error { + if t, ok := src.(string); ok { + *c = testFixedStringPtrSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testFixedStringPtrSerializer", src) +} + +func TestFixedStringFromDriverValuerTypeNonStdReturn(t *testing.T) { + conn, err := GetConnection("native", nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + + require.NoError(t, err) + require.NoError(t, conn.Ping(ctx)) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_fixed_string ( + Col1 FixedString(5) + , Col2 FixedString(5) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_fixed_string") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_fixed_string") + require.NoError(t, err) + + s := "Value" + type data struct { + Col1 string `ch:"Col1"` + Col2 testFixedStringPtrSerializer `ch:"Col2"` + } + require.NoError(t, batch.AppendStruct(&data{ + Col1: "Value", + Col2: testFixedStringPtrSerializer{s}, + })) + require.NoError(t, batch.Send()) + + var dest data + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_fixed_string").ScanStruct(&dest)) + assert.Equal(t, "Value", dest.Col1) + assert.Equal(t, testFixedStringPtrSerializer{"Value"}, dest.Col2) +} diff --git a/tests/float64_test.go b/tests/float64_test.go new file mode 100644 index 0000000000..76993f816d --- /dev/null +++ b/tests/float64_test.go @@ -0,0 +1,119 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tests + +import ( + "context" + "database/sql/driver" + "fmt" + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/require" + "testing" +) + +func TestFloat64(t *testing.T) { + ctx := context.Background() + + conn, err := GetNativeConnection(clickhouse.Settings{ + "max_execution_time": 60, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + require.NoError(t, err) + + const ddl = ` + CREATE TABLE IF NOT EXISTS test_float64 ( + Col1 Float64 + , Col2 Float64 + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + defer func() { + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_float64")) + }() + + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_float64 (Col1, Col2)") + require.NoError(t, err) + require.NoError(t, batch.Append(1.1, 2.1)) + require.NoError(t, batch.Send()) + + row := conn.QueryRow(ctx, "SELECT Col1, Col2 from test_float64") + require.NoError(t, err) + + var ( + col1 float64 + col2 float64 + ) + require.NoError(t, row.Scan(&col1, &col2)) + require.Equal(t, float64(1.1), col1) + require.Equal(t, float64(2.1), col2) +} + +type testFloat64Serializer struct { + val float64 +} + +func (c testFloat64Serializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testFloat64Serializer) Scan(src any) error { + if t, ok := src.(float64); ok { + *c = testFloat64Serializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testFloat64Serializer", src) +} + +func TestFloat64Valuer(t *testing.T) { + ctx := context.Background() + + conn, err := GetNativeConnection(clickhouse.Settings{ + "max_execution_time": 60, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + require.NoError(t, err) + + const ddl = ` + CREATE TABLE IF NOT EXISTS test_float64_valuer ( + Col1 Float64 + , Col2 Float64 + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + defer func() { + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_float64_valuer")) + }() + + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_float64_valuer (Col1, Col2)") + require.NoError(t, err) + require.NoError(t, batch.Append(testFloat64Serializer{val: 1.1}, testFloat64Serializer{val: 2.1})) + require.NoError(t, batch.Send()) + + row := conn.QueryRow(ctx, "SELECT Col1, Col2 from test_float64_valuer") + require.NoError(t, err) + + var ( + col1 float64 + col2 float64 + ) + require.NoError(t, row.Scan(&col1, &col2)) + require.Equal(t, float64(1.1), col1) + require.Equal(t, float64(2.1), col2) +} diff --git a/tests/geo_multipolygon_test.go b/tests/geo_multipolygon_test.go index 19cfe28e43..65f16d24f4 100644 --- a/tests/geo_multipolygon_test.go +++ b/tests/geo_multipolygon_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "math/rand" @@ -197,3 +198,82 @@ func TestGeoMultiPolygonFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testMutiPolygonSerializer struct { + val orb.MultiPolygon +} + +func (c testMutiPolygonSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testMutiPolygonSerializer) Scan(src any) error { + if t, ok := src.(orb.MultiPolygon); ok { + *c = testMutiPolygonSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testMutiPolygonSerializer", src) +} + +func TestGeoMultiPolygonValuer(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{ + "allow_experimental_geo_types": 1, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 12, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_geo_multipolygon_flush ( + Col1 MultiPolygon + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_geo_multipolygon_flush") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_geo_multipolygon_flush") + require.NoError(t, err) + vals := [1000]orb.MultiPolygon{} + for i := 0; i < 1000; i++ { + vals[i] = orb.MultiPolygon{ + orb.Polygon{ + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + }, + orb.Polygon{ + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + }, + } + require.NoError(t, batch.Append(testMutiPolygonSerializer{val: vals[i]})) + require.NoError(t, batch.Flush()) + } + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_geo_multipolygon_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 orb.MultiPolygon + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/geo_point_test.go b/tests/geo_point_test.go index e659b1a360..f3d3333343 100644 --- a/tests/geo_point_test.go +++ b/tests/geo_point_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "math/rand" @@ -114,3 +115,61 @@ func TestGeoPointFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testPointSerializer struct { + val orb.Point +} + +func (c testPointSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testPointSerializer) Scan(src any) error { + if t, ok := src.(orb.Point); ok { + *c = testPointSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testPointSerializer", src) +} + +func TestGeoPointValuer(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{ + "allow_experimental_geo_types": 1, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 12, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_geo_point_flush ( + Col1 Point + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_geo_point_flush") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_geo_point_flush") + require.NoError(t, err) + vals := [1000]orb.Point{} + for i := 0; i < 1000; i++ { + vals[i] = orb.Point{rand.Float64(), rand.Float64()} + require.NoError(t, batch.Append(testPointSerializer{val: vals[i]})) + require.NoError(t, batch.Flush()) + } + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_geo_point_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 orb.Point + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/geo_polygon_test.go b/tests/geo_polygon_test.go index 6d8b064a7f..6ce24cdbd0 100644 --- a/tests/geo_polygon_test.go +++ b/tests/geo_polygon_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "math/rand" @@ -151,3 +152,72 @@ func TestGeoPolygonFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testPolygonSerializer struct { + val orb.Polygon +} + +func (c testPolygonSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testPolygonSerializer) Scan(src any) error { + if t, ok := src.(orb.Polygon); ok { + *c = testPolygonSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testPolygonSerializer", src) +} + +func TestGeoPolygonValuer(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{ + "allow_experimental_geo_types": 1, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 12, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_geo_polygon_flush ( + Col1 Polygon + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_geo_polygon_flush") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_geo_polygon_flush") + require.NoError(t, err) + vals := [1000]orb.Polygon{} + for i := 0; i < 1000; i++ { + vals[i] = orb.Polygon{ + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + orb.Ring{ + orb.Point{rand.Float64(), rand.Float64()}, + orb.Point{rand.Float64(), rand.Float64()}, + }, + } + require.NoError(t, batch.Append(testPolygonSerializer{val: vals[i]})) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Flush()) + } + require.Equal(t, 0, batch.Rows()) + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_geo_polygon_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 orb.Polygon + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/geo_ring_test.go b/tests/geo_ring_test.go index 417df3ebee..e550c136f2 100644 --- a/tests/geo_ring_test.go +++ b/tests/geo_ring_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -126,3 +127,66 @@ func TestGeoRingFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testGeoRingSerializer struct { + val orb.Ring +} + +func (c testGeoRingSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testGeoRingSerializer) Scan(src any) error { + if t, ok := src.(orb.Ring); ok { + *c = testGeoRingSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testGeoRingSerializer", src) +} + +func TestGeoRingValuer(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{ + "allow_experimental_geo_types": 1, + }, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 12, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_geo_ring_valuer ( + Col1 Ring + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_geo_ring_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_geo_ring_valuer") + require.NoError(t, err) + vals := [1000]orb.Ring{} + for i := 0; i < 1000; i++ { + vals[i] = orb.Ring{ + orb.Point{1, 2}, + orb.Point{1, 2}, + } + require.NoError(t, batch.Append(testGeoRingSerializer{val: vals[i]})) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Flush()) + } + require.Equal(t, 0, batch.Rows()) + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_geo_ring_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 orb.Ring + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/ipv4_test.go b/tests/ipv4_test.go index e732577a78..ea67633b81 100644 --- a/tests/ipv4_test.go +++ b/tests/ipv4_test.go @@ -19,7 +19,9 @@ package tests import ( "context" + "database/sql/driver" "encoding/binary" + "fmt" "github.com/ClickHouse/clickhouse-go/v2/lib/column" "github.com/stretchr/testify/require" "net" @@ -504,3 +506,53 @@ func TestIPv4Flush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testIPv4Serializer struct { + val net.IP +} + +func (c testIPv4Serializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testIPv4Serializer) Scan(src any) error { + if t, ok := src.(net.IP); ok { + *c = testIPv4Serializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testIPv4Serializer", src) +} + +func TestIPv4Valuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, nil) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_ipv4_ring_valuer ( + Col1 IPv4 + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_ipv4_ring_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv4_ring_valuer") + require.NoError(t, err) + vals := [1000]net.IP{} + for i := 0; i < 1000; i++ { + vals[i] = RandIPv4() + require.NoError(t, batch.Append(testIPv4Serializer{val: vals[i]})) + require.NoError(t, batch.Flush()) + } + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_ipv4_ring_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 net.IP + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1.To4()) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/ipv6_test.go b/tests/ipv6_test.go index 88b63fb198..f8b0f8df25 100644 --- a/tests/ipv6_test.go +++ b/tests/ipv6_test.go @@ -19,6 +19,8 @@ package tests import ( "context" + "database/sql/driver" + "fmt" "github.com/ClickHouse/ch-go/proto" "github.com/ClickHouse/clickhouse-go/v2/lib/column" "github.com/stretchr/testify/require" @@ -466,3 +468,55 @@ func TestIPv6Flush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testIPv6Serializer struct { + val net.IP +} + +func (c testIPv6Serializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testIPv6Serializer) Scan(src any) error { + if t, ok := src.(net.IP); ok { + *c = testIPv6Serializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testIPv6Serializer", src) +} + +func TestIPv6Valuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_ipv6_ring_valuer ( + Col1 IPv6 + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_ipv6_ring_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv6_ring_valuer") + require.NoError(t, err) + vals := [1000]net.IP{} + for i := 0; i < 1000; i++ { + vals[i] = RandIPv6() + require.NoError(t, batch.Append(testIPv6Serializer{val: vals[i]})) + require.NoError(t, batch.Flush()) + } + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_ipv6_ring_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 net.IP + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/map_test.go b/tests/map_test.go index 6c834d491a..623ab32d57 100644 --- a/tests/map_test.go +++ b/tests/map_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -293,3 +294,63 @@ func TestOrderedMap(t *testing.T) { } require.Equal(t, 1000, i) } + +type testMapSerializer struct { + val map[string]uint64 +} + +func (c testMapSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testMapSerializer) Scan(src any) error { + if t, ok := src.(map[string]uint64); ok { + *c = testMapSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testTupleSerializer", src) +} + +func TestMapValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_map_flush ( + Col1 Map(String, UInt64) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_map_flush") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_map_flush") + require.NoError(t, err) + vals := [1000]map[string]uint64{} + for i := 0; i < 1000; i++ { + vals[i] = map[string]uint64{ + "i": uint64(i), + } + require.NoError(t, batch.Append(testMapSerializer{val: vals[i]})) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Flush()) + } + require.Equal(t, 0, batch.Rows()) + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_map_flush") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 map[string]uint64 + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/string_test.go b/tests/string_test.go index 7dc655f75a..366456f638 100644 --- a/tests/string_test.go +++ b/tests/string_test.go @@ -373,3 +373,62 @@ func TestStringFromDriverValuerType(t *testing.T) { assert.Equal(t, "Value", dest.Col1) assert.Equal(t, testStringSerializer{"Value"}, dest.Col2) } + +type testStringPtrSerializer struct { + val string +} + +func (c testStringPtrSerializer) Value() (driver.Value, error) { + return &c.val, nil +} + +func (c *testStringPtrSerializer) Scan(src any) error { + if t, ok := src.(string); ok { + *c = testStringPtrSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testStringPtrSerializer", src) +} + +func TestStringFromDriverValuerTypeNonStdReturn(t *testing.T) { + conn, err := GetConnection("native", nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + + require.NoError(t, err) + require.NoError(t, conn.Ping(ctx)) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_string ( + Col1 String + , Col2 String + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_string") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_string") + require.NoError(t, err) + + type data struct { + Col1 string `ch:"Col1"` + Col2 testStringPtrSerializer `ch:"Col2"` + } + s := "Value" + require.NoError(t, batch.AppendStruct(&data{ + Col1: s, + Col2: testStringPtrSerializer{s}, + })) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + + var dest data + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_string").ScanStruct(&dest)) + assert.Equal(t, "Value", dest.Col1) + assert.Equal(t, testStringPtrSerializer{s}, dest.Col2) +} diff --git a/tests/tuple_test.go b/tests/tuple_test.go index 3b80ddbac3..c616385d9b 100644 --- a/tests/tuple_test.go +++ b/tests/tuple_test.go @@ -19,6 +19,7 @@ package tests import ( "context" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -455,3 +456,62 @@ func TestTupleFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testTupleSerializer struct { + val map[string]any +} + +func (c testTupleSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testTupleSerializer) Scan(src any) error { + if t, ok := src.(map[string]any); ok { + *c = testTupleSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testTupleSerializer", src) +} + +func TestTupleValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, nil) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_tuple_valuer ( + Col1 Tuple(name String, id Int64) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_tuple_valuer") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_tuple_valuer") + require.NoError(t, err) + vals := [1000]map[string]any{} + for i := 0; i < 1000; i++ { + vals[i] = map[string]any{ + "id": int64(i), + "name": RandAsciiString(10), + } + require.NoError(t, batch.Append(testTupleSerializer{val: vals[i]})) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Flush()) + } + require.Equal(t, 0, batch.Rows()) + require.NoError(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_tuple_valuer") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 map[string]any + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +} diff --git a/tests/uuid_test.go b/tests/uuid_test.go index b7ac809240..9bf275ec26 100644 --- a/tests/uuid_test.go +++ b/tests/uuid_test.go @@ -19,6 +19,8 @@ package tests import ( "context" + "database/sql/driver" + "fmt" "testing" "github.com/ClickHouse/clickhouse-go/v2/lib/column" @@ -339,3 +341,57 @@ func TestUUIDFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testUUIDValuer struct { + val uuid.UUID +} + +func (c testUUIDValuer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testUUIDValuer) Scan(src any) error { + if t, ok := src.(string); ok { + *c = testUUIDValuer{val: uuid.MustParse(t)} + return nil + } + return fmt.Errorf("cannot scan %T into testUUIDValuer", src) +} + +func TestUUIDValuer(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS uuid_valuer1") + }() + const ddl = ` + CREATE TABLE uuid_valuer1 ( + Col1 UUID + ) Engine MergeTree() ORDER BY tuple() + ` + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO uuid_valuer1") + require.NoError(t, err) + vals := [1000]uuid.UUID{} + for i := 0; i < 1000; i++ { + vals[i] = uuid.New() + batch.Append(testUUIDValuer{val: vals[i]}) + require.Equal(t, 1, batch.Rows()) + batch.Flush() + } + require.Equal(t, 0, batch.Rows()) + batch.Send() + rows, err := conn.Query(ctx, "SELECT * FROM uuid_valuer1") + require.NoError(t, err) + i := 0 + for rows.Next() { + var col1 uuid.UUID + require.NoError(t, rows.Scan(&col1)) + require.Equal(t, vals[i], col1) + i += 1 + } + require.Equal(t, 1000, i) +}