From 66b44f7c0edc205927fb8be96aaf263b31828fa1 Mon Sep 17 00:00:00 2001 From: oGi4i Date: Tue, 14 Jun 2022 15:56:41 +0300 Subject: [PATCH] feat(pgdialect): add hstore support --- dialect/pgdialect/append.go | 55 +++++++++ dialect/pgdialect/array_parser.go | 36 +----- dialect/pgdialect/dialect.go | 5 + dialect/pgdialect/hstore.go | 73 ++++++++++++ dialect/pgdialect/hstore_parser.go | 142 ++++++++++++++++++++++++ dialect/pgdialect/hstore_parser_test.go | 62 +++++++++++ dialect/pgdialect/hstore_scan.go | 82 ++++++++++++++ dialect/pgdialect/sqltype.go | 6 +- dialect/pgdialect/stream_parser.go | 60 ++++++++++ dialect/sqltype/sqltype.go | 1 + internal/dbtest/pg_test.go | 71 ++++++++++++ schema/scan.go | 5 + 12 files changed, 561 insertions(+), 37 deletions(-) create mode 100644 dialect/pgdialect/hstore.go create mode 100644 dialect/pgdialect/hstore_parser.go create mode 100644 dialect/pgdialect/hstore_parser_test.go create mode 100644 dialect/pgdialect/hstore_scan.go create mode 100644 dialect/pgdialect/stream_parser.go diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index d5e0d0a57..a60bf5de2 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -307,3 +307,58 @@ func arrayAppendString(b []byte, s string) []byte { b = append(b, '"') return b } + +//------------------------------------------------------------------------------ + +var mapStringStringType = reflect.TypeOf(map[string]string(nil)) + +func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := d.hstoreAppender(typ.Elem()); fn != nil { + return schema.PtrAppender(fn) + } + case reflect.Map: + // ok: + default: + return nil + } + + if typ.Key() == stringType && typ.Elem() == stringType { + return appendMapStringStringValue + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + err := fmt.Errorf("bun: Hstore(unsupported %s)", v.Type()) + return dialect.AppendError(b, err) + } +} + +func appendMapStringString(b []byte, m map[string]string) []byte { + if m == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + for key, value := range m { + b = arrayAppendString(b, key) + b = append(b, '=', '>') + b = arrayAppendString(b, value) + b = append(b, ',') + } + if len(m) > 0 { + b = b[:len(b)-1] // Strip trailing comma. + } + + b = append(b, '\'') + + return b +} + +func appendMapStringStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + m := v.Convert(mapStringStringType).Interface().(map[string]string) + return appendMapStringString(b, m) +} diff --git a/dialect/pgdialect/array_parser.go b/dialect/pgdialect/array_parser.go index 0dff754f8..a8358337e 100644 --- a/dialect/pgdialect/array_parser.go +++ b/dialect/pgdialect/array_parser.go @@ -8,17 +8,13 @@ import ( ) type arrayParser struct { - b []byte - i int - - buf []byte + *streamParser err error } func newArrayParser(b []byte) *arrayParser { p := &arrayParser{ - b: b, - i: 1, + streamParser: newStreamParser(b, 1), } if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { p.err = fmt.Errorf("bun: can't parse array: %q", b) @@ -135,31 +131,3 @@ func (p *arrayParser) readSubstring() ([]byte, error) { return p.buf, nil } - -func (p *arrayParser) valid() bool { - return p.i < len(p.b) -} - -func (p *arrayParser) readByte() (byte, error) { - if p.valid() { - c := p.b[p.i] - p.i++ - return c, nil - } - return 0, io.EOF -} - -func (p *arrayParser) unreadByte() { - p.i-- -} - -func (p *arrayParser) peek() byte { - if p.valid() { - return p.b[p.i] - } - return 0 -} - -func (p *arrayParser) skipNext() { - p.i++ -} diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index 852132b7f..1b64ea753 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -88,6 +88,11 @@ func (d *Dialect) onField(field *schema.Field) { field.Append = d.arrayAppender(field.StructField.Type) field.Scan = arrayScanner(field.StructField.Type) } + + if field.DiscoveredSQLType == sqltype.HSTORE { + field.Append = d.hstoreAppender(field.StructField.Type) + field.Scan = hstoreScanner(field.StructField.Type) + } } func (d *Dialect) IdentQuote() byte { diff --git a/dialect/pgdialect/hstore.go b/dialect/pgdialect/hstore.go new file mode 100644 index 000000000..029f7cb6d --- /dev/null +++ b/dialect/pgdialect/hstore.go @@ -0,0 +1,73 @@ +package pgdialect + +import ( + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type HStoreValue struct { + v reflect.Value + + append schema.AppenderFunc + scan schema.ScannerFunc +} + +// HStore accepts a map[string]string and returns a wrapper for working with PostgreSQL +// hstore data type. +// +// For struct fields you can use hstore tag: +// +// Attrs map[string]string `bun:",hstore"` +func HStore(vi interface{}) *HStoreValue { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("bun: HStore(nil)")) + } + + typ := v.Type() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() != reflect.Map { + panic(fmt.Errorf("bun: Hstore(unsupported %s)", typ)) + } + + return &HStoreValue{ + v: v, + + append: pgDialect.hstoreAppender(v.Type()), + scan: hstoreScanner(v.Type()), + } +} + +var ( + _ schema.QueryAppender = (*HStoreValue)(nil) + _ sql.Scanner = (*HStoreValue)(nil) +) + +func (h *HStoreValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) { + if h.append == nil { + panic(fmt.Errorf("bun: HStore(unsupported %s)", h.v.Type())) + } + return h.append(fmter, b, h.v), nil +} + +func (h *HStoreValue) Scan(src interface{}) error { + if h.scan == nil { + return fmt.Errorf("bun: HStore(unsupported %s)", h.v.Type()) + } + if h.v.Kind() != reflect.Ptr { + return fmt.Errorf("bun: HStore(non-pointer %s)", h.v.Type()) + } + return h.scan(h.v.Elem(), src) +} + +func (h *HStoreValue) Value() interface{} { + if h.v.IsValid() { + return h.v.Interface() + } + return nil +} diff --git a/dialect/pgdialect/hstore_parser.go b/dialect/pgdialect/hstore_parser.go new file mode 100644 index 000000000..7a18b50b1 --- /dev/null +++ b/dialect/pgdialect/hstore_parser.go @@ -0,0 +1,142 @@ +package pgdialect + +import ( + "bytes" + "fmt" +) + +type hstoreParser struct { + *streamParser + err error +} + +func newHStoreParser(b []byte) *hstoreParser { + p := &hstoreParser{ + streamParser: newStreamParser(b, 0), + } + if len(b) < 6 || b[0] != '"' { + p.err = fmt.Errorf("bun: can't parse hstore: %q", b) + } + return p +} + +func (p *hstoreParser) NextKey() (string, error) { + if p.err != nil { + return "", p.err + } + + err := p.skipByte('"') + if err != nil { + return "", err + } + + key, err := p.readSubstring() + if err != nil { + return "", err + } + + const separator = "=>" + + for i := range separator { + err = p.skipByte(separator[i]) + if err != nil { + return "", err + } + } + + return string(key), nil +} + +func (p *hstoreParser) NextValue() (string, error) { + if p.err != nil { + return "", p.err + } + + c, err := p.readByte() + if err != nil { + return "", err + } + + switch c { + case '"': + value, err := p.readSubstring() + if err != nil { + return "", err + } + + if p.peek() == ',' { + p.skipNext() + } + + if p.peek() == ' ' { + p.skipNext() + } + + return string(value), nil + default: + value := p.readSimple() + if bytes.Equal(value, []byte("NULL")) { + value = nil + } + + if p.peek() == ',' { + p.skipNext() + } + + return string(value), nil + } +} + +func (p *hstoreParser) readSimple() []byte { + p.unreadByte() + + if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { + b := p.b[p.i : p.i+i] + p.i += i + return b + } + + b := p.b[p.i:len(p.b)] + p.i = len(p.b) + return b +} + +func (p *hstoreParser) readSubstring() ([]byte, error) { + c, err := p.readByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if c == '"' { + break + } + + next, err := p.readByte() + if err != nil { + return nil, err + } + + if c == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + c, err = p.readByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + c = next + } + continue + } + + p.buf = append(p.buf, c) + c = next + } + + return p.buf, nil +} diff --git a/dialect/pgdialect/hstore_parser_test.go b/dialect/pgdialect/hstore_parser_test.go new file mode 100644 index 000000000..aeb8c2e15 --- /dev/null +++ b/dialect/pgdialect/hstore_parser_test.go @@ -0,0 +1,62 @@ +package pgdialect + +import ( + "io" + "testing" +) + +func TestHStoreParser(t *testing.T) { + tests := []struct { + s string + m map[string]string + }{ + {`""=>""`, map[string]string{"": ""}}, + {`"\\"=>"\\"`, map[string]string{`\`: `\`}}, + {`"'"=>"'"`, map[string]string{"'": "'"}}, + {`"'\"{}"=>"'\"{}"`, map[string]string{`'"{}`: `'"{}`}}, + + {`"1"=>"2", "3"=>"4"`, map[string]string{"1": "2", "3": "4"}}, + {`"1"=>NULL`, map[string]string{"1": ""}}, + {`"1"=>"NULL"`, map[string]string{"1": "NULL"}}, + {`"{1}"=>"{2}", "{3}"=>"{4}"`, map[string]string{"{1}": "{2}", "{3}": "{4}"}}, + } + + for testi, test := range tests { + p := newHStoreParser([]byte(test.s)) + + got := make(map[string]string) + for { + key, err := p.NextKey() + if err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + + value, err := p.NextValue() + if err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + + got[key] = value + } + + if len(got) != len(test.m) { + t.Fatalf( + "test #%d got %d elements, wanted %d (got=%#v wanted=%#v)", + testi, len(got), len(test.m), got, test.m) + } + + for k, v := range got { + if v != test.m[k] { + t.Fatalf( + "test #%d key #%s does not match: %s != %s (got=%#v wanted=%#v)", + testi, k, v, test.m[k], got, test.m) + } + } + } +} diff --git a/dialect/pgdialect/hstore_scan.go b/dialect/pgdialect/hstore_scan.go new file mode 100644 index 000000000..b10b06b8d --- /dev/null +++ b/dialect/pgdialect/hstore_scan.go @@ -0,0 +1,82 @@ +package pgdialect + +import ( + "fmt" + "io" + "reflect" + + "github.com/uptrace/bun/schema" +) + +func hstoreScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := hstoreScanner(typ.Elem()); fn != nil { + return schema.PtrScanner(fn) + } + case reflect.Map: + // ok: + default: + return nil + } + + if typ.Key() == stringType && typ.Elem() == stringType { + return scanMapStringStringValue + } + return func(dest reflect.Value, src interface{}) error { + return fmt.Errorf("bun: Hstore(unsupported %s)", dest.Type()) + } +} + +func scanMapStringStringValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + m, err := decodeMapStringString(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(m)) + return nil +} + +func decodeMapStringString(src interface{}) (map[string]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + m := make(map[string]string) + + p := newHStoreParser(b) + for { + key, err := p.NextKey() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + value, err := p.NextValue() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + m[key] = value + } + + return m, nil +} diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index bfef89fa1..6c6294d71 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -53,11 +53,11 @@ func fieldSQLType(field *schema.Field) string { if v, ok := field.Tag.Option("composite"); ok { return v } - if _, ok := field.Tag.Option("hstore"); ok { - return "hstore" + if field.Tag.HasOption("hstore") { + return sqltype.HSTORE } - if _, ok := field.Tag.Options["array"]; ok { + if field.Tag.HasOption("array") { switch field.IndirectType.Kind() { case reflect.Slice, reflect.Array: sqlType := sqlType(field.IndirectType.Elem()) diff --git a/dialect/pgdialect/stream_parser.go b/dialect/pgdialect/stream_parser.go new file mode 100644 index 000000000..7b9a15f62 --- /dev/null +++ b/dialect/pgdialect/stream_parser.go @@ -0,0 +1,60 @@ +package pgdialect + +import ( + "fmt" + "io" +) + +type streamParser struct { + b []byte + i int + + buf []byte +} + +func newStreamParser(b []byte, start int) *streamParser { + return &streamParser{ + b: b, + i: start, + } +} + +func (p *streamParser) valid() bool { + return p.i < len(p.b) +} + +func (p *streamParser) skipByte(skip byte) error { + c, err := p.readByte() + if err != nil { + return err + } + if c == skip { + return nil + } + p.unreadByte() + return fmt.Errorf("got %q, wanted %q", c, skip) +} + +func (p *streamParser) readByte() (byte, error) { + if p.valid() { + c := p.b[p.i] + p.i++ + return c, nil + } + return 0, io.EOF +} + +func (p *streamParser) unreadByte() { + p.i-- +} + +func (p *streamParser) peek() byte { + if p.valid() { + return p.b[p.i] + } + return 0 +} + +func (p *streamParser) skipNext() { + p.i++ +} diff --git a/dialect/sqltype/sqltype.go b/dialect/sqltype/sqltype.go index f58b2f1d1..1031fd352 100644 --- a/dialect/sqltype/sqltype.go +++ b/dialect/sqltype/sqltype.go @@ -12,4 +12,5 @@ const ( Timestamp = "TIMESTAMP" JSON = "JSON" JSONB = "JSONB" + HSTORE = "HSTORE" ) diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index 6d6076e27..cbdb5affd 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -540,3 +540,74 @@ func TestPostgresUUID(t *testing.T) { require.NoError(t, err) require.NotZero(t, model.ID) } + +func TestPostgresHStore(t *testing.T) { + type Model struct { + ID int64 `bun:",pk,autoincrement"` + Attrs1 map[string]string `bun:",hstore"` + Attrs2 *map[string]string `bun:",hstore"` + Attrs3 *map[string]string `bun:",hstore"` + } + + db := pg(t) + defer db.Close() + + _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS HSTORE;`) + require.NoError(t, err) + + _, err = db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) + require.NoError(t, err) + + _, err = db.NewCreateTable().Model((*Model)(nil)).Exec(ctx) + require.NoError(t, err) + + model1 := &Model{ + ID: 123, + Attrs1: map[string]string{"one": "two", "three": "four"}, + Attrs2: &map[string]string{"two": "three", "four": "five"}, + } + _, err = db.NewInsert().Model(model1).Exec(ctx) + require.NoError(t, err) + + model2 := new(Model) + err = db.NewSelect().Model(model2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, model1, model2) + + attrs1 := make(map[string]string) + err = db.NewSelect().Model((*Model)(nil)). + Column("attrs1"). + Scan(ctx, pgdialect.HStore(&attrs1)) + require.NoError(t, err) + require.Equal(t, map[string]string{"one": "two", "three": "four"}, attrs1) + + attrs2 := make(map[string]string) + err = db.NewSelect().Model((*Model)(nil)). + Column("attrs2"). + Scan(ctx, pgdialect.HStore(&attrs2)) + require.NoError(t, err) + require.Equal(t, map[string]string{"two": "three", "four": "five"}, attrs2) + + var attrs3 map[string]string + err = db.NewSelect().Model((*Model)(nil)). + Column("attrs3"). + Scan(ctx, pgdialect.HStore(&attrs3)) + require.NoError(t, err) + require.Nil(t, attrs3) +} + +func TestPostgresHStoreQuote(t *testing.T) { + db := pg(t) + defer db.Close() + + _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS HSTORE;`) + require.NoError(t, err) + + wanted := map[string]string{"'": "'", "''": "''", "'''": "'''", "\"": "\""} + m := make(map[string]string) + err = db.NewSelect(). + ColumnExpr("?::hstore", pgdialect.HStore(wanted)). + Scan(ctx, pgdialect.HStore(&m)) + require.NoError(t, err) + require.Equal(t, wanted, m) +} diff --git a/schema/scan.go b/schema/scan.go index 069b14e44..96b31caf3 100644 --- a/schema/scan.go +++ b/schema/scan.go @@ -449,6 +449,11 @@ func PtrScanner(fn ScannerFunc) ScannerFunc { if dest.IsNil() { dest.Set(reflect.New(dest.Type().Elem())) } + + if dest.Kind() == reflect.Map { + return fn(dest, src) + } + return fn(dest.Elem(), src) } }