Skip to content

Commit

Permalink
Merge pull request #1437 from getquill/equal-equal
Browse files Browse the repository at this point in the history
Allow == for Option[T] and/or T columns
  • Loading branch information
deusaquilus authored May 22, 2019
2 parents e3f0dbb + fae9a36 commit 1d57e6d
Show file tree
Hide file tree
Showing 7 changed files with 443 additions and 51 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ MySparkTest.scala
MyTestJdbc.scala
MyJdbcTest.scala
MySqlTest.scala
MyTest.scala
quill-core/src/main/resources/logback.xml
quill-jdbc/src/main/resources/logback.xml
log.txt*
Expand Down
9 changes: 8 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ lazy val `quill-core` =
.settings(libraryDependencies ++= Seq(
"com.typesafe" % "config" % "1.3.4",
"com.typesafe.scala-logging" %% "scala-logging" % "3.9.0",
"org.scala-lang" % "scala-reflect" % scalaVersion.value
"org.scala-lang" % "scala-reflect" % scalaVersion.value,

"org.scala-lang" % "scala-library" % "2.11.11",
"org.scala-lang" % "scala-compiler" % "2.11.11"
))
.jsSettings(
libraryDependencies += "org.scala-js" %%% "scalajs-java-time" % "0.2.5",
Expand Down Expand Up @@ -474,6 +477,10 @@ lazy val basicSettings = Seq(
scalaVersion := "2.11.12",
crossScalaVersions := Seq("2.11.12","2.12.7"),
libraryDependencies ++= Seq(
"org.scala-lang" % "scala-library" % "2.11.11",
"org.scala-lang" % "scala-compiler" % "2.11.11",
"org.scala-lang" % "scala-reflect" % "2.11.11",

"org.scalamacros" %% "resetallattrs" % "1.0.0",
"org.scalatest" %%% "scalatest" % "3.0.7" % Test,
"ch.qos.logback" % "logback-classic" % "1.2.3" % Test,
Expand Down
30 changes: 30 additions & 0 deletions quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,36 @@ private[dsl] trait QueryDsl {
)
}

object extras extends LowPriorityExtras {
implicit class NumericOptionOps[A: Numeric](a: Option[A]) {
def ===[B: Numeric](b: Option[B]): Boolean = a.exists(av => b.exists(bv => av == bv))
def ===[B: Numeric](b: B): Boolean = a.exists(av => av == b)
def =!=[B: Numeric](b: Option[B]): Boolean = a.exists(av => b.exists(bv => av != bv))
def =!=[B: Numeric](b: B): Boolean = a.exists(av => av != b)
}
implicit class NumericRegOps[A: Numeric](a: A) {
def ===[B: Numeric](b: Option[B]): Boolean = b.exists(bv => bv == a)
def ===[B: Numeric](b: B): Boolean = a == b
def =!=[B: Numeric](b: Option[B]): Boolean = b.exists(bv => bv != a)
def =!=[B: Numeric](b: B): Boolean = a != b
}
}

trait LowPriorityExtras {
implicit class OptionOps[T](a: Option[T]) {
def ===(b: Option[T]): Boolean = a.exists(av => b.exists(bv => av == bv))
def ===(b: T): Boolean = a.exists(av => av == b)
def =!=(b: Option[T]): Boolean = a.exists(av => b.exists(bv => av != bv))
def =!=(b: T): Boolean = a.exists(av => av != b)
}
implicit class RegOps[T](a: T) {
def ===(b: Option[T]): Boolean = b.exists(bv => bv == a)
def ===(b: T): Boolean = a == b
def =!=(b: Option[T]): Boolean = b.exists(bv => bv != a)
def =!=(b: T): Boolean = a != b
}
}

sealed trait Query[+T] {

def map[R](f: T => R): Query[R]
Expand Down
100 changes: 87 additions & 13 deletions quill-core/src/main/scala/io/getquill/quotation/Parsing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.getquill.norm.capture.AvoidAliasConflict
import scala.annotation.tailrec
import scala.collection.immutable.StringOps
import scala.reflect.macros.TypecheckException
import io.getquill.ast.Implicits._

trait Parsing {
this: Quotation =>
Expand Down Expand Up @@ -501,24 +502,44 @@ trait Parsing {
case q"${ astParser(a) }.apply[..$t](...$values)" => FunctionApply(a, values.flatten.map(astParser(_)))
}

private def rejectOptions(a: Tree, b: Tree) = {
if ((!is[Null](a) && is[Option[_]](a)) || (!is[Null](b) && is[Option[_]](b)))
c.abort(a.pos, "Can't compare `Option` values since databases have different behavior for null comparison. Use `Option` methods like `forall` and `exists` instead.")
def withInnerTypechecks(left: Tree, right: Tree)(equality: BinaryOperator): Operation = {
val (leftIsOptional, rightIsOptional) = checkInnerTypes(left, right)
val a = astParser(left)
val b = astParser(right)
val comparison = BinaryOperation(a, equality, b)
(leftIsOptional, rightIsOptional) match {
case (true, true) => OptionIsDefined(a) +&&+ OptionIsDefined(b) +&&+ comparison
case (true, false) => OptionIsDefined(a) +&&+ comparison
case (false, true) => OptionIsDefined(b) +&&+ comparison
case (false, false) => comparison
}
}

val equalityOperationParser: Parser[Operation] = Parser[Operation] {
case q"$a.==($b)" =>
checkTypes(a, b)
rejectOptions(a, b)
BinaryOperation(astParser(a), EqualityOperator.`==`, astParser(b))
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$a.equals($b)" =>
checkTypes(a, b)
rejectOptions(a, b)
BinaryOperation(astParser(a), EqualityOperator.`==`, astParser(b))
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$a.!=($b)" =>
checkTypes(a, b)
rejectOptions(a, b)
BinaryOperation(astParser(a), EqualityOperator.`!=`, astParser(b))
withInnerTypechecks(a, b)(EqualityOperator.`!=`)

case q"$pack.extras.NumericOptionOps[$t]($a)($imp).===[$q]($b)($imp2)" =>
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$pack.extras.NumericRegOps[$t]($a)($imp).===[$q]($b)($imp2)" =>
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$pack.extras.NumericOptionOps[$t]($a)($imp).=!=[$q]($b)($imp2)" =>
withInnerTypechecks(a, b)(EqualityOperator.`!=`)
case q"$pack.extras.NumericRegOps[$t]($a)($imp).=!=[$q]($b)($imp2)" =>
withInnerTypechecks(a, b)(EqualityOperator.`!=`)

case q"$pack.extras.OptionOps[$t]($a).===($b)" =>
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$pack.extras.RegOps[$t]($a).===($b)" =>
withInnerTypechecks(a, b)(EqualityOperator.`==`)
case q"$pack.extras.OptionOps[$t]($a).=!=($b)" =>
withInnerTypechecks(a, b)(EqualityOperator.`!=`)
case q"$pack.extras.RegOps[$t]($a).=!=($b)" =>
withInnerTypechecks(a, b)(EqualityOperator.`!=`)
}

val booleanOperationParser: Parser[Operation] =
Expand Down Expand Up @@ -594,6 +615,15 @@ trait Parsing {
private def isTypeTuple(tpe: Type) =
tpe.typeSymbol.fullName startsWith "scala.Tuple"

/**
* Need special handling to check if a type is null since need to check if it's Option, Some or None. Don't want
* to use `<:<` since that would also match things like `Nothing` and `Null`.
*/
def isOptionType(tpe: Type) = {
val era = tpe.erasure
era =:= typeOf[Option[Any]] || era =:= typeOf[Some[Any]] || era =:= typeOf[None.type]
}

/**
* Recursively traverse an `Option[T]` or `Option[Option[T]]`, or `Option[Option[Option[T]]]` etc...
* until we find the `T`
Expand All @@ -604,7 +634,9 @@ trait Parsing {
case TypeRef(_, cls, List(arg)) if (cls.isClass && cls.asClass.fullName == "scala.Option") =>
innerOptionParam(arg)
// If it's not a ref-type but an Option, convert to a ref-type and reprocess
case _ if (tpe <:< typeOf[Option[Any]]) =>
// also since Nothing is a subtype of everything need to know to stop searching once Nothing
// has been reached.
case _ if (isOptionType(tpe) && !(tpe =:= typeOf[Nothing])) =>
innerOptionParam(tpe.baseType(typeOf[Option[Any]].typeSymbol))
// Otherwise we have gotten to the actual type inside the nesting. Check what it is.
case other => other
Expand Down Expand Up @@ -754,6 +786,48 @@ trait Parsing {
private def parseConflictAssigns(targets: List[Tree]) =
OnConflict.Update(targets.map(assignmentParser(_)))

/**
* Type-check two trees, if one of them has optionals, go into the optionals to find the root types
* in each of them. Then compare the types that are inside. If they are not compareable, abort the build.
* Otherwise return type of which side (or both) has the optional. In order to do the actual comparison,
* the 'weak conformance' operator is used and a subclass is allowed on either side of the `==`. Weak
* conformance is necessary so that Longs can be compared to Ints etc...
*/
private def checkInnerTypes(lhs: Tree, rhs: Tree): (Boolean, Boolean) = {
val leftType = typecheckUnquoted(lhs).tpe
val rightType = typecheckUnquoted(rhs).tpe
val leftInner = innerOptionParam(leftType)
val rightInner = innerOptionParam(rightType)
val leftIsOptional = isOptionType(leftType) && !is[Nothing](lhs)
val rightIsOptional = isOptionType(rightType) && !is[Nothing](rhs)

if (rightInner.`weak_<:<`(leftInner) ||
rightInner.widen.`weak_<:<`(leftInner.widen) ||
leftInner.`weak_<:<`(rightInner) ||
leftInner.widen.`weak_<:<`(rightInner.widen)) {
(leftIsOptional, rightIsOptional)
} else {
if (leftIsOptional || rightIsOptional)
c.abort(lhs.pos, s"${leftType.widen} == ${rightType.widen} is not allowed since ${leftInner.widen}, ${rightInner.widen} are different types.")
else
c.abort(lhs.pos, s"${leftType.widen} == ${rightType.widen} is not allowed since they are different types.")
}
}

private def typecheckUnquoted(tree: Tree): Tree = {
def unquoted(maybeQuoted: Tree) =
is[CoreDsl#Quoted[Any]](maybeQuoted) match {
case false => maybeQuoted
case true => q"unquote($maybeQuoted)"
}
val t = TypeName(c.freshName("T"))
try
c.typecheck(unquoted(tree), c.TYPEmode)
catch {
case t: TypecheckException => c.abort(tree.pos, t.msg)
}
}

private def checkTypes(lhs: Tree, rhs: Tree): Unit = {
def unquoted(tree: Tree) =
is[CoreDsl#Quoted[Any]](tree) match {
Expand Down
15 changes: 15 additions & 0 deletions quill-core/src/test/scala/io/getquill/MoreAstOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.getquill
import io.getquill.ast._

object MoreAstOps {
implicit class AstOpsExt2(body: Ast) {
def +++(other: Constant) =
if (other.v.isInstanceOf[String])
BinaryOperation(body, StringOperator.`+`, other)
else
BinaryOperation(body, NumericOperator.`+`, other)

def +>+(other: Ast) = BinaryOperation(body, NumericOperator.`>`, other)
def +!=+(other: Ast) = BinaryOperation(body, EqualityOperator.`!=`, other)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,10 @@ import io.getquill.testContext._
import io.getquill.ast.NumericOperator
import io.getquill.ast.Implicits._
import io.getquill.norm.ConcatBehavior.{ AnsiConcat, NonAnsiConcat }
import io.getquill.MoreAstOps._

class FlattenOptionOperationSpec extends Spec {

implicit class AstOpsExt2(body: Ast) {
def +++(other: Constant) =
if (other.v.isInstanceOf[String])
BinaryOperation(body, StringOperator.`+`, other)
else
BinaryOperation(body, NumericOperator.`+`, other)

def +>+(other: Ast) = BinaryOperation(body, NumericOperator.`>`, other)
def +!=+(other: Ast) = BinaryOperation(body, EqualityOperator.`!=`, other)
}

def o = Ident("o")
def c1 = Constant(1)
def cFoo = Constant("foo")
Expand Down
Loading

0 comments on commit 1d57e6d

Please sign in to comment.