Skip to content

Commit

Permalink
fix: "model does not have column" error (#850)
Browse files Browse the repository at this point in the history
This fixes the case of two belongs-to fields with different names
but with the same type.
  • Loading branch information
martoche authored Sep 10, 2023
1 parent fd7b609 commit 16367aa
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 7 deletions.
155 changes: 155 additions & 0 deletions internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func TestORM(t *testing.T) {
{testRelationExcludeAll},
{testM2MRelationExcludeColumn},
{testRelationBelongsToSelf},
{testRelationsCycle},
{testCompositeHasMany},
{testRelationsDifferentFieldsWithSameType},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -372,6 +374,65 @@ func testRelationBelongsToSelf(t *testing.T, db *bun.DB) {
}, models)
}

func testRelationsCycle(t *testing.T, db *bun.DB) {
type Child struct {
Id int64 `bun:",pk"`
SiblingId int64
Sibling *Child `bun:"rel:belongs-to"`
}

type Parent struct {
Id int64 `bun:",pk"`
GrandParentId int64
FirstChildId int64
FirstChild *Child `bun:"rel:belongs-to"`
SecondChildId int64
SecondChild *Child `bun:"rel:belongs-to"`
}

type GrandParent struct {
Id int64 `bun:",pk"`
Parents []*Parent `bun:"rel:has-many"`
}

err := db.ResetModel(ctx, (*GrandParent)(nil))
require.NoError(t, err)
err = db.ResetModel(ctx, (*Parent)(nil))
require.NoError(t, err)
err = db.ResetModel(ctx, (*Child)(nil))
require.NoError(t, err)

_, err = db.NewInsert().Model(&GrandParent{Id: 1}).Exec(ctx)
require.NoError(t, err)
_, err = db.NewInsert().Model(&Parent{Id: 1, GrandParentId: 1, FirstChildId: 1, SecondChildId: 2}).Exec(ctx)
require.NoError(t, err)
_, err = db.NewInsert().Model(&Child{Id: 1}).Exec(ctx)
require.NoError(t, err)
_, err = db.NewInsert().Model(&Child{Id: 2, SiblingId: 1}).Exec(ctx)
require.NoError(t, err)

var grandParent GrandParent
err = db.NewSelect().
Model(&grandParent).
Relation("Parents.FirstChild").
Relation("Parents.SecondChild.Sibling").
Where("grand_parent.id = ?", 1).
Scan(ctx)

require.NoError(t, err)
require.Equal(t, GrandParent{
Id: 1,
Parents: []*Parent{{
Id: 1,
GrandParentId: 1,
FirstChildId: 1,
FirstChild: &Child{Id: 1, SiblingId: 0},
SecondChildId: 2,
SecondChild: &Child{Id: 2, SiblingId: 1, Sibling: &Child{Id: 1, SiblingId: 0}},
}},
}, grandParent)
}

func testM2MRelationExcludeColumn(t *testing.T, db *bun.DB) {
type Item struct {
ID int64 `bun:",pk,autoincrement"`
Expand Down Expand Up @@ -442,6 +503,100 @@ func testCompositeHasMany(t *testing.T, db *bun.DB) {
require.Equal(t, 2, len(department.Employees))
}

func testRelationsDifferentFieldsWithSameType(t *testing.T, db *bun.DB) {
type Country struct {
Id int `bun:",pk"`
}

type Address struct {
Id int `bun:",pk"`
CountryId int
Country *Country `bun:"rel:has-one,join:country_id=id"`
}

type CompanyAddress struct {
Id int `bun:",pk"`
CompanyId int
BillingAddressId int
BillingAddress *Address `bun:"rel:belongs-to,join:billing_address_id=id"`
ShippingAddressId int
ShippingAddress *Address `bun:"rel:belongs-to,join:shipping_address_id=id"`
CountryId int
Country *Country `bun:"rel:has-one,join:country_id=id"`
}

type Company struct {
Id int `bun:",pk"`
ParentCompanyId int
CompanyAddress *CompanyAddress `bun:"rel:has-one,join:id=company_id"`
}

type ParentCompany struct {
Id int `bun:",pk"`
Companies []*Company `bun:"rel:has-many,join:id=parent_company_id"`
}

models := []interface{}{
(*Country)(nil),
(*Address)(nil),
(*CompanyAddress)(nil),
(*Company)(nil),
(*ParentCompany)(nil),
}
for _, model := range models {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
_, err = db.NewCreateTable().Model(model).Exec(ctx)
require.NoError(t, err)
}

models = []interface{}{
&Country{Id: 1},
&Address{Id: 1, CountryId: 1},
&CompanyAddress{Id: 1, CompanyId: 1, BillingAddressId: 1, ShippingAddressId: 1, CountryId: 1},
&Company{Id: 1, ParentCompanyId: 1},
&ParentCompany{Id: 1},
}
for _, model := range models {
res, err := db.NewInsert().Model(model).Exec(ctx)
require.NoError(t, err)

n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), n)
}

var parentCompany ParentCompany
err := db.NewSelect().
Model(&parentCompany).
Relation("Companies.CompanyAddress.Country").
Relation("Companies.CompanyAddress.BillingAddress.Country").
Relation("Companies.CompanyAddress.ShippingAddress.Country").
Where("parent_company.id = ?", 1).
Scan(ctx)

require.NoError(t, err)
require.Equal(t, ParentCompany{
Id: 1,
Companies: []*Company{
{
Id: 1,
ParentCompanyId: 1,
CompanyAddress: &CompanyAddress{
Id: 1,
CompanyId: 1,
BillingAddressId: 1,
BillingAddress: &Address{Id: 1, CountryId: 1, Country: &Country{Id: 1}},
ShippingAddressId: 1,
ShippingAddress: &Address{Id: 1, CountryId: 1, Country: &Country{Id: 1}},
CountryId: 1,
Country: &Country{Id: 1},
},
},
},
}, parentCompany)
}

type Genre struct {
ID int `bun:",pk"`
Name string
Expand Down
37 changes: 30 additions & 7 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -806,18 +807,38 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}

func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
if seen == nil {
seen = map[reflect.Type]struct{}{t.Type: {}}
type seenKey struct {
Table reflect.Type
FieldIndex string
}

type seenMap map[seenKey]struct{}

func NewSeenKey(table reflect.Type, fieldIndex []int) (key seenKey) {
key.Table = table
for _, index := range fieldIndex {
key.FieldIndex += strconv.Itoa(index) + "-"
}
return key
}

if _, ok := seen[field.IndirectType]; ok {
return
func (s seenMap) Clone() seenMap {
t := make(seenMap)
for k, v := range s {
t[k] = v
}
return t
}

func (t *Table) inlineFields(field *Field, seen seenMap) {
if seen == nil {
seen = make(seenMap)
}
seen[field.IndirectType] = struct{}{}

joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
key := NewSeenKey(joinTable.Type, f.Index)

f = f.Clone()
f.GoName = field.GoName + "_" + f.GoName
f.Name = field.Name + "__" + f.Name
Expand All @@ -834,7 +855,9 @@ func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
continue
}

if _, ok := seen[f.IndirectType]; !ok {
if _, ok := seen[key]; !ok {
seen = seen.Clone()
seen[key] = struct{}{}
t.inlineFields(f, seen)
}
}
Expand Down

0 comments on commit 16367aa

Please sign in to comment.