Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8154][SQL] Remove Term/Code type aliases in code generation. #6694

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -161,7 +161,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
})
case BooleanType =>
buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0)))
buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0))
case LongType =>
buildCast[Long](_, l => new Timestamp(l))
case IntegerType =>
Expand Down Expand Up @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -76,16 +76,16 @@ abstract class Expression extends TreeNode[Expression] {
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm;
}
"""
}
Expand Down Expand Up @@ -166,7 +166,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (Term, Term) => Code): String = {
f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
Expand All @@ -182,7 +182,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.isNull}) {
if (!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
Expand Down Expand Up @@ -217,7 +217,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
f: String => String): String = {
val eval = child.gen(ctx)
// reuse the previous isNull
ev.isNull = eval.isNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -50,7 +50,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
}
Expand All @@ -74,7 +74,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
else math.sqrt(value)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
Expand Down Expand Up @@ -138,7 +138,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
Expand Down Expand Up @@ -236,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
/**
* Special case handling due to division by 0 => null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -296,7 +296,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
/**
* Special case handling for x % 0 ==> null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand All @@ -322,102 +322,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}

/**
* A function that calculates bitwise and(&) of two numbers.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "&"

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any]
case ShortType =>
((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any]
case IntegerType =>
((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
case LongType =>
((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2)
}

/**
* A function that calculates bitwise or(|) of two numbers.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "|"

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any]
case ShortType =>
((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any]
case IntegerType =>
((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
case LongType =>
((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2)
}

/**
* A function that calculates bitwise xor of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any]
case ShortType =>
((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any]
case IntegerType =>
((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
case LongType =>
((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
}

protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2)
}

/**
* A function that calculates bitwise not(~) of a number.
*/
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
override def toString: String = s"~$child"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")

private lazy val not: (Any) => Any = dataType match {
case ByteType =>
((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any]
case ShortType =>
((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any]
case IntegerType =>
((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any]
case LongType =>
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
}

protected override def evalInternal(evalE: Any) = not(evalE)
}

case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

Expand All @@ -442,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand Down Expand Up @@ -496,7 +400,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isNativeType(left.dataType)) {

val eval1 = left.gen(ctx)
Expand Down
Loading