From 62e154755a8b845d2cd208c5db1f1f61e52d0c91 Mon Sep 17 00:00:00 2001 From: fwbrasil Date: Thu, 28 Sep 2017 00:25:53 -0700 Subject: [PATCH] fix nested joins --- .../norm/FlattenOptionOperation.scala | 8 +-- .../io/getquill/norm/SymbolicReduction.scala | 9 +-- .../norm/FlattenOptionOperationSpec.scala | 15 ++++ .../getquill/norm/SymbolicReductionSpec.scala | 19 +++++- .../io/getquill/context/sql/SqlQuery.scala | 9 ++- .../context/sql/norm/ExpandJoin.scala | 68 ++++++++++++------- .../sql/norm/ExpandNestedQueries.scala | 2 +- .../getquill/context/sql/SqlQuerySpec.scala | 19 +++++- .../sql/idiom/VerifySqlQuerySpec.scala | 4 +- .../context/sql/norm/ExpandJoinSpec.scala | 8 +-- .../sql/norm/RenamePropertiesSpec.scala | 4 +- 11 files changed, 119 insertions(+), 46 deletions(-) diff --git a/quill-core/src/main/scala/io/getquill/norm/FlattenOptionOperation.scala b/quill-core/src/main/scala/io/getquill/norm/FlattenOptionOperation.scala index 3e27fd181d..d4baa3e333 100644 --- a/quill-core/src/main/scala/io/getquill/norm/FlattenOptionOperation.scala +++ b/quill-core/src/main/scala/io/getquill/norm/FlattenOptionOperation.scala @@ -9,13 +9,13 @@ object FlattenOptionOperation extends StatelessTransformer { case OptionMap(ast, alias, body) => apply(BetaReduction(body, alias -> ast)) case OptionForall(ast, alias, body) => - val isEmpty = apply(BinaryOperation(ast, EqualityOperator.`==`, NullValue): Ast) - val exists = apply(BetaReduction(body, alias -> ast)) - BinaryOperation(isEmpty, BooleanOperator.`||`, exists) + val isEmpty = BinaryOperation(ast, EqualityOperator.`==`, NullValue) + val exists = BetaReduction(body, alias -> ast) + apply(BinaryOperation(isEmpty, BooleanOperator.`||`, exists): Ast) case OptionExists(ast, alias, body) => apply(BetaReduction(body, alias -> ast)) case OptionContains(ast, body) => - BinaryOperation(ast, EqualityOperator.`==`, body) + apply(BinaryOperation(ast, EqualityOperator.`==`, body): Ast) case other => super.apply(other) } diff --git a/quill-core/src/main/scala/io/getquill/norm/SymbolicReduction.scala b/quill-core/src/main/scala/io/getquill/norm/SymbolicReduction.scala index 46b22e46b4..d5d739f287 100644 --- a/quill-core/src/main/scala/io/getquill/norm/SymbolicReduction.scala +++ b/quill-core/src/main/scala/io/getquill/norm/SymbolicReduction.scala @@ -8,6 +8,7 @@ import io.getquill.ast.UnionAll import io.getquill.ast.Join import io.getquill.ast.Ident import io.getquill.ast.Property +import io.getquill.ast.InnerJoin object SymbolicReduction { @@ -38,19 +39,19 @@ object SymbolicReduction { // a.filter(b => c).join(d).on((e, f) => g) => // a.join(d).on((e, f) => g).filter(x => c[b := x._1]) - case Join(tpe, Filter(a, b, c), d, e, f, g) => + case Join(InnerJoin, Filter(a, b, c), d, e, f, g) => val x = Ident("x") val x1 = Property(x, "_1") val cr = BetaReduction(c, b -> x1) - Some(Filter(Join(tpe, a, d, e, f, g), x, cr)) + Some(Filter(Join(InnerJoin, a, d, e, f, g), x, cr)) // a.join(b.filter(c => d)).on((e, f) => g) => // a.join(b).on((e, f) => g).filter(x => d[c := x._2]) - case Join(tpe, a, Filter(b, c, d), e, f, g) => + case Join(InnerJoin, a, Filter(b, c, d), e, f, g) => val x = Ident("x") val x2 = Property(x, "_2") val dr = BetaReduction(d, c -> x2) - Some(Filter(Join(tpe, a, b, e, f, g), x, dr)) + Some(Filter(Join(InnerJoin, a, b, e, f, g), x, dr)) case other => None } diff --git a/quill-core/src/test/scala/io/getquill/norm/FlattenOptionOperationSpec.scala b/quill-core/src/test/scala/io/getquill/norm/FlattenOptionOperationSpec.scala index 84430f01f1..64a8a98f31 100644 --- a/quill-core/src/test/scala/io/getquill/norm/FlattenOptionOperationSpec.scala +++ b/quill-core/src/test/scala/io/getquill/norm/FlattenOptionOperationSpec.scala @@ -37,6 +37,21 @@ class FlattenOptionOperationSpec extends Spec { BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`!=`, Constant(1)) ) } + "map + forall + binop" in { + val q = quote { + (o: Option[TestEntity]) => o.map(_.i).forall(i => i != 1) && true + } + FlattenOptionOperation(q.ast.body: Ast) mustEqual + BinaryOperation( + BinaryOperation( + BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`==`, NullValue), + BooleanOperator.`||`, + BinaryOperation(Property(Ident("o"), "i"), EqualityOperator.`!=`, Constant(1)) + ), + BooleanOperator.`&&`, + Constant(true) + ) + } "exists" in { val q = quote { (o: Option[Int]) => o.exists(i => i > 1) diff --git a/quill-core/src/test/scala/io/getquill/norm/SymbolicReductionSpec.scala b/quill-core/src/test/scala/io/getquill/norm/SymbolicReductionSpec.scala index 287577316f..9545da94b2 100644 --- a/quill-core/src/test/scala/io/getquill/norm/SymbolicReductionSpec.scala +++ b/quill-core/src/test/scala/io/getquill/norm/SymbolicReductionSpec.scala @@ -60,7 +60,7 @@ class SymbolicReductionSpec extends Spec { SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast) } - "a.filter(b => c).join(d).on((e, f) => g) => a.join(d).on((e, f) => g).filter(x => c[b := x._1])" in { + "a.filter(b => c).innerJoin(d).on((e, f) => g) => a.innerJoin(d).on((e, f) => g).filter(x => c[b := x._1])" in { val q = quote { qr1.filter(a => a.i == 1).join(qr2).on((a, b) => a.i == b.i) } @@ -70,7 +70,7 @@ class SymbolicReductionSpec extends Spec { SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast) } - "a.join(b.filter(c => d)).on((e, f) => g) => a.join(b).on((e, f) => g).filter(x => d[c := x._2])" in { + "a.innerJoin(b.filter(c => d)).on((e, f) => g) => a.innerJoin(b).on((e, f) => g).filter(x => d[c := x._2])" in { val q = quote { qr1.join(qr2.filter(b => b.i == 1)).on((a, b) => a.i == b.i) } @@ -79,4 +79,19 @@ class SymbolicReductionSpec extends Spec { } SymbolicReduction.unapply(q.ast) mustEqual Some(n.ast) } + + "doesn't reduce non-inner-joins since they aren't commutative" - { + "a.filter.*join(b)" in { + val q = quote { + qr1.filter(a => a.i == 1).leftJoin(qr2).on((a, b) => a.i == b.i) + } + SymbolicReduction.unapply(q.ast) mustEqual None + } + "a.*join(b.filter)" in { + val q = quote { + qr1.rightJoin(qr2.filter(b => b.i == 1)).on((a, b) => a.i == b.i) + } + SymbolicReduction.unapply(q.ast) mustEqual None + } + } } diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala index 0ae870a5eb..5302f2215d 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala @@ -104,8 +104,13 @@ object SqlQuery { def base(q: Ast, alias: String) = { def nest(ctx: FromContext) = FlattenSqlQuery(from = sources :+ ctx, select = select(alias)) q match { - case Map(_: GroupBy, _, _) => nest(source(q, alias)) - case Nested(q) => nest(QueryContext(apply(q), alias)) + case Map(_: GroupBy, _, _) => nest(source(q, alias)) + case Nested(q) => nest(QueryContext(apply(q), alias)) + case Join(tpe, a, b, iA, iB, on) => + FlattenSqlQuery( + from = source(q, alias) :: Nil, + select = SelectValue(iA, None) :: SelectValue(iB, None) :: Nil + ) case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias) case q if (sources == Nil) => flatten(sources, q, alias) case other => nest(source(q, alias)) diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandJoin.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandJoin.scala index b389456e19..09408ea491 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandJoin.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandJoin.scala @@ -1,5 +1,7 @@ package io.getquill.context.sql.norm +import io.getquill.ast.Ast +import io.getquill.ast.Filter import io.getquill.ast.Ident import io.getquill.ast.Join import io.getquill.ast.Map @@ -12,32 +14,50 @@ object ExpandJoin extends StatelessTransformer { override def apply(q: Query) = q match { - case q @ Join(_, _, _, Ident(a), Ident(b), _) => - val (qr, tuple) = expandedTuple(q) - Map(qr, Ident(s"$a$b"), tuple) + case Filter(Expand(ar, at), b, c) => + val id = ident(at) + val cr = BetaReduction(c, b -> at) + Map(Filter(ar, id, cr), id, at) + case Expand(qr, map) => + Map(qr, ident(map), map) case other => super.apply(other) } - private def expandedTuple(q: Join): (Join, Tuple) = - q match { - - case Join(t, a: Join, b: Join, tA, tB, o) => - val (ar, at) = expandedTuple(a) - val (br, bt) = expandedTuple(b) - val or = BetaReduction(o, tA -> at, tB -> bt) - (Join(t, ar, br, tA, tB, or), Tuple(List(at, bt))) - - case Join(t, a: Join, b, tA, tB, o) => - val (ar, at) = expandedTuple(a) - val or = BetaReduction(o, tA -> at) - (Join(t, ar, b, tA, tB, or), Tuple(List(at, tB))) - - case Join(t, a, b: Join, tA, tB, o) => - val (br, bt) = expandedTuple(b) - val or = BetaReduction(o, tB -> bt) - (Join(t, a, br, tA, tB, or), Tuple(List(tA, bt))) - - case q @ Join(t, a, b, tA, tB, on) => - (q, Tuple(List(tA, tB))) + object Expand { + def unapply(q: Ast): Option[(Ast, Ast)] = + q match { + case Join(t, Expand(ar, at), Expand(br, bt), tA, tB, o) => + val or = BetaReduction(o, tA -> at, tB -> bt) + Some((Join(t, ar, br, tA, tB, or), Tuple(List(at, bt)))) + + case Join(t, Expand(ar, at), b, tA, tB, o) => + val or = BetaReduction(o, tA -> at) + Some((Join(t, ar, b, tA, tB, or), Tuple(List(at, tB)))) + + case Join(t, a, Expand(br, bt), tA, tB, o) => + val or = BetaReduction(o, tB -> bt) + Some((Join(t, a, br, tA, tB, or), Tuple(List(tA, bt)))) + + case q @ Join(t, a, b, tA, tB, on) => + Some((q, Tuple(List(tA, tB)))) + + case Filter(Expand(ar, at), b, c) => + val id = ident(at) + val cr = BetaReduction(c, b -> at) + Some((Filter(ar, id, cr), id)) + + case _ => None + } + } + + private def ident(ast: Ast): Ident = + ast match { + case Tuple(values) => + values.map(ident).foldLeft(Ident("")) { + case (Ident(a), Ident(b)) => + Ident(s"$a$b") + } + case i: Ident => i + case other => Ident(other.toString) } } diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala index 086d11a31d..a7cfc4b93d 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala @@ -93,7 +93,7 @@ object ExpandNestedQueries { } } - references.toList match { + references.toList.sortBy(_.ast.toString).toList match { case Nil => select case refs => refs.map(expandReference) } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala index 8dd8bef1df..63dc1c172c 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala @@ -43,6 +43,23 @@ class SqlQuerySpec extends Spec { "SELECT a.i, b.i FROM TestEntity a LEFT JOIN TestEntity2 b ON a.i = b.i WHERE (b.i IS NULL) OR (b.i = 1)" } + "nested join" in { + val q = quote { + qr1.leftJoin(qr2).on { + case (a, b) => + a.i == b.i + }.filter { + case (a, b) => + b.map(_.l).contains(3L) + }.leftJoin(qr3).on { + case ((a, b), c) => + b.map(_.i).contains(a.i) && b.map(_.i).contains(c.i) + } + } + testContext.run(q).string mustEqual + "SELECT x01x11.s, x01x11.i, x01x11.l, x01x11.o, x01x11.s, x01x11.i, x01x11.l, x01x11.o, x12.s, x12.i, x12.l, x12.o FROM (SELECT x01.s s, x01.i i, x01.o o, x01.l l, x11.s s, x11.i i, x11.l l, x11.o o FROM TestEntity x01 LEFT JOIN TestEntity2 x11 ON x01.i = x11.i WHERE x11.l = 3) x01x11 LEFT JOIN TestEntity3 x12 ON (x01x11.i = x01x11.i) AND (x01x11.i = x12.i)" + } + "flat outer join" in { val q = quote { for { @@ -269,7 +286,7 @@ class SqlQuerySpec extends Spec { } } testContext.run(q).string mustEqual - "SELECT t.i, SUM(t.i) FROM (SELECT b.i i, a.i i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) t GROUP BY t.i" + "SELECT t.i, SUM(t.i) FROM (SELECT a.i i, b.i i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) t GROUP BY t.i" } } "invalid groupby criteria" in { diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/VerifySqlQuerySpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/VerifySqlQuerySpec.scala index 23c288f4aa..d1a857405c 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/VerifySqlQuerySpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/VerifySqlQuerySpec.scala @@ -30,8 +30,8 @@ class VerifySqlQuerySpec extends Spec { case (a, b) => b.isDefined } } - VerifySqlQuery(SqlQuery(q.ast)).toString mustEqual - "Some(The monad composition can't be expressed using applicative joins. Faulty expression: 'x01._2.isDefined'. Free variables: 'List(x01)'., Faulty expression: 'x01'. Free variables: 'List(x01)'.)" + VerifySqlQuery(SqlQuery(q.ast)).toString mustEqual + "Some(The monad composition can't be expressed using applicative joins. Faulty expression: 'x01._2.isDefined'. Free variables: 'List(x01)'.)" } "invalid flatJoin on" in { diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandJoinSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandJoinSpec.scala index e5686905ad..312597c26c 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandJoinSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandJoinSpec.scala @@ -23,28 +23,28 @@ class ExpandJoinSpec extends Spec { qr1.join(qr2).on((a, b) => a.s == b.s).join(qr3).on((c, d) => c._1.s == d.s) } ExpandJoin(q.ast).toString mustEqual - """querySchema("TestEntity").join(querySchema("TestEntity2")).on((a, b) => a.s == b.s).join(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(cd => ((a, b), d))""" + """querySchema("TestEntity").join(querySchema("TestEntity2")).on((a, b) => a.s == b.s).join(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(abd => ((a, b), d))""" } "left" in { val q = quote { qr1.leftJoin(qr2).on((a, b) => a.s == b.s).leftJoin(qr3).on((c, d) => c._1.s == d.s) } ExpandJoin(q.ast).toString mustEqual - """querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(cd => ((a, b), d))""" + """querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3")).on((c, d) => a.s == d.s).map(abd => ((a, b), d))""" } "right" in { val q = quote { qr1.leftJoin(qr2.leftJoin(qr3).on((a, b) => a.s == b.s)).on((c, d) => c.s == d._1.s) } ExpandJoin(q.ast).toString mustEqual - """querySchema("TestEntity").leftJoin(querySchema("TestEntity2").leftJoin(querySchema("TestEntity3")).on((a, b) => a.s == b.s)).on((c, d) => c.s == a.s).map(cd => (c, (a, b)))""" + """querySchema("TestEntity").leftJoin(querySchema("TestEntity2").leftJoin(querySchema("TestEntity3")).on((a, b) => a.s == b.s)).on((c, d) => c.s == a.s).map(cab => (c, (a, b)))""" } "both" in { val q = quote { qr1.leftJoin(qr2).on((a, b) => a.s == b.s).leftJoin(qr3.leftJoin(qr2).on((c, d) => c.s == d.s)).on((e, f) => e._1.s == f._1.s) } ExpandJoin(q.ast).toString mustEqual - """querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3").leftJoin(querySchema("TestEntity2")).on((c, d) => c.s == d.s)).on((e, f) => a.s == c.s).map(ef => ((a, b), (c, d)))""" + """querySchema("TestEntity").leftJoin(querySchema("TestEntity2")).on((a, b) => a.s == b.s).leftJoin(querySchema("TestEntity3").leftJoin(querySchema("TestEntity2")).on((c, d) => c.s == d.s)).on((e, f) => a.s == c.s).map(abcd => ((a, b), (c, d)))""" } } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala index 5d7b19e182..44b9f67464 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala @@ -199,14 +199,14 @@ class RenamePropertiesSpec extends Spec { e.leftJoin(f).on((a, b) => a.s == b.s).map(t => t._1.s) } testContext.run(q).string mustEqual - "SELECT a.field_s FROM test_entity a LEFT JOIN TestEntity t ON a.field_s = t.s WHERE t.i = 1" + "SELECT a.field_s FROM test_entity a LEFT JOIN (SELECT t.s FROM TestEntity t WHERE t.i = 1) t ON a.field_s = t.s" } "right" in { val q = quote { f.rightJoin(e).on((a, b) => a.s == b.s).map(t => t._2.s) } testContext.run(q).string mustEqual - "SELECT b.field_s FROM TestEntity t RIGHT JOIN test_entity b ON t.s = b.field_s WHERE t.i = 1" + "SELECT b.field_s FROM (SELECT t.s FROM TestEntity t WHERE t.i = 1) t RIGHT JOIN test_entity b ON t.s = b.field_s" } "flat inner" in { val q = quote {