Skip to content

Commit

Permalink
feat: conditions not supporting composite in
Browse files Browse the repository at this point in the history
  • Loading branch information
kmpm committed Jul 29, 2022
1 parent 92b81c3 commit e5d78d4
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
1 change: 1 addition & 0 deletions dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ const (
UpdateFromTable
MSSavepoint
GeneratedIdentity
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
)
3 changes: 2 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func New() *Dialect {
feature.TableNotExists |
feature.InsertOnConflict |
feature.SelectExists |
feature.GeneratedIdentity
feature.GeneratedIdentity |
feature.CompositeIn
return d
}

Expand Down
3 changes: 2 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func New() *Dialect {
feature.DeleteTableAlias |
feature.InsertOnConflict |
feature.TableNotExists |
feature.SelectExists
feature.SelectExists |
feature.CompositeIn
return d
}

Expand Down
93 changes: 93 additions & 0 deletions relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"reflect"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
Expand Down Expand Up @@ -60,6 +61,14 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
q = q.Model(hasManyModel)

var where []byte

if q.db.dialect.Features().Has(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
}

func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery {
if len(j.Relation.JoinFields) > 1 {
where = append(where, '(')
}
Expand Down Expand Up @@ -88,6 +97,29 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
return q
}

func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery {
where = appendMultiValues(
q.db.Formatter(),
where,
j.JoinModel.rootValue(),
j.JoinModel.parentIndex(),
j.Relation.BaseFields,
j.Relation.JoinFields,
j.JoinModel.Table().SQLAlias,
)

q = q.Where(internal.String(where))

if j.Relation.PolymorphicField != nil {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}

j.applyTo(q)
q = q.Apply(j.hasManyColumns)

return q
}

func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery {
b := make([]byte, 0, 32)

Expand Down Expand Up @@ -312,3 +344,64 @@ func appendChildValues(
}
return b
}

func getColumns(table schema.Safe, fields []*schema.Field) [][]byte {
//Based upon query_base.appendColumns
var list [][]byte
for _, f := range fields {
b := []byte{}

if len(table) > 0 {
b = append(b, table...)
b = append(b, '.')
}
b = append(b, f.SQLName...)
list = append(list, b)
}
return list
}

func appendMultiValues(
fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, table schema.Safe,
) []byte {
// This is a mix of appendChildValues and query_base.appendColumns
if len(joinFields) != len(baseFields) {
panic("asdfasdf")
}

// First get the columns
joins := getColumns(table, joinFields)
// Then values
b = append(b, '(')
seen := make(map[string]struct{})
walk(v, index, func(v reflect.Value) {
start := len(b)
for i, f := range baseFields {
if i > 0 {
b = append(b, " AND "...)
}
if len(baseFields) > 1 {
b = append(b, '(')
}
b = append(b, joins[i]...)
b = append(b, '=')
b = f.AppendValue(fmter, b, v)
if len(baseFields) > 1 {
b = append(b, ')')
}
}

b = append(b, ") OR ("...)

if _, ok := seen[string(b[start:])]; ok {
b = b[:start]
} else {
seen[string(b[start:])] = struct{}{}
}
})
if len(seen) > 0 {
b = b[:len(b)-6] // trim ") OR ("
}
b = append(b, ')')
return b
}

0 comments on commit e5d78d4

Please sign in to comment.