Skip to content

Commit

Permalink
[SPARK-8154][SQL] Remove Term/Code type aliases in code generation.
Browse files Browse the repository at this point in the history
From my perspective as a code reviewer, I find them more confusing than using String directly.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#6694 from rxin/SPARK-8154 and squashes the following commits:

4e5056c [Reynold Xin] [SPARK-8154][SQL] Remove Term/Code type aliases in code generation.
  • Loading branch information
rxin authored and nemccarthy committed Jun 19, 2015
1 parent 2b1f7e8 commit c0c8654
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -76,7 +76,7 @@ abstract class Expression extends TreeNode[Expression] {
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
Expand Down Expand Up @@ -166,7 +166,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (Term, Term) => Code): String = {
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand All @@ -182,7 +182,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.isNull}) {
if (!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
Expand Down Expand Up @@ -217,7 +217,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
f: String => String): String = {
val eval = child.gen(ctx)
// reuse the previous isNull
ev.isNull = eval.isNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -50,7 +50,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
}
Expand All @@ -74,7 +74,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
else math.sqrt(value)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
Expand Down Expand Up @@ -138,7 +138,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
Expand Down Expand Up @@ -236,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
/**
* Special case handling due to division by 0 => null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -296,7 +296,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
/**
* Special case handling for x % 0 ==> null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand Down Expand Up @@ -400,7 +400,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {

val eval1 = left.gen(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.apache.spark.sql.types._

/**
* A function that calculates bitwise and(&) of two numbers.
*
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "&"
Expand All @@ -48,6 +50,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme

/**
* A function that calculates bitwise or(|) of two numbers.
*
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "|"
Expand All @@ -71,6 +75,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet

/**
* A function that calculates bitwise xor of two numbers.
*
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"
Expand Down Expand Up @@ -112,8 +118,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
}

protected override def evalInternal(evalE: Any) = not(evalE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String)

/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
Expand All @@ -65,14 +65,14 @@ class CodeGenContext {
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
def freshName(prefix: String): Term = {
def freshName(prefix: String): String = {
s"$prefix${curId.getAndIncrement}"
}

/**
* Return the code to access a column for given DataType
*/
def getColumn(dataType: DataType, ordinal: Int): Code = {
def getColumn(dataType: DataType, ordinal: Int): String = {
if (isNativeType(dataType)) {
s"i.${accessorForType(dataType)}($ordinal)"
} else {
Expand All @@ -83,7 +83,7 @@ class CodeGenContext {
/**
* Return the code to update a column in Row for given DataType
*/
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
if (isNativeType(dataType)) {
s"${mutatorForType(dataType)}($ordinal, $value)"
} else {
Expand All @@ -94,23 +94,23 @@ class CodeGenContext {
/**
* Return the name of accessor in Row for a DataType
*/
def accessorForType(dt: DataType): Term = dt match {
def accessorForType(dt: DataType): String = dt match {
case IntegerType => "getInt"
case other => s"get${boxedType(dt)}"
}

/**
* Return the name of mutator in Row for a DataType
*/
def mutatorForType(dt: DataType): Term = dt match {
def mutatorForType(dt: DataType): String = dt match {
case IntegerType => "setInt"
case other => s"set${boxedType(dt)}"
}

/**
* Return the Java type for a DataType
*/
def javaType(dt: DataType): Term = dt match {
def javaType(dt: DataType): String = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
Expand All @@ -131,7 +131,7 @@ class CodeGenContext {
/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): Term = dt match {
def boxedType(dt: DataType): String = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
Expand All @@ -146,7 +146,7 @@ class CodeGenContext {
/**
* Return the representation of default value for given DataType
*/
def defaultValue(dt: DataType): Term = dt match {
def defaultValue(dt: DataType): String = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "(short)-1"
Expand All @@ -161,7 +161,7 @@ class CodeGenContext {
/**
* Returns a function to generate equal expression in Java
*/
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
case BinaryType => { case (eval1, eval2) =>
s"java.util.Arrays.equals($eval1, $eval2)" }
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ import org.apache.spark.util.Utils
*/
package object codegen {

type Term = String
type Code = String

/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val condEval = predicate.gen(ctx)
val trueEval = trueValue.gen(ctx)
val falseEval = falseValue.gen(ctx)
Expand Down Expand Up @@ -155,7 +155,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
return res
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val len = branchesArr.length
val got = ctx.freshName("got")

Expand Down Expand Up @@ -248,7 +248,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
return res
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val keyEval = key.gen(ctx)
val len = branchesArr.length
val got = ctx.freshName("got")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._

/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
Expand All @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
}
}
Expand All @@ -59,7 +59,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -88,7 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres

override def eval(input: Row): Any = value

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// change the isNull and primitive to consts, to inline them
if (value == null) {
ev.isNull = "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
// name of function in java.lang.Math
def funcName: String = name.toLowerCase

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
Expand Down Expand Up @@ -93,7 +93,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)")
}
}
Expand Down Expand Up @@ -180,7 +180,7 @@ case class Atan2(left: Expression, right: Expression)
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
Expand All @@ -194,7 +194,7 @@ case class Hypot(left: Expression, right: Expression)

case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
Expand Down
Loading

0 comments on commit c0c8654

Please sign in to comment.