Skip to content

Commit

Permalink
handle log expressions similar to Hive
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 17, 2015
1 parent 188be51 commit 63dee44
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
}
}

abstract class UnaryLogExpression(f: Double => Double, name: String)
extends UnaryMathExpression(f, name) { self: Product =>

// values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
protected val yAsymptote: Double = 0.0

protected override def nullSafeEval(input: Any): Any = {
val d = input.asInstanceOf[Double]
if (d <= yAsymptote) null else f(d)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, c =>
s"""
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.${funcName}($c);
}
"""
)
}
}

/**
* A binary expression specifically for math functions that take two `Double`s as input and returns
* a `Double`.
Expand Down Expand Up @@ -390,18 +414,28 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas
}
}

case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG")

case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.log($c) / java.lang.Math.log(2)")
nullSafeCodeGen(ctx, ev, c =>
s"""
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2);
}
"""
)
}
}

case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10")

case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") {
protected override val yAsymptote: Double = -1.0
}

case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
Expand Down Expand Up @@ -675,11 +709,32 @@ case class Logarithm(left: Expression, right: Expression)
this(EulerNumber(), child)
}

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val dLeft = input1.asInstanceOf[Double]
val dRight = input2.asInstanceOf[Double]
// Unlike Hive, we support Log base in (0.0, 1.0]
if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (left.isInstanceOf[EulerNumber]) {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
nullSafeCodeGen(ctx, ev, (c1, c2) =>
s"""
if ($c2 <= 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c2);
}
""")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
nullSafeCodeGen(ctx, ev, (c1, c2) =>
s"""
if ($c1 <= 0.0 || $c2 <= 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1);
}
""")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param c expression
* @param f The functions in scala.math or elsewhere used to generate expected results
* @param domain The set of values to run the function with
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
* @tparam T Generic type for primitives
* @tparam U Generic type for the output of the given function `f`
Expand All @@ -58,9 +59,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
c: Expression => Expression,
f: T => U,
domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
expectNull: Boolean = false,
expectNaN: Boolean = false,
evalType: DataType = DoubleType): Unit = {
if (expectNaN) {
if (expectNull) {
domain.foreach { case value =>
checkEvaluation(c(Literal(value)), null, EmptyRow)
}
} else if (expectNaN) {
domain.foreach { value =>
checkNaN(c(Literal(value)), EmptyRow)
}
Expand All @@ -78,14 +84,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param c The DataFrame function
* @param f The functions in scala.math
* @param domain The set of values to run the function with
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
*/
private def testBinary(
c: (Expression, Expression) => Expression,
f: (Double, Double) => Double,
domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
expectNaN: Boolean = false): Unit = {
if (expectNaN) {
expectNull: Boolean = false, expectNaN: Boolean = false): Unit = {
if (expectNull) {
domain.foreach { case (v1, v2) =>
checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null))
}
} else if (expectNaN) {
domain.foreach { case (v1, v2) =>
checkNaN(c(Literal(v1), Literal(v2)), EmptyRow)
}
Expand Down Expand Up @@ -265,18 +276,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("log") {
testUnary(Log, math.log, (0 to 20).map(_ * 0.1))
testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNaN = true)
testUnary(Log, math.log, (1 to 20).map(_ * 0.1))
testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true)
}

test("log10") {
testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1))
testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNaN = true)
testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1))
testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true)
}

test("log1p") {
testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1))
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNaN = true)
testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1))
testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true)
}

test("bin") {
Expand All @@ -298,8 +309,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("log2") {
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNaN = true)
testUnary(Log2, f, (1 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true)
}

test("sqrt") {
Expand Down Expand Up @@ -406,12 +417,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))

// negative input should yield NaN output
checkNaN(
// negative input should yield null output
checkEvaluation(
Logarithm(Literal(-1.0), Literal(1.0)),
null,
create_row(null))
checkNaN(
checkEvaluation(
Logarithm(Literal(1.0), Literal(-1.0)),
null,
create_row(null))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class MathExpressionsSuite extends QueryTest {
if (f(-1) === math.log1p(-1)) {
checkAnswer(
nnDoubleData.select(c('b)),
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity)
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null)
)
}

Expand Down

0 comments on commit 63dee44

Please sign in to comment.