From a958fcbab680b0c5ad7980f369c7b73f7673db87 Mon Sep 17 00:00:00 2001 From: I am Goroot Date: Sun, 30 Jan 2022 17:56:51 +0200 Subject: [PATCH] feat: add CreateTableQuery.DetectForeignKeys --- internal/dbtest/db_test.go | 72 ++++++++++++++++++++++++++++++++++++++ query_table_create.go | 11 ++++++ 2 files changed, 83 insertions(+) diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index dcc0bb1d7..66b7b6e8a 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -227,6 +227,7 @@ func TestDB(t *testing.T) { {testJSONValuer}, {testSelectBool}, {testFKViolation}, + {testDetectForeignKeys}, {testInterfaceAny}, {testInterfaceJSON}, {testScanRawMessage}, @@ -843,6 +844,77 @@ func testFKViolation(t *testing.T, db *bun.DB) { require.Equal(t, 0, n) } +func testDetectForeignKeys(t *testing.T, db *bun.DB) { + type User struct { + ID int `bun:",pk"` + Type string `bun:",pk"` + Name string + } + type Deck struct { + ID int `bun:",pk"` + UserID int + UserType string + User *User `bun:"rel:belongs-to,join:user_id=id,join:user_type=type"` + } + + if db.Dialect().Name() == dialect.SQLite { + _, err := db.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + } + + for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} { + _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) + require.NoError(t, err) + } + + _, err := db.NewCreateTable(). + Model((*User)(nil)). + IfNotExists(). + Exec(ctx) + require.NoError(t, err) + + _, err = db.NewCreateTable(). + Model((*Deck)(nil)). + IfNotExists(). + DetectForeignKeys(). + Exec(ctx) + require.NoError(t, err) + + // Empty deck should violate FK constraint. + _, err = db.NewInsert().Model(new(Deck)).Exec(ctx) + require.Error(t, err) + + // Create a deck that violates the user_id FK contraint + deck := &Deck{UserID: 42} + + _, err = db.NewInsert().Model(deck).Exec(ctx) + require.Error(t, err) + + decks := []*Deck{deck} + _, err = db.NewInsert().Model(&decks).Exec(ctx) + require.Error(t, err) + + n, err := db.NewSelect().Model((*Deck)(nil)).Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, n) + + _, err = db.NewInsert().Model(&User{ID: 1, Type: "admin", Name: "root"}).Exec(ctx) + require.NoError(t, err) + res, err := db.NewInsert().Model(&Deck{UserID: 1, UserType: "admin"}).Exec(ctx) + require.NoError(t, err) + + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), affected) + + //Select with Relation should work + d := Deck{} + err = db.NewSelect().Model(&d).Relation("User").Scan(ctx) + require.NoError(t, err) + require.NotNil(t, d.User) + require.Equal(t, d.User.Name, "root") +} + func testInterfaceAny(t *testing.T, db *bun.DB) { switch db.Dialect().Name() { case dialect.MySQL: diff --git a/query_table_create.go b/query_table_create.go index f2312bc69..9c62c0672 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -102,6 +102,17 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery { return q } +func (q *CreateTableQuery) DetectForeignKeys() *CreateTableQuery { + for _, relation := range q.tableModel.Table().Relations { + q.ForeignKey("(?) REFERENCES ? (?)", + Safe(appendColumns(nil, "", relation.BaseFields)), + relation.JoinTable.SQLName, + Safe(appendColumns(nil, "", relation.JoinFields)), + ) + } + return q +} + //------------------------------------------------------------------------------ func (q *CreateTableQuery) Operation() string {