diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 91da2a1d4..902845a2d 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -236,6 +236,8 @@ func TestDB(t *testing.T) { {testUpsert}, {testMultiUpdate}, {testTxScanAndCount}, + {testEmbedModelValue}, + {testEmbedModelPointer}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -1029,3 +1031,59 @@ func testTxScanAndCount(t *testing.T, db *bun.DB) { require.NoError(t, err) } } + +func testEmbedModelValue(t *testing.T, db *bun.DB) { + type Embed struct { + Foo string + Bar string + } + type Model struct { + X Embed `bun:"embed:x_"` + Y Embed `bun:"embed:y_"` + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + m1 := &Model{ + X: Embed{Foo: "x.foo", Bar: "x.bar"}, + Y: Embed{Foo: "y.foo", Bar: "y.bar"}, + } + _, err = db.NewInsert().Model(m1).Exec(ctx) + require.NoError(t, err) + + var m2 Model + err = db.NewSelect().Model(&m2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, *m1, m2) +} + +func testEmbedModelPointer(t *testing.T, db *bun.DB) { + type Embed struct { + Foo string + Bar string + } + type Model struct { + X *Embed `bun:"embed:x_"` + Y *Embed `bun:"embed:y_"` + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + m1 := &Model{ + X: &Embed{Foo: "x.foo", Bar: "x.bar"}, + Y: &Embed{Foo: "y.foo", Bar: "y.bar"}, + } + _, err = db.NewInsert().Model(m1).Exec(ctx) + require.NoError(t, err) + + var m2 Model + err = db.NewSelect().Model(&m2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, *m1, m2) +} diff --git a/schema/table.go b/schema/table.go index 96856a997..88b8d8e25 100644 --- a/schema/table.go +++ b/schema/table.go @@ -203,7 +203,7 @@ func (t *Table) fieldByGoName(name string) *Field { func (t *Table) initFields() { t.Fields = make([]*Field, 0, t.Type.NumField()) t.FieldMap = make(map[string]*Field, t.Type.NumField()) - t.addFields(t.Type, nil) + t.addFields(t.Type, "", nil) if len(t.PKs) == 0 { for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { @@ -230,7 +230,7 @@ func (t *Table) initFields() { } } -func (t *Table) addFields(typ reflect.Type, baseIndex []int) { +func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) unexported := f.PkgPath != "" @@ -242,10 +242,6 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - // Make a copy so the slice is not shared between fields. - index := make([]int, len(baseIndex)) - copy(index, baseIndex) - if f.Anonymous { if f.Name == "BaseModel" && f.Type == baseModelType { if len(index) == 0 { @@ -258,7 +254,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { if fieldType.Kind() != reflect.Struct { continue } - t.addFields(fieldType, append(index, f.Index...)) + t.addFields(fieldType, "", withIndex(index, f.Index)) tag := tagparser.Parse(f.Tag.Get("bun")) if _, inherit := tag.Options["inherit"]; inherit { @@ -274,7 +270,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - if field := t.newField(f, index); field != nil { + if field := t.newField(f, prefix, index); field != nil { t.addField(field) } } @@ -315,10 +311,20 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } //nolint -func (t *Table) newField(f reflect.StructField, index []int) *Field { - sqlName := internal.Underscore(f.Name) +func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field { tag := tagparser.Parse(f.Tag.Get("bun")) + if prefix, ok := tag.Option("embed"); ok { + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct", + t.TypeName, f.Name, fieldType.Kind())) + } + t.addFields(fieldType, prefix, withIndex(index, f.Index)) + return nil + } + + sqlName := internal.Underscore(f.Name) if tag.Name != "" && tag.Name != sqlName { if isKnownFieldOption(tag.Name) { internal.Warn.Printf( @@ -328,10 +334,10 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } sqlName = tag.Name } - if s, ok := tag.Option("column"); ok { sqlName = s } + sqlName = prefix + sqlName for name := range tag.Options { if !isKnownFieldOption(name) { @@ -339,7 +345,7 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } } - index = append(index, f.Index...) + index = withIndex(index, f.Index) if field := t.fieldWithLock(sqlName); field != nil { if indexEqual(field.Index, index) { return field @@ -795,7 +801,7 @@ func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { f.GoName = field.GoName + "_" + f.GoName f.Name = field.Name + "__" + f.Name f.SQLName = t.quoteIdent(f.Name) - f.Index = appendNew(field.Index, f.Index...) + f.Index = withIndex(field.Index, f.Index) t.fieldsMapMu.Lock() if _, ok := t.FieldMap[f.Name]; !ok { @@ -853,13 +859,6 @@ func (t *Table) quoteIdent(s string) Safe { return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) } -func appendNew(dst []int, src ...int) []int { - cp := make([]int, len(dst)+len(src)) - copy(cp, dst) - copy(cp[len(dst):], src) - return cp -} - func isKnownTableOption(name string) bool { switch name { case "table", "alias", "select": @@ -991,3 +990,10 @@ func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time return field.ScanWithCheck(fv, tm) } } + +func withIndex(a, b []int) []int { + dest := make([]int, 0, len(a)+len(b)) + dest = append(dest, a...) + dest = append(dest, b...) + return dest +}