From 1da85122b92427fe3bb3974068ea823af42ea70f Mon Sep 17 00:00:00 2001
From: fwbrasil
Date: Fri, 20 Oct 2017 08:59:20 -0700
Subject: [PATCH] spark sql support

---
 README.md                                     |  31 +++
 build.sbt                                     |  34 +++-
 build/build.sh                                |   6 +-
 .../io/getquill/norm/AttachToEntity.scala     |  21 +-
 .../norm/capture/AvoidAliasConflict.scala     |  15 +-
 .../io/getquill/norm/capture/Dealias.scala    |   4 +-
 .../io/getquill/norm/AttachToEntitySpec.scala |  84 +++++++-
 .../norm/capture/AvoidAliasConflictSpec.scala |  20 +-
 .../getquill/norm/capture/DealiasSpec.scala   |   8 +-
 .../scala/io/getquill/QuillSparkContext.scala |  79 ++++++++
 .../spark/AliasNestedQueryColumns.scala       |  32 +++
 .../io/getquill/context/spark/Binding.scala   |   9 +
 .../io/getquill/context/spark/Decoders.scala  |  16 ++
 .../io/getquill/context/spark/Encoders.scala  |  36 ++++
 .../getquill/context/spark/SparkDialect.scala |  69 +++++++
 quill-spark/src/test/resources/.placeholder   |   1 +
 .../spark/AliasNestedQueryColumnsSpec.scala   |  68 +++++++
 .../context/spark/DepartmentsSparkSpec.scala  | 115 +++++++++++
 .../context/spark/EncodingSparkSpec.scala     | 189 ++++++++++++++++++
 .../context/spark/PeopleSparkSpec.scala       | 111 ++++++++++
 .../context/spark/QuillSparkContextSpec.scala |  39 ++++
 .../io/getquill/context/spark/package.scala   |  18 ++
 .../io/getquill/context/sql/SqlQuery.scala    |   6 +
 23 files changed, 970 insertions(+), 41 deletions(-)
 create mode 100644 quill-spark/src/main/scala/io/getquill/QuillSparkContext.scala
 create mode 100644 quill-spark/src/main/scala/io/getquill/context/spark/AliasNestedQueryColumns.scala
 create mode 100644 quill-spark/src/main/scala/io/getquill/context/spark/Binding.scala
 create mode 100644 quill-spark/src/main/scala/io/getquill/context/spark/Decoders.scala
 create mode 100644 quill-spark/src/main/scala/io/getquill/context/spark/Encoders.scala
 create mode 100644 quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala
 create mode 100644 quill-spark/src/test/resources/.placeholder
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/AliasNestedQueryColumnsSpec.scala
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/DepartmentsSparkSpec.scala
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/EncodingSparkSpec.scala
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/PeopleSparkSpec.scala
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/QuillSparkContextSpec.scala
 create mode 100644 quill-spark/src/test/scala/io/getquill/context/spark/package.scala

diff --git a/README.md b/README.md
index 934cdf4256..b4cded6179 100644
--- a/README.md
+++ b/README.md
@@ -1271,6 +1271,37 @@ case class MyDao(c: MyContext) extends MySchema {
 }
 ```
 
+## Spark Context
+
+Quill provides a context that allows users to run queries on top of Spark's SQL engine.
+Example usage:
+
+```scala
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.Dataset
+import io.getquill.QuillSparkContext._
+
+// Replace with your Spark SQL context
+implicit val sqlContext =
+  SparkSession
+    .builder()
+    .master("local")
+    .appName("spark test")
+    .getOrCreate()
+    .sqlContext
+
+import sqlContext.implicits._
+
+case class Person(name: String, age: Int)
+
+val people = List(Person("John", 22)).toQuery
+
+val q = quote {
+  people.filter(_.name == "John")
+}
+
+val filtered: Dataset[Person] = run(q)
+```
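+
+Any `Dataset[T]`, `RDD[T]`, or `Seq[T]` can be turned into a query through
+`toQuery` and then composed within quotations like a regular Quill query. As an
+illustrative sketch building on the example above (the `Address` entity and its
+data are made up for this example):
+
+```scala
+case class Address(owner: String, street: String)
+
+val addresses = List(Address("John", "1st Ave")).toQuery
+
+val q2 = quote {
+  for {
+    p <- people if (p.name == "John")
+    a <- addresses if (a.owner == p.name)
+  } yield (p.name, a.street)
+}
+
+val streets: Dataset[(String, String)] = run(q2)
+```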
+
 ## SQL Contexts
 
 Example:
diff --git a/build.sbt b/build.sbt
index 0a05c2eb20..b8e124520b 100644
--- a/build.sbt
+++ b/build.sbt
@@ -3,19 +3,22 @@ import com.typesafe.sbt.SbtScalariform.ScalariformKeys
 import scalariform.formatter.preferences._
 import sbtrelease.ReleasePlugin
 
+lazy val scalaVersionProperty = Option(System.getProperty("scalaVersion"))
+
+lazy val modules = Seq[sbt.ClasspathDep[sbt.ProjectReference]](
+  `quill-core-jvm`, `quill-core-js`, `quill-sql-jvm`, `quill-sql-js`,
+  `quill-jdbc`, `quill-finagle-mysql`, `quill-finagle-postgres`, `quill-async`,
+  `quill-async-mysql`, `quill-async-postgres`, `quill-cassandra`, `quill-orientdb`
+) ++
+  Seq[sbt.ClasspathDep[sbt.ProjectReference]](`quill-spark`)
+  .filter(_ => scalaVersionProperty.map(_.startsWith("2.11")).getOrElse(true))
+
 lazy val `quill` =
   (project in file("."))
     .settings(tutSettings ++ commonSettings)
     .settings(`tut-settings`:_*)
-    .dependsOn(
-      `quill-core-jvm`, `quill-core-js`, `quill-sql-jvm`, `quill-sql-js`,
-      `quill-jdbc`, `quill-finagle-mysql`, `quill-finagle-postgres`, `quill-async`,
-      `quill-async-mysql`, `quill-async-postgres`, `quill-cassandra`, `quill-orientdb`
-    ).aggregate(
-      `quill-core-jvm`, `quill-core-js`, `quill-sql-jvm`, `quill-sql-js`,
-      `quill-jdbc`, `quill-finagle-mysql`, `quill-finagle-postgres`, `quill-async`,
-      `quill-async-mysql`, `quill-async-postgres`, `quill-cassandra`, `quill-orientdb`
-    )
+    .aggregate(modules.map(_.project): _*)
+    .dependsOn(modules: _*)
 
 lazy val superPure = new org.scalajs.sbtplugin.cross.CrossType {
   def projectDir(crossBase: File, projectType: String): File =
@@ -74,6 +77,19 @@ lazy val `quill-jdbc` =
   )
   .dependsOn(`quill-sql-jvm` % "compile->compile;test->test")
 
+lazy val `quill-spark` =
+  (project in file("quill-spark"))
+    .settings(commonSettings: _*)
+    .settings(mimaSettings: _*)
+    .settings(
+      crossScalaVersions := Seq("2.11.11"),
+      fork in Test := true,
+      libraryDependencies ++= Seq(
+        "org.apache.spark" %% "spark-sql" % "2.2.0"
+      )
+    )
+    .dependsOn(`quill-sql-jvm` % "compile->compile;test->test")
+
 lazy val `quill-finagle-mysql` =
   (project in file("quill-finagle-mysql"))
     .settings(commonSettings: _*)
diff --git a/build/build.sh b/build/build.sh
index c378650756..61e80e5f19 100755
--- a/build/build.sh
+++ b/build/build.sh
@@ -4,9 +4,9 @@ set -e # Any subsequent(*) commands which fail will cause the shell script to ex
 chown root ~/.ssh/config
 chmod 644 ~/.ssh/config
 
-SBT_CMD="sbt clean"
-SBT_CMD_2_11=" ++2.11.11 coverage test tut coverageReport coverageAggregate checkUnformattedFiles"
-SBT_CMD_2_12=" ++2.12.2 test"
+SBT_CMD="sbt"
+SBT_CMD_2_11=" -DscalaVersion=2.11.11 ++2.11.11 clean coverage test tut coverageReport coverageAggregate checkUnformattedFiles"
+SBT_CMD_2_12=" -DscalaVersion=2.12.3 ++2.12.3 clean test"
 SBT_PUBLISH=" coverageOff publish"
 
 if [[ $SCALA_VERSION == "2.11" ]]
diff --git a/quill-core/src/main/scala/io/getquill/norm/AttachToEntity.scala b/quill-core/src/main/scala/io/getquill/norm/AttachToEntity.scala
index b1cd463f6f..2d95b58fbd 100644
--- a/quill-core/src/main/scala/io/getquill/norm/AttachToEntity.scala
+++ b/quill-core/src/main/scala/io/getquill/norm/AttachToEntity.scala
@@ -5,13 +5,22 @@ import io.getquill.ast._
 
 object AttachToEntity {
 
-  def apply(f: (Query, Ident) => Query, alias: Option[Ident] = None)(q: Query): Query =
+  private object IsEntity {
+    def unapply(q: Ast): Option[Ast] =
+      q match {
+        case q: Entity => Some(q)
+        case q: Infix  => Some(q)
+        case _         => None
+      }
+  }
+
+  def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(q: Ast): Ast =
     q match {
-      case Map(a: Entity, b, c)       => Map(f(a, b), b, c)
-      case FlatMap(a: Entity, b, c)   => FlatMap(f(a, b), b, c)
-      case Filter(a: Entity, b, c)    => Filter(f(a, b), b, c)
-      case SortBy(a: Entity, b, c, d) => SortBy(f(a, b), b, c, d)
+      case Map(IsEntity(a), b, c)       => Map(f(a, b), b, c)
+      case FlatMap(IsEntity(a), b, c)   => FlatMap(f(a, b), b, c)
+      case Filter(IsEntity(a), b, c)    => Filter(f(a, b), b, c)
+      case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d)
 
       case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join | _: FlatJoin =>
         f(q, alias.getOrElse(Ident("x")))
@@ -24,7 +33,7 @@ object AttachToEntity {
       case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a))
       case Distinct(a: Query)        => Distinct(apply(f, alias)(a))
 
-      case e: Entity                 => f(e, alias.getOrElse(Ident("x")))
+      case IsEntity(q)               => f(q, alias.getOrElse(Ident("x")))
 
       case other                     => fail(s"Can't find an 'Entity' in '$q'")
     }
diff --git a/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala b/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala
index da8c20848d..865a3a52a3 100644
--- a/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala
+++ b/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala
@@ -19,24 +19,21 @@ private case class AvoidAliasConflict(state: collection.Set[Ident])
 
 object Unaliased {
 
-  private def isUnaliased(q: Query): Boolean =
+  private def isUnaliased(q: Ast): Boolean =
     q match {
       case Nested(q: Query)         => isUnaliased(q)
       case Take(q: Query, _)        => isUnaliased(q)
       case Drop(q: Query, _)        => isUnaliased(q)
       case Aggregation(_, q: Query) => isUnaliased(q)
       case Distinct(q: Query)       => isUnaliased(q)
-      case _: Entity                => true
-      case _: Nested | _: Take | _: Drop | _: Aggregation |
-        _: Distinct | _: FlatMap | _: Map | _: Filter | _: SortBy |
-        _: GroupBy | _: Union | _: UnionAll | _: Join | _: FlatJoin =>
-        false
+      case _: Entity | _: Infix     => true
+      case _                        => false
     }
 
-  def unapply(q: Ast): Option[Query] =
+  def unapply(q: Ast): Option[Ast] =
     q match {
-      case q: Query if (isUnaliased(q)) => Some(q)
-      case _                            => None
+      case q if (isUnaliased(q)) => Some(q)
+      case _                     => None
     }
 }
diff --git a/quill-core/src/main/scala/io/getquill/norm/capture/Dealias.scala b/quill-core/src/main/scala/io/getquill/norm/capture/Dealias.scala
index ecbd19394b..772f32b443 100644
--- a/quill-core/src/main/scala/io/getquill/norm/capture/Dealias.scala
+++ b/quill-core/src/main/scala/io/getquill/norm/capture/Dealias.scala
@@ -36,8 +36,8 @@ case class Dealias(state: Option[Ident]) extends StatefulTransformer[Option[Iden
         val (bn, _) = apply(b)
         (UnionAll(an, bn), Dealias(None))
       case Join(t, a, b, iA, iB, o) =>
-        val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
-        val ((bn, iBn, onn), _) = ont.dealias(b, iB, on)((_, _, _))
+        val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
+        val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
         (Join(t, an, bn, iAn, iBn, onn), Dealias(None))
       case FlatJoin(t, a, iA, o) =>
         val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
diff --git a/quill-core/src/test/scala/io/getquill/norm/AttachToEntitySpec.scala b/quill-core/src/test/scala/io/getquill/norm/AttachToEntitySpec.scala
index b863de4ca5..cdfb8aa873 100644
--- a/quill-core/src/test/scala/io/getquill/norm/AttachToEntitySpec.scala
+++ b/quill-core/src/test/scala/io/getquill/norm/AttachToEntitySpec.scala
@@ -6,11 +6,7 @@ import io.getquill.ast.Constant
 import io.getquill.ast.Ident
 import io.getquill.ast.Map
 import io.getquill.ast.SortBy
-import io.getquill.testContext.implicitOrd
-import io.getquill.testContext.qr1
-import io.getquill.testContext.qr2
-import io.getquill.testContext.quote
-import io.getquill.testContext.unquote
+import io.getquill.testContext._
 
 class AttachToEntitySpec extends Spec {
 
@@ -90,6 +86,84 @@ class AttachToEntitySpec extends Spec {
     }
   }
 
+  val iqr1 = quote {
+    infix"$qr1".as[Query[TestEntity]]
+  }
+
+  "attaches clause to the root of the query (infix)" - {
+    "query is the entity" in {
+      val n = quote {
+        iqr1.sortBy(x => 1)
+      }
+      attachToEntity(iqr1.ast) mustEqual n.ast
+    }
+    "query is a composition" - {
+      "map" in {
+        val q = quote {
+          iqr1.filter(t => t.i == 1).map(t => t.s)
+        }
+        val n = quote {
+          iqr1.sortBy(t => 1).filter(t => t.i == 1).map(t => t.s)
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "flatMap" in {
+        val q = quote {
+          iqr1.filter(t => t.i == 1).flatMap(t => qr2)
+        }
+        val n = quote {
+          iqr1.sortBy(t => 1).filter(t => t.i == 1).flatMap(t => qr2)
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "filter" in {
+        val q = quote {
+          iqr1.filter(t => t.i == 1).filter(t => t.s == "s1")
+        }
+        val n = quote {
+          iqr1.sortBy(t => 1).filter(t => t.i == 1).filter(t => t.s == "s1")
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "sortBy" in {
+        val q = quote {
+          iqr1.sortBy(t => t.s)
+        }
+        val n = quote {
+          iqr1.sortBy(t => 1).sortBy(t => t.s)
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "take" in {
+        val q = quote {
+          iqr1.sortBy(b => b.s).take(1)
+        }
+        val n = quote {
+          iqr1.sortBy(b => 1).sortBy(b => b.s).take(1)
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "drop" in {
+        val q = quote {
+          iqr1.sortBy(b => b.s).drop(1)
+        }
+        val n = quote {
+          iqr1.sortBy(b => 1).sortBy(b => b.s).drop(1)
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+      "distinct" in {
+        val q = quote {
+          iqr1.sortBy(b => b.s).drop(1).distinct
+        }
+        val n = quote {
+          iqr1.sortBy(b => 1).sortBy(b => b.s).drop(1).distinct
+        }
+        attachToEntity(q.ast) mustEqual n.ast
+      }
+    }
+  }
+
   "falls back to the query if it's not possible to flatten it" - {
     "union" in {
       val q = quote {
diff --git a/quill-core/src/test/scala/io/getquill/norm/capture/AvoidAliasConflictSpec.scala b/quill-core/src/test/scala/io/getquill/norm/capture/AvoidAliasConflictSpec.scala
index 5727407d68..0fe23bf764 100644
--- a/quill-core/src/test/scala/io/getquill/norm/capture/AvoidAliasConflictSpec.scala
+++ b/quill-core/src/test/scala/io/getquill/norm/capture/AvoidAliasConflictSpec.scala
@@ -1,12 +1,7 @@
 package io.getquill.norm.capture
 
 import io.getquill.Spec
-import io.getquill.testContext.implicitOrd
-import io.getquill.testContext.qr1
-import io.getquill.testContext.qr2
-import io.getquill.testContext.qr3
-import io.getquill.testContext.quote
-import io.getquill.testContext.unquote
+import io.getquill.testContext._
 
 class AvoidAliasConflictSpec extends Spec {
 
@@ -160,6 +155,19 @@ class AvoidAliasConflictSpec extends Spec {
     }
   }
 
+  "considers infix as unaliased" in {
+    val i = quote {
+      infix"$qr1".as[Query[TestEntity]]
+    }
+    val q = quote {
+      i.flatMap(a => qr2.flatMap(a => qr3))
+    }
+    val n = quote {
+      i.flatMap(a => qr2.flatMap(a1 => qr3))
+    }
+    AvoidAliasConflict(q.ast) mustEqual n.ast
+  }
+
   "takes in consideration the aliases already defined" - {
     "flatMap" in {
       val q = quote {
diff --git a/quill-core/src/test/scala/io/getquill/norm/capture/DealiasSpec.scala b/quill-core/src/test/scala/io/getquill/norm/capture/DealiasSpec.scala
index 37bf86f293..66a93f2607 100644
--- a/quill-core/src/test/scala/io/getquill/norm/capture/DealiasSpec.scala
+++ b/quill-core/src/test/scala/io/getquill/norm/capture/DealiasSpec.scala
@@ -119,7 +119,7 @@ class DealiasSpec extends Spec {
       Dealias(q.ast) mustEqual n.ast
     }
   }
-  "outer join" - {
+  "join" - {
     "left" in {
       val q = quote {
         qr1.filter(a => a.s == "s").map(b => b.s).fullJoin(qr1).on((a, b) => a == b.s)
@@ -147,6 +147,12 @@ class DealiasSpec extends Spec {
       }
       Dealias(q.ast) mustEqual n.ast
     }
+    "self join" in {
+      val q = quote {
+        qr1.join(qr1).on((a, b) => a.i == b.i)
+      }
+      Dealias(q.ast) mustEqual q.ast
+    }
   }
   "entity" in {
     Dealias(qr1.ast) mustEqual qr1.ast
diff --git a/quill-spark/src/main/scala/io/getquill/QuillSparkContext.scala b/quill-spark/src/main/scala/io/getquill/QuillSparkContext.scala
new file mode 100644
index 0000000000..d30d796d73
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/QuillSparkContext.scala
+package io.getquill
+
+import scala.util.Success
+import scala.util.Try
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{ Encoder => SparkEncoder }
+import org.apache.spark.sql.SQLContext
+import io.getquill.context.Context
+import org.apache.spark.rdd.RDD
+import language.implicitConversions
+import io.getquill.context.spark.Encoders
+import io.getquill.context.spark.Decoders
+import io.getquill.context.spark.SparkDialect
+import io.getquill.context.spark.Binding
+import io.getquill.context.spark.DatasetBinding
+import io.getquill.context.spark.ValueBinding
+import scala.reflect.ClassTag
+
+object QuillSparkContext extends QuillSparkContext
+
+trait QuillSparkContext
+  extends Context[SparkDialect, Literal]
+  with Encoders
+  with Decoders {
+
+  type Result[T] = Dataset[T]
+  type RunQuerySingleResult[T] = T
+  type RunQueryResult[T] = T
+
+  type PrepareRow = List[Binding]
+  type ResultRow = Null
+
+  def close() = {}
+
+  def probe(statement: String): Try[_] = Success(Unit)
+
+  val idiom = SparkDialect
+  val naming = Literal
+
+  private implicit def datasetEncoder[T] =
+    (idx: Int, ds: Dataset[T], row: List[Binding]) =>
+      row :+ DatasetBinding(ds)
+
+  case class ToQuery[T](ds: Dataset[T]) {
+    def toQuery =
+      quote {
+        infix"${lift(ds)}".as[Query[T]]
+      }
+  }
+
+  implicit def seqToQuery[T: SparkEncoder: ClassTag](t: Seq[T])(implicit spark: SQLContext) =
+    rddToQuery(spark.sparkContext.parallelize(t))
+
+  implicit def datasetToQuery[T](ds: Dataset[T]) = ToQuery(ds)
+
+  implicit def rddToQuery[T: SparkEncoder](ds: RDD[T])(implicit spark: SQLContext) = {
+    import spark.implicits._
+    ToQuery(ds.toDS)
+  }
+
+  def executeQuery[T](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit enc: SparkEncoder[T], spark: SQLContext) =
+    spark.sql(prepareString(string, prepare)).as[T]
+
+  def executeQuerySingle[T](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit enc: SparkEncoder[T], spark: SQLContext) =
+    spark.sql(prepareString(string, prepare)).as[T]
+
+  private def prepareString(string: String, prepare: Prepare)(implicit spark: SQLContext) = {
+    var dsId = 0
+    prepare(Nil)._2.foldLeft(string) {
+      case (string, DatasetBinding(ds)) =>
+        dsId += 1
+        val name = s"ds$dsId"
+        ds.createOrReplaceTempView(name)
+        string.replaceFirst("\\?", name)
+      case (string, ValueBinding(value)) =>
+        string.replaceFirst("\\?", value)
+    }
+  }
+}
diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/AliasNestedQueryColumns.scala b/quill-spark/src/main/scala/io/getquill/context/spark/AliasNestedQueryColumns.scala
new file mode 100644
index 0000000000..0ac2ac6c42
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/context/spark/AliasNestedQueryColumns.scala
+package io.getquill.context.spark
+
+import io.getquill.context.sql.SqlQuery
+import io.getquill.context.sql.FlattenSqlQuery
+import io.getquill.context.sql._
+import io.getquill.ast.Ident
+
+object AliasNestedQueryColumns {
+
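+  // Select values without an explicit alias are aliased positionally as
+  // _1, _2, ..., mirroring Scala tuple field names so that Spark's encoders
+  // can bind nested query columns by name; plain idents keep their own name.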
+  def apply(q: SqlQuery): SqlQuery =
+    q match {
+      case q: FlattenSqlQuery =>
+        val aliased =
+          q.select.zipWithIndex.map {
+            case (s @ SelectValue(i: Ident, alias), idx) => s
+            case (f, idx) => f.copy(alias = f.alias.orElse(Some(s"_${idx + 1}")))
+          }
+
+        q.copy(from = q.from.map(apply), select = aliased)
+
+      case SetOperationSqlQuery(a, op, b) => SetOperationSqlQuery(apply(a), op, apply(b))
+      case UnaryOperationSqlQuery(op, a)  => UnaryOperationSqlQuery(op, apply(a))
+    }
+
+  private def apply(f: FromContext): FromContext =
+    f match {
+      case QueryContext(a, alias)    => QueryContext(apply(a), alias)
+      case JoinContext(t, a, b, on)  => JoinContext(t, apply(a), apply(b), on)
+      case FlatJoinContext(t, a, on) => FlatJoinContext(t, apply(a), on)
+      case other                     => other
+    }
+}
\ No newline at end of file
diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/Binding.scala b/quill-spark/src/main/scala/io/getquill/context/spark/Binding.scala
new file mode 100644
index 0000000000..587f40aa2c
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/context/spark/Binding.scala
+package io.getquill.context.spark
+
+import org.apache.spark.sql.Dataset
+
+sealed trait Binding
+
+case class DatasetBinding[T](ds: Dataset[T]) extends Binding
+
+case class ValueBinding(str: String) extends Binding
\ No newline at end of file
diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/Decoders.scala b/quill-spark/src/main/scala/io/getquill/context/spark/Decoders.scala
new file mode 100644
index 0000000000..b94de5246c
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/context/spark/Decoders.scala
+package io.getquill.context.spark
+
+import io.getquill.util.Messages
+import io.getquill.QuillSparkContext
+
+trait Decoders {
+  this: QuillSparkContext =>
+
+  type Decoder[T] = BaseDecoder[T]
+
+  implicit def dummyDecoder[T] =
+    (idx: Int, row: ResultRow) => Messages.fail("quill decoders are not used for spark")
+
+  implicit def mappedDecoder[I, O](implicit mapped: MappedEncoding[I, O], decoder: Decoder[I]): Decoder[O] =
+    dummyDecoder[O]
+}
\ No newline at end of file
diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/Encoders.scala b/quill-spark/src/main/scala/io/getquill/context/spark/Encoders.scala
new file mode 100644
index 0000000000..d5b0c564b2
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/context/spark/Encoders.scala
+package io.getquill.context.spark
+
+import io.getquill.QuillSparkContext
+
+trait Encoders {
+  this: QuillSparkContext =>
+
+  type Encoder[T] = BaseEncoder[T]
+
+  def encoder[T](f: T => String): Encoder[T] =
+    (index: Index, value: T, row: PrepareRow) =>
+      row :+ ValueBinding(f(value))
+
+  private def toStringEncoder[T]: Encoder[T] = encoder((v: T) => s"$v")
+
+  private def quotedToStringEncoder[T]: Encoder[T] = encoder(v => s""""$v"""")
+
+  implicit def mappedEncoder[I, O](implicit mapped: MappedEncoding[I, O], e: Encoder[O]): Encoder[I] =
+    mappedBaseEncoder(mapped, e)
+
+  implicit def optionEncoder[T](implicit d: Encoder[T]): Encoder[Option[T]] =
+    (index: Index, value: Option[T], row: PrepareRow) =>
+      value match {
+        case None    => row :+ ValueBinding("null")
+        case Some(v) => d(index, v, row)
+      }
+
+  implicit val stringEncoder: Encoder[String] = quotedToStringEncoder
+  implicit val bigDecimalEncoder: Encoder[BigDecimal] = toStringEncoder
+  implicit val booleanEncoder: Encoder[Boolean] = toStringEncoder
+  implicit val byteEncoder: Encoder[Byte] = toStringEncoder
+  implicit val shortEncoder: Encoder[Short] = toStringEncoder
+  implicit val intEncoder: Encoder[Int] = toStringEncoder
+  implicit val longEncoder: Encoder[Long] = toStringEncoder
+  implicit val doubleEncoder: Encoder[Double] = toStringEncoder
+}
\ No newline at end of file
diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala
new file mode 100644
index 0000000000..21ac3c4650
--- /dev/null
+++ b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala
+package io.getquill.context.spark
+
+import io.getquill.NamingStrategy
+import io.getquill.ast._
+import io.getquill.ast.Property
+import io.getquill.ast.Query
+import io.getquill.context.sql.SqlQuery
+import io.getquill.context.sql.idiom.SqlIdiom
+import io.getquill.context.sql.idiom.VerifySqlQuery
+import io.getquill.context.sql.norm.SqlNormalize
+import io.getquill.idiom.StatementInterpolator.Impl
+import io.getquill.idiom.StatementInterpolator.TokenImplicit
+import io.getquill.idiom.StatementInterpolator.Tokenizer
+import io.getquill.idiom.StatementInterpolator.stringTokenizer
+import io.getquill.idiom.StatementInterpolator.tokenTokenizer
+import io.getquill.idiom.Token
+import io.getquill.util.Messages.fail
+import io.getquill.util.Messages.trace
+
+class SparkDialect extends SqlIdiom {
+
+  def liftingPlaceholder(index: Int): String = "?"
+
+  override def prepareForProbing(string: String) = string
+
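+  // `translate` normalizes the AST and, for queries, builds and verifies the
+  // SQL query before tokenizing it; any other AST is tokenized directly.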
+  override def translate(ast: Ast)(implicit naming: NamingStrategy) = {
+    val normalizedAst = SqlNormalize(ast)
+
+    implicit val tokenizer = defaultTokenizer
+
+    val token =
+      normalizedAst match {
+        case q: Query =>
+          val sql = SqlQuery(q)
+          trace("sql")(sql)
+          VerifySqlQuery(sql).map(fail)
+          sql.token
+        case other =>
+          other.token
+      }
+
+    (normalizedAst, stmt"$token")
+  }
+
+  override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] {
+    case q => super.sqlQueryTokenizer.token(AliasNestedQueryColumns(q))
+  }
+
+  override implicit def propertyTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Property] = {
+    def path(ast: Ast): Token =
+      ast match {
+        case Property(a, b) =>
+          stmt"${path(a)}.${strategy.column(b).token}"
+        case other =>
+          other.token
+      }
+    Tokenizer[Property] {
+      case p => path(p).token
+    }
+  }
+
+  override implicit def valueTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Value] = Tokenizer[Value] {
+    case Tuple(values) => stmt"(${values.token})"
+    case other         => super.valueTokenizer.token(other)
+  }
+}
+
+object SparkDialect extends SparkDialect
diff --git a/quill-spark/src/test/resources/.placeholder b/quill-spark/src/test/resources/.placeholder
new file mode 100644
index 0000000000..9c558e357c
--- /dev/null
+++ b/quill-spark/src/test/resources/.placeholder
+.
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/AliasNestedQueryColumnsSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/AliasNestedQueryColumnsSpec.scala
new file mode 100644
index 0000000000..5b17f3d028
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/AliasNestedQueryColumnsSpec.scala
+package io.getquill.context.spark
+
+import io.getquill.Spec
+
+case class Test(i: Int, j: Int)
+
+class AliasNestedQueryColumnsSpec extends Spec {
+
+  import testContext._
+  import sqlContext.implicits._
+
+  val entities = Seq(Test(1, 2))
+
+  val qr1 = entities.toQuery
+  val qr2 = entities.toQuery
+
+  "adds tuple alias" - {
+    "flatten query" in {
+      val q = quote {
+        qr1.map(e => (e.i, e.j))
+      }
+      testContext.run(q).collect.toList mustEqual
+        entities.map(e => (e.i, e.j))
+    }
+    "set operation" in {
+      val q = quote {
+        qr1 ++ qr1
+      }
+      testContext.run(q).collect.toList mustEqual
+        entities ++ entities
+    }
+    "unary operation" in {
+      val q = quote {
+        qr1.filter(t => qr2.nested.nonEmpty)
+      }
+      testContext.run(q).collect.toList mustEqual
+        entities
+    }
+    "nested" - {
+      "query" in {
+        val q = quote {
+          qr1.nested
+        }
+        testContext.run(q).collect.toList mustEqual
+          entities
+      }
+      "join" in {
+        val q = quote {
+          qr1.join(qr2).on((a, b) => a.i == b.i).map {
+            case (a, b) => (a.i, b.i)
+          }
+        }
+        testContext.run(q).collect.toList mustEqual
+          List((1, 1))
+      }
+      "flatJoin" in {
+        val q = quote {
+          for {
+            a <- qr1
+            b <- qr2.join(b => b.i == a.i)
+          } yield (a.i, b.i)
+        }
+        testContext.run(q.dynamic).collect.toList mustEqual
+          List((1, 1))
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/DepartmentsSparkSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/DepartmentsSparkSpec.scala
new file mode 100644
index 0000000000..f22dffd7be
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/DepartmentsSparkSpec.scala
+package io.getquill.context.spark
+
+import io.getquill.Spec
+
+case class Department(dpt: String)
+case class Employee(emp: String, dpt: String)
+case class Task(emp: String, tsk: String)
+
+class DepartmentsSparkSpec extends Spec {
+
+  import testContext._
+  import sqlContext.implicits._
+
+  val departments = Seq(
+    Department("Product"),
+    Department("Quality"),
+    Department("Research"),
+    Department("Sales")
+  ).toQuery
+
+  val employees = Seq(
+    Employee("Alex", "Product"),
+    Employee("Bert", "Product"),
+    Employee("Cora", "Research"),
+    Employee("Drew", "Research"),
+    Employee("Edna", "Research"),
+    Employee("Fred", "Sales")
+  ).toQuery
+
+  val tasks = Seq(
+    Task("Alex", "build"),
+    Task("Bert", "build"),
+    Task("Cora", "abstract"),
+    Task("Cora", "build"),
+    Task("Cora", "design"),
+    Task("Drew", "abstract"),
+    Task("Drew", "design"),
+    Task("Edna", "abstract"),
+    Task("Edna", "call"),
+    Task("Edna", "design"),
+    Task("Fred", "call")
+  ).toQuery
+
+  "Example 8 - nested naive" in {
+    val q = quote {
+      (u: String) =>
+        for {
+          d <- departments if (
+            (for {
+              e <- employees if (
+                e.dpt == d.dpt && (
+                  for {
+                    t <- tasks if (e.emp == t.emp && t.tsk == u)
+                  } yield {}
+                ).isEmpty
+              )
+            } yield {}).isEmpty
+          )
+        } yield d.dpt
+    }
+    testContext.run(q("abstract")).collect().toList mustEqual
+      List("Research", "Quality")
+  }
+
+  "Example 9 - nested db" in {
+    val q = {
+      val nestedOrg =
+        quote {
+          for {
+            d <- departments
+          } yield {
+            (d.dpt,
+              for {
+                e <- employees if (d.dpt == e.dpt)
+              } yield {
+                (e.emp,
+                  for {
+                    t <- tasks if (e.emp == t.emp)
+                  } yield {
+                    t.tsk
+                  })
+              })
+          }
+        }
+
+      def any[T] =
+        quote { (xs: Query[T]) => (p: T => Boolean) =>
+          (for {
+            x <- xs if (p(x))
+          } yield {}).nonEmpty
+        }
+
+      def all[T] =
+        quote { (xs: Query[T]) => (p: T => Boolean) =>
+          !any(xs)(x => !p(x))
+        }
+
+      def contains[T] =
+        quote { (xs: Query[T]) => (u: T) =>
+          any(xs)(x => x == u)
+        }
+
+      quote {
+        (u: String) =>
+          for {
+            (dpt, employees) <- nestedOrg if (all(employees) { case (emp, tasks) => contains(tasks)(u) })
+          } yield {
+            dpt
+          }
+      }
+    }
+    testContext.run(q("abstract")).collect().toList mustEqual
+      List("Research", "Quality")
+  }
+}
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/EncodingSparkSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/EncodingSparkSpec.scala
new file mode 100644
index 0000000000..8e94a079b6
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/EncodingSparkSpec.scala
+package io.getquill.context.spark
+
+import io.getquill.Spec
+
+case class EncodingTestEntity(
+  v1: String,
+  v2: BigDecimal,
+  v3: Boolean,
+  v4: Byte,
+  v5: Short,
+  v6: Int,
+  v7: Long,
+  v8: Double,
+  v9: Array[Byte],
+  o1: Option[String],
+  o2: Option[BigDecimal],
+  o3: Option[Boolean],
+  o4: Option[Byte],
+  o5: Option[Short],
+  o6: Option[Int],
+  o7: Option[Long],
+  o8: Option[Double],
+  o9: Option[Array[Byte]]
+)
+
+class EncodingSparkSpec extends Spec {
+
+  import testContext._
+  import sqlContext.implicits._
+
+  implicit val e = org.apache.spark.sql.Encoders.DATE
+
+  "encodes and decodes types" in {
+    verify(testContext.run(entities).collect.toList)
+  }
+
+  "string" in {
+    val v = "s"
+    val q = quote {
+      entities.filter(_.v1 == lift(v)).map(_.v1)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "bigDecimal" in {
+    val v = BigDecimal(1.1)
+    val q = quote {
+      entities.filter(_.v2 == lift(v)).map(_.v2)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "boolean" in {
+    val v = true
+    val q = quote {
+      entities.filter(_.v3 == lift(v)).map(_.v3)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "byte" in {
+    val v = 11.toByte
+    val q = quote {
+      entities.filter(_.v4 == lift(v)).map(_.v4)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "short" in {
+    val v = 23.toShort
+    val q = quote {
+      entities.filter(_.v5 == lift(v)).map(_.v5)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "int" in {
+    val v = 33
+    val q = quote {
+      entities.filter(_.v6 == lift(v)).map(_.v6)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "long" in {
+    val v = 431L
+    val q = quote {
+      entities.filter(_.v7 == lift(v)).map(_.v7)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  "double" in {
+    val v = 42d
+    val q = quote {
+      entities.filter(_.v8 == lift(v)).map(_.v8)
+    }
+    testContext.run(q).collect.toList mustEqual List(v)
+  }
+
+  val entities =
+    Seq(
+      EncodingTestEntity(
+        "s",
+        BigDecimal(1.1),
+        true,
+        11.toByte,
+        23.toShort,
+        33,
+        431L,
+        42d,
+        Array(1.toByte, 2.toByte),
+        Some("s"),
+        Some(BigDecimal(1.1)),
+        Some(true),
+        Some(11.toByte),
+        Some(23.toShort),
+        Some(33),
+        Some(431L),
+        Some(42d),
+        Some(Array(1.toByte, 2.toByte))
+      ),
+      EncodingTestEntity(
+        "",
+        BigDecimal(0),
+        false,
+        0.toByte,
+        0.toShort,
+        0,
+        0L,
+        0D,
+        Array(),
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None
+      )
+    ).toQuery
+
+  def verify(result: List[EncodingTestEntity]) =
+    result match {
+      case List(e1, e2) =>
+
+        e1.v1 mustEqual "s"
+        e1.v2 mustEqual BigDecimal(1.1)
+        e1.v3 mustEqual true
+        e1.v4 mustEqual 11.toByte
+        e1.v5 mustEqual 23.toShort
+        e1.v6 mustEqual 33
+        e1.v7 mustEqual 431L
+        e1.v8 mustEqual 42d
+        e1.v9.toList mustEqual List(1.toByte, 2.toByte)
+
+        e1.o1 mustEqual Some("s")
+        e1.o2 mustEqual Some(BigDecimal(1.1))
+        e1.o3 mustEqual Some(true)
+        e1.o4 mustEqual Some(11.toByte)
+        e1.o5 mustEqual Some(23.toShort)
+        e1.o6 mustEqual Some(33)
+        e1.o7 mustEqual Some(431L)
+        e1.o8 mustEqual Some(42d)
+        e1.o9.map(_.toList) mustEqual Some(List(1.toByte, 2.toByte))
+
+        e2.v1 mustEqual ""
+        e2.v2 mustEqual BigDecimal(0)
+        e2.v3 mustEqual false
+        e2.v4 mustEqual 0.toByte
+        e2.v5 mustEqual 0.toShort
+        e2.v6 mustEqual 0
+        e2.v7 mustEqual 0L
+        e2.v8 mustEqual 0d
+        e2.v9.toList mustEqual Nil
+
+        e2.o1 mustEqual None
+        e2.o2 mustEqual None
+        e2.o3 mustEqual None
+        e2.o4 mustEqual None
+        e2.o5 mustEqual None
+        e2.o6 mustEqual None
+        e2.o7 mustEqual None
+        e2.o8 mustEqual None
+        e2.o9.map(_.toList) mustEqual None
+    }
+}
\ No newline at end of file
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/PeopleSparkSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/PeopleSparkSpec.scala
new file mode 100644
index 0000000000..1dd79c9621
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/PeopleSparkSpec.scala
+package io.getquill.context.spark
+
+import io.getquill.Spec
+
+case class Person(name: String, age: Int)
+case class Couple(her: String, him: String)
+
+class PeopleSparkSpec extends Spec {
+
+  val context = io.getquill.context.sql.testContext
+
+  import testContext._
+  import sqlContext.implicits._
+
+  val couples = Seq(
+    Couple("Alex", "Bert"),
+    Couple("Cora", "Drew"),
+    Couple("Edna", "Fred")
+  ).toQuery
+
+  val people = Seq(
+    Person("Alex", 60),
+    Person("Bert", 55),
+    Person("Cora", 33),
+    Person("Drew", 31),
+    Person("Edna", 21),
+    Person("Fred", 60)
+  ).toQuery
+
+  "Example 1 - differences" in {
+    val q =
+      quote {
+        for {
+          c <- couples
+          w <- people
+          m <- people if (c.her == w.name && c.him == m.name && w.age > m.age)
+        } yield {
+          (w.name, w.age - m.age)
+        }
+      }
+    testContext.run(q).collect.toList mustEqual
+      List(("Cora", 2), ("Alex", 5))
+  }
+
+  "Example 2 - range simple" in {
+    val rangeSimple = quote {
+      (a: Int, b: Int) =>
+        for {
+          u <- people if (a <= u.age && u.age < b)
+        } yield {
+          u
+        }
+    }
+
+    testContext.run(rangeSimple(30, 40)).collect.toList mustEqual
+      List(Person("Cora", 33), Person("Drew", 31))
+  }
+
+  val satisfies =
+    quote {
+      (p: Int => Boolean) =>
+        for {
+          u <- people if (p(u.age))
+        } yield {
+          u
+        }
+    }
+
+  "Example 3 - satisfies" in {
+    testContext.run(satisfies((x: Int) => 20 <= x && x < 30)).collect.toList mustEqual
+      List(Person("Edna", 21))
+  }
+
+  "Example 4 - satisfies" in {
+    testContext.run(satisfies((x: Int) => x % 2 == 0)).collect.toList mustEqual
+      List(Person("Alex", 60), Person("Fred", 60))
+  }
+
+  "Example 5 - compose" in {
+    val q = {
+      val range = quote {
+        (a: Int, b: Int) =>
+          for {
+            u <- people if (a <= u.age && u.age < b)
+          } yield {
+            u
+          }
+      }
+      val ageFromName = quote {
+        (s: String) =>
+          for {
+            u <- people if (s == u.name)
+          } yield {
+            u.age
+          }
+      }
+      quote {
+        (s: String, t: String) =>
+          for {
+            a <- ageFromName(s)
+            b <- ageFromName(t)
+            r <- range(a, b)
+          } yield {
+            r
+          }
+      }
+    }
+    testContext.run(q("Drew", "Bert")).collect.toList mustEqual
+      List(Person("Cora", 33), Person("Drew", 31))
+  }
+}
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/QuillSparkContextSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/QuillSparkContextSpec.scala
new file mode 100644
index 0000000000..0c8667379d
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/QuillSparkContextSpec.scala
+package io.getquill.context.spark
+
+import io.getquill.Spec
+
+class QuillSparkContextSpec extends Spec {
+
+  import sqlContext.implicits._
+  import testContext._
+
+  val entities = Seq(Test(1, 2))
+
+  "toQuery" - {
+
+    "seq" in {
+      testContext.run(entities.toQuery).collect.toList mustEqual
+        entities
+    }
+
+    "dataset" in {
+      val q = sqlContext.createDataset(entities).toQuery
+      testContext.run(q).collect.toList mustEqual
+        entities
+    }
+
+    "rdd" in {
+      val q = sqlContext.sparkContext.parallelize(entities).toQuery
+      testContext.run(q).collect.toList mustEqual
+        entities
+    }
+  }
+
+  "query single" in {
+    val q = quote {
+      entities.toQuery.map(t => t.i).max
+    }
+    testContext.run(q).collect.toList mustEqual
+      List(Some(1))
+  }
+}
\ No newline at end of file
diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/package.scala b/quill-spark/src/test/scala/io/getquill/context/spark/package.scala
new file mode 100644
index 0000000000..0b58e15e7b
--- /dev/null
+++ b/quill-spark/src/test/scala/io/getquill/context/spark/package.scala
+package io.getquill.context
+
+import org.apache.spark.sql.SparkSession
+import io.getquill.QuillSparkContext
+
+package object spark {
+
+  val sparkSession =
+    SparkSession
+      .builder()
+      .master("local")
+      .appName("spark test")
+      .getOrCreate()
+
+  implicit val sqlContext = sparkSession.sqlContext
+
+  val testContext = QuillSparkContext
+}
\ No newline at end of file
diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala
index fa9b0fac87..d4586462b7 100644
--- a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala
+++ b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala
@@ -85,6 +85,12 @@ object SqlQuery {
 
   private def flattenContexts(query: Ast): (List[FromContext], Ast) =
     query match {
+      case FlatMap(q: Infix, Ident(alias), p: Query) =>
+        val source = this.source(q, alias)
+        val (nestedContexts, finalFlatMapBody) = flattenContexts(p)
+        (source +: nestedContexts, finalFlatMapBody)
+      case FlatMap(q: Infix, Ident(alias), p: Infix) =>
+        fail(s"Infix can't be used as a `flatMap` body. $query")
       case FlatMap(q: Query, Ident(alias), p: Query) =>
         val source = this.source(q, alias)
         val (nestedContexts, finalFlatMapBody) = flattenContexts(p)