Skip to content

Commit

Permalink
feat: add CreateTableQuery.DetectForeignKeys
Browse files Browse the repository at this point in the history
  • Loading branch information
iamgoroot committed Jan 30, 2022
1 parent 3cb01ac commit a958fcb
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
72 changes: 72 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ func TestDB(t *testing.T) {
{testJSONValuer},
{testSelectBool},
{testFKViolation},
{testDetectForeignKeys},
{testInterfaceAny},
{testInterfaceJSON},
{testScanRawMessage},
Expand Down Expand Up @@ -843,6 +844,77 @@ func testFKViolation(t *testing.T, db *bun.DB) {
require.Equal(t, 0, n)
}

func testDetectForeignKeys(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk"`
Type string `bun:",pk"`
Name string
}
type Deck struct {
ID int `bun:",pk"`
UserID int
UserType string
User *User `bun:"rel:belongs-to,join:user_id=id,join:user_type=type"`
}

if db.Dialect().Name() == dialect.SQLite {
_, err := db.Exec("PRAGMA foreign_keys = ON;")
require.NoError(t, err)
}

for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
}

_, err := db.NewCreateTable().
Model((*User)(nil)).
IfNotExists().
Exec(ctx)
require.NoError(t, err)

_, err = db.NewCreateTable().
Model((*Deck)(nil)).
IfNotExists().
DetectForeignKeys().
Exec(ctx)
require.NoError(t, err)

// Empty deck should violate FK constraint.
_, err = db.NewInsert().Model(new(Deck)).Exec(ctx)
require.Error(t, err)

// Create a deck that violates the user_id FK contraint
deck := &Deck{UserID: 42}

_, err = db.NewInsert().Model(deck).Exec(ctx)
require.Error(t, err)

decks := []*Deck{deck}
_, err = db.NewInsert().Model(&decks).Exec(ctx)
require.Error(t, err)

n, err := db.NewSelect().Model((*Deck)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

_, err = db.NewInsert().Model(&User{ID: 1, Type: "admin", Name: "root"}).Exec(ctx)
require.NoError(t, err)
res, err := db.NewInsert().Model(&Deck{UserID: 1, UserType: "admin"}).Exec(ctx)
require.NoError(t, err)

affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

//Select with Relation should work
d := Deck{}
err = db.NewSelect().Model(&d).Relation("User").Scan(ctx)
require.NoError(t, err)
require.NotNil(t, d.User)
require.Equal(t, d.User.Name, "root")
}

func testInterfaceAny(t *testing.T, db *bun.DB) {
switch db.Dialect().Name() {
case dialect.MySQL:
Expand Down
11 changes: 11 additions & 0 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
return q
}

func (q *CreateTableQuery) DetectForeignKeys() *CreateTableQuery {
for _, relation := range q.tableModel.Table().Relations {
q.ForeignKey("(?) REFERENCES ? (?)",
Safe(appendColumns(nil, "", relation.BaseFields)),
relation.JoinTable.SQLName,
Safe(appendColumns(nil, "", relation.JoinFields)),
)
}
return q
}

//------------------------------------------------------------------------------

func (q *CreateTableQuery) Operation() string {
Expand Down

0 comments on commit a958fcb

Please sign in to comment.