From 2b3623dd665d873911fd20ca707016929921e862 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 10 Sep 2021 17:33:20 +0300 Subject: [PATCH] fix: improve zero checker for ptr values --- internal/dbtest/db_test.go | 26 ++++++++++++++++++++++++++ schema/table.go | 11 ++++++++--- schema/zerochecker.go | 4 ---- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 4067c7a3a..976a70c27 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -223,6 +223,7 @@ func TestDB(t *testing.T) { {testInterfaceAny}, {testInterfaceJSON}, {testScanBytes}, + {testPointers}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -847,3 +848,28 @@ func testScanBytes(t *testing.T, db *bun.DB) { require.Equal(t, models, models1) } + +func testPointers(t *testing.T, db *bun.DB) { + type Model struct { + ID *int64 `bun:",allowzero,default:0"` + Str *string + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + id := int64(1) + str := "hello" + models := []Model{ + {}, + {ID: &id, Str: &str}, + } + _, err = db.NewInsert().Model(&models).Exec(ctx) + require.NoError(t, err) + + var models2 []Model + err = db.NewSelect().Model(&models2).Order("id ASC").Scan(ctx) + require.NoError(t, err) +} diff --git a/schema/table.go b/schema/table.go index 7498a2bc8..8bed5ed38 100644 --- a/schema/table.go +++ b/schema/table.go @@ -193,10 +193,15 @@ func (t *Table) initFields() { } } if len(t.PKs) == 1 { - switch t.PKs[0].IndirectType.Kind() { + pk := t.PKs[0] + if pk.SQLDefault != "" { + return + } + + switch pk.IndirectType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - t.PKs[0].AutoIncrement = true + pk.AutoIncrement = true } } } @@ -359,7 +364,7 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) field.Append = t.dialect.FieldAppender(field) field.Scan = FieldScanner(t.dialect, field) - field.IsZero = FieldZeroChecker(field) + field.IsZero = zeroChecker(field.StructField.Type) if v, ok := tag.Options["alt"]; ok { t.FieldMap[v] = field diff --git a/schema/zerochecker.go b/schema/zerochecker.go index 95efeee6b..f088b8c2c 100644 --- a/schema/zerochecker.go +++ b/schema/zerochecker.go @@ -13,10 +13,6 @@ type isZeroer interface { type IsZeroerFunc func(reflect.Value) bool -func FieldZeroChecker(field *Field) IsZeroerFunc { - return zeroChecker(field.IndirectType) -} - func zeroChecker(typ reflect.Type) IsZeroerFunc { if typ.Implements(isZeroerType) { return isZeroInterface