From c3b9839b63affa05e3549d7e8cdb6950e9abb0ba Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 3 Jul 2015 23:48:46 +0800 Subject: [PATCH] rely on implict cast to handle string input --- .../spark/sql/catalyst/expressions/math.scala | 44 ++++--------------- .../ExpressionTypeCheckingSuite.scala | 19 +++++--- .../expressions/MathFunctionsSuite.scala | 7 --- 3 files changed, 21 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 92d8118c67252..f858650df410d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -524,7 +524,7 @@ case class Logarithm(left: Expression, right: Expression) } } -case class Round(child: Expression, scale: Expression) extends Expression { +case class Round(child: Expression, scale: Expression) extends Expression with ExpectsInputTypes { def this(child: Expression) = { this(child, Literal(0)) @@ -537,17 +537,17 @@ case class Round(child: Expression, scale: Expression) extends Expression { override def foldable: Boolean = child.foldable override lazy val dataType: DataType = child.dataType match { - case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) case t => t } + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegralType) + override def checkInputDataTypes(): TypeCheckResult = { child.dataType match { - case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement + case _: NumericType => // satisfy requirement case dt => - return TypeCheckFailure(s"Only numeric, string or binary data types" + - s" are allowed for ROUND function, got $dt") + return TypeCheckFailure(s"Only numeric type is allowed for ROUND function, got $dt") } scale match { case Literal(value, LongType) => @@ -555,12 +555,11 @@ case class Round(child: Expression, scale: Expression) extends Expression { return TypeCheckFailure("ROUND scale argument out of allowed range") } case _ => - if ((scale.dataType.isInstanceOf[IntegralType] || scale.dataType == NullType) && - scale.foldable) { - // TODO: foldable LongType is not checked for out of range + if (scale.dataType.isInstanceOf[IntegralType] && scale.foldable) { + // TODO: How to check out of range for foldable LongType Expression // satisfy requirement } else { - return TypeCheckFailure("Only Integral or Null foldable Expression " + + return TypeCheckFailure("Only foldable Integral Expression " + s"is allowed for ROUND scale arguments, got ${child.dataType}") } } @@ -596,10 +595,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { numericRound(x.asInstanceOf[Float], _scale) case DoubleType => numericRound(x.asInstanceOf[Double], _scale) - case StringType => - stringLikeRound(x.asInstanceOf[UTF8String].toString, _scale) - case BinaryType => - stringLikeRound(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) } } @@ -612,12 +607,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) } - private def stringLikeRound(input: String, scale: Int): Any = { - try numericRound(input.toDouble, scale) catch { - case _: NumberFormatException => null - } - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) @@ -637,19 +626,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { }""" } - def stringLikeConvert(primitive: String): String = { - val dName = ctx.freshName("converter") - s""" - Double $dName = 0.0; - try { - $dName = Double.valueOf(${primitive}.toString()); - } catch (NumberFormatException e) { - ${ev.isNull} = true; - } - ${fractionalCheck(dName, "doubleValue()")} - """ - } - def decimalRound(): String = { s""" if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { @@ -676,10 +652,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { fractionalCheck(ce.primitive, "floatValue()") case DoubleType => fractionalCheck(ce.primitive, "doubleValue()") - case StringType => - stringLikeConvert(ce.primitive) - case BinaryType => - stringLikeConvert(s"${ctx.stringType}.fromBytes(${ce.primitive})") } ce.code + s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 6bae906ee9f57..8b596ff2526d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") } + def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains(errorMessage)) + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") @@ -173,14 +180,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for ROUND") { - assertError(Round(Literal(null), 'booleanField), - "Only Integral or Null foldable Expression is allowed for ROUND scale argument") - assertError(Round(Literal(null), 'complexField), - "Only Integral or Null foldable Expression is allowed for ROUND scale argument") + assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), + "Only foldable Integral Expression is allowed for ROUND scale arguments") + assertErrorWithImplicitCast(Round(Literal(null), 'complexField), + "Only foldable Integral Expression is allowed for ROUND scale arguments") assertSuccess(Round(Literal(null), Literal(null))) assertError(Round('booleanField, 'intField), - "Only numeric, string or binary data types are allowed for ROUND function") - assertError(Round(Literal(null), Literal(1L + Int.MaxValue)), + "Only numeric type is allowed for ROUND function") + assertErrorWithImplicitCast(Round(Literal(null), Literal(1L + Int.MaxValue)), "ROUND scale argument out of allowed range") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 477ae969240e9..7aa924c6d4584 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -342,8 +342,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("round") { val domain = -16 to 16 val doublePi = math.Pi - val stringPi = "3.141592653589793" - val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte) val shortPi: Short = 31415 val intPi = 314159265 val longPi = 31415926535897932L @@ -352,10 +350,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { domain.foreach { scale => checkEvaluation(Round(doublePi, scale), BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(stringPi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(arrayPi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) checkEvaluation(Round(shortPi, scale), BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow) checkEvaluation(Round(intPi, scale), @@ -363,7 +357,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(longPi, scale), BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow) } - checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null