Skip to content

Commit

Permalink
feat: add Join to UpdateQuery (#908)
Browse files Browse the repository at this point in the history
* feat: add Join to UpdateQuery

* chore: implement LSP suggestion (could apply De Morgan's law)

Improves readability (IMHO) so thought to implement it.
  • Loading branch information
svanharmelen authored Jan 6, 2024
1 parent 8a43835 commit 8c4d8be
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 1 deletion.
7 changes: 7 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,13 @@ func TestQuery(t *testing.T) {
func(db *bun.DB) schema.QueryAppender {
return db.NewUpdate().Model(&Model{42, ""}).OmitZero()
},
func(db *bun.DB) schema.QueryAppender {
return db.NewUpdate().
Model((*Story)(nil)).
Set("name = ?", "new-name").
Join("JOIN user ON user.id = story.user_id").
Where("user.id = ?", 1)
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE "stories" SET name = N'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE `stories` AS `story` SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-163
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UPDATE "stories" AS "story" SET name = 'new-name' JOIN user ON user.id = story.user_id WHERE (user.id = 1)
2 changes: 1 addition & 1 deletion query_merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewMergeQuery(db *DB) *MergeQuery {
conn: db.DB,
},
}
if !(q.db.dialect.Name() == dialect.MSSQL || q.db.dialect.Name() == dialect.PG) {
if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG {
q.err = errors.New("bun: merge not supported for current dialect")
}
return q
Expand Down
35 changes: 35 additions & 0 deletions query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type UpdateQuery struct {
setQuery
idxHintsQuery

joins []joinQuery
omitZero bool
}

Expand Down Expand Up @@ -133,6 +134,33 @@ func (q *UpdateQuery) OmitZero() *UpdateQuery {

//------------------------------------------------------------------------------

func (q *UpdateQuery) Join(join string, args ...interface{}) *UpdateQuery {
q.joins = append(q.joins, joinQuery{
join: schema.SafeQuery(join, args),
})
return q
}

func (q *UpdateQuery) JoinOn(cond string, args ...interface{}) *UpdateQuery {
return q.joinOn(cond, args, " AND ")
}

func (q *UpdateQuery) JoinOnOr(cond string, args ...interface{}) *UpdateQuery {
return q.joinOn(cond, args, " OR ")
}

func (q *UpdateQuery) joinOn(cond string, args []interface{}, sep string) *UpdateQuery {
if len(q.joins) == 0 {
q.err = errors.New("bun: query has no joins")
return q
}
j := &q.joins[len(q.joins)-1]
j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep))
return q
}

//------------------------------------------------------------------------------

func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery {
q.addWhereCols(cols)
return q
Expand Down Expand Up @@ -230,6 +258,13 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
}

for _, j := range q.joins {
b, err = j.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}

if q.hasFeature(feature.Output) && q.hasReturning() {
b = append(b, " OUTPUT "...)
b, err = q.appendOutput(fmter, b)
Expand Down

0 comments on commit 8c4d8be

Please sign in to comment.