Skip to content

Commit

Permalink
feat: improve nil ptr values handling
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Mar 2, 2022
1 parent 01af843 commit b398e6b
Show file tree
Hide file tree
Showing 29 changed files with 73 additions and 19 deletions.
1 change: 0 additions & 1 deletion dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ func fieldSQLType(field *schema.Field) string {
if v, ok := field.Tag.Option("composite"); ok {
return v
}

if _, ok := field.Tag.Option("hstore"); ok {
return "hstore"
}
Expand Down
15 changes: 14 additions & 1 deletion internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ func TestQuery(t *testing.T) {
},
func(db *bun.DB) schema.QueryAppender {
type Model struct {
Raw *json.RawMessage `bun:",nullzero"`
Raw *json.RawMessage
}
return db.NewInsert().Model(new(Model))
},
Expand All @@ -682,6 +682,19 @@ func TestQuery(t *testing.T) {
}
return db.NewInsert().Model(&Model{ID: 123, Slice: make([]Item, 0)})
},
func(db *bun.DB) schema.QueryAppender {
type Model struct {
Time *time.Time
}
return db.NewInsert().Model(new(Model))
},
func(db *bun.DB) schema.QueryAppender {
type Model struct {
Time *time.Time
}
tm := time.Unix(0, 0)
return db.NewInsert().Model(&Model{Time: &tm})
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES (DEFAULT) RETURNING `time`
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES ('1970-01-01 00:00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-mariadb-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO `models` (`raw`) VALUES ('null')
INSERT INTO `models` (`raw`) VALUES (DEFAULT) RETURNING `raw`
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES ('1970-01-01 00:00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-mssql2019-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO "models" ("raw") VALUES ('null')
INSERT INTO "models" ("raw") VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES ('1970-01-01 00:00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-mysql5-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO `models` (`raw`) VALUES ('null')
INSERT INTO `models` (`raw`) VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`time`) VALUES ('1970-01-01 00:00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-mysql8-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO `models` (`raw`) VALUES ('null')
INSERT INTO `models` (`raw`) VALUES (DEFAULT)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES (DEFAULT) RETURNING "time"
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES ('1970-01-01 00:00:00+00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-pg-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO "models" ("raw") VALUES ('null')
INSERT INTO "models" ("raw") VALUES (DEFAULT) RETURNING "raw"
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES (DEFAULT) RETURNING "time"
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES ('1970-01-01 00:00:00+00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-pgx-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO "models" ("raw") VALUES ('null')
INSERT INTO "models" ("raw") VALUES (DEFAULT) RETURNING "raw"
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-109
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES (NULL) RETURNING "time"
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-110
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("time") VALUES ('1970-01-01 00:00:00+00:00')
2 changes: 1 addition & 1 deletion internal/dbtest/testdata/snapshots/TestQuery-sqlite-59
Original file line number Diff line number Diff line change
@@ -1 +1 @@
INSERT INTO "models" ("raw") VALUES ('null')
INSERT INTO "models" ("raw") VALUES (NULL) RETURNING "raw"
2 changes: 1 addition & 1 deletion query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (q *whereBaseQuery) appendWhere(
field := q.tableModel.Table().SoftDeleteField
b = append(b, field.SQLName...)

if field.NullZero {
if field.IsPtr || field.NullZero {
if q.flags.Has(deletedFlag) {
b = append(b, " IS NOT NULL"...)
} else {
Expand Down
10 changes: 6 additions & 4 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (q *InsertQuery) appendStructValues(
switch {
case isTemplate:
b = append(b, '?')
case f.NullZero && f.HasZeroValue(strct):
case (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)):
if q.db.features.Has(feature.DefaultPlaceholder) {
b = append(b, "DEFAULT"...)
} else if f.SQLDefault != "" {
Expand Down Expand Up @@ -397,9 +397,11 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) {
q.addReturningField(f)
continue
}
if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) {
q.addReturningField(f)
continue
if f.NotNull && f.SQLDefault == "" {
if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) {
q.addReturningField(f)
continue
}
}
fields = append(fields, f)
}
Expand Down
4 changes: 3 additions & 1 deletion schema/append_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return appendBytesValue
case timeType:
return appendTimeValue
case timePtrType:
return PtrAppender(appendTimeValue)
case ipType:
return appendIPValue
case ipNetType:
Expand Down Expand Up @@ -135,7 +137,7 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return ifaceAppenderFunc
case reflect.Ptr:
if typ.Implements(jsonMarshalerType) {
return AppendJSONValue
return nilAwareAppender(AppendJSONValue)
}
if fn := Appender(dialect, typ.Elem()); fn != nil {
return PtrAppender(fn)
Expand Down
28 changes: 25 additions & 3 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type Field struct {
StructField reflect.StructField
IsPtr bool

Tag tagparser.Tag
IndirectType reflect.Type
Expand Down Expand Up @@ -51,15 +52,36 @@ func (f *Field) Value(strct reflect.Value) reflect.Value {
return fieldByIndexAlloc(strct, f.Index)
}

func (f *Field) HasNilValue(v reflect.Value) bool {
if len(f.Index) == 1 {
return v.Field(f.Index[0]).IsNil()
}

for _, index := range f.Index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return true
}
v = v.Elem()
}
v = v.Field(index)
}
return v.IsNil()
}

func (f *Field) HasZeroValue(v reflect.Value) bool {
for _, idx := range f.Index {
if len(f.Index) == 1 {
return f.IsZero(v.Field(f.Index[0]))
}

for _, index := range f.Index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return true
}
v = v.Elem()
}
v = v.Field(idx)
v = v.Field(index)
}
return f.IsZero(v)
}
Expand All @@ -70,7 +92,7 @@ func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []by
return dialect.AppendNull(b)
}

if f.NullZero && f.IsZero(fv) {
if (f.IsPtr && fv.IsNil()) || (f.NullZero && f.IsZero(fv)) {
return dialect.AppendNull(b)
}
if f.Append == nil {
Expand Down
3 changes: 2 additions & 1 deletion schema/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (

var (
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
timePtrType = reflect.TypeOf((*time.Time)(nil))
timeType = timePtrType.Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
Expand Down
1 change: 1 addition & 0 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie

field := &Field{
StructField: f,
IsPtr: f.Type.Kind() == reflect.Ptr,

Tag: tag,
IndirectType: indirectType(f.Type),
Expand Down

0 comments on commit b398e6b

Please sign in to comment.