Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix variable shadowing issue in action metas #1412

Merged
merged 1 commit into from
Apr 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ quill-sql/io/
MyTest.scala
MySparkTest.scala
MyTestJdbc.scala
MyTestSql.scala
MySqlTest.scala
quill-core/src/main/resources/logback.xml
quill-jdbc/src/main/resources/logback.xml
log.txt*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,9 @@
package io.getquill.norm.capture

import io.getquill.ast._
import io.getquill.ast.Entity
import io.getquill.ast.Filter
import io.getquill.ast.FlatMap
import io.getquill.ast.Ident
import io.getquill.ast.Join
import io.getquill.ast.Map
import io.getquill.ast.Query
import io.getquill.ast.SortBy
import io.getquill.ast.StatefulTransformer
import io.getquill.ast.{ Entity, Filter, FlatJoin, FlatMap, GroupBy, Ident, Join, Map, Query, SortBy, StatefulTransformer, _ }
import io.getquill.norm.BetaReduction
import io.getquill.ast.FlatJoin
import io.getquill.ast.GroupBy

private case class AvoidAliasConflict(state: collection.Set[Ident])
private[getquill] case class AvoidAliasConflict(state: collection.Set[Ident])
extends StatefulTransformer[collection.Set[Ident]] {

object Unaliased {
Expand Down Expand Up @@ -99,12 +88,59 @@ private case class AvoidAliasConflict(state: collection.Set[Ident])
else
loop(x, 1)
}

/**
* Sometimes we need to change the variables in a function because they will might conflict with some variable
* further up in the macro. Right now, this only happens when you do something like this:
* <code>
* val q = quote { (v: Foo) => query[Foo].insert(v) }
* run(q(lift(v)))
* </code>
* Since 'v' is used by actionMeta in order to map keys to values for insertion, using it as a function argument
* messes up the output SQL like so:
* <code>
* INSERT INTO MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?)
* </code>
* Therefore, we need to have a method to remove such conflicting variables from Function ASTs
*/
private def applyFunction(f: Function): Function = {
val (newBody, _, newParams) =
f.params.foldLeft((f.body, state, List[Ident]())) {
case ((body, state, newParams), param) => {
val fresh = freshIdent(param)
val pr = BetaReduction(body, param -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(prr, t.state, newParams :+ fresh)
}
}
Function(newParams, newBody)
}

private def applyForeach(f: Foreach): Foreach = {
val fresh = freshIdent(f.alias)
val pr = BetaReduction(f.body, f.alias -> fresh)
val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
Foreach(f.query, fresh, prr)
}
}

private[capture] object AvoidAliasConflict {
private[getquill] object AvoidAliasConflict {

def apply(q: Query): Query =
AvoidAliasConflict(collection.Set[Ident]())(q) match {
case (q, _) => q
}

/**
* Make sure query parameters do not collide with paramters of a AST function. Do this
* by walkning through the function's subtree and transforming and queries encountered.
*/
def sanitizeVariables(f: Function, dangerousVariables: Set[Ident]): Function = {
AvoidAliasConflict(dangerousVariables).applyFunction(f)
}

/** Same is `sanitizeVariables` but for Foreach **/
def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
AvoidAliasConflict(dangerousVariables).applyForeach(f)
}
}
20 changes: 17 additions & 3 deletions quill-core/src/main/scala/io/getquill/quotation/Parsing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.getquill.norm.BetaReduction
import io.getquill.util.Messages.RichContext
import io.getquill.util.Interleave
import io.getquill.dsl.CoreDsl
import io.getquill.norm.capture.AvoidAliasConflict

import scala.annotation.tailrec
import scala.collection.immutable.StringOps
Expand All @@ -17,6 +18,10 @@ trait Parsing {

import c.universe.{ Ident => _, Constant => _, Function => _, If => _, Block => _, _ }

// Variables that need to be sanitized out in various places due to internal conflicts with the way
// macros hard handeled in MetaDsl
private[getquill] val dangerousVariables = Set("v").map(Ident(_))

case class Parser[T](p: PartialFunction[Tree, T])(implicit ct: ClassTag[T]) {

def apply(tree: Tree) =
Expand Down Expand Up @@ -316,8 +321,15 @@ trait Parsing {
case q"new { def apply[..$t1](...$params) = $body }" =>
c.fail("Anonymous classes aren't supported for function declaration anymore. Use a method with a type parameter instead. " +
"For instance, replace `val q = quote { new { def apply[T](q: Query[T]) = ... } }` by `def q[T] = quote { (q: Query[T] => ... }`")
case q"(..$params) => $body" =>
Function(params.map(identParser(_)), astParser(body))
case q"(..$params) => $body" => {
val subtree = Function(params.map(identParser(_)), astParser(body))
// If there are actions inside the subtree, we need to do some additional sanitizations
// of the variables so that their content will not collide with code that we have generated.
if (CollectAst.byType[Action](subtree).nonEmpty)
AvoidAliasConflict.sanitizeVariables(subtree, dangerousVariables)
else
subtree
}
}

val identParser: Parser[Ident] = Parser[Ident] {
Expand Down Expand Up @@ -699,7 +711,9 @@ trait Parsing {
case q"$action.returning[$r](($alias) => $body)" =>
Returning(astParser(action), identParser(alias), astParser(body))
case q"$query.foreach[$t1, $t2](($alias) => $body)($f)" if (is[CoreDsl#Query[Any]](query)) =>
Foreach(astParser(query), identParser(alias), astParser(body))
// If there are actions inside the subtree, we need to do some additional sanitizations
// of the variables so that their content will not collide with code that we have generated.
AvoidAliasConflict.sanitizeVariables(Foreach(astParser(query), identParser(alias), astParser(body)), dangerousVariables)
}

private val assignmentParser: Parser[Assignment] = Parser[Assignment] {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.getquill.context.sql.idiom

import io.getquill.Spec
import io.getquill.context.mirror.Row
import io.getquill.context.sql.testContext
import io.getquill.context.sql.testContext._

Expand Down Expand Up @@ -844,6 +845,37 @@ class SqlIdiomSpec extends Spec {
}
"action" - {
"insert" - {
"not affected by variable name" - {
"simple" in {
val q = quote { (v: TestEntity) =>
query[TestEntity].insert(v)
}
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(q(lift(v))).string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)"
}
"returning" in {
val q = quote { (v: TestEntity) =>
query[TestEntity].insert(v)
}
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(q(lift(v)).returning(v => v.i)).string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)"
}
"foreach" in {
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(
liftQuery(List(v)).foreach(v => query[TestEntity].insert(v))
).groups mustEqual List(("INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)", List(Row(v.productIterator.toList: _*))))
}
"foreach returning" in {
val v = TestEntity("s", 1, 2L, Some(1))
testContext.run(
liftQuery(List(v)).foreach(v => query[TestEntity].insert(v).returning(v => v.i))
).groups mustEqual
List(("INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)", "i",
List(Row(v.productIterator.toList.filter(m => !m.isInstanceOf[Int]): _*))
))
}
}
"simple" in {
val q = quote {
qr1.insert(_.i -> 1, _.s -> "s")
Expand Down