Skip to content

Commit

Permalink
chore: rework inlining and embedding (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco authored Jan 29, 2024
1 parent 19a8264 commit 9052fc4
Show file tree
Hide file tree
Showing 11 changed files with 489 additions and 369 deletions.
5 changes: 0 additions & 5 deletions bun.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ type (

BeforeScanRowHook = schema.BeforeScanRowHook
AfterScanRowHook = schema.AfterScanRowHook

// DEPRECATED. Use BeforeScanRowHook instead.
BeforeScanHook = schema.BeforeScanHook
// DEPRECATED. Use AfterScanRowHook instead.
AfterScanHook = schema.AfterScanHook
)

type BeforeSelectHook interface {
Expand Down
22 changes: 0 additions & 22 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ func TestDB(t *testing.T) {
{testJSONMarshaler},
{testNilDriverValue},
{testRunInTxAndSavepoint},
{testEmbedTypeField},
{testDriverValuerReturnsItself},
{testNoPanicWhenReturningNullColumns},
}
Expand Down Expand Up @@ -1440,27 +1439,6 @@ func testEmbedModelPointer(t *testing.T, db *bun.DB) {
require.Equal(t, *m1, m2)
}

func testEmbedTypeField(t *testing.T, db *bun.DB) {
type Embed string
type Model struct {
Embed
}

ctx := context.Background()
mustResetModel(t, ctx, db, (*Model)(nil))

m1 := &Model{
Embed: Embed("foo"),
}
_, err := db.NewInsert().Model(m1).Exec(ctx)
require.NoError(t, err)

var m2 Model
err = db.NewSelect().Model(&m2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, *m1, m2)
}

type JSONField struct {
Foo string `json:"foo"`
}
Expand Down
2 changes: 1 addition & 1 deletion internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ func TestQuery(t *testing.T) {
func(db *bun.DB) schema.QueryAppender {
type ID string
type Model struct {
ID
ID ID
}
return db.NewInsert().Model(&Model{ID: ID("embed")})
},
Expand Down
6 changes: 3 additions & 3 deletions model_table_has_many.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ func (m *hasManyModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex]
m.scanIndex++

field, err := m.table.Field(column)
if err != nil {
return err
field := m.table.LookupField(column)
if field == nil {
return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
}

if err := field.ScanValue(m.strct, src); err != nil {
Expand Down
20 changes: 1 addition & 19 deletions model_table_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ func (m *structTableModel) BeforeScanRow(ctx context.Context) error {
if m.table.HasBeforeScanRowHook() {
return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx)
}
if m.table.HasBeforeScanHook() {
return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx)
}
return nil
}

Expand All @@ -144,21 +141,6 @@ func (m *structTableModel) AfterScanRow(ctx context.Context) error {
return firstErr
}

if m.table.HasAfterScanHook() {
firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx)

for _, j := range m.joins {
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil {
firstErr = err
}
}
}

return firstErr
}

return nil
}

Expand Down Expand Up @@ -325,7 +307,7 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err
}
}

if field, ok := m.table.FieldMap[column]; ok {
if field := m.table.LookupField(column); field != nil {
if src == nil && m.isNil() {
return true, nil
}
Expand Down
4 changes: 3 additions & 1 deletion relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte
return append(b, j.BaseModel.Table().SQLAlias...)
}

func (j *relationJoin) appendSoftDelete(fmter schema.Formatter, b []byte, flags internal.Flag) []byte {
func (j *relationJoin) appendSoftDelete(
fmter schema.Formatter, b []byte, flags internal.Flag,
) []byte {
b = append(b, '.')

field := j.JoinModel.Table().SoftDeleteField
Expand Down
35 changes: 16 additions & 19 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ func (f *Field) String() string {
return f.Name
}

func (f *Field) WithIndex(path []int) *Field {
if len(path) == 0 {
return f
}
clone := *f
clone.Index = makeIndex(path, f.Index)
return &clone
}

func (f *Field) Clone() *Field {
cp := *f
cp.Index = cp.Index[:len(f.Index):len(f.Index)]
Expand Down Expand Up @@ -103,13 +112,6 @@ func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []by
return f.Append(fmter, b, fv)
}

func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
if f.Scan == nil {
return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
}
return f.Scan(fv, src)
}

func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
if src == nil {
if fv, ok := fieldByIndex(strct, f.Index); ok {
Expand All @@ -122,18 +124,13 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
return f.ScanWithCheck(fv, src)
}

func (f *Field) SkipUpdate() bool {
return f.Tag.HasOption("skipupdate")
func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
if f.Scan == nil {
return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
}
return f.Scan(fv, src)
}

func indexEqual(ind1, ind2 []int) bool {
if len(ind1) != len(ind2) {
return false
}
for i, ind := range ind1 {
if ind != ind2[i] {
return false
}
}
return true
func (f *Field) SkipUpdate() bool {
return f.Tag.HasOption("skipupdate")
}
16 changes: 0 additions & 16 deletions schema/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,6 @@ var beforeAppendModelHookType = reflect.TypeOf((*BeforeAppendModelHook)(nil)).El

//------------------------------------------------------------------------------

type BeforeScanHook interface {
BeforeScan(context.Context) error
}

var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()

//------------------------------------------------------------------------------

type AfterScanHook interface {
AfterScan(context.Context) error
}

var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()

//------------------------------------------------------------------------------

type BeforeScanRowHook interface {
BeforeScanRow(context.Context) error
}
Expand Down
Loading

0 comments on commit 9052fc4

Please sign in to comment.