Skip to content

Commit

Permalink
rely on implict cast to handle string input
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 14, 2015
1 parent b0bff79 commit c3b9839
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ case class Logarithm(left: Expression, right: Expression)
}
}

case class Round(child: Expression, scale: Expression) extends Expression {
case class Round(child: Expression, scale: Expression) extends Expression with ExpectsInputTypes {

def this(child: Expression) = {
this(child, Literal(0))
Expand All @@ -537,30 +537,29 @@ case class Round(child: Expression, scale: Expression) extends Expression {
override def foldable: Boolean = child.foldable

override lazy val dataType: DataType = child.dataType match {
case StringType | BinaryType => DoubleType
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
case t => t
}

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegralType)

override def checkInputDataTypes(): TypeCheckResult = {
child.dataType match {
case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement
case _: NumericType => // satisfy requirement
case dt =>
return TypeCheckFailure(s"Only numeric, string or binary data types" +
s" are allowed for ROUND function, got $dt")
return TypeCheckFailure(s"Only numeric type is allowed for ROUND function, got $dt")
}
scale match {
case Literal(value, LongType) =>
if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
return TypeCheckFailure("ROUND scale argument out of allowed range")
}
case _ =>
if ((scale.dataType.isInstanceOf[IntegralType] || scale.dataType == NullType) &&
scale.foldable) {
// TODO: foldable LongType is not checked for out of range
if (scale.dataType.isInstanceOf[IntegralType] && scale.foldable) {
// TODO: How to check out of range for foldable LongType Expression
// satisfy requirement
} else {
return TypeCheckFailure("Only Integral or Null foldable Expression " +
return TypeCheckFailure("Only foldable Integral Expression " +
s"is allowed for ROUND scale arguments, got ${child.dataType}")
}
}
Expand Down Expand Up @@ -596,10 +595,6 @@ case class Round(child: Expression, scale: Expression) extends Expression {
numericRound(x.asInstanceOf[Float], _scale)
case DoubleType =>
numericRound(x.asInstanceOf[Double], _scale)
case StringType =>
stringLikeRound(x.asInstanceOf[UTF8String].toString, _scale)
case BinaryType =>
stringLikeRound(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale)
}
}

Expand All @@ -612,12 +607,6 @@ case class Round(child: Expression, scale: Expression) extends Expression {
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
}

private def stringLikeRound(input: String, scale: Int): Any = {
try numericRound(input.toDouble, scale) catch {
case _: NumberFormatException => null
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val ce = child.gen(ctx)

Expand All @@ -637,19 +626,6 @@ case class Round(child: Expression, scale: Expression) extends Expression {
}"""
}

def stringLikeConvert(primitive: String): String = {
val dName = ctx.freshName("converter")
s"""
Double $dName = 0.0;
try {
$dName = Double.valueOf(${primitive}.toString());
} catch (NumberFormatException e) {
${ev.isNull} = true;
}
${fractionalCheck(dName, "doubleValue()")}
"""
}

def decimalRound(): String = {
s"""
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
Expand All @@ -676,10 +652,6 @@ case class Round(child: Expression, scale: Expression) extends Expression {
fractionalCheck(ce.primitive, "floatValue()")
case DoubleType =>
fractionalCheck(ce.primitive, "doubleValue()")
case StringType =>
stringLikeConvert(ce.primitive)
case BinaryType =>
stringLikeConvert(s"${ctx.stringType}.fromBytes(${ce.primitive})")
}

ce.code + s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).")
}

def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
assertSuccess(expr)
}
assert(e.getMessage.contains(errorMessage))
}

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
assertError(Abs('stringField), "function abs accepts numeric type")
Expand Down Expand Up @@ -173,14 +180,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}

test("check types for ROUND") {
assertError(Round(Literal(null), 'booleanField),
"Only Integral or Null foldable Expression is allowed for ROUND scale argument")
assertError(Round(Literal(null), 'complexField),
"Only Integral or Null foldable Expression is allowed for ROUND scale argument")
assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
"Only foldable Integral Expression is allowed for ROUND scale arguments")
assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
"Only foldable Integral Expression is allowed for ROUND scale arguments")
assertSuccess(Round(Literal(null), Literal(null)))
assertError(Round('booleanField, 'intField),
"Only numeric, string or binary data types are allowed for ROUND function")
assertError(Round(Literal(null), Literal(1L + Int.MaxValue)),
"Only numeric type is allowed for ROUND function")
assertErrorWithImplicitCast(Round(Literal(null), Literal(1L + Int.MaxValue)),
"ROUND scale argument out of allowed range")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("round") {
val domain = -16 to 16
val doublePi = math.Pi
val stringPi = "3.141592653589793"
val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte)
val shortPi: Short = 31415
val intPi = 314159265
val longPi = 31415926535897932L
Expand All @@ -352,18 +350,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
domain.foreach { scale =>
checkEvaluation(Round(doublePi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(stringPi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(arrayPi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(shortPi, scale),
BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow)
checkEvaluation(Round(intPi, scale),
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
checkEvaluation(Round(longPi, scale),
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow)
}
checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow)

// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
Expand Down

0 comments on commit c3b9839

Please sign in to comment.