Skip to content

Commit

Permalink
Merge branch 'SPARK-21774' into 'spark_2.1'
Browse files Browse the repository at this point in the history
[SPARK-21774] 数字类型和字符串比较的时候都统一转成double类型进行比较

现在字符串和数值的比较都是把字符串转成跟数值一样的数据格式之后再去比较  
测试case:
`select "1.1" = 1;`  
`"1.1" = 1`这样的判断,如果是把1.1转成int类型之后就是1了,它就和1相等了...  
resolve apache#110

See merge request !67
  • Loading branch information
cenyuhai committed Sep 15, 2017
2 parents f0eb740 + bd0ac93 commit 0961cac
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,42 @@ object TypeCoercion {
}
}

/**
* This function determines the target type of a comparison operator when one operand
* is a String and the other is not. It also handles when one op is a Date and the
* other is a Timestamp by making the target type to be String.
*/
val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
case (StringType, DateType) => Some(StringType)
case (DateType, StringType) => Some(StringType)
case (StringType, TimestampType) => Some(StringType)
case (TimestampType, StringType) => Some(StringType)
case (TimestampType, DateType) => Some(StringType)
case (DateType, TimestampType) => Some(StringType)
case (StringType, NullType) => Some(StringType)
case (NullType, StringType) => Some(StringType)
case (StringType, r: NumericType) => Some(DoubleType)
case (l: NumericType, StringType) => Some(DoubleType)
case (l: StringType, r: AtomicType) if r != StringType => Some(r)
case (l: AtomicType, r: StringType) if l != StringType => Some(l)
case (l, r) => None
}

/**
* Promotes strings that appear in arithmetic expressions.
*/
object PromoteStrings extends Rule[LogicalPlan] {
private def castExpr(expr: Expression, targetType: DataType): Expression = {
(expr.dataType, targetType) match {
case (NullType, dt) => Literal.create(null, targetType)
case (l, dt) if (l != dt) => Cast(expr, targetType)
case _ => expr
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
Expand All @@ -314,34 +346,10 @@ object TypeCoercion {
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))

// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
p.makeCopy(Array(left, Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
p.makeCopy(Array(Cast(left, StringType), right))
case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
p.makeCopy(Array(left, Cast(right, StringType)))
case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(Cast(left, StringType), right))

// Comparisons between dates and timestamps.
case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))

// Checking NullType
case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
p.makeCopy(Array(left, Literal.create(null, StringType)))
case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
p.makeCopy(Array(Literal.create(null, StringType), right))

case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType =>
p.makeCopy(Array(Cast(left, right.dataType), right))
case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType =>
p.makeCopy(Array(left, Cast(right, left.dataType)))
case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
Expand All @@ -356,6 +364,8 @@ object TypeCoercion {
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType))
case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,13 @@ class TypeCoercionSuite extends PlanTest {
test("binary comparison with string promotion") {
ruleTest(PromoteStrings,
GreaterThan(Literal("123"), Literal(1)),
GreaterThan(Cast(Literal("123"), IntegerType), Literal(1)))
GreaterThan(Cast(Literal("123"), DoubleType), Cast(Literal(1), DoubleType)))
ruleTest(PromoteStrings,
GreaterThan(Literal("123"), Literal(1L)),
GreaterThan(Cast(Literal("123"), DoubleType), Cast(Literal(1L), DoubleType)))
ruleTest(PromoteStrings,
GreaterThan(Literal("123"), Literal(0.1)),
GreaterThan(Cast(Literal("123"), DoubleType), Literal(0.1)))
ruleTest(PromoteStrings,
LessThan(Literal(true), Literal("123")),
LessThan(Literal(true), Cast(Literal("123"), BooleanType)))
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2503,4 +2503,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage)
}
}

test("SPARK-21774: should cast a string to double type when compare with a int") {
withTempView("src") {
Seq(("0", 1), ("-0.4", 2)).toDF("a", "b").createOrReplaceTempView("src")
checkAnswer(sql("SELECT a FROM src WHERE a=0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM src WHERE a=0L"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM src WHERE a=0.0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM src WHERE a=-0.4"), Seq(Row("-0.4")))
}
}
}

0 comments on commit 0961cac

Please sign in to comment.