Skip to content

Commit

Permalink
fix: improve zero checker for ptr values
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 10, 2021
1 parent d7de8d3 commit 2b3623d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
26 changes: 26 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ func TestDB(t *testing.T) {
{testInterfaceAny},
{testInterfaceJSON},
{testScanBytes},
{testPointers},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -847,3 +848,28 @@ func testScanBytes(t *testing.T, db *bun.DB) {

require.Equal(t, models, models1)
}

func testPointers(t *testing.T, db *bun.DB) {
type Model struct {
ID *int64 `bun:",allowzero,default:0"`
Str *string
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

id := int64(1)
str := "hello"
models := []Model{
{},
{ID: &id, Str: &str},
}
_, err = db.NewInsert().Model(&models).Exec(ctx)
require.NoError(t, err)

var models2 []Model
err = db.NewSelect().Model(&models2).Order("id ASC").Scan(ctx)
require.NoError(t, err)
}
11 changes: 8 additions & 3 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,15 @@ func (t *Table) initFields() {
}
}
if len(t.PKs) == 1 {
switch t.PKs[0].IndirectType.Kind() {
pk := t.PKs[0]
if pk.SQLDefault != "" {
return
}

switch pk.IndirectType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
t.PKs[0].AutoIncrement = true
pk.AutoIncrement = true
}
}
}
Expand Down Expand Up @@ -359,7 +364,7 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
field.Append = t.dialect.FieldAppender(field)
field.Scan = FieldScanner(t.dialect, field)
field.IsZero = FieldZeroChecker(field)
field.IsZero = zeroChecker(field.StructField.Type)

if v, ok := tag.Options["alt"]; ok {
t.FieldMap[v] = field
Expand Down
4 changes: 0 additions & 4 deletions schema/zerochecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ type isZeroer interface {

type IsZeroerFunc func(reflect.Value) bool

func FieldZeroChecker(field *Field) IsZeroerFunc {
return zeroChecker(field.IndirectType)
}

func zeroChecker(typ reflect.Type) IsZeroerFunc {
if typ.Implements(isZeroerType) {
return isZeroInterface
Expand Down

0 comments on commit 2b3623d

Please sign in to comment.