WIP: TypedSQL
marmbrus committed Sep 3, 2014
1 parent e2c901b commit 6e1eaf3
Showing 5 changed files with 290 additions and 3 deletions.
2 changes: 2 additions & 0 deletions project/SparkBuild.scala
@@ -209,6 +209,8 @@ object Catalyst {

object SQL {
  lazy val settings = Seq(
+    addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full),
+    libraryDependencies += "ch.epfl.lamp" %% "scala-records" % "0.1",
    initialCommands in console :=
      """
        |import org.apache.spark.sql.catalyst.analysis._
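The two added settings pull in the macro-paradise compiler plugin (which backports quasiquote support to Scala 2.10) and the scala-records library used for the typed query results. As a hedged sketch, not part of this commit, a separate sbt project that wanted to experiment with the sql"..." interpolator would likely carry the same two coordinates:

// Hypothetical user-project build settings (an assumption, not stated by this commit).
// paradise 2.0.1 provides quasiquotes on Scala 2.10; scala-records supplies the
// record types that the typed interpolator returns.
addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)
libraryDependencies += "ch.epfl.lamp" %% "scala-records" % "0.1"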
12 changes: 10 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -26,8 +26,16 @@ import org.apache.spark.sql.catalyst.types._
/**
 * Provides experimental support for generating catalyst schemas for scala objects.
 */
-object ScalaReflection {
-  import scala.reflect.runtime.universe._
+object ScalaReflection extends ScalaReflection {
+  val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+}
+
+trait ScalaReflection {
+
+  /** The universe we work in (runtime or macro) */
+  val universe: scala.reflect.api.Universe
+
+  import universe._

  case class Schema(dataType: DataType, nullable: Boolean)

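The change above turns ScalaReflection into a trait that abstracts over the reflection universe, so the same schema-inference logic can run both at runtime (the companion object binds the runtime universe) and inside a macro (where the universe is the macro context's). A minimal sketch of the unchanged runtime path, using a hypothetical Person case class; the commented schema is an approximation:

import org.apache.spark.sql.catalyst.ScalaReflection

case class Person(name: String, age: Int)

// Derive a Catalyst schema via runtime reflection, exactly as before this refactor.
// Roughly: StructType(Seq(StructField("name", StringType, nullable = true),
//                         StructField("age", IntegerType, nullable = false)))
val personSchema = ScalaReflection.schemaFor[Person]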
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -52,7 +52,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
  with SQLConf
  with ExpressionConversions
  with UDFRegistration
-  with Serializable {
+  with Serializable
+  with TypedSQL {

  self =>

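Mixing TypedSQL into SQLContext puts the sql string interpolator in scope once the context's members are imported. A hedged usage sketch follows; note that the macro below currently pattern-matches its prefix against org.apache.spark.sql.test.TestSQLContext, so in this WIP state the interpolator is really only exercised through the test context, and sc/people here are assumed rather than defined by the commit:

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext._  // brings the implicit SQLInterpolation class into scope

// Intended usage: a compile-time-checked query whose results expose a typed `name` field.
val adults = sql"SELECT name FROM $people WHERE age > 21"
adults.map(_.name)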
200 changes: 200 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -0,0 +1,200 @@
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._

import scala.language.experimental.macros
import scala.language.existentials

import records._
import Macros.RecordMacros

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}

/**
* A collection of Scala macros for working with SQL in a type-safe way.
*/
private[sql] object SQLMacros {
  import scala.reflect.macros._

  def sqlImpl(c: Context)(args: c.Expr[Any]*) =
    new Macros[c.type](c).sql(args)

  case class Schema(dataType: DataType, nullable: Boolean)

  class Macros[C <: Context](val c: C) extends ScalaReflection {
    val universe: c.universe.type = c.universe

    import c.universe._

    val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"

    val rMacros = new RecordMacros[c.type](c)

    trait InterpolatedItem {
      def placeholderName: String
      def registerCode: Tree
      def localRegister(catalog: Catalog, registry: FunctionRegistry)
    }

    case class InterpolatedUDF(index: Int, expr: c.Expr[Any], returnType: DataType)
      extends InterpolatedItem {

      val placeholderName = s"func$index"

      def registerCode = q"""registerFunction($placeholderName, $expr)"""

      def localRegister(catalog: Catalog, registry: FunctionRegistry) = {
        registry.registerFunction(
          placeholderName, (_: Seq[Expression]) => ScalaUdf(null, returnType, Nil))
      }
    }

    case class InterpolatedTable(index: Int, expr: c.Expr[Any], schema: StructType)
      extends InterpolatedItem {

      val placeholderName = s"table$index"

      def registerCode = q"""$expr.registerTempTable($placeholderName)"""

      def localRegister(catalog: Catalog, registry: FunctionRegistry) = {
        catalog.registerTable(None, placeholderName, LocalRelation(schema.toAttributes :_*))
      }
    }

    case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type)

    def sql(args: Seq[c.Expr[Any]]) = {

      val q"""
        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
          scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

      //rawParts.map(_.toString).foreach(println)

      val parts =
        rawParts.map(
          _.toString.stripPrefix("\"")
            .replaceAll("\\\\", "")
            .stripSuffix("\""))

      val interpolatedArguments = args.zipWithIndex.map { case (arg, i) =>
        // println(arg + " " + arg.actualType)
        arg.actualType match {
          case TypeRef(_, _, Seq(schemaType)) =>
            InterpolatedTable(i, arg, schemaFor(schemaType).dataType.asInstanceOf[StructType])
          case TypeRef(_, _, Seq(inputType, outputType)) =>
            InterpolatedUDF(i, arg, schemaFor(outputType).dataType)
        }
      }

      val query = parts(0) + args.indices.map { i =>
        interpolatedArguments(i).placeholderName + parts(i + 1)
      }.mkString("")

      val parser = new SqlParser()
      val logicalPlan = parser(query)
      val catalog = new SimpleCatalog(true)
      val functionRegistry = new SimpleFunctionRegistry
      val analyzer = new Analyzer(catalog, functionRegistry, true)

      interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
      val analyzedPlan = analyzer(logicalPlan)

      val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
      val record = genRecord(q"row", fields)

      val tree = q"""
        ..${interpolatedArguments.map(_.registerCode)}
        val result = sql($query)
        result.map(row => $record)
      """

      c.Expr(tree)
    }

    // TODO: Handle nullable fields
    def genRecord(row: Tree, fields: Seq[(String, DataType)]) = {
      case class ImplSchema(name: String, tpe: Type, impl: Tree)

      val implSchemas = for {
        ((name, dataType), i) <- fields.zipWithIndex
      } yield {
        val tpe = c.typeCheck(genGetField(q"null: $rowTpe", i, dataType)).tpe
        val tree = genGetField(row, i, dataType)

        ImplSchema(name, tpe, tree)
      }

      val schema = implSchemas.map(f => (f.name, f.tpe))

      val (spFlds, objFields) = implSchemas.partition(s =>
        rMacros.specializedTypes.contains(s.tpe))

      val spImplsByTpe = {
        val grouped = spFlds.groupBy(_.tpe)
        grouped.mapValues { _.map(s => s.name -> s.impl).toMap }
      }

      val dataObjImpl = {
        val impls = objFields.map(s => s.name -> s.impl).toMap
        val lookupTree = rMacros.genLookup(q"fieldName", impls, mayCache = false)
        q"($lookupTree).asInstanceOf[T]"
      }

      rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
        case tpe if spImplsByTpe.contains(tpe) =>
          rMacros.genLookup(q"fieldName", spImplsByTpe(tpe), mayCache = false)
      }
    }

    /**
     * Generate a tree that retrieves a given field for a given type.
     * Constructs a nested record if necessary.
     */
    def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
      case t: PrimitiveType =>
        val methodName = newTermName("get" + primitiveForType(t))
        q"$row.$methodName($index)"
      case StructType(structFields) =>
        val fields = structFields.map(f => (f.name, f.dataType))
        genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
      case _ =>
        c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
    }
  }

  // TODO: Duplicated from codegen PR...
  protected def primitiveForType(dt: PrimitiveType) = dt match {
    case IntegerType => "Int"
    case LongType => "Long"
    case ShortType => "Short"
    case ByteType => "Byte"
    case DoubleType => "Double"
    case FloatType => "Float"
    case BooleanType => "Boolean"
    case StringType => "String"
  }
}

trait TypedSQL {
  self: SQLContext =>

  /**
   * :: Experimental ::
   * Adds a string interpolator that allows users to run Spark SQL queries that return type-safe
   * results.
   *
   * This feature is experimental and the return types of this interpolation may change in future
   * releases.
   */
  @Experimental
  implicit class SQLInterpolation(val strCtx: StringContext) {
    // TODO: Handle functions...
    def sql(args: Any*): Any = macro SQLMacros.sqlImpl
  }
}
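In outline, the macro substitutes placeholder names (table0, func0, ...) for the interpolated arguments, parses and analyzes the resulting SQL text at compile time against locally registered placeholder tables and UDFs, and then emits a tree that registers the real arguments, runs the query, and maps each Row into a scala-records record matching the analyzed output schema. A rough, simplified picture of the expansion for one of the queries in the test suite below; the structural type stands in for the specialized record that genRecord actually generates:

// sql"SELECT name FROM $people WHERE age = 30" expands, approximately, into:
people.registerTempTable("table0")
val result = sql("SELECT name FROM table0 WHERE age = 30")
// One accessor per analyzed output attribute; StringType maps to Row.getString.
result.map(row => new { val name: String = row.getString(0) })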
76 changes: 76 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext

case class Person(name: String, age: Int)

case class Car(owner: Person, model: String)

class TypedSqlSuite extends FunSuite {
  import TestSQLContext._

  val people = sparkContext.parallelize(
    Person("Michael", 30) ::
    Person("Bob", 40) :: Nil)

  val cars = sparkContext.parallelize(
    Car(Person("Michael", 30), "GrandAm") :: Nil)

  test("typed query") {
    val results = sql"SELECT name FROM $people WHERE age = 30"
    assert(results.first().name == "Michael")
  }

  test("int results") {
    val results = sql"SELECT * FROM $people WHERE age = 30"
    assert(results.first().name == "Michael")
    assert(results.first().age == 30)
  }

  test("nested results") {
    val results = sql"SELECT * FROM $cars"
    assert(results.first().owner.name == "Michael")
  }

  test("join query") {
    val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""

    assert(results.first().name == "Michael")
  }

  test("lambda udf") {
    def addOne = (_: Int) + 1
    val result = sql"SELECT $addOne(1) as two, $addOne(2) as three".first
    assert(result.two === 2)
    assert(result.three === 3)
  }

  test("with quotes") {
    assert(sql"SELECT 'test' as str".first.str == "test")
  }

  ignore("function udf") {
    // This does not even get to the macro code.
    // def addOne(i: Int) = i + 1
    // assert(sql"SELECT $addOne(1) as two".first.two === 2)
  }
}
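Since the interpolator returns, roughly, an RDD of records, the typed fields compose with ordinary RDD operations. A hedged sketch of an additional case that could sit inside TypedSqlSuite (not part of this commit):

  test("typed results compose with RDD operations") {
    // r.name and r.age are the statically typed record fields generated by the macro.
    val labels = sql"SELECT name, age FROM $people".map(r => s"${r.name}:${r.age}").collect()
    assert(labels.contains("Michael:30"))
  }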
