Skip to content

Commit

Permalink
fix property renaming for nested and infix queries
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Sep 9, 2017
1 parent ec83092 commit d35c195
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 24 deletions.
65 changes: 43 additions & 22 deletions quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,42 @@ object RenameProperties extends StatelessTransformer {
case (q, schema) => q
}

private def applySchema(q: Ast): (Ast, Ast) =
q match {
case q: Action => applySchema(q)
case q: Query => applySchema(q)
case q =>
CollectAst.byType[Entity](q) match {
case schema :: Nil => (q, schema)
case _ => (q, Tuple(List.empty))
}
}

private def applySchema(q: Action): (Action, Ast) =
q match {
case Insert(q: Query, assignments) => applySchema(q, assignments, Insert)
case Update(q: Query, assignments) => applySchema(q, assignments, Update)
case Delete(q: Query) =>
case Insert(q, assignments) => applySchema(q, assignments, Insert)
case Update(q, assignments) => applySchema(q, assignments, Update)
case Delete(q) =>
applySchema(q) match {
case (q, schema) => (Delete(q), schema)
}
case Returning(action: Action, alias, body) =>
case Returning(action, alias, body) =>
applySchema(action) match {
case (action, schema) =>
val replace = replacements(alias, schema)
val bodyr = BetaReduction(body, replace: _*)
(Returning(action, alias, bodyr), schema)
}
case q => (q, Tuple(List.empty))
case Foreach(q, alias, body) =>
applySchema(q) match {
case (q, schema) =>
val replace = replacements(alias, schema)
val bodyr = BetaReduction(body, replace: _*)
(Foreach(q, alias, bodyr), schema)
}
}

private def applySchema(q: Query, a: List[Assignment], f: (Query, List[Assignment]) => Action): (Action, Ast) =
private def applySchema(q: Ast, a: List[Assignment], f: (Ast, List[Assignment]) => Action): (Action, Ast) =
applySchema(q) match {
case (q, schema) =>
val ar = a.map {
Expand All @@ -47,17 +64,17 @@ object RenameProperties extends StatelessTransformer {

private def applySchema(q: Query): (Query, Ast) =
q match {
case e: Entity => (e, e)
case Map(q: Query, x, p) => applySchema(q, x, p, Map)
case Filter(q: Query, x, p) => applySchema(q, x, p, Filter)
case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy)
case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _))
case Take(q: Query, n) => applySchema(q, Take(_, n))
case Drop(q: Query, n) => applySchema(q, Drop(_, n))
case Distinct(q: Query) => applySchema(q, Distinct)

case FlatMap(q: Query, x, p) =>
case e: Entity => (e, e)
case Map(q, x, p) => applySchema(q, x, p, Map)
case Filter(q, x, p) => applySchema(q, x, p, Filter)
case SortBy(q, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
case GroupBy(q, x, p) => applySchema(q, x, p, GroupBy)
case Aggregation(op, q) => applySchema(q, Aggregation(op, _))
case Take(q, n) => applySchema(q, Take(_, n))
case Drop(q, n) => applySchema(q, Drop(_, n))
case Distinct(q) => applySchema(q, Distinct)

case FlatMap(q, x, p) =>
applySchema(q, x, p, FlatMap) match {
case (FlatMap(q, x, p: Query), oldSchema) =>
val (pr, newSchema) = applySchema(p)
Expand All @@ -66,7 +83,7 @@ object RenameProperties extends StatelessTransformer {
(flatMap, Tuple(List.empty))
}

case Join(typ, a: Query, b: Query, iA, iB, on) =>
case Join(typ, a, b, iA, iB, on) =>
(applySchema(a), applySchema(b)) match {
case ((a, schemaA), (b, schemaB)) =>
val replaceA = replacements(iA, schemaA)
Expand All @@ -75,25 +92,29 @@ object RenameProperties extends StatelessTransformer {
(Join(typ, a, b, iA, iB, onr), Tuple(List(schemaA, schemaB)))
}

case FlatJoin(typ, a: Query, iA, on) =>
case FlatJoin(typ, a, iA, on) =>
applySchema(a) match {
case (a, schemaA) =>
val replaceA = replacements(iA, schemaA)
val onr = BetaReduction(on, replaceA: _*)
(FlatJoin(typ, a, iA, onr), schemaA)
}

case q =>
case Nested(q) =>
val (qr, schema) = applySchema(q)
(Nested(qr), schema)

case _: Union | _: UnionAll =>
(q, Tuple(List.empty))
}

private def applySchema(ast: Query, f: Ast => Query): (Query, Ast) =
private def applySchema(ast: Ast, f: Ast => Query): (Query, Ast) =
applySchema(ast) match {
case (ast, schema) =>
(f(ast), schema)
}

private def applySchema[T](q: Query, x: Ident, p: Ast, f: (Ast, Ident, Ast) => T): (T, Ast) =
private def applySchema[T](q: Ast, x: Ident, p: Ast, f: (Ast, Ident, Ast) => T): (T, Ast) =
applySchema(q) match {
case (q, schema) =>
val replace = replacements(x, schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class RenamePropertiesSpec extends Spec {
val q = quote {
e.insert(lift(TestEntity("s", 1, 1L, None))).returning(_.i)
}
val mirror = testContext.run(q.dynamic)
val mirror = testContext.run(q)
mirror.returningColumn mustEqual "field_i"
}
}
Expand All @@ -66,7 +66,7 @@ class RenamePropertiesSpec extends Spec {
val q = quote {
e.flatMap(t => qr2.map(u => t)).map(t => t.s)
}
testContext.run(q.dynamic).string mustEqual
testContext.run(q).string mustEqual
"SELECT t.field_s FROM test_entity t, TestEntity2 u"
}
"with filter" in {
Expand Down Expand Up @@ -239,6 +239,40 @@ class RenamePropertiesSpec extends Spec {
"SELECT b.field_i, b.field_s FROM TestEntity2 a RIGHT JOIN test_entity b ON a.s = b.field_s"
}
}

"nested" - {
"body" in {
val q = quote {
e.nested
}
testContext.run(q).string mustEqual
"SELECT x.field_s, x.field_i, x.l, x.o FROM (SELECT x.field_s, x.field_i, x.l, x.o FROM test_entity x) x"
}
"transitive" in {
val q = quote {
e.nested.map(t => t.s)
}
testContext.run(q).string mustEqual
"SELECT t.field_s FROM (SELECT x.field_s FROM test_entity x) t"
}
}

"infix" - {
"body" in {
val q = quote {
infix"$e".as[Query[TestEntity]]
}
testContext.run(q).string mustEqual
"SELECT x.field_s, x.field_i, x.l, x.o FROM (SELECT x.* FROM test_entity x) x"
}
"transitive" in {
val q = quote {
infix"$e".as[Query[TestEntity]].map(t => t.s)
}
testContext.run(q).string mustEqual
"SELECT t.field_s FROM (SELECT x.* FROM test_entity x) t"
}
}
}

"respects the schema definition for embeddeds" - {
Expand Down

0 comments on commit d35c195

Please sign in to comment.