Skip to content

Commit

Permalink
fix(pgdialect): fix bytea[] handling
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 14, 2021
1 parent 60ffe29 commit a5ca013
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 10 deletions.
8 changes: 4 additions & 4 deletions dialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ func AppendString(b []byte, s string) []byte {
return b
}

func AppendBytes(b []byte, bytes []byte) []byte {
if bytes == nil {
func AppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return AppendNull(b)
}

b = append(b, `'\x`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...)
hex.Encode(b[s:], bytes)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

b = append(b, '\'')

Expand Down
34 changes: 30 additions & 4 deletions dialect/pgdialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgdialect

import (
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -64,7 +65,7 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case bool:
return dialect.AppendBool(b, v)
case []byte:
return dialect.AppendBytes(b, v)
return arrayAppendBytes(b, v)
case string:
return arrayAppendString(b, v)
case time.Time:
Expand All @@ -76,19 +77,28 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
}

func arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Kind() == reflect.String {
return arrayAppendStringValue
}
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.String:
return arrayAppendStringValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return arrayAppendBytesValue
}
}
return schema.Appender(typ, customAppender)
}

func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}

func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendBytes(b, v.Bytes())
}

func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
iface, err := v.Interface().(driver.Valuer).Value()
if err != nil {
Expand Down Expand Up @@ -280,6 +290,22 @@ func appendFloat64Slice(b []byte, floats []float64) []byte {

//------------------------------------------------------------------------------

func arrayAppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
}

b = append(b, `"\\x`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

b = append(b, '"')

return b
}

func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
for _, r := range s {
Expand Down
11 changes: 11 additions & 0 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgdialect

import (
"bytes"
"encoding/hex"
"fmt"
"io"
)
Expand Down Expand Up @@ -114,6 +115,16 @@ func (p *arrayParser) readSubstring() ([]byte, error) {
c = next
}

if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 {
data := p.buf[2:]
buf := make([]byte, hex.DecodedLen(len(data)))
n, err := hex.Decode(buf, data)
if err != nil {
return nil, err
}
return buf[:n], nil
}

return p.buf, nil
}

Expand Down
48 changes: 47 additions & 1 deletion internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package dbtest_test

import (
"database/sql"
"database/sql/driver"
"fmt"
"net"
"reflect"
"testing"
Expand All @@ -16,13 +18,14 @@ import (

func TestPGArray(t *testing.T) {
type Model struct {
ID int
ID int64
Array1 []string `bun:",array"`
Array2 *[]string `bun:",array"`
Array3 *[]string `bun:",array"`
}

db := pg(t)
defer db.Close()

_, err := db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -57,6 +60,49 @@ func TestPGArray(t *testing.T) {
require.Nil(t, strs)
}

type Hash [32]byte

func (h *Hash) Scan(src interface{}) error {
srcB, ok := src.([]byte)
if !ok {
return fmt.Errorf("can't scan %T into Hash", src)
}
if len(srcB) != len(h) {
return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), len(h))
}
copy(h[:], srcB)
return nil
}

func (h Hash) Value() (driver.Value, error) {
return h[:], nil
}

func TestPGArrayValuer(t *testing.T) {
type Model struct {
ID int64
Array []Hash `bun:",array"`
}

db := pg(t)
defer db.Close()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

model1 := &Model{
ID: 123,
Array: []Hash{Hash{}},
}
_, err = db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)

model2 := new(Model)
err = db.NewSelect().Model(model2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, model1, model2)
}

type Recipe struct {
bun.BaseModel `bun:"?tenant.recipes"`

Expand Down
6 changes: 5 additions & 1 deletion migrate/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ func init() {
}
`

const sqlTemplate = `SELECT 1
const sqlTemplate = `SET statement_timeout = 0;
--bun:split
SELECT 1
--bun:split
Expand Down

0 comments on commit a5ca013

Please sign in to comment.