Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ianoc/case class tuple converters #1131

Merged
merged 7 commits into from
Dec 16, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object ScaldingBuild extends Build {
val scalaCheckVersion = "1.11.5"
val hadoopVersion = "1.2.1"
val algebirdVersion = "0.8.2"
val bijectionVersion = "0.7.0"
val bijectionVersion = "0.7.1"
val chillVersion = "0.5.1"
val slf4jVersion = "1.6.6"
val parquetVersion = "1.6.0rc4"
Expand Down Expand Up @@ -200,6 +200,7 @@ object ScaldingBuild extends Build {
scaldingJson,
scaldingJdbc,
scaldingHadoopTest,
scaldingMacros,
maple
)

Expand Down Expand Up @@ -401,6 +402,16 @@ object ScaldingBuild extends Build {
}
).dependsOn(scaldingCore)

// Build definition for the "macros" module.
// Library dependencies are computed from the active scalaVersion: scala-library and scala-reflect
// at that version, plus bijection-macros at bijectionVersion.
// The quasiquotes artifact is appended only when isScala210x(scalaVersion) is true.
// NOTE(review): presumably because quasiquotes ship with the compiler from 2.11 on — confirm.
lazy val scaldingMacros = module("macros").settings(
libraryDependencies <++= (scalaVersion) { scalaVersion => Seq(
"org.scala-lang" % "scala-library" % scalaVersion,
"org.scala-lang" % "scala-reflect" % scalaVersion,
"com.twitter" %% "bijection-macros" % bijectionVersion
) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % "2.0.1") else Seq())
},
// The macro paradise compiler plugin is resolved against the full Scala version (CrossVersion.full).
addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)
// NOTE(review): scaldingHadoopTest is a regular (not test-scoped) dependency here — confirm this is intended.
).dependsOn(scaldingCore, scaldingHadoopTest)

// This one uses a different naming convention
lazy val maple = Project(
id = "maple",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.twitter.scalding.macros

import scala.language.experimental.macros

import com.twitter.scalding._
import com.twitter.scalding.macros.impl.MacroImpl
import com.twitter.bijection.macros.IsCaseClass

/**
* Implicit entry points for the case class macros. Importing the members of this object makes a
* macro-generated TupleSetter / TupleConverter implicitly available for any T that has
* IsCaseClass evidence.
*/
object MacroImplicits {
/**
* Implicit materializers: each one requires IsCaseClass[T] evidence (the context bound) and expands,
* at compile time, to the corresponding implementation in MacroImpl.
*/
implicit def materializeCaseClassTupleSetter[T: IsCaseClass]: TupleSetter[T] = macro MacroImpl.caseClassTupleSetterImpl[T]
implicit def materializeCaseClassTupleConverter[T: IsCaseClass]: TupleConverter[T] = macro MacroImpl.caseClassTupleConverterImpl[T]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.twitter.scalding.macros

import scala.language.experimental.macros

import com.twitter.scalding._
import com.twitter.scalding.macros.impl.MacroImpl
import com.twitter.bijection.macros.IsCaseClass

object Macros {
// The macros in this object only work for "simple" case classes: every field must be one of
// Int, Boolean, String, Long, Short, Float or Double, or be another case class that is itself
// built (recursively) from those types. Any other field type aborts compilation.
// caseClassTupleSetter expands to MacroImpl.caseClassTupleSetterImpl and requires IsCaseClass[T] evidence.
def caseClassTupleSetter[T: IsCaseClass]: TupleSetter[T] = macro MacroImpl.caseClassTupleSetterImpl[T]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should document here that these only work when the case class is "simple" (primitive + string) or nested simple.

// Expands to MacroImpl.caseClassTupleConverterImpl and requires IsCaseClass[T] evidence.
// Only works when every field of T is Int, Boolean, String, Long, Short, Float or Double, or a
// nested case class built from those types.
def caseClassTupleConverter[T: IsCaseClass]: TupleConverter[T] = macro MacroImpl.caseClassTupleConverterImpl[T]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice to add a macro for:

def fieldsOf[T: IsCaseClass]: Fields
that returns a Fields object with the dotted names and the types set.

The goal is to use this with TypedDelimited and generate the types needed so cascading can parse the strings without storing strings in the tuples.

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package com.twitter.scalding.macros.impl

import scala.collection.mutable.{ Map => MMap }
import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.reflect.runtime.universe._
import scala.util.{ Try => BasicTry }

import cascading.tuple.{ Tuple, TupleEntry }

import com.twitter.scalding._
import com.twitter.bijection.macros.{ IsCaseClass, MacroGenerated }
import com.twitter.bijection.macros.impl.IsCaseClassImpl
/**
 * Core macro implementations that derive cascading tuple setters and converters for case classes.
 *
 * The implementations live in their own object (and module) so that they are compiled separately
 * from the code that invokes them, which makes it easier to provide helper methods interfacing
 * with the macros.
 *
 * Supported field types are String, Boolean, Short, Int, Long, Float and Double, plus nested case
 * classes made of those types. Nested case classes are flattened, in declaration order, into
 * consecutive tuple positions. Any other field type aborts macro expansion.
 */
object MacroImpl {
  /** Macro entry point that builds a TupleSetter[T] without requiring IsCaseClass[T] evidence. */
  def caseClassTupleSetterNoProof[T]: TupleSetter[T] = macro caseClassTupleSetterNoProofImpl[T]

  /**
   * Implementation behind the evidence-carrying setter macros. The `proof` expression is only
   * there to force the caller to supply IsCaseClass[T]; it is not inspected.
   */
  def caseClassTupleSetterImpl[T](c: Context)(proof: c.Expr[IsCaseClass[T]])(implicit T: c.WeakTypeTag[T]): c.Expr[TupleSetter[T]] =
    caseClassTupleSetterNoProofImpl(c)(T)

  /**
   * Generates an anonymous TupleSetter[T] (also marked MacroGenerated) whose `apply` allocates a
   * cascading Tuple of the flattened arity and writes every (possibly nested) field of `t` into
   * its position using the typed `setXxx` methods.
   */
  def caseClassTupleSetterNoProofImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[TupleSetter[T]] = {
    import c.universe._

    // For every case accessor of `outerTpe`, yields a function from the final tuple index to the
    // tree that stores that field. `pTree` is the expression that evaluates to the enclosing
    // instance; nested case classes recurse with `pTree.accessor` as the new enclosing instance.
    def expandMethod(outerTpe: Type, pTree: Tree): Iterable[Int => Tree] = {
      outerTpe
        .declarations
        .collect { case m: MethodSymbol if m.isCaseAccessor => m }
        .flatMap { accessorMethod =>
          accessorMethod.returnType match {
            case tpe if tpe =:= typeOf[String] => List((idx: Int) => q"""tup.setString(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Boolean] => List((idx: Int) => q"""tup.setBoolean(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Short] => List((idx: Int) => q"""tup.setShort(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Int] => List((idx: Int) => q"""tup.setInteger(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Long] => List((idx: Int) => q"""tup.setLong(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Float] => List((idx: Int) => q"""tup.setFloat(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Double] => List((idx: Int) => q"""tup.setDouble(${idx}, $pTree.$accessorMethod)""")
            case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) =>
              expandMethod(tpe, q"""$pTree.$accessorMethod""")
            case _ => c.abort(c.enclosingPosition, s"Case class ${T} is not pure primitives or nested case classes")
          }
        }
    }

    // Indices are assigned only after flattening, so each leaf field gets its final position.
    val set =
      expandMethod(T.tpe, q"t")
        .zipWithIndex
        .map { case (treeGenerator, idx) => treeGenerator(idx) }

    val res = q"""
    new _root_.com.twitter.scalding.TupleSetter[$T] with _root_.com.twitter.bijection.macros.MacroGenerated {
      override def apply(t: $T): _root_.cascading.tuple.Tuple = {
        val tup = _root_.cascading.tuple.Tuple.size(${set.size})
        ..$set
        tup
      }
      override val arity: scala.Int = ${set.size}
    }
    """
    c.Expr[TupleSetter[T]](res)
  }

  /** Macro entry point that builds a TupleConverter[T] without requiring IsCaseClass[T] evidence. */
  def caseClassTupleConverterNoProof[T]: TupleConverter[T] = macro caseClassTupleConverterNoProofImpl[T]

  /**
   * Implementation behind the evidence-carrying converter macros. The `proof` expression is only
   * there to force the caller to supply IsCaseClass[T]; it is not inspected.
   */
  def caseClassTupleConverterImpl[T](c: Context)(proof: c.Expr[IsCaseClass[T]])(implicit T: c.WeakTypeTag[T]): c.Expr[TupleConverter[T]] =
    caseClassTupleConverterNoProofImpl(c)(T)

  /**
   * Generates an anonymous TupleConverter[T] (also marked MacroGenerated) whose `apply` rebuilds a
   * `T` from a TupleEntry by reading consecutive positions with the typed `getXxx` methods and
   * calling the companion `apply` of `T` and of every nested case class.
   */
  def caseClassTupleConverterNoProofImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[TupleConverter[T]] = {
    import c.universe._

    // A tree that builds one constructor argument, plus the number of tuple cells it consumes.
    case class AccessorBuilder(builder: Tree, size: Int)

    // Leaf field: a single typed getter call at the given index, consuming exactly one cell.
    def getPrimitive(strAccessor: Tree): Int => AccessorBuilder =
      (idx: Int) => AccessorBuilder(q"""${strAccessor}(${idx})""", 1)

    // Threads the running tuple index through the children (each child starts where the previous
    // one ended) and wraps the resulting argument trees in a call to the companion of `tpe`.
    def flattenAccessorBuilders(tpe: Type, idx: Int, childGetters: List[(Int => AccessorBuilder)]): AccessorBuilder = {
      val (_, accessors) = childGetters.foldLeft((idx, List[AccessorBuilder]())) {
        case ((curIdx, eles), t) =>
          val nextEle = t(curIdx)
          (curIdx + nextEle.size, eles :+ nextEle)
      }

      val builder = q"""
        ${tpe.typeSymbol.companionSymbol}(..${accessors.map(_.builder)})
      """
      // `sum` (unlike `reduce`) is defined for a case class with no fields: its arity is 0.
      val size = accessors.map(_.size).sum
      AccessorBuilder(builder, size)
    }

    // One index-to-builder function per case accessor of `outerTpe`, in declaration order.
    def expandMethod(outerTpe: Type): List[(Int => AccessorBuilder)] = {
      outerTpe.declarations
        .collect { case m: MethodSymbol if m.isCaseAccessor => m.returnType }
        .toList
        .map { fieldTpe =>
          fieldTpe match {
            case tpe if tpe =:= typeOf[String] => getPrimitive(q"t.getString")
            // TupleEntry exposes getBoolean; there is no `Boolean` method on it.
            case tpe if tpe =:= typeOf[Boolean] => getPrimitive(q"t.getBoolean")
            case tpe if tpe =:= typeOf[Short] => getPrimitive(q"t.getShort")
            case tpe if tpe =:= typeOf[Int] => getPrimitive(q"t.getInteger")
            case tpe if tpe =:= typeOf[Long] => getPrimitive(q"t.getLong")
            case tpe if tpe =:= typeOf[Float] => getPrimitive(q"t.getFloat")
            case tpe if tpe =:= typeOf[Double] => getPrimitive(q"t.getDouble")
            case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) =>
              (idx: Int) => flattenAccessorBuilders(tpe, idx, expandMethod(tpe))
            case _ => c.abort(c.enclosingPosition, s"Case class ${T} is not pure primitives or nested case classes")
          }
        }
    }

    val accessorBuilders = flattenAccessorBuilders(T.tpe, 0, expandMethod(T.tpe))

    val res = q"""
    new _root_.com.twitter.scalding.TupleConverter[$T] with _root_.com.twitter.bijection.macros.MacroGenerated {
      override def apply(t: _root_.cascading.tuple.TupleEntry): $T = {
        ${accessorBuilders.builder}
      }
      override val arity: scala.Int = ${accessorBuilders.size}
    }
    """
    c.Expr[TupleConverter[T]](res)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.twitter.scalding.macros

import org.scalatest.WordSpec
import com.twitter.scalding.macros.{ _ => _ }

/**
* Hygiene test: checks that the macros expand correctly when the call site has none of the scalding
* or bijection names imported. This is why every reference below is written with its absolute path,
* and why each implicit materializer is imported only inside the single test that needs it.
*/
class MacroDepHygiene extends WordSpec {
import com.twitter.bijection.macros.impl.IsCaseClassImpl

// Fixtures: A is a flat case class, B nests A twice, and C is a plain (non-case) class.
case class A(x: Int, y: String)
case class B(x: A, y: String, z: A)
class C

// True when the value carries the MacroGenerated marker trait mixed in by the macro output.
def isMg(a: Any) = a.isInstanceOf[com.twitter.bijection.macros.MacroGenerated]

"TupleSetter macro" should {
// Resolves an implicit TupleSetter[T] and reports whether it was produced by the macro.
def isTupleSetterAvailable[T](implicit proof: com.twitter.scalding.TupleSetter[T]) = isMg(proof)

"work fine without any imports" in {
com.twitter.scalding.macros.Macros.caseClassTupleSetter[A]
com.twitter.scalding.macros.Macros.caseClassTupleSetter[B]
}

"implicitly work fine without any imports" in {
import com.twitter.scalding.macros.MacroImplicits.materializeCaseClassTupleSetter
assert(isTupleSetterAvailable[A])
assert(isTupleSetterAvailable[B])
}

// No materializer is imported here, so whichever TupleSetter[C] resolves must not be MacroGenerated.
// NOTE(review): presumably a default implicit from scalding's TupleSetter companion is picked up — confirm.
"fail if not a case class" in {
assert(!isTupleSetterAvailable[C])
}
}

"TupleConverter macro" should {
// Resolves an implicit TupleConverter[T] and reports whether it was produced by the macro.
def isTupleConverterAvailable[T](implicit proof: com.twitter.scalding.TupleConverter[T]) = isMg(proof)

"work fine without any imports" in {
com.twitter.scalding.macros.Macros.caseClassTupleConverter[A]
com.twitter.scalding.macros.Macros.caseClassTupleConverter[B]
}

"implicitly work fine without any imports" in {
import com.twitter.scalding.macros.MacroImplicits.materializeCaseClassTupleConverter
assert(isTupleConverterAvailable[A])
assert(isTupleConverterAvailable[B])
}

// No materializer is imported here, so whichever TupleConverter[C] resolves must not be MacroGenerated.
// NOTE(review): presumably a default implicit from scalding's TupleConverter companion is picked up — confirm.
"fail if not a case class" in {
assert(!isTupleConverterAvailable[C])
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package com.twitter.scalding.macros

import cascading.tuple.{ Tuple => CTuple, TupleEntry }

import org.scalatest.{ Matchers, WordSpec }

import com.twitter.scalding._
import com.twitter.scalding.macros._
import com.twitter.scalding.macros.impl._
import com.twitter.scalding.serialization.Externalizer

import com.twitter.bijection.macros.{ IsCaseClass, MacroGenerated }

// Test fixtures. They are declared at the top level (not nested inside the test class) to avoid
// any complications in the serialization test.
// Flattened sizes: SampleClassA occupies 2 tuple cells, SampleClassB 5 (2 + 2 + 1) and
// SampleClassC 19 (2 + 5 + 2 + 5 + 5).
case class SampleClassA(x: Int, y: String)
case class SampleClassB(a1: SampleClassA, a2: SampleClassA, y: String)
case class SampleClassC(a: SampleClassA, b: SampleClassB, c: SampleClassA, d: SampleClassB, e: SampleClassB)

/**
 * Unit tests for the macro-generated TupleSetter / TupleConverter instances: they must carry the
 * MacroGenerated marker, be Java-serializable, and round trip between case classes and tuples.
 */
class MacrosUnitTests extends WordSpec with Matchers {
  import MacroImplicits._

  /** Asserts that `t` is MacroGenerated and hands it back so the call can be chained. */
  def isMg[T](t: T): T = {
    t shouldBe a[MacroGenerated]
    t
  }

  /** Converts a TupleEntry with the implicit converter, after checking it is macro generated. */
  def mgConv[T](te: TupleEntry)(implicit conv: TupleConverter[T]): T = isMg(conv)(te)

  /** Turns a value into a TupleEntry with the implicit setter, after checking it is macro generated. */
  def mgSet[T](t: T)(implicit set: TupleSetter[T]): TupleEntry = new TupleEntry(isMg(set)(t))

  /** value -> TupleEntry -> value must give back an equal value. */
  def shouldRoundTrip[T: IsCaseClass: TupleSetter: TupleConverter](t: T): Unit = {
    t shouldBe mgConv(mgSet(t))
  }

  /** TupleEntry -> value must equal `t`, and value -> TupleEntry must equal the starting entry. */
  def shouldRoundTripOther[T: IsCaseClass: TupleSetter: TupleConverter](te: TupleEntry, t: T): Unit = {
    val decoded = mgConv(te)
    decoded shouldBe t
    mgSet(decoded) shouldBe te
  }

  /** Checks that the Externalizer can Java-serialize `t`. */
  def canExternalize(t: AnyRef): Unit = {
    Externalizer(t).javaWorks shouldBe true
  }

  "MacroGenerated TupleSetter" should {
    def setterSerializes[T](implicit set: TupleSetter[T]): Unit = {
      canExternalize(isMg(set))
    }

    "be serializable for case class A" in { setterSerializes[SampleClassA] }
    "be serializable for case class B" in { setterSerializes[SampleClassB] }
    "be serializable for case class C" in { setterSerializes[SampleClassC] }
  }

  "MacroGenerated TupleConverter" should {
    def converterSerializes[T](implicit conv: TupleConverter[T]): Unit = {
      canExternalize(isMg(conv))
    }

    "be serializable for case class A" in { converterSerializes[SampleClassA] }
    "be serializable for case class B" in { converterSerializes[SampleClassB] }
    "be serializable for case class C" in { converterSerializes[SampleClassC] }
  }

  "MacroGenerated TupleSetter and TupleConverter" should {
    "round trip class -> tupleentry -> class" in {
      shouldRoundTrip(SampleClassA(100, "onehundred"))
      shouldRoundTrip(SampleClassB(SampleClassA(100, "onehundred"), SampleClassA(-1, "zero"), "what"))

      val sampleA = SampleClassA(73, "hrm")
      val sampleB = SampleClassB(sampleA, sampleA, "hrm")
      shouldRoundTrip(sampleB)
      shouldRoundTrip(
        SampleClassC(
          sampleA,
          sampleB,
          SampleClassA(123980, "hey"),
          SampleClassB(sampleA, SampleClassA(-1, "zero"), "zoo"),
          sampleB))
    }

    "Case Class should form expected tuple" in {
      val nested = SampleClassC(
        SampleClassA(1, "asdf"),
        SampleClassB(SampleClassA(2, "bcdf"), SampleClassA(5, "jkfs"), "wetew"),
        SampleClassA(9, "xcmv"),
        SampleClassB(SampleClassA(23, "ck"), SampleClassA(13, "dafk"), "xcv"),
        SampleClassB(SampleClassA(34, "were"), SampleClassA(654, "power"), "adsfmx"))
      val flattened = implicitly[TupleSetter[SampleClassC]].apply(nested)
      assert(flattened.size == 19)
      assert(flattened.get(0) === 1)
      assert(flattened.get(18) === "adsfmx")
    }

    "round trip tupleentry -> class -> tupleEntry" in {
      // Writes the two cells of SampleClassA(100, "onehundred") starting at `offset`.
      def fillA(tup: CTuple, offset: Int): Unit = {
        tup.setInteger(offset, 100)
        tup.setString(offset + 1, "onehundred")
      }

      // Writes the five cells of SampleClassB(a, a, "what") starting at `offset`.
      def fillB(tup: CTuple, offset: Int): Unit = {
        fillA(tup, offset)
        fillA(tup, offset + 2)
        tup.setString(offset + 4, "what")
      }

      val tupleA = CTuple.size(2)
      fillA(tupleA, 0)
      val entryA = new TupleEntry(tupleA)
      val expectedA = SampleClassA(100, "onehundred")
      shouldRoundTripOther(entryA, expectedA)

      val tupleB = CTuple.size(5)
      fillB(tupleB, 0)
      val entryB = new TupleEntry(tupleB)
      val expectedB = SampleClassB(expectedA, expectedA, "what")
      shouldRoundTripOther(entryB, expectedB)

      // Layout of SampleClassC(a, b, a, b, b): A at 0, B at 2, A at 7, B at 9, B at 14.
      val tupleC = CTuple.size(19)
      fillA(tupleC, 0)
      fillB(tupleC, 2)
      fillA(tupleC, 7)
      fillB(tupleC, 9)
      fillB(tupleC, 14)
      val entryC = new TupleEntry(tupleC)
      val expectedC = SampleClassC(expectedA, expectedB, expectedA, expectedB, expectedB)
      shouldRoundTripOther(entryC, expectedC)
    }
  }
}