Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8117] [SQL] Move codegen implementation into Expression #6660

Closed
wants to merge 16 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException

override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
}

object BindReferences extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -433,6 +434,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match {

case Cast(child @ BinaryType(), StringType) =>
castOrNull (ctx, ev, c =>
s"new org.apache.spark.sql.types.UTF8String().set($c)")

case Cast(child @ DateType(), StringType) =>
castOrNull(ctx, ev, c =>
s"""new org.apache.spark.sql.types.UTF8String().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")

case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)")

case Cast(child @ DecimalType(), IntegerType) =>
castOrNull(ctx, ev, c => s"($c).toInt()")

case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")

case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")

// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
castOrNull(ctx, ev, c =>
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))")

case other =>
super.genCode(ctx, ev)
}
}

object Cast {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -51,6 +52,37 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): Any

/**
* Returns an [[EvaluatedExpression]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
* @param ctx a [[CodeGenContext]]
*/
def gen(ctx: CodeGenContext): EvaluatedExpression = {
val nullTerm = ctx.freshName("nullTerm")
val primitiveTerm = ctx.freshName("primitiveTerm")
val objectTerm = ctx.freshName("objectTerm")
val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm)
ve.code = genCode(ctx, ve)
ve
}

/**
* Returns Java source code for this expression
*/
def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val e = this.asInstanceOf[Expression]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this cast?

ctx.references += e
s"""
/* expression: ${this} */
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) ${ev.primitiveTerm} =
(${ctx.boxedType(e.dataType)})${ev.objectTerm};
"""
}

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
Expand Down Expand Up @@ -116,6 +148,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"


/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function from two primitive term names to a tree that evaluates them.
*/
def evaluate(ctx: CodeGenContext,
ev: EvaluatedExpression,
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
}

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)

s"""
${eval1.code}
boolean ${ev.nullTerm} = ${eval1.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${eval2.code}
if(!${eval2.nullTerm}) {
${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we want to cast here? the result code should return the correct time (basically we should assume in execution time, all the type casts are added explicitly; if there are any that's not, we should fix them)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+ will convert byte into int, should we do this kind of cast in Add(), Minus() and Times()?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ic - might be ok to leave it here for now. i'd add a comment explaining why it is here though.

} else {
${ev.nullTerm} = true;
}
}
"""
}
}

abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
Expand All @@ -124,6 +191,18 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
def castOrNull(ctx: CodeGenContext,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protected?

also need to add javadoc for this

ev: EvaluatedExpression,
f: String => String): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
}
"""
}
}

// TODO Semantically we probably not need GroupExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -86,6 +87,8 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>

def decimalMethod: String = ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can u add inline comment explaining what this is for?


override def dataType: DataType = left.dataType

override def checkInputDataTypes(): TypeCheckResult = {
Expand Down Expand Up @@ -114,12 +117,21 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (left.dataType.isInstanceOf[DecimalType]) {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
} else {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } )
}
}

protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
override def decimalMethod: String = "$plus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
Expand All @@ -134,6 +146,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "-"
override def decimalMethod: String = "$minus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
Expand All @@ -148,6 +161,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "*"
override def decimalMethod: String = "$times"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
Expand All @@ -162,6 +176,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "/"
override def decimalMethod: String = "$divide"

override def nullable: Boolean = true

override lazy val resolved =
Expand All @@ -188,10 +204,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
} else {
s"$symbol"
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
}
"""
}
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "%"
override def decimalMethod: String = "reminder"

override def nullable: Boolean = true

override lazy val resolved =
Expand All @@ -218,6 +262,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
} else {
s"$symbol"
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
}
"""
}
}

/**
Expand Down Expand Up @@ -336,6 +406,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
}
}
"""
} else {
super.genCode(ctx, ev)
}
}
override def toString: String = s"MaxOf($left, $right)"
}

Expand Down Expand Up @@ -363,5 +460,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (ctx.isNativeType(left.dataType)) {

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)

eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
}
}
"""
} else {
super.genCode(ctx, ev)
}
}

override def toString: String = s"MinOf($left, $right)"
}
Loading