From 5be346ff68e6e3b28868990b0b05c96123617291 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 10 Jun 2024 12:00:47 +0300 Subject: [PATCH] pgdialect: add Range and MultiRange --- dialect/pgdialect/append.go | 313 +------------ dialect/pgdialect/array.go | 582 +++++++++++++++++++++++- dialect/pgdialect/array_parser.go | 148 +++--- dialect/pgdialect/array_parser_test.go | 40 +- dialect/pgdialect/array_scan.go | 301 ------------ dialect/pgdialect/dialect.go | 10 +- dialect/pgdialect/go.mod | 9 +- dialect/pgdialect/go.sum | 15 + dialect/pgdialect/hstore_parser.go | 154 +++---- dialect/pgdialect/hstore_parser_test.go | 47 +- dialect/pgdialect/hstore_scan.go | 25 +- dialect/pgdialect/range.go | 240 ++++++++++ dialect/pgdialect/sqltype.go | 9 +- dialect/pgdialect/stream_parser.go | 60 --- example/migrate/main.go | 2 +- extra/bundebug/debug.go | 6 +- internal/dbtest/db_test.go | 14 +- internal/dbtest/pg_test.go | 24 + internal/parser/parser.go | 50 +- query_select.go | 4 +- schema/dialect.go | 2 +- schema/table.go | 1 + 22 files changed, 1078 insertions(+), 978 deletions(-) create mode 100644 dialect/pgdialect/range.go delete mode 100644 dialect/pgdialect/stream_parser.go diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index 7e9491abc..c95fa86e7 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -2,12 +2,9 @@ package pgdialect import ( "database/sql/driver" - "encoding/hex" "fmt" "reflect" - "strconv" "time" - "unicode/utf8" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/schema" @@ -32,316 +29,10 @@ var ( sliceTimeType = reflect.TypeOf([]time.Time(nil)) ) -func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { - switch v := v.(type) { - case int64: - return strconv.AppendInt(b, v, 10) - case float64: - return dialect.AppendFloat64(b, v) - case bool: - return dialect.AppendBool(b, v) - case []byte: - return arrayAppendBytes(b, v) - case string: - return arrayAppendString(b, v) - case time.Time: - return fmter.Dialect().AppendTime(b, v) - default: - err := fmt.Errorf("pgdialect: can't append %T", v) - return dialect.AppendError(b, err) - } -} - -func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendString(b, v.String()) -} - -func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendBytes(b, v.Bytes()) -} - -func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - iface, err := v.Interface().(driver.Valuer).Value() - if err != nil { - return dialect.AppendError(b, err) - } - return arrayAppend(fmter, b, iface) -} - -//------------------------------------------------------------------------------ - -func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc { - kind := typ.Kind() - - switch kind { - case reflect.Ptr: - if fn := d.arrayAppender(typ.Elem()); fn != nil { - return schema.PtrAppender(fn) - } - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return appendStringSliceValue - case intType: - return appendIntSliceValue - case int64Type: - return appendInt64SliceValue - case float64Type: - return appendFloat64SliceValue - case timeType: - return appendTimeSliceValue - } - } - - appendElem := d.arrayElemAppender(elemType) - if appendElem == nil { - panic(fmt.Errorf("pgdialect: %s is not supported", typ)) - } - - return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - kind := v.Kind() - switch kind { - case reflect.Ptr, reflect.Slice: - if v.IsNil() { - return dialect.AppendNull(b) - } - } - - if kind == reflect.Ptr { - v = v.Elem() - } - - b = append(b, '\'') - - b = append(b, '{') - ln := v.Len() - for i := 0; i < ln; i++ { - elem := v.Index(i) - b = appendElem(fmter, b, elem) - b = append(b, ',') - } - if v.Len() > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b - } -} - -func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc { - if typ.Implements(driverValuerType) { - return arrayAppendDriverValue - } - switch typ.Kind() { - case reflect.String: - return arrayAppendStringValue - case reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return arrayAppendBytesValue - } - } - return schema.Appender(d, typ) -} - -func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ss := v.Convert(sliceStringType).Interface().([]string) - return appendStringSlice(b, ss) +func appendTime(buf []byte, tm time.Time) []byte { + return tm.UTC().AppendFormat(buf, "2006-01-02 15:04:05.999999-07:00") } -func appendStringSlice(b []byte, ss []string) []byte { - if ss == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, s := range ss { - b = arrayAppendString(b, s) - b = append(b, ',') - } - if len(ss) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ints := v.Convert(sliceIntType).Interface().([]int) - return appendIntSlice(b, ints) -} - -func appendIntSlice(b []byte, ints []int) []byte { - if ints == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, int64(n), 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ints := v.Convert(sliceInt64Type).Interface().([]int64) - return appendInt64Slice(b, ints) -} - -func appendInt64Slice(b []byte, ints []int64) []byte { - if ints == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, n, 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - floats := v.Convert(sliceFloat64Type).Interface().([]float64) - return appendFloat64Slice(b, floats) -} - -func appendFloat64Slice(b []byte, floats []float64) []byte { - if floats == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range floats { - b = dialect.AppendFloat64(b, n) - b = append(b, ',') - } - if len(floats) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -//------------------------------------------------------------------------------ - -func arrayAppendBytes(b []byte, bs []byte) []byte { - if bs == nil { - return dialect.AppendNull(b) - } - - b = append(b, `"\\x`...) - - s := len(b) - b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) - hex.Encode(b[s:], bs) - - b = append(b, '"') - - return b -} - -func arrayAppendString(b []byte, s string) []byte { - b = append(b, '"') - for _, r := range s { - switch r { - case 0: - // ignore - case '\'': - b = append(b, "''"...) - case '"': - b = append(b, '\\', '"') - case '\\': - b = append(b, '\\', '\\') - default: - if r < utf8.RuneSelf { - b = append(b, byte(r)) - break - } - l := len(b) - if cap(b)-l < utf8.UTFMax { - b = append(b, make([]byte, utf8.UTFMax)...) - } - n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) - b = b[:l+n] - } - } - b = append(b, '"') - return b -} - -func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ts := v.Convert(sliceTimeType).Interface().([]time.Time) - return appendTimeSlice(fmter, b, ts) -} - -func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte { - if ts == nil { - return dialect.AppendNull(b) - } - b = append(b, '\'') - b = append(b, '{') - for _, t := range ts { - b = append(b, '"') - b = t.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") - b = append(b, '"') - b = append(b, ',') - } - if len(ts) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - b = append(b, '\'') - return b -} - -//------------------------------------------------------------------------------ - var mapStringStringType = reflect.TypeOf(map[string]string(nil)) func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc { diff --git a/dialect/pgdialect/array.go b/dialect/pgdialect/array.go index 281cff733..46b55659b 100644 --- a/dialect/pgdialect/array.go +++ b/dialect/pgdialect/array.go @@ -2,9 +2,16 @@ package pgdialect import ( "database/sql" + "database/sql/driver" + "encoding/hex" "fmt" "reflect" + "strconv" + "time" + "unicode/utf8" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) @@ -20,7 +27,7 @@ type ArrayValue struct { // // For struct fields you can use array tag: // -// Emails []string `bun:",array"` +// Emails []string `bun:",array"` func Array(vi interface{}) *ArrayValue { v := reflect.ValueOf(vi) if !v.IsValid() { @@ -63,3 +70,576 @@ func (a *ArrayValue) Value() interface{} { } return nil } + +//------------------------------------------------------------------------------ + +func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := d.arrayAppender(typ.Elem()); fn != nil { + return schema.PtrAppender(fn) + } + case reflect.Slice, reflect.Array: + // continue below + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + case timeType: + return appendTimeSliceValue + } + } + + appendElem := d.arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, "'{"...) + + ln := v.Len() + for i := 0; i < ln; i++ { + elem := v.Index(i) + if i > 0 { + b = append(b, ',') + } + b = appendElem(fmter, b, elem) + } + + b = append(b, "}'"...) + + return b + } +} + +func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + switch typ.Kind() { + case reflect.String: + return arrayAppendStringValue + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return arrayAppendBytesValue + } + } + return schema.Appender(d, typ) +} + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return arrayAppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return fmter.Dialect().AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendBytes(b, v.Bytes()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ts := v.Convert(sliceTimeType).Interface().([]time.Time) + return appendTimeSlice(fmter, b, ts) +} + +func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte { + if ts == nil { + return dialect.AppendNull(b) + } + b = append(b, '\'') + b = append(b, '{') + for _, t := range ts { + b = append(b, '"') + b = appendTime(b, t) + b = append(b, '"') + b = append(b, ',') + } + if len(ts) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + b = append(b, '\'') + return b +} + +//------------------------------------------------------------------------------ + +func arrayScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := arrayScanner(typ.Elem()); fn != nil { + return schema.PtrScanner(fn) + } + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringSliceValue + case intType: + return scanIntSliceValue + case int64Type: + return scanInt64SliceValue + case float64Type: + return scanFloat64SliceValue + } + } + + scanElem := schema.Scanner(elemType) + return func(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + kind := dest.Kind() + + if src == nil { + if kind != reflect.Slice || !dest.IsNil() { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if kind == reflect.Slice { + if dest.IsNil() { + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + } else if dest.Len() > 0 { + dest.Set(dest.Slice(0, 0)) + } + } + + b, err := toBytes(src) + if err != nil { + return err + } + + p := newArrayParser(b) + nextValue := internal.MakeSliceNextElemFunc(dest) + for p.Next() { + elem := p.Elem() + elemValue := nextValue() + if err := scanElem(elemValue, elem); err != nil { + return fmt.Errorf("scanElem failed: %w", err) + } + } + return p.Err() + } +} + +func scanStringSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeStringSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeStringSlice(src interface{}) ([]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]string, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + slice = append(slice, string(elem)) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanIntSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeIntSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeIntSlice(src interface{}) ([]int, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.Atoi(bytesToString(elem)) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanInt64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeInt64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeInt64Slice(src interface{}) ([]int64, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int64, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseInt(bytesToString(elem), 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := scanFloat64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Slice(src interface{}) ([]float64, error) { + if src == -1 { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]float64, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseFloat(bytesToString(elem), 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return stringToBytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +//------------------------------------------------------------------------------ + +func arrayAppendBytes(b []byte, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, `"\\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + b = append(b, '"') + + return b +} + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/dialect/pgdialect/array_parser.go b/dialect/pgdialect/array_parser.go index a8358337e..462f8d91d 100644 --- a/dialect/pgdialect/array_parser.go +++ b/dialect/pgdialect/array_parser.go @@ -2,132 +2,92 @@ package pgdialect import ( "bytes" - "encoding/hex" "fmt" "io" ) type arrayParser struct { - *streamParser - err error + p pgparser + + elem []byte + err error } func newArrayParser(b []byte) *arrayParser { - p := &arrayParser{ - streamParser: newStreamParser(b, 1), - } + p := new(arrayParser) + if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { - p.err = fmt.Errorf("bun: can't parse array: %q", b) + p.err = fmt.Errorf("pgdialect: can't parse array: %q", b) + return p } + + p.p.Reset(b[1 : len(b)-1]) return p } -func (p *arrayParser) NextElem() ([]byte, error) { +func (p *arrayParser) Next() bool { if p.err != nil { - return nil, p.err + return false } + p.err = p.readNext() + return p.err == nil +} + +func (p *arrayParser) Err() error { + if p.err != io.EOF { + return p.err + } + return nil +} - c, err := p.readByte() - if err != nil { - return nil, err +func (p *arrayParser) Elem() []byte { + return p.elem +} + +func (p *arrayParser) readNext() error { + ch := p.p.Read() + if ch == 0 { + return io.EOF } - switch c { + switch ch { case '}': - return nil, io.EOF + return io.EOF case '"': - b, err := p.readSubstring() + b, err := p.p.ReadSubstring(ch) if err != nil { - return nil, err - } - - if p.peek() == ',' { - p.skipNext() + return err } - return b, nil - default: - b := p.readSimple() - if bytes.Equal(b, []byte("NULL")) { - b = nil + if p.p.Peek() == ',' { + p.p.Advance() } - if p.peek() == ',' { - p.skipNext() + p.elem = b + return nil + case '[', '(': + rng, err := p.p.ReadRange(ch) + if err != nil { + return err } - return b, nil - } -} - -func (p *arrayParser) readSimple() []byte { - p.unreadByte() - - if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { - b := p.b[p.i : p.i+i] - p.i += i - return b - } - - b := p.b[p.i : len(p.b)-1] - p.i = len(p.b) - 1 - return b -} - -func (p *arrayParser) readSubstring() ([]byte, error) { - c, err := p.readByte() - if err != nil { - return nil, err - } - - p.buf = p.buf[:0] - for { - if c == '"' { - break + if p.p.Peek() == ',' { + p.p.Advance() } - next, err := p.readByte() - if err != nil { - return nil, err + p.elem = rng + return nil + default: + lit := p.p.ReadLiteral(ch) + if bytes.Equal(lit, []byte("NULL")) { + lit = nil } - if c == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - c, err = p.readByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - c = next - } - continue + if p.p.Peek() == ',' { + p.p.Advance() } - if c == '\'' && next == '\'' { - p.buf = append(p.buf, next) - c, err = p.readByte() - if err != nil { - return nil, err - } - continue - } - - p.buf = append(p.buf, c) - c = next - } - if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { - data := p.buf[2:] - buf := make([]byte, hex.DecodedLen(len(data))) - n, err := hex.Decode(buf, data) - if err != nil { - return nil, err - } - return buf[:n], nil + p.elem = lit + return nil } - - return p.buf, nil } diff --git a/dialect/pgdialect/array_parser_test.go b/dialect/pgdialect/array_parser_test.go index 8e0ba8e4a..f2752434d 100644 --- a/dialect/pgdialect/array_parser_test.go +++ b/dialect/pgdialect/array_parser_test.go @@ -1,8 +1,10 @@ package pgdialect import ( - "io" + "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestArrayParser(t *testing.T) { @@ -21,35 +23,21 @@ func TestArrayParser(t *testing.T) { {"{1,NULL}", []string{"1", ""}}, {`{"1","2"}`, []string{"1", "2"}}, {`{"{1}","{2}"}`, []string{"{1}", "{2}"}}, + {`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}}, } - for testi, test := range tests { - p := newArrayParser([]byte(test.s)) + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + p := newArrayParser([]byte(test.s)) - var got []string - for { - s, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) + got := make([]string, 0) + for p.Next() { + elem := p.Elem() + got = append(got, string(elem)) } - got = append(got, string(s)) - } - - if len(got) != len(test.els) { - t.Fatalf( - "test #%d got %d elements, wanted %d (got=%#v wanted=%#v)", - testi, len(got), len(test.els), got, test.els) - } - for i, el := range got { - if el != test.els[i] { - t.Fatalf( - "test #%d el #%d does not match: %s != %s (got=%#v wanted=%#v)", - testi, i, el, test.els[i], got, test.els) - } - } + require.NoError(t, p.Err()) + require.Equal(t, test.els, got) + }) } } diff --git a/dialect/pgdialect/array_scan.go b/dialect/pgdialect/array_scan.go index a8ff29715..6b8abda3d 100644 --- a/dialect/pgdialect/array_scan.go +++ b/dialect/pgdialect/array_scan.go @@ -1,302 +1 @@ package pgdialect - -import ( - "fmt" - "io" - "reflect" - "strconv" - - "github.com/uptrace/bun/internal" - "github.com/uptrace/bun/schema" -) - -func arrayScanner(typ reflect.Type) schema.ScannerFunc { - kind := typ.Kind() - - switch kind { - case reflect.Ptr: - if fn := arrayScanner(typ.Elem()); fn != nil { - return schema.PtrScanner(fn) - } - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return scanStringSliceValue - case intType: - return scanIntSliceValue - case int64Type: - return scanInt64SliceValue - case float64Type: - return scanFloat64SliceValue - } - } - - scanElem := schema.Scanner(elemType) - return func(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - kind := dest.Kind() - - if src == nil { - if kind != reflect.Slice || !dest.IsNil() { - dest.Set(reflect.Zero(dest.Type())) - } - return nil - } - - if kind == reflect.Slice { - if dest.IsNil() { - dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) - } else if dest.Len() > 0 { - dest.Set(dest.Slice(0, 0)) - } - } - - b, err := toBytes(src) - if err != nil { - return err - } - - p := newArrayParser(b) - nextValue := internal.MakeSliceNextElemFunc(dest) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return err - } - - elemValue := nextValue() - if err := scanElem(elemValue, elem); err != nil { - return err - } - } - - return nil - } -} - -func scanStringSliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeStringSlice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeStringSlice(src interface{}) ([]string, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]string, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - slice = append(slice, string(elem)) - } - - return slice, nil -} - -func scanIntSliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeIntSlice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeIntSlice(src interface{}) ([]int, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]int, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.Atoi(bytesToString(elem)) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanInt64SliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeInt64Slice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeInt64Slice(src interface{}) ([]int64, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]int64, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.ParseInt(bytesToString(elem), 10, 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := scanFloat64Slice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func scanFloat64Slice(src interface{}) ([]float64, error) { - if src == -1 { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]float64, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.ParseFloat(bytesToString(elem), 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func toBytes(src interface{}) ([]byte, error) { - switch src := src.(type) { - case string: - return stringToBytes(src), nil - case []byte: - return src, nil - default: - return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) - } -} diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index f100e682c..358971f61 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -89,9 +89,17 @@ func (d *Dialect) onField(field *schema.Field) { if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") { field.Append = d.arrayAppender(field.StructField.Type) field.Scan = arrayScanner(field.StructField.Type) + return } - if field.DiscoveredSQLType == sqltype.HSTORE { + if field.Tag.HasOption("multirange") { + field.Append = d.arrayAppender(field.StructField.Type) + field.Scan = arrayScanner(field.StructField.Type) + return + } + + switch field.DiscoveredSQLType { + case sqltype.HSTORE: field.Append = d.hstoreAppender(field.StructField.Type) field.Scan = hstoreScanner(field.StructField.Type) } diff --git a/dialect/pgdialect/go.mod b/dialect/pgdialect/go.mod index 478515879..2cdb938a7 100644 --- a/dialect/pgdialect/go.mod +++ b/dialect/pgdialect/go.mod @@ -6,12 +6,19 @@ toolchain go1.22.1 replace github.com/uptrace/bun => ../.. -require github.com/uptrace/bun v1.2.1 +require ( + github.com/stretchr/testify v1.8.1 + github.com/uptrace/bun v1.2.1 +) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect golang.org/x/sys v0.18.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/dialect/pgdialect/go.sum b/dialect/pgdialect/go.sum index 6ecea8d8d..3fb829687 100644 --- a/dialect/pgdialect/go.sum +++ b/dialect/pgdialect/go.sum @@ -1,9 +1,20 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= @@ -14,5 +25,9 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/dialect/pgdialect/hstore_parser.go b/dialect/pgdialect/hstore_parser.go index 7a18b50b1..df9918219 100644 --- a/dialect/pgdialect/hstore_parser.go +++ b/dialect/pgdialect/hstore_parser.go @@ -3,140 +3,98 @@ package pgdialect import ( "bytes" "fmt" + "io" ) type hstoreParser struct { - *streamParser - err error + p pgparser + + key string + value string + err error } func newHStoreParser(b []byte) *hstoreParser { - p := &hstoreParser{ - streamParser: newStreamParser(b, 0), - } + p := new(hstoreParser) if len(b) < 6 || b[0] != '"' { - p.err = fmt.Errorf("bun: can't parse hstore: %q", b) + p.err = fmt.Errorf("pgdialect: can't parse hstore: %q", b) + return p } + p.p.Reset(b) return p } -func (p *hstoreParser) NextKey() (string, error) { +func (p *hstoreParser) Next() bool { if p.err != nil { - return "", p.err + return false } + p.err = p.readNext() + return p.err == nil +} - err := p.skipByte('"') - if err != nil { - return "", err +func (p *hstoreParser) Err() error { + if p.err != io.EOF { + return p.err } + return nil +} - key, err := p.readSubstring() - if err != nil { - return "", err - } +func (p *hstoreParser) Key() string { + return p.key +} - const separator = "=>" +func (p *hstoreParser) Value() string { + return p.value +} - for i := range separator { - err = p.skipByte(separator[i]) - if err != nil { - return "", err - } +func (p *hstoreParser) readNext() error { + if !p.p.Valid() { + return io.EOF } - return string(key), nil -} + if err := p.p.Skip('"'); err != nil { + return err + } -func (p *hstoreParser) NextValue() (string, error) { - if p.err != nil { - return "", p.err + key, err := p.p.ReadUnescapedSubstring('"') + if err != nil { + return err + } + p.key = string(key) + + if err := p.p.SkipPrefix([]byte("=>")); err != nil { + return err } - c, err := p.readByte() + ch, err := p.p.ReadByte() if err != nil { - return "", err + return err } - switch c { + switch ch { case '"': - value, err := p.readSubstring() + value, err := p.p.ReadUnescapedSubstring(ch) if err != nil { - return "", err - } - - if p.peek() == ',' { - p.skipNext() - } - - if p.peek() == ' ' { - p.skipNext() + return err } - - return string(value), nil + p.skipComma() + p.value = string(value) + return nil default: - value := p.readSimple() + value := p.p.ReadLiteral(ch) if bytes.Equal(value, []byte("NULL")) { value = nil } - - if p.peek() == ',' { - p.skipNext() - } - - return string(value), nil - } -} - -func (p *hstoreParser) readSimple() []byte { - p.unreadByte() - - if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { - b := p.b[p.i : p.i+i] - p.i += i - return b + p.skipComma() + return nil } - - b := p.b[p.i:len(p.b)] - p.i = len(p.b) - return b } -func (p *hstoreParser) readSubstring() ([]byte, error) { - c, err := p.readByte() - if err != nil { - return nil, err +func (p *hstoreParser) skipComma() { + if p.p.Peek() == ',' { + p.p.Advance() } - - p.buf = p.buf[:0] - for { - if c == '"' { - break - } - - next, err := p.readByte() - if err != nil { - return nil, err - } - - if c == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - c, err = p.readByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - c = next - } - continue - } - - p.buf = append(p.buf, c) - c = next + if p.p.Peek() == ' ' { + p.p.Advance() } - - return p.buf, nil } diff --git a/dialect/pgdialect/hstore_parser_test.go b/dialect/pgdialect/hstore_parser_test.go index aeb8c2e15..2323611c2 100644 --- a/dialect/pgdialect/hstore_parser_test.go +++ b/dialect/pgdialect/hstore_parser_test.go @@ -1,8 +1,10 @@ package pgdialect import ( - "io" + "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestHStoreParser(t *testing.T) { @@ -21,42 +23,17 @@ func TestHStoreParser(t *testing.T) { {`"{1}"=>"{2}", "{3}"=>"{4}"`, map[string]string{"{1}": "{2}", "{3}": "{4}"}}, } - for testi, test := range tests { - p := newHStoreParser([]byte(test.s)) - - got := make(map[string]string) - for { - key, err := p.NextKey() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) - } + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + p := newHStoreParser([]byte(test.s)) - value, err := p.NextValue() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) + got := make(map[string]string) + for p.Next() { + got[p.Key()] = p.Value() } - got[key] = value - } - - if len(got) != len(test.m) { - t.Fatalf( - "test #%d got %d elements, wanted %d (got=%#v wanted=%#v)", - testi, len(got), len(test.m), got, test.m) - } - - for k, v := range got { - if v != test.m[k] { - t.Fatalf( - "test #%d key #%s does not match: %s != %s (got=%#v wanted=%#v)", - testi, k, v, test.m[k], got, test.m) - } - } + require.NoError(t, p.Err()) + require.Equal(t, test.m, got) + }) } } diff --git a/dialect/pgdialect/hstore_scan.go b/dialect/pgdialect/hstore_scan.go index b10b06b8d..62ab89a3a 100644 --- a/dialect/pgdialect/hstore_scan.go +++ b/dialect/pgdialect/hstore_scan.go @@ -2,7 +2,6 @@ package pgdialect import ( "fmt" - "io" "reflect" "github.com/uptrace/bun/schema" @@ -58,25 +57,11 @@ func decodeMapStringString(src interface{}) (map[string]string, error) { m := make(map[string]string) p := newHStoreParser(b) - for { - key, err := p.NextKey() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - value, err := p.NextValue() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - m[key] = value + for p.Next() { + m[p.Key()] = p.Value() + } + if err := p.Err(); err != nil { + return nil, err } - return m, nil } diff --git a/dialect/pgdialect/range.go b/dialect/pgdialect/range.go new file mode 100644 index 000000000..b942a068e --- /dev/null +++ b/dialect/pgdialect/range.go @@ -0,0 +1,240 @@ +package pgdialect + +import ( + "bytes" + "database/sql" + "encoding/hex" + "fmt" + "io" + "time" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" + "github.com/uptrace/bun/schema" +) + +type MultiRange[T any] []Range[T] + +type Range[T any] struct { + Lower, Upper T + LowerBound, UpperBound RangeBound +} + +type RangeBound byte + +const ( + RangeBoundInclusiveLeft RangeBound = '[' + RangeBoundInclusiveRight RangeBound = ']' + RangeBoundExclusiveLeft RangeBound = '(' + RangeBoundExclusiveRight RangeBound = ')' +) + +func NewRange[T any](lower, upper T) Range[T] { + return Range[T]{ + Lower: lower, + Upper: upper, + LowerBound: RangeBoundInclusiveLeft, + UpperBound: RangeBoundExclusiveRight, + } +} + +var _ sql.Scanner = (*Range[any])(nil) + +func (r *Range[T]) Scan(anySrc any) (err error) { + src := anySrc.([]byte) + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + r.LowerBound = RangeBound(src[0]) + src = src[1:] + + src, err = scanElem(&r.Lower, src) + if err != nil { + return err + } + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + if ch := src[0]; ch != ',' { + return fmt.Errorf("got %q, wanted %q", ch, ',') + } + src = src[1:] + + src, err = scanElem(&r.Upper, src) + if err != nil { + return err + } + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + r.UpperBound = RangeBound(src[0]) + src = src[1:] + + if len(src) > 0 { + return fmt.Errorf("unread data: %q", src) + } + return nil +} + +var _ schema.QueryAppender = (*Range[any])(nil) + +func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) { + buf = append(buf, byte(r.LowerBound)) + buf = appendElem(buf, r.Lower) + buf = append(buf, ',') + buf = appendElem(buf, r.Upper) + buf = append(buf, byte(r.UpperBound)) + return buf, nil +} + +func appendElem(buf []byte, val any) []byte { + switch val := val.(type) { + case time.Time: + buf = append(buf, '"') + buf = appendTime(buf, val) + buf = append(buf, '"') + return buf + default: + panic(fmt.Errorf("unsupported range type: %T", val)) + } +} + +func scanElem(ptr any, src []byte) ([]byte, error) { + switch ptr := ptr.(type) { + case *time.Time: + src, str, err := readStringLiteral(src) + if err != nil { + return nil, err + } + + tm, err := internal.ParseTime(internal.String(str)) + if err != nil { + return nil, err + } + *ptr = tm + + return src, nil + default: + panic(fmt.Errorf("unsupported range type: %T", ptr)) + } +} + +func readStringLiteral(src []byte) ([]byte, []byte, error) { + p := newParser(src) + + if err := p.Skip('"'); err != nil { + return nil, nil, err + } + + str, err := p.ReadSubstring('"') + if err != nil { + return nil, nil, err + } + + src = p.Remaining() + return src, str, nil +} + +//------------------------------------------------------------------------------ + +type pgparser struct { + parser.Parser + buf []byte +} + +func newParser(b []byte) *pgparser { + p := new(pgparser) + p.Reset(b) + return p +} + +func (p *pgparser) ReadLiteral(ch byte) []byte { + p.Unread() + lit, _ := p.ReadSep(',') + return lit +} + +func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, false) +} + +func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, true) +} + +func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) { + ch, err := p.ReadByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if ch == '"' { + break + } + + next, err := p.ReadByte() + if err != nil { + return nil, err + } + + if ch == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + ch = next + } + continue + } + + if escaped && ch == '\'' && next == '\'' { + p.buf = append(p.buf, next) + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + continue + } + + p.buf = append(p.buf, ch) + ch = next + } + + if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { + data := p.buf[2:] + buf := make([]byte, hex.DecodedLen(len(data))) + n, err := hex.Decode(buf, data) + if err != nil { + return nil, err + } + return buf[:n], nil + } + + return p.buf, nil +} + +func (p *pgparser) ReadRange(ch byte) ([]byte, error) { + p.buf = p.buf[:0] + p.buf = append(p.buf, ch) + + for p.Valid() { + ch = p.Read() + p.buf = append(p.buf, ch) + if ch == ']' || ch == ')' { + break + } + } + + return p.buf, nil +} diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index dadea5c1c..40802e51d 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -32,8 +32,7 @@ const ( pgTypeText = "TEXT" // variable length string without limit // JSON Types - pgTypeJSON = "JSON" // text representation of json data - pgTypeJSONB = "JSONB" // binary representation of json data + pgTypeJSON = "JSON" // text representation of json data // Binary Data Types pgTypeBytea = "BYTEA" // binary string @@ -83,7 +82,7 @@ func sqlType(typ reflect.Type) string { case ipNetType: return pgTypeCidr case jsonRawMessageType: - return pgTypeJSONB + return sqltype.JSONB } sqlType := schema.DiscoverSQLType(typ) @@ -95,14 +94,14 @@ func sqlType(typ reflect.Type) string { switch typ.Kind() { case reflect.Map, reflect.Struct: if sqlType == sqltype.VarChar { - return pgTypeJSONB + return sqltype.JSONB } return sqlType case reflect.Array, reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return pgTypeBytea } - return pgTypeJSONB + return sqltype.JSONB } return sqlType diff --git a/dialect/pgdialect/stream_parser.go b/dialect/pgdialect/stream_parser.go deleted file mode 100644 index 7b9a15f62..000000000 --- a/dialect/pgdialect/stream_parser.go +++ /dev/null @@ -1,60 +0,0 @@ -package pgdialect - -import ( - "fmt" - "io" -) - -type streamParser struct { - b []byte - i int - - buf []byte -} - -func newStreamParser(b []byte, start int) *streamParser { - return &streamParser{ - b: b, - i: start, - } -} - -func (p *streamParser) valid() bool { - return p.i < len(p.b) -} - -func (p *streamParser) skipByte(skip byte) error { - c, err := p.readByte() - if err != nil { - return err - } - if c == skip { - return nil - } - p.unreadByte() - return fmt.Errorf("got %q, wanted %q", c, skip) -} - -func (p *streamParser) readByte() (byte, error) { - if p.valid() { - c := p.b[p.i] - p.i++ - return c, nil - } - return 0, io.EOF -} - -func (p *streamParser) unreadByte() { - p.i-- -} - -func (p *streamParser) peek() byte { - if p.valid() { - return p.b[p.i] - } - return 0 -} - -func (p *streamParser) skipNext() { - p.i++ -} diff --git a/example/migrate/main.go b/example/migrate/main.go index f763f9ade..8d7fb74ce 100644 --- a/example/migrate/main.go +++ b/example/migrate/main.go @@ -27,7 +27,7 @@ func main() { db := bun.NewDB(sqldb, sqlitedialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) app := &cli.App{ diff --git a/extra/bundebug/debug.go b/extra/bundebug/debug.go index 5e04b03e4..70519b685 100644 --- a/extra/bundebug/debug.go +++ b/extra/bundebug/debug.go @@ -41,9 +41,9 @@ func WithWriter(w io.Writer) Option { // FromEnv configures the hook using the environment variable value. // For example, WithEnv("BUNDEBUG"): -// - BUNDEBUG=0 - disables the hook. -// - BUNDEBUG=1 - enables the hook. -// - BUNDEBUG=2 - enables the hook and verbose mode. +// - BUNDEBUG=0 - disables the hook. +// - BUNDEBUG=1 - enables the hook. +// - BUNDEBUG=2 - enables the hook and verbose mode. func FromEnv(keys ...string) Option { if len(keys) == 0 { keys = []string{"BUNDEBUG"} diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 43fee3601..4f103be52 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -67,7 +67,7 @@ func pg(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, pgdialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -90,7 +90,7 @@ func pgx(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, pgdialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -113,7 +113,7 @@ func mysql8(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, mysqldialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -136,7 +136,7 @@ func mysql5(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, mysqldialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -159,7 +159,7 @@ func mariadb(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, mysqldialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -177,7 +177,7 @@ func sqlite(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, sqlitedialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) @@ -200,7 +200,7 @@ func mssql2019(tb testing.TB) *bun.DB { db := bun.NewDB(sqldb, mssqldialect.New()) db.AddQueryHook(bundebug.NewQueryHook( bundebug.WithEnabled(false), - bundebug.FromEnv(""), + bundebug.FromEnv(), )) require.Equal(tb, "DB", db.String()) diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index cc4031033..103d68170 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -837,3 +837,27 @@ func TestPostgresCustomTypeBytes(t *testing.T) { err = db.NewSelect().Model(out).Scan(ctx) require.NoError(t, err) } + +func TestPostgresMultiRange(t *testing.T) { + type Model struct { + ID int64 `bun:",pk,autoincrement"` + Value pgdialect.MultiRange[time.Time] `bun:",multirange,type:tstzmultirange"` + } + + ctx := context.Background() + + db := pg(t) + t.Cleanup(func() { db.Close() }) + + mustResetModel(t, ctx, db, (*Model)(nil)) + + r1 := pgdialect.NewRange(time.Unix(1000, 0), time.Unix(2000, 0)) + r2 := pgdialect.NewRange(time.Unix(5000, 0), time.Unix(6000, 0)) + in := &Model{Value: pgdialect.MultiRange[time.Time]{r1, r2}} + _, err := db.NewInsert().Model(in).Exec(ctx) + require.NoError(t, err) + + out := new(Model) + err = db.NewSelect().Model(out).Scan(ctx) + require.NoError(t, err) +} diff --git a/internal/parser/parser.go b/internal/parser/parser.go index cdfc0be16..1f2704478 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -2,6 +2,8 @@ package parser import ( "bytes" + "fmt" + "io" "strconv" "github.com/uptrace/bun/internal" @@ -22,23 +24,43 @@ func NewString(s string) *Parser { return New(internal.Bytes(s)) } +func (p *Parser) Reset(b []byte) { + p.b = b + p.i = 0 +} + func (p *Parser) Valid() bool { return p.i < len(p.b) } -func (p *Parser) Bytes() []byte { +func (p *Parser) Remaining() []byte { return p.b[p.i:] } +func (p *Parser) ReadByte() (byte, error) { + if p.Valid() { + ch := p.b[p.i] + p.Advance() + return ch, nil + } + return 0, io.ErrUnexpectedEOF +} + func (p *Parser) Read() byte { if p.Valid() { - c := p.b[p.i] + ch := p.b[p.i] p.Advance() - return c + return ch } return 0 } +func (p *Parser) Unread() { + if p.i > 0 { + p.i-- + } +} + func (p *Parser) Peek() byte { if p.Valid() { return p.b[p.i] @@ -50,19 +72,25 @@ func (p *Parser) Advance() { p.i++ } -func (p *Parser) Skip(skip byte) bool { - if p.Peek() == skip { +func (p *Parser) Skip(skip byte) error { + ch := p.Peek() + if ch == skip { p.Advance() - return true + return nil } - return false + return fmt.Errorf("got %q, wanted %q", ch, skip) } -func (p *Parser) SkipBytes(skip []byte) bool { - if len(skip) > len(p.b[p.i:]) { - return false +func (p *Parser) SkipPrefix(skip []byte) error { + if !bytes.HasPrefix(p.b[p.i:], skip) { + return fmt.Errorf("got %q, wanted prefix %q", p.b, skip) } - if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + p.i += len(skip) + return nil +} + +func (p *Parser) CutPrefix(skip []byte) bool { + if !bytes.HasPrefix(p.b[p.i:], skip) { return false } p.i += len(skip) diff --git a/query_select.go b/query_select.go index c0e145110..932cd48be 100644 --- a/query_select.go +++ b/query_select.go @@ -564,8 +564,8 @@ func (q *SelectQuery) appendQuery( return nil, err } - for _, j := range q.joins { - b, err = j.AppendQuery(fmter, b) + for _, join := range q.joins { + b, err = join.AppendQuery(fmter, b) if err != nil { return nil, err } diff --git a/schema/dialect.go b/schema/dialect.go index 8814313f7..330293444 100644 --- a/schema/dialect.go +++ b/schema/dialect.go @@ -118,7 +118,7 @@ func (BaseDialect) AppendJSON(b, jsonb []byte) []byte { case '\000': continue case '\\': - if p.SkipBytes([]byte("u0000")) { + if p.CutPrefix([]byte("u0000")) { b = append(b, `\\u0000`...) } else { b = append(b, '\\') diff --git a/schema/table.go b/schema/table.go index 0a23156a2..e0ab61082 100644 --- a/schema/table.go +++ b/schema/table.go @@ -918,6 +918,7 @@ func isKnownFieldOption(name string) bool { "array", "hstore", "composite", + "multirange", "json_use_number", "msgpack", "notnull",