Skip to content

Commit

Permalink
Merge pull request #1412 from deusaquilus2/fix_v
Browse files Browse the repository at this point in the history
Fix variable shadowing issue in action metas
  • Loading branch information
deusaquilus authored Apr 22, 2019
2 parents 1a4c214 + 8d355a8 commit 2c3555f
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ quill-sql/io/
MyTest.scala
MySparkTest.scala
MyTestJdbc.scala
MyTestSql.scala
MySqlTest.scala
quill-core/src/main/resources/logback.xml
quill-jdbc/src/main/resources/logback.xml
log.txt*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,9 @@
package io.getquill.norm.capture

import io.getquill.ast._
import io.getquill.ast.Entity
import io.getquill.ast.Filter
import io.getquill.ast.FlatMap
import io.getquill.ast.Ident
import io.getquill.ast.Join
import io.getquill.ast.Map
import io.getquill.ast.Query
import io.getquill.ast.SortBy
import io.getquill.ast.StatefulTransformer
import io.getquill.ast.{ Entity, Filter, FlatJoin, FlatMap, GroupBy, Ident, Join, Map, Query, SortBy, StatefulTransformer, _ }
import io.getquill.norm.BetaReduction
import io.getquill.ast.FlatJoin
import io.getquill.ast.GroupBy

private case class AvoidAliasConflict(state: collection.Set[Ident])
private[getquill] case class AvoidAliasConflict(state: collection.Set[Ident])
extends StatefulTransformer[collection.Set[Ident]] {

object Unaliased {
Expand Down Expand Up @@ -99,12 +88,59 @@ private case class AvoidAliasConflict(state: collection.Set[Ident])
else
loop(x, 1)
}

/**
* Sometimes we need to change the variables in a function because they will might conflict with some variable
* further up in the macro. Right now, this only happens when you do something like this:
* <code>
* val q = quote { (v: Foo) => query[Foo].insert(v) }
* run(q(lift(v)))
* </code>
* Since 'v' is used by actionMeta in order to map keys to values for insertion, using it as a function argument
* messes up the output SQL like so:
* <code>
* INSERT INTO MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?)
* </code>
* Therefore, we need to have a method to remove such conflicting variables from Function ASTs
*/
private def applyFunction(f: Function): Function = {
val (newBody, _, newParams) =
f.params.foldLeft((f.body, state, List[Ident]())) {
case ((body, state, newParams), param) => {
val fresh = freshIdent(param)
val pr = BetaReduction(body, param -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(prr, t.state, newParams :+ fresh)
}
}
Function(newParams, newBody)
}

private def applyForeach(f: Foreach): Foreach = {
val fresh = freshIdent(f.alias)
val pr = BetaReduction(f.body, f.alias -> fresh)
val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
Foreach(f.query, fresh, prr)
}
}

private[capture] object AvoidAliasConflict {
private[getquill] object AvoidAliasConflict {

def apply(q: Query): Query =
AvoidAliasConflict(collection.Set[Ident]())(q) match {
case (q, _) => q
}

/**
* Make sure query parameters do not collide with paramters of a AST function. Do this
* by walkning through the function's subtree and transforming and queries encountered.
*/
def sanitizeVariables(f: Function, dangerousVariables: Set[Ident]): Function = {
AvoidAliasConflict(dangerousVariables).applyFunction(f)
}

/** Same is `sanitizeVariables` but for Foreach **/
def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
AvoidAliasConflict(dangerousVariables).applyForeach(f)
}
}
20 changes: 17 additions & 3 deletions quill-core/src/main/scala/io/getquill/quotation/Parsing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.getquill.norm.BetaReduction
import io.getquill.util.Messages.RichContext
import io.getquill.util.Interleave
import io.getquill.dsl.CoreDsl
import io.getquill.norm.capture.AvoidAliasConflict

import scala.annotation.tailrec
import scala.collection.immutable.StringOps
Expand All @@ -17,6 +18,10 @@ trait Parsing {

import c.universe.{ Ident => _, Constant => _, Function => _, If => _, Block => _, _ }

// Variables that need to be sanitized out in various places due to internal conflicts with the way
// macros hard handeled in MetaDsl
private[getquill] val dangerousVariables = Set("v").map(Ident(_))

case class Parser[T](p: PartialFunction[Tree, T])(implicit ct: ClassTag[T]) {

def apply(tree: Tree) =
Expand Down Expand Up @@ -316,8 +321,15 @@ trait Parsing {
case q"new { def apply[..$t1](...$params) = $body }" =>
c.fail("Anonymous classes aren't supported for function declaration anymore. Use a method with a type parameter instead. " +
"For instance, replace `val q = quote { new { def apply[T](q: Query[T]) = ... } }` by `def q[T] = quote { (q: Query[T] => ... }`")
case q"(..$params) => $body" =>
Function(params.map(identParser(_)), astParser(body))
case q"(..$params) => $body" => {
val subtree = Function(params.map(identParser(_)), astParser(body))
// If there are actions inside the subtree, we need to do some additional sanitizations
// of the variables so that their content will not collide with code that we have generated.
if (CollectAst.byType[Action](subtree).nonEmpty)
AvoidAliasConflict.sanitizeVariables(subtree, dangerousVariables)
else
subtree
}
}

val identParser: Parser[Ident] = Parser[Ident] {
Expand Down Expand Up @@ -699,7 +711,9 @@ trait Parsing {
case q"$action.returning[$r](($alias) => $body)" =>
Returning(astParser(action), identParser(alias), astParser(body))
case q"$query.foreach[$t1, $t2](($alias) => $body)($f)" if (is[CoreDsl#Query[Any]](query)) =>
Foreach(astParser(query), identParser(alias), astParser(body))
// If there are actions inside the subtree, we need to do some additional sanitizations
// of the variables so that their content will not collide with code that we have generated.
AvoidAliasConflict.sanitizeVariables(Foreach(astParser(query), identParser(alias), astParser(body)), dangerousVariables)
}

private val assignmentParser: Parser[Assignment] = Parser[Assignment] {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.getquill.context.sql.idiom

import io.getquill.Spec
import io.getquill.context.mirror.Row
import io.getquill.context.sql.testContext
import io.getquill.context.sql.testContext._

Expand Down Expand Up @@ -844,6 +845,37 @@ class SqlIdiomSpec extends Spec {
}
"action" - {
"insert" - {
"not affected by variable name" - {
"simple" in {
val q = quote { (v: TestEntity) =>
query[TestEntity].insert(v)
}
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(q(lift(v))).string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)"
}
"returning" in {
val q = quote { (v: TestEntity) =>
query[TestEntity].insert(v)
}
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(q(lift(v)).returning(v => v.i)).string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)"
}
"foreach" in {
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(
liftQuery(List(v)).foreach(v => query[TestEntity].insert(v))
).groups mustEqual List(("INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)", List(Row(v.productIterator.toList: _*))))
}
"foreach returning" in {
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(
liftQuery(List(v)).foreach(v => query[TestEntity].insert(v).returning(v => v.i))
).groups mustEqual
List(("INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)", "i",
List(Row(v.productIterator.toList.filter(m => !m.isInstanceOf[Int]): _*))
))
}
}
"simple" in {
val q = quote {
qr1.insert(_.i -> 1, _.s -> "s")
Expand Down

0 comments on commit 2c3555f

Please sign in to comment.