From 61760eeb92476fe327fd352c3879f71e31289e64 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 11 Jul 2015 22:02:44 +0800 Subject: [PATCH] address reviews --- .../spark/sql/catalyst/expressions/math.scala | 145 +++++++++++++----- .../expressions/MathFunctionsSuite.scala | 51 +++--- 2 files changed, 141 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index cdb2db8c4b046..7e7a0e280a62d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -523,6 +523,20 @@ case class Logarithm(left: Expression, right: Expression) } } +/** + * Round the `child`'s result to `scale` decimal places when `scale` >= 0 + * or round at the integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. 
+ * + * Round's dataType always equals `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to a scale update in DecimalType's [[PrecisionInfo]]. + * + * @param child expr to be rounded; any [[NumericType]] is allowed as input + * @param scale new scale to round to; this should be a constant int at runtime + */ case class Round(child: Expression, scale: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -559,10 +573,27 @@ case class Round(child: Expression, scale: Expression) } } - private lazy val scaleV = scale.eval(EmptyRow) - private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] - protected override def nullSafeEval(input1: Any, input2: Any): Any = { + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] @@ -604,45 +635,89 @@ case class Round(child: Expression, scale: Expression) ${ev.isNull} = true; }""" case ByteType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case ShortType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case IntegerType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case LongType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" - case FloatType => - s""" - if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); - }""" - case DoubleType => - s""" - if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. 
+ if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); - }""" + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } } - ce.code + s""" - boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${evaluationCode} - } + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } """ + } } + + override def prettyName: String = "round" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7aa924c6d4584..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -340,32 +340,43 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -16 to 16 - val doublePi = math.Pi + val domain = -6 to 6 + val doublePi: Double = math.Pi val shortPi: Short = 31415 - val intPi = 314159265 - val longPi = 31415926535897932L - val bdPi = BigDecimal(31415926535897932L, 10) - - domain.foreach { scale => - checkEvaluation(Round(doublePi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(shortPi, scale), - BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow) - checkEvaluation(Round(intPi, scale), - BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow) - checkEvaluation(Round(longPi, scale), - 
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow) + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null - val (validScales, nullScales) = domain.splitAt(27) - validScales.foreach { scale => - checkEvaluation(Round(bdPi, scale), - Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow) + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) } - nullScales.foreach { scale => + (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } }