-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-8117] [SQL] Move codegen implementation into Expression #6660
Changes from 3 commits
593d617
3ff25f8
e57959d
b145047
8c6d82d
c5fb514
12ff88a
2344bc0
48c454f
b5d3617
02262c9
86fac2c
e03edaa
bad6828
f42c732
9adaeaf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} | ||
import org.apache.spark.sql.catalyst.trees | ||
import org.apache.spark.sql.catalyst.trees.TreeNode | ||
import org.apache.spark.sql.types._ | ||
|
@@ -51,6 +52,37 @@ abstract class Expression extends TreeNode[Expression] { | |
/** Returns the result of evaluating this expression on a given input Row */ | ||
def eval(input: Row = null): Any | ||
|
||
/**
 * Builds the [[EvaluatedExpression]] for this expression: Java source code that computes the
 * expression's result for an input row, together with the names of the generated variables
 * that hold the null flag, the primitive value and the boxed value.
 *
 * Fresh variable names are requested from `ctx` first; the code itself is then produced by
 * [[genCode]] and stored on the returned object.
 *
 * @param ctx a [[CodeGenContext]]
 * @return the populated [[EvaluatedExpression]]
 */
def gen(ctx: CodeGenContext): EvaluatedExpression = {
  val isNullName = ctx.freshName("nullTerm")
  val valueName = ctx.freshName("primitiveTerm")
  val boxedName = ctx.freshName("objectTerm")
  // The code starts out empty because genCode needs the term names to generate it.
  val evaluated = EvaluatedExpression("", isNullName, valueName, boxedName)
  evaluated.code = genCode(ctx, evaluated)
  evaluated
}
|
||
/**
 * Returns Java source code for this expression.
 *
 * This default implementation does not generate specialised code. It registers this
 * expression in `ctx.references` and emits Java that calls `eval` on that registered
 * instance, then unboxes the result into the primitive term when it is not null.
 *
 * @param ctx the [[CodeGenContext]] that collects referenced expressions
 * @param ev  holder of the variable names the generated code must declare
 * @return Java source declaring `ev.objectTerm`, `ev.nullTerm` and `ev.primitiveTerm`
 */
def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  // `this` is already an Expression, so it can be registered without a cast.
  ctx.references += this
  val referenceIndex = ctx.references.size - 1

  // The description is embedded in a Java block comment. Escape a comment terminator and
  // backslash-u sequences so arbitrary expression text cannot break the generated source.
  val description = toString.replace("*/", "*\\/").replace("\\u", "\\\\u")

  // NOTE(review): `expressions` and `i` are names expected to exist in the generated class
  // (presumably the reference array and the input row) — confirm against the code generator.
  s"""
    /* expression: $description */
    Object ${ev.objectTerm} = expressions[$referenceIndex].eval(i);
    boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
    ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} =
      ${ctx.defaultValue(dataType)};
    if (!${ev.nullTerm}) ${ev.primitiveTerm} =
      (${ctx.boxedType(dataType)})${ev.objectTerm};
  """
}
|
||
/** | ||
* Returns `true` if this expression and all its children have been resolved to a specific schema | ||
* and input data types checking passed, and `false` if it still contains any unresolved | ||
|
@@ -116,6 +148,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express | |
override def nullable: Boolean = left.nullable || right.nullable | ||
|
||
override def toString: String = s"($left $symbol $right)" | ||
|
||
|
||
/**
 * Shorthand for generating binary evaluation code, which depends on two sub-evaluations of
 * the same type. If either of the sub-expressions is null, the result of this computation
 * is assumed to be null. The right-hand side is only evaluated when the left-hand side is
 * not null.
 *
 * @param ctx the [[CodeGenContext]] used to generate the children and look up Java types
 * @param ev  holder of the variable names the generated code must declare
 * @param f   a function from the two primitive term names to a Java expression that
 *            combines them
 * @return Java source declaring `ev.nullTerm` and `ev.primitiveTerm`
 */
def evaluate(ctx: CodeGenContext,
    ev: EvaluatedExpression,
    f: (String, String) => String): String = {
  // TODO: Enforce left.dataType == right.dataType here. Right now some timestamp tests fail
  // if we enforce this, so mismatching operand types are silently accepted.

  val eval1 = left.gen(ctx)
  val eval2 = right.gen(ctx)
  val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)

  // NOTE(review): the explicit cast of the result to this expression's primitive type was
  // questioned in review (all casts should already be explicit at execution time) and kept
  // for now. Java's binary numeric promotion widens byte and short operands to int, so the
  // cast presumably narrows such results back to the declared type — TODO confirm and
  // remove it once it is no longer needed.
  s"""
    ${eval1.code}
    boolean ${ev.nullTerm} = ${eval1.nullTerm};
    ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
    if (!${ev.nullTerm}) {
      ${eval2.code}
      if(!${eval2.nullTerm}) {
        ${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
      } else {
        ${ev.nullTerm} = true;
      }
    }
  """
}
} | ||
|
||
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { | ||
|
@@ -124,6 +191,18 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] | |
|
||
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { | ||
self: Product => | ||
/**
 * Generates code for a unary expression whose result is null whenever its child is null.
 *
 * The child's code is emitted first. The result's null flag is copied from the child, the
 * result value starts at the default for this expression's data type, and it is replaced
 * by the Java expression produced by `f` only when the child is not null.
 *
 * NOTE(review): reviewers suggested making this method `protected`; the visibility is left
 * unchanged here so that existing callers keep compiling.
 *
 * @param ctx the [[CodeGenContext]] used to generate the child and look up Java types
 * @param ev  holder of the variable names the generated code must declare
 * @param f   a function from the child's primitive term name to a Java expression
 * @return Java source declaring `ev.nullTerm` and `ev.primitiveTerm`
 */
def castOrNull(ctx: CodeGenContext,
    ev: EvaluatedExpression,
    f: String => String): String = {
  val childEval = child.gen(ctx)
  val childCode = childEval.code
  val resultJavaType = ctx.primitiveType(dataType)
  val resultDefault = ctx.defaultValue(dataType)
  val guardedAssignment = s"""
    boolean ${ev.nullTerm} = ${childEval.nullTerm};
    $resultJavaType ${ev.primitiveTerm} = $resultDefault;
    if (!${ev.nullTerm}) {
      ${ev.primitiveTerm} = ${f(childEval.primitiveTerm)};
    }
  """
  childCode + guardedAssignment
}
} | ||
|
||
// TODO: Semantically we probably do not need GroupExpression | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext} | ||
import org.apache.spark.sql.catalyst.util.TypeUtils | ||
import org.apache.spark.sql.types._ | ||
|
||
|
@@ -86,6 +87,8 @@ case class Abs(child: Expression) extends UnaryArithmetic { | |
abstract class BinaryArithmetic extends BinaryExpression { | ||
self: Product => | ||
|
||
/**
 * Name of the method that generated code calls on the left operand when the operands are
 * decimals: the generated expression has the form `left.decimalMethod(right)`. The default
 * is the empty string; concrete arithmetic operators override it with their method name.
 */
def decimalMethod: String = "" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add an inline comment explaining what this is for? |
||
|
||
override def dataType: DataType = left.dataType | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = { | ||
|
@@ -114,12 +117,21 @@ abstract class BinaryArithmetic extends BinaryExpression { | |
} | ||
} | ||
|
||
/**
 * Generates code for a null-propagating binary arithmetic operation. Decimal operands are
 * combined with a method call named by `decimalMethod`; every other type uses the infix
 * operator given by `symbol`.
 */
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  val combine: (String, String) => String = left.dataType match {
    case _: DecimalType => (lhs, rhs) => s"$lhs.$decimalMethod($rhs)"
    case _ => (lhs, rhs) => s"$lhs $symbol $rhs"
  }
  evaluate(ctx, ev, combine)
}
|
||
protected def evalInternal(evalE1: Any, evalE2: Any): Any = | ||
sys.error(s"BinaryArithmetics must override either eval or evalInternal") | ||
} | ||
|
||
case class Add(left: Expression, right: Expression) extends BinaryArithmetic { | ||
override def symbol: String = "+" | ||
override def decimalMethod: String = "$plus" | ||
|
||
override lazy val resolved = | ||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) | ||
|
@@ -134,6 +146,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { | |
|
||
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { | ||
override def symbol: String = "-" | ||
override def decimalMethod: String = "$minus" | ||
|
||
override lazy val resolved = | ||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) | ||
|
@@ -148,6 +161,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti | |
|
||
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { | ||
override def symbol: String = "*" | ||
override def decimalMethod: String = "$times" | ||
|
||
override lazy val resolved = | ||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) | ||
|
@@ -162,6 +176,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti | |
|
||
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { | ||
override def symbol: String = "/" | ||
// Generated Java calls this method on the left decimal operand. Scala encodes the `/`
// operator as the JVM method name `$div`; there is no `$divide` method to call.
override def decimalMethod: String = "$div"
|
||
override def nullable: Boolean = true | ||
|
||
override lazy val resolved = | ||
|
@@ -188,10 +204,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic | |
} | ||
} | ||
} | ||
|
||
/**
 * Generates division code whose result is null when either operand is null or when the
 * divisor is zero. Both operands are always evaluated before the check.
 */
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  val dividend = left.gen(ctx)
  val divisor = right.gen(ctx)

  // Decimals use method calls for both the zero test and the operation itself;
  // every other type uses plain Java operators.
  val (divisorIsZero, operator) = left.dataType match {
    case _: DecimalType => (s"${divisor.primitiveTerm}.isZero()", s".$decimalMethod")
    case _ => (s"${divisor.primitiveTerm} == 0", s"$symbol")
  }

  val operandCode = dividend.code + divisor.code
  operandCode + s"""
    boolean ${ev.nullTerm} = false;
    ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
      ${ctx.defaultValue(left.dataType)};
    if (${dividend.nullTerm} || ${divisor.nullTerm} || $divisorIsZero) {
      ${ev.nullTerm} = true;
    } else {
      ${ev.primitiveTerm} = ${dividend.primitiveTerm}$operator(${divisor.primitiveTerm});
    }
  """
}
} | ||
|
||
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { | ||
override def symbol: String = "%" | ||
// Generated Java calls this method on the left decimal operand. The previous value,
// "reminder", was a misspelling. Assumes Decimal exposes a `remainder` method — confirm.
override def decimalMethod: String = "remainder"
|
||
override def nullable: Boolean = true | ||
|
||
override lazy val resolved = | ||
|
@@ -218,6 +262,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet | |
} | ||
} | ||
} | ||
|
||
/**
 * Generates remainder code whose result is null when either operand is null or when the
 * right operand is zero. Both operands are always evaluated before the check.
 */
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  val lhs = left.gen(ctx)
  val rhs = right.gen(ctx)

  // Decimals use method calls for both the zero test and the operation itself;
  // every other type uses plain Java operators.
  val (rhsIsZero, operator) = left.dataType match {
    case _: DecimalType => (s"${rhs.primitiveTerm}.isZero()", s".$decimalMethod")
    case _ => (s"${rhs.primitiveTerm} == 0", s"$symbol")
  }

  val operandCode = lhs.code + rhs.code
  operandCode + s"""
    boolean ${ev.nullTerm} = false;
    ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
      ${ctx.defaultValue(left.dataType)};
    if (${lhs.nullTerm} || ${rhs.nullTerm} || $rhsIsZero) {
      ${ev.nullTerm} = true;
    } else {
      ${ev.primitiveTerm} = ${lhs.primitiveTerm}$operator(${rhs.primitiveTerm});
    }
  """
}
} | ||
|
||
/** | ||
|
@@ -336,6 +406,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { | |
} | ||
} | ||
|
||
/**
 * Generates code that picks the larger of the two operands. A null operand is ignored in
 * favour of the other one, so the result is null only when both operands are null.
 *
 * Only types reported as native by `ctx.isNativeType` get this specialised code; every
 * other type is delegated to the inherited implementation.
 */
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  if (!ctx.isNativeType(left.dataType)) {
    super.genCode(ctx, ev)
  } else {
    val lhs = left.gen(ctx)
    val rhs = right.gen(ctx)
    val operandCode = lhs.code + rhs.code
    operandCode + s"""
      boolean ${ev.nullTerm} = false;
      ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
        ${ctx.defaultValue(left.dataType)};

      if (${lhs.nullTerm}) {
        ${ev.nullTerm} = ${rhs.nullTerm};
        ${ev.primitiveTerm} = ${rhs.primitiveTerm};
      } else if (${rhs.nullTerm}) {
        ${ev.nullTerm} = ${lhs.nullTerm};
        ${ev.primitiveTerm} = ${lhs.primitiveTerm};
      } else {
        if (${lhs.primitiveTerm} > ${rhs.primitiveTerm}) {
          ${ev.primitiveTerm} = ${lhs.primitiveTerm};
        } else {
          ${ev.primitiveTerm} = ${rhs.primitiveTerm};
        }
      }
    """
  }
}
override def toString: String = s"MaxOf($left, $right)" | ||
} | ||
|
||
|
@@ -363,5 +460,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { | |
} | ||
} | ||
|
||
/**
 * Generates code that picks the smaller of the two operands. A null operand is ignored in
 * favour of the other one, so the result is null only when both operands are null.
 *
 * Only types reported as native by `ctx.isNativeType` get this specialised code; every
 * other type is delegated to the inherited implementation.
 */
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
  if (!ctx.isNativeType(left.dataType)) {
    super.genCode(ctx, ev)
  } else {
    val lhs = left.gen(ctx)
    val rhs = right.gen(ctx)
    val operandCode = lhs.code + rhs.code
    operandCode + s"""
      boolean ${ev.nullTerm} = false;
      ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
        ${ctx.defaultValue(left.dataType)};

      if (${lhs.nullTerm}) {
        ${ev.nullTerm} = ${rhs.nullTerm};
        ${ev.primitiveTerm} = ${rhs.primitiveTerm};
      } else if (${rhs.nullTerm}) {
        ${ev.nullTerm} = ${lhs.nullTerm};
        ${ev.primitiveTerm} = ${lhs.primitiveTerm};
      } else {
        if (${lhs.primitiveTerm} < ${rhs.primitiveTerm}) {
          ${ev.primitiveTerm} = ${lhs.primitiveTerm};
        } else {
          ${ev.primitiveTerm} = ${rhs.primitiveTerm};
        }
      }
    """
  }
}
|
||
override def toString: String = s"MinOf($left, $right)" | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this cast?