Skip to content

Commit

Permalink
fix nested joins
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Oct 1, 2017
1 parent 1a52c2b commit 62e1547
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ object FlattenOptionOperation extends StatelessTransformer {
case OptionMap(ast, alias, body) =>
apply(BetaReduction(body, alias -> ast))
case OptionForall(ast, alias, body) =>
val isEmpty = apply(BinaryOperation(ast, EqualityOperator.`==`, NullValue): Ast)
val exists = apply(BetaReduction(body, alias -> ast))
BinaryOperation(isEmpty, BooleanOperator.`||`, exists)
val isEmpty = BinaryOperation(ast, EqualityOperator.`==`, NullValue)
val exists = BetaReduction(body, alias -> ast)
apply(BinaryOperation(isEmpty, BooleanOperator.`||`, exists): Ast)
case OptionExists(ast, alias, body) =>
apply(BetaReduction(body, alias -> ast))
case OptionContains(ast, body) =>
BinaryOperation(ast, EqualityOperator.`==`, body)
apply(BinaryOperation(ast, EqualityOperator.`==`, body): Ast)
case other =>
super.apply(other)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.getquill.ast.UnionAll
import io.getquill.ast.Join
import io.getquill.ast.Ident
import io.getquill.ast.Property
import io.getquill.ast.InnerJoin

object SymbolicReduction {

Expand Down Expand Up @@ -38,19 +39,19 @@ object SymbolicReduction {

// a.filter(b => c).join(d).on((e, f) => g) =>
// a.join(d).on((e, f) => g).filter(x => c[b := x._1])
case Join(tpe, Filter(a, b, c), d, e, f, g) =>
case Join(InnerJoin, Filter(a, b, c), d, e, f, g) =>
val x = Ident("x")
val x1 = Property(x, "_1")
val cr = BetaReduction(c, b -> x1)
Some(Filter(Join(tpe, a, d, e, f, g), x, cr))
Some(Filter(Join(InnerJoin, a, d, e, f, g), x, cr))

// a.join(b.filter(c => d)).on((e, f) => g) =>
// a.join(b).on((e, f) => g).filter(x => d[c := x._2])
case Join(tpe, a, Filter(b, c, d), e, f, g) =>
case Join(InnerJoin, a, Filter(b, c, d), e, f, g) =>
val x = Ident("x")
val x2 = Property(x, "_2")
val dr = BetaReduction(d, c -> x2)
Some(Filter(Join(tpe, a, b, e, f, g), x, dr))
Some(Filter(Join(InnerJoin, a, b, e, f, g), x, dr))

case other => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ class FlattenOptionOperationSpec extends Spec {
BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`!=`, Constant(1))
)
}
"map + forall + binop" in {
val q = quote {
(o: Option[TestEntity]) => o.map(_.i).forall(i => i != 1) && true
}
FlattenOptionOperation(q.ast.body: Ast) mustEqual
BinaryOperation(
BinaryOperation(
BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`==`, NullValue),
BooleanOperator.`||`,
BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`!=`, Constant(1))
),
BooleanOperator.`&&`,
Constant(true)
)
}
"exists" in {
val q = quote {
(o: Option[Int]) => o.exists(i => i > 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class SymbolicReductionSpec extends Spec {
SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast)
}

"a.filter(b => c).join(d).on((e, f) => g) => a.join(d).on((e, f) => g).filter(x => c[b := x._1])" in {
"a.filter(b => c).innerJoin(d).on((e, f) => g) => a.innerJoin(d).on((e, f) => g).filter(x => c[b := x._1])" in {
val q = quote {
qr1.filter(a => a.i == 1).join(qr2).on((a, b) => a.i == b.i)
}
Expand All @@ -70,7 +70,7 @@ class SymbolicReductionSpec extends Spec {
SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast)
}

"a.join(b.filter(c => d)).on((e, f) => g) => a.join(b).on((e, f) => g).filter(x => d[c := x._2])" in {
"a.innerJoin(b.filter(c => d)).on((e, f) => g) => a.innerJoin(b).on((e, f) => g).filter(x => d[c := x._2])" in {
val q = quote {
qr1.join(qr2.filter(b => b.i == 1)).on((a, b) => a.i == b.i)
}
Expand All @@ -79,4 +79,19 @@ class SymbolicReductionSpec extends Spec {
}
SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast)
}

"doesn't reduce non-inner-joins since they aren't commutative" - {
"a.filter.*join(b)" in {
val q = quote {
qr1.filter(a => a.i == 1).leftJoin(qr2).on((a, b) => a.i == b.i)
}
SymbolicReduction.unapply(q.ast) mustEqual None
}
"a.*join(b.filter)" in {
val q = quote {
qr1.rightJoin(qr2.filter(b => b.i == 1)).on((a, b) => a.i == b.i)
}
SymbolicReduction.unapply(q.ast) mustEqual None
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,13 @@ object SqlQuery {
def base(q: Ast, alias: String) = {
def nest(ctx: FromContext) = FlattenSqlQuery(from = sources :+ ctx, select = select(alias))
q match {
case Map(_: GroupBy, _, _) => nest(source(q, alias))
case Nested(q) => nest(QueryContext(apply(q), alias))
case Map(_: GroupBy, _, _) => nest(source(q, alias))
case Nested(q) => nest(QueryContext(apply(q), alias))
case Join(tpe, a, b, iA, iB, on) =>
FlattenSqlQuery(
from = source(q, alias) :: Nil,
select = SelectValue(iA, None) :: SelectValue(iB, None) :: Nil
)
case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias)
case q if (sources == Nil) => flatten(sources, q, alias)
case other => nest(source(q, alias))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.getquill.context.sql.norm

import io.getquill.ast.Ast
import io.getquill.ast.Filter
import io.getquill.ast.Ident
import io.getquill.ast.Join
import io.getquill.ast.Map
Expand All @@ -12,32 +14,50 @@ object ExpandJoin extends StatelessTransformer {

override def apply(q: Query) =
q match {
case q @ Join(_, _, _, Ident(a), Ident(b), _) =>
val (qr, tuple) = expandedTuple(q)
Map(qr, Ident(s"$a$b"), tuple)
case Filter(Expand(ar, at), b, c) =>
val id = ident(at)
val cr = BetaReduction(c, b -> at)
Map(Filter(ar, id, cr), id, at)
case Expand(qr, map) =>
Map(qr, ident(map), map)
case other => super.apply(other)
}

private def expandedTuple(q: Join): (Join, Tuple) =
q match {

case Join(t, a: Join, b: Join, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tA -> at, tB -> bt)
(Join(t, ar, br, tA, tB, or), Tuple(List(at, bt)))

case Join(t, a: Join, b, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val or = BetaReduction(o, tA -> at)
(Join(t, ar, b, tA, tB, or), Tuple(List(at, tB)))

case Join(t, a, b: Join, tA, tB, o) =>
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tB -> bt)
(Join(t, a, br, tA, tB, or), Tuple(List(tA, bt)))

case q @ Join(t, a, b, tA, tB, on) =>
(q, Tuple(List(tA, tB)))
object Expand {
def unapply(q: Ast): Option[(Ast, Ast)] =
q match {
case Join(t, Expand(ar, at), Expand(br, bt), tA, tB, o) =>
val or = BetaReduction(o, tA -> at, tB -> bt)
Some((Join(t, ar, br, tA, tB, or), Tuple(List(at, bt))))

case Join(t, Expand(ar, at), b, tA, tB, o) =>
val or = BetaReduction(o, tA -> at)
Some((Join(t, ar, b, tA, tB, or), Tuple(List(at, tB))))

case Join(t, a, Expand(br, bt), tA, tB, o) =>
val or = BetaReduction(o, tB -> bt)
Some((Join(t, a, br, tA, tB, or), Tuple(List(tA, bt))))

case q @ Join(t, a, b, tA, tB, on) =>
Some((q, Tuple(List(tA, tB))))

case Filter(Expand(ar, at), b, c) =>
val id = ident(at)
val cr = BetaReduction(c, b -> at)
Some((Filter(ar, id, cr), id))

case _ => None
}
}

private def ident(ast: Ast): Ident =
ast match {
case Tuple(values) =>
values.map(ident).foldLeft(Ident("")) {
case (Ident(a), Ident(b)) =>
Ident(s"$a$b")
}
case i: Ident => i
case other => Ident(other.toString)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object ExpandNestedQueries {
}
}

references.toList match {
references.toList.sortBy(_.ast.toString).toList match {
case Nil => select
case refs => refs.map(expandReference)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ class SqlQuerySpec extends Spec {
"SELECT a.i, b.i FROM TestEntity a LEFT JOIN TestEntity2 b ON a.i = b.i WHERE (b.i IS NULL) OR (b.i = 1)"
}

"nested join" in {
val q = quote {
qr1.leftJoin(qr2).on {
case (a, b) =>
a.i == b.i
}.filter {
case (a, b) =>
b.map(_.l).contains(3L)
}.leftJoin(qr3).on {
case ((a, b), c) =>
b.map(_.i).contains(a.i) && b.map(_.i).contains(c.i)
}
}
testContext.run(q).string mustEqual
"SELECT x01x11.s, x01x11.i, x01x11.l, x01x11.o, x01x11.s, x01x11.i, x01x11.l, x01x11.o, x12.s, x12.i, x12.l, x12.o FROM (SELECT x01.s s, x01.i i, x01.o o, x01.l l, x11.s s, x11.i i, x11.l l, x11.o o FROM TestEntity x01 LEFT JOIN TestEntity2 x11 ON x01.i = x11.i WHERE x11.l = 3) x01x11 LEFT JOIN TestEntity3 x12 ON (x01x11.i = x01x11.i) AND (x01x11.i = x12.i)"
}

"flat outer join" in {
val q = quote {
for {
Expand Down Expand Up @@ -269,7 +286,7 @@ class SqlQuerySpec extends Spec {
}
}
testContext.run(q).string mustEqual
"SELECT t.i, SUM(t.i) FROM (SELECT b.i i, a.i i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) t GROUP BY t.i"
"SELECT t.i, SUM(t.i) FROM (SELECT a.i i, b.i i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) t GROUP BY t.i"
}
}
"invalid groupby criteria" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class VerifySqlQuerySpec extends Spec {
case (a, b) => b.isDefined
}
}
VerifySqlQuery(SqlQuery(q.ast)).toString mustEqual
"Some(The monad composition can't be expressed using applicative joins. Faulty expression: 'x01._2.isDefined'. Free variables: 'List(x01)'., Faulty expression: 'x01'. Free variables: 'List(x01)'.)"
VerifySqlQuery(SqlQuery(q.ast)).toString mustEqual
"Some(The monad composition can't be expressed using applicative joins. Faulty expression: 'x01._2.isDefined'. Free variables: 'List(x01)'.)"
}

"invalid flatJoin on" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,28 @@ class ExpandJoinSpec extends Spec {
qr1.join(qr2).on((a, b) => a.s == b.s).join(qr3).on((c, d) => c._1.s == d.s)
}
ExpandJoin(q.ast).toString mustEqual
"""querySchema("TestEntity").join(querySchema("TestEntity2")).on((a, b) => a.s == b.s).join(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(cd => ((a, b), d))"""
"""querySchema("TestEntity").join(querySchema("TestEntity2")).on((a, b) => a.s == b.s).join(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(abd => ((a, b), d))"""
}
"left" in {
val q = quote {
qr1.leftJoin(qr2).on((a, b) => a.s == b.s).leftJoin(qr3).on((c, d) => c._1.s == d.s)
}
ExpandJoin(q.ast).toString mustEqual
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(cd => ((a, b), d))"""
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(abd => ((a, b), d))"""
}
"right" in {
val q = quote {
qr1.leftJoin(qr2.leftJoin(qr3).on((a, b) => a.s == b.s)).on((c, d) => c.s == d._1.s)
}
ExpandJoin(q.ast).toString mustEqual
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2").leftJoin(querySchema("TestEntity3")).on((a, b) => a.s == b.s)).on((c, d) => c.s == a.s).map(cd => (c, (a, b)))"""
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2").leftJoin(querySchema("TestEntity3")).on((a, b) => a.s == b.s)).on((c, d) => c.s == a.s).map(cab => (c, (a, b)))"""
}
"both" in {
val q = quote {
qr1.leftJoin(qr2).on((a, b) => a.s == b.s).leftJoin(qr3.leftJoin(qr2).on((c, d) => c.s == d.s)).on((e, f) => e._1.s == f._1.s)
}
ExpandJoin(q.ast).toString mustEqual
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3").leftJoin(querySchema("TestEntity2")).on((c, d) => c.s == d.s)).on((e, f) => a.s == c.s).map(ef => ((a, b), (c, d)))"""
"""querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3").leftJoin(querySchema("TestEntity2")).on((c, d) => c.s == d.s)).on((e, f) => a.s == c.s).map(abcd => ((a, b), (c, d)))"""
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,14 @@ class RenamePropertiesSpec extends Spec {
e.leftJoin(f).on((a, b) => a.s == b.s).map(t => t._1.s)
}
testContext.run(q).string mustEqual
"SELECT a.field_s FROM test_entity a LEFT JOIN TestEntity t ON a.field_s = t.s WHERE t.i = 1"
"SELECT a.field_s FROM test_entity a LEFT JOIN (SELECT t.s FROM TestEntity t WHERE t.i = 1) t ON a.field_s = t.s"
}
"right" in {
val q = quote {
f.rightJoin(e).on((a, b) => a.s == b.s).map(t => t._2.s)
}
testContext.run(q).string mustEqual
"SELECT b.field_s FROM TestEntity t RIGHT JOIN test_entity b ON t.s = b.field_s WHERE t.i = 1"
"SELECT b.field_s FROM (SELECT t.s FROM TestEntity t WHERE t.i = 1) t RIGHT JOIN test_entity b ON t.s = b.field_s"
}
"flat inner" in {
val q = quote {
Expand Down

0 comments on commit 62e1547

Please sign in to comment.