Skip to content

Commit

Permalink
fix: check for nils when appeding driver.Value
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Feb 22, 2022
1 parent d3f93b4 commit 7bb1640
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:
MSSQL_USER: sa
MSSQL_PASSWORD: passWORD1
ports:
- 14339:1433
- 1433:1433
options: >-
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S tcp:localhost,1433 -U sa -P passWORD1 -Q
'select 1' -b -o /dev/null" --health-interval=10s --health-timeout=5s --health-retries=5
Expand All @@ -90,4 +90,4 @@ jobs:
MYSQL: user:pass@/test
MYSQL5: user:pass@tcp(localhost:53306)/test
MARIADB: user:pass@tcp(localhost:13306)/test
MSSQL2019: sqlserver://sa:passWORD1@localhost:14339?database=test
MSSQL2019: sqlserver://sa:passWORD1@localhost:1433?database=master
28 changes: 28 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ func TestDB(t *testing.T) {
{testEmbedModelValue},
{testEmbedModelPointer},
{testJSONMarshaler},
{testNilDriverValue},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -1284,3 +1285,30 @@ func testJSONMarshaler(t *testing.T, db *bun.DB) {
require.NoError(t, err)
require.Equal(t, "bar", m2.Field.Foo)
}

type DriverValue struct {
s string
}

var _ driver.Valuer = (*DriverValue)(nil)

func (v *DriverValue) Value() (driver.Value, error) {
return v.s, nil
}

func testNilDriverValue(t *testing.T, db *bun.DB) {
type Model struct {
Value *DriverValue `bun:"type:varchar(100)"`
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

_, err = db.NewInsert().Model(&Model{}).Exec(ctx)
require.NoError(t, err)

_, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx)
require.NoError(t, err)
}
19 changes: 17 additions & 2 deletions schema/append_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,21 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return appendJSONRawMessageValue
}

kind := typ.Kind()

if typ.Implements(queryAppenderType) {
if kind == reflect.Ptr {
return nilAwareAppender(appendQueryAppenderValue)
}
return appendQueryAppenderValue
}
if typ.Implements(driverValuerType) {
if kind == reflect.Ptr {
return nilAwareAppender(appendDriverValue)
}
return appendDriverValue
}

kind := typ.Kind()

if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) {
Expand Down Expand Up @@ -156,6 +162,15 @@ func ifaceAppenderFunc(fmter Formatter, b []byte, v reflect.Value) []byte {
return appender(fmter, b, elem)
}

func nilAwareAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
return fn(fmter, b, v)
}
}

func PtrAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
Expand Down

0 comments on commit 7bb1640

Please sign in to comment.