From 8c4d8be3aa4e64582698b37fd21434b8960dddc0 Mon Sep 17 00:00:00 2001 From: Sander van Harmelen Date: Sat, 6 Jan 2024 07:57:46 +0100 Subject: [PATCH] feat: add Join to UpdateQuery (#908) * feat: add Join to UpdateQuery * chore: implement LSP suggestion (could apply De Morgan's law) Improves readability (IMHO) so thought to implement it. --- internal/dbtest/query_test.go | 7 ++++ .../testdata/snapshots/TestQuery-mariadb-163 | 1 + .../snapshots/TestQuery-mssql2019-163 | 1 + .../testdata/snapshots/TestQuery-mysql5-163 | 1 + .../testdata/snapshots/TestQuery-mysql8-163 | 1 + .../testdata/snapshots/TestQuery-pg-163 | 1 + .../testdata/snapshots/TestQuery-pgx-163 | 1 + .../testdata/snapshots/TestQuery-sqlite-163 | 1 + query_merge.go | 2 +- query_update.go | 35 +++++++++++++++++++ 10 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mariadb-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mssql2019-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-163 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-163 diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index f5785e491..81f9abe99 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -1001,6 +1001,13 @@ func TestQuery(t *testing.T) { func(db *bun.DB) schema.QueryAppender { return db.NewUpdate().Model(&Model{42, ""}).OmitZero() }, + func(db *bun.DB) schema.QueryAppender { + return db.NewUpdate(). + Model((*Story)(nil)). + Set("name = ?", "new-name"). + Join("JOIN user ON user.id = story.user_id"). + Where("user.id = ?", 1) + }, } timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-163 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-163 new file mode 100644 index 000000000..825cb3842 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-163 @@ -0,0 +1 @@ +UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-163 b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-163 new file mode 100644 index 000000000..22bbf8252 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-163 @@ -0,0 +1 @@ +UPDATE "stories" SET name = N'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-163 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-163 new file mode 100644 index 000000000..825cb3842 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-163 @@ -0,0 +1 @@ +UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-163 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-163 new file mode 100644 index 000000000..825cb3842 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-163 @@ -0,0 +1 @@ +UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-163 b/internal/dbtest/testdata/snapshots/TestQuery-pg-163 new file mode 100644 index 000000000..f2c526a68 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-163 @@ -0,0 +1 @@ +UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-163 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-163 new file mode 100644 index 000000000..f2c526a68 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-163 @@ -0,0 +1 @@ +UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-163 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-163 new file mode 100644 index 000000000..f2c526a68 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-163 @@ -0,0 +1 @@ +UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1) diff --git a/query_merge.go b/query_merge.go index 706dc20ae..626752b8a 100644 --- a/query_merge.go +++ b/query_merge.go @@ -29,7 +29,7 @@ func NewMergeQuery(db *DB) *MergeQuery { conn: db.DB, }, } - if !(q.db.dialect.Name() == dialect.MSSQL || q.db.dialect.Name() == dialect.PG) { + if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG { q.err = errors.New("bun: merge not supported for current dialect") } return q diff --git a/query_update.go b/query_update.go index 146d695b8..e56ba20d1 100644 --- a/query_update.go +++ b/query_update.go @@ -20,6 +20,7 @@ type UpdateQuery struct { setQuery idxHintsQuery + joins []joinQuery omitZero bool } @@ -133,6 +134,33 @@ func (q *UpdateQuery) OmitZero() *UpdateQuery { //------------------------------------------------------------------------------ +func (q *UpdateQuery) Join(join string, args ...interface{}) *UpdateQuery { + q.joins = append(q.joins, joinQuery{ + join: schema.SafeQuery(join, args), + }) + return q +} + +func (q *UpdateQuery) JoinOn(cond string, args ...interface{}) *UpdateQuery { + return q.joinOn(cond, args, " AND ") +} + +func (q *UpdateQuery) JoinOnOr(cond string, args ...interface{}) *UpdateQuery { + return q.joinOn(cond, args, " OR ") +} + +func (q *UpdateQuery) joinOn(cond string, args []interface{}, sep string) *UpdateQuery { + if len(q.joins) == 0 { + q.err = errors.New("bun: query has no joins") + return q + } + j := &q.joins[len(q.joins)-1] + j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep)) + return q +} + +//------------------------------------------------------------------------------ + func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery { q.addWhereCols(cols) return q @@ -230,6 +258,13 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e } } + for _, j := range q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + if q.hasFeature(feature.Output) && q.hasReturning() { b = append(b, " OUTPUT "...) b, err = q.appendOutput(fmter, b)