Skip to content

Commit

Permalink
add codegen, test for all types
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Jul 10, 2015
1 parent ec625b0 commit c1f6824
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{NullType, BooleanType, DataType}


case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
Expand Down Expand Up @@ -313,62 +314,102 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
}
}

/**
 * Returns the smallest non-null value among its children.
 *
 * Null children are skipped; the result is null only when every child evaluates to null.
 * At least two children are required.
 */
case class Least(children: Expression*) extends Expression {
  require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)

  // Null is produced only when all children are null, so one non-nullable child suffices.
  override def nullable: Boolean = children.forall(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  // Looked up on first use from this expression's data type.
  private lazy val ordering = TypeUtils.getOrdering(dataType)

  /**
   * Accepts the children when they share a single type (NullType children are ignored for
   * this comparison) and that type passes the ordering check; otherwise reports a failure.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
      TypeCheckResult.TypeCheckFailure(
        s"The expressions should all have the same type," +
          s" got LEAST (${children.map(_.dataType)}).")
    } else {
      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
    }
  }

  // NOTE(review): only the first child's type is consulted, while checkInputDataTypes
  // tolerates NullType children -- presumably an earlier analysis step aligns the child
  // types before this is used; confirm.
  override def dataType: DataType = children.head.dataType

  /**
   * Evaluates every child against `input` and keeps the smallest non-null result.
   *
   * @param input the row the children are evaluated on
   * @return the least non-null child value, or null if all children evaluate to null
   */
  override def eval(input: InternalRow): Any = {
    children.foldLeft[Any](null) { (r, c) =>
      val evalc = c.eval(input)
      if (evalc != null) {
        // The first non-null value seeds the result; later ones must be strictly smaller.
        if (r == null || ordering.lt(evalc, r)) evalc else r
      } else {
        r
      }
    }
  }

  override def toString: String = s"LEAST(${children.mkString(", ")})"

  /**
   * Emits code that evaluates all children, starts from a null result, and lets each
   * child in turn replace the result when the result is still null or the child is
   * non-null and compares lower.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val evalChildren = children.map(_.gen(ctx))
    // Update step for the i-th child. While the result is still null the child is copied
    // unconditionally, including its own null flag.
    def updateEval(i: Int): String =
      s"""
        if (${ev.isNull} || (!${evalChildren(i).isNull} && ${
          ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
          ${ev.isNull} = ${evalChildren(i).isNull};
          ${ev.primitive} = ${evalChildren(i).primitive};
        }
      """
    s"""
      ${evalChildren.map(_.code).mkString("\n")}
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      ${evalChildren.indices.map(updateEval).mkString("\n")}
    """
  }
}

/**
 * Returns the largest non-null value among its children.
 *
 * Null children are skipped; the result is null only when every child evaluates to null.
 * At least two children are required.
 */
case class Greatest(children: Expression*) extends Expression {
  require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)

  // Null is produced only when all children are null, so one non-nullable child suffices.
  override def nullable: Boolean = children.forall(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  // Looked up on first use from this expression's data type.
  private lazy val ordering = TypeUtils.getOrdering(dataType)

  /**
   * Accepts the children when they share a single type (NullType children are ignored for
   * this comparison) and that type passes the ordering check; otherwise reports a failure.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
      TypeCheckResult.TypeCheckFailure(
        s"The expressions should all have the same type," +
          s" got GREATEST (${children.map(_.dataType)}).")
    } else {
      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
    }
  }

  // NOTE(review): only the first child's type is consulted, while checkInputDataTypes
  // tolerates NullType children -- presumably an earlier analysis step aligns the child
  // types before this is used; confirm.
  override def dataType: DataType = children.head.dataType

  /**
   * Evaluates every child against `input` and keeps the largest non-null result.
   *
   * @param input the row the children are evaluated on
   * @return the greatest non-null child value, or null if all children evaluate to null
   */
  override def eval(input: InternalRow): Any = {
    children.foldLeft[Any](null) { (r, c) =>
      val evalc = c.eval(input)
      if (evalc != null) {
        // The first non-null value seeds the result; later ones must be strictly larger.
        if (r == null || ordering.gt(evalc, r)) evalc else r
      } else {
        r
      }
    }
  }

  // Previously printed "LEAST(...)", a copy-paste slip from the Least expression.
  override def toString: String = s"GREATEST(${children.mkString(", ")})"

  /**
   * Emits code that evaluates all children, starts from a null result, and lets each
   * child in turn replace the result when the result is still null or the child is
   * non-null and compares higher.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val evalChildren = children.map(_.gen(ctx))
    // Update step for the i-th child. While the result is still null the child is copied
    // unconditionally, including its own null flag.
    def updateEval(i: Int): String =
      s"""
        if (${ev.isNull} || (!${evalChildren(i).isNull} && ${
          ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
          ${ev.isNull} = ${evalChildren(i).isNull};
          ${ev.primitive} = ${evalChildren(i).primitive};
        }
      """
    s"""
      ${evalChildren.map(_.code).mkString("\n")}
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      ${evalChildren.indices.map(updateEval).mkString("\n")}
    """
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Timestamp, Date}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -134,21 +137,82 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}

test("function least") {
  // Row layout: two int columns (indices 0-1) followed by three string columns (2-4).
  val row = create_row(1, 2, "a", "b", "c")
  val c1 = 'a.int.at(0)
  val c2 = 'a.int.at(1)
  val c3 = 'a.string.at(2)
  val c4 = 'a.string.at(3)
  val c5 = 'a.string.at(4)

  // Bound references, optionally mixed with a literal.
  checkEvaluation(Least(c4, c3, c5), "a", row)
  checkEvaluation(Least(c1, c2), 1, row)
  checkEvaluation(Least(c1, c2, Literal(-1)), -1, row)
  checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row)

  // All-null input yields null.
  checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty)

  // One literal pair per value type.
  checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty)
  checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty)
  checkEvaluation(
    Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty)
  checkEvaluation(
    Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty)
  checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty)
  checkEvaluation(
    Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty)
  checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty)
  checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty)
  checkEvaluation(
    Least(
      Literal(BigDecimal("1234567890987654321123456")),
      Literal(BigDecimal("1234567890987654321123458"))),
    BigDecimal("1234567890987654321123456"), InternalRow.empty)
  checkEvaluation(
    Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
    Date.valueOf("2015-01-01"), InternalRow.empty)
  checkEvaluation(
    Least(
      Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
      Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
    Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty)
}

test("function greatest") {
  // Row layout: two int columns (indices 0-1) followed by three string columns (2-4).
  val inputRow = create_row(1, 2, "a", "b", "c")
  val firstInt = 'a.int.at(0)
  val secondInt = 'a.int.at(1)
  val strA = 'a.string.at(2)
  val strB = 'a.string.at(3)
  val strC = 'a.string.at(4)

  // Cases evaluated against `inputRow`: bound references, optionally mixed with a literal.
  val rowCases = Seq[(Expression, Any)](
    (Greatest(strB, strC, strA), "c"),
    (Greatest(secondInt, firstInt), 2),
    (Greatest(firstInt, secondInt, Literal(2)), 2),
    (Greatest(strB, strC, strA, Literal("ccc")), "ccc"))
  rowCases.foreach { case (expr, expected) =>
    checkEvaluation(expr, expected, inputRow)
  }

  // Literal-only cases evaluated against an empty row: the all-null case first, then one
  // pair per value type.
  val literalCases = Seq[(Expression, Any)](
    (Greatest(Literal(null), Literal(null)), null),
    (Greatest(Literal(-1.0), Literal(2.5)), 2.5),
    (Greatest(Literal(-1), Literal(2)), 2),
    (Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat),
    (Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue),
    (Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte),
    (Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort),
    (Greatest(Literal("abc"), Literal("aaaa")), "abc"),
    (Greatest(Literal(true), Literal(false)), true),
    (Greatest(
      Literal(BigDecimal("1234567890987654321123456")),
      Literal(BigDecimal("1234567890987654321123458"))),
      BigDecimal("1234567890987654321123458")),
    (Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
      Date.valueOf("2015-07-01")),
    (Greatest(
      Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
      Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
      Timestamp.valueOf("2015-07-01 10:00:00")))
  literalCases.foreach { case (expr, expected) =>
    checkEvaluation(expr, expected, InternalRow.empty)
  }
}

}

0 comments on commit c1f6824

Please sign in to comment.