diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 35c7f00d4e42a..a967ab9270215 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -221,6 +221,24 @@ trait HiveTypeCoercion {
             b.makeCopy(Array(newLeft, newRight))
           }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
       }
+
+      // Also widen types for IN expressions.
+      case q: LogicalPlan => q transformExpressions {
+        // Skip nodes whose children have not been resolved yet.
+        case e if !e.childrenResolved => e
+
+        case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
+          // Fold the tightest common type over the value and every list element,
+          // staying at None once any pair of types has no common widening.
+          b.map(_.dataType).foldLeft(Option(a.dataType)) {
+            case (Some(dt), c) => findTightestCommonType(dt, c)
+            case (None, _) => None
+          } match {
+            // If there is no applicable conversion, leave the expression unchanged.
+            case None => i
+            case Some(dt) => i.makeCopy(Array(Cast(a, dt), b.map(Cast(_, dt))))
+          }
+      }
     }
   }
 
@@ -270,6 +288,16 @@ trait HiveTypeCoercion {
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
       case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
+      case i @ In(a, b) if a.dataType == StringType
+          && b.exists(_.dataType.isInstanceOf[NumericType]) =>
+        i.makeCopy(Array(Cast(a, DoubleType), b))
+      case i @ In(a, b) if a.dataType.isInstanceOf[NumericType]
+          && b.exists(_.dataType == StringType) =>
+        // Cast only the string elements; the widening rule above reconciles any remaining numeric mismatch.
+        i.makeCopy(Array(a, b.map {
+          case e if e.dataType == StringType => Cast(e, DoubleType)
+          case e => e
+        }))
 
       case Sum(e) if e.dataType == StringType =>
         Sum(Cast(e, DoubleType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2d03fbfb0d311..9dc6f3c398ae1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -308,8 +308,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
       case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
-          val hSet = list.map(e => e.eval(null))
-          InSet(v, HashSet() ++ hSet)
+        val hSet = list.map(e => e.eval(null))
+        InSet(v, HashSet() ++ hSet)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9e02e69fda3f2..b7fd2fb37e703 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -88,6 +88,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row(1, 1) :: Nil)
   }
 
+  test("SPARK-6201 IN string promotion") {
+    // Each element is one valid JSON object; column a holds string values.
+    jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}", "{\"a\": \"2\"}", "{\"a\": \"3\"}")))
+      .registerTempTable("d")
+
+    checkAnswer(
+      sql("select * from d where d.a in (1,2)"),
+      Seq(Row("1"), Row("2")))
+  }
+
   test("SPARK-3176 Added Parser of SQL ABS()") {
     checkAnswer(
       sql("SELECT ABS(-1.3)"),
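
For reference, here is a minimal, self-contained sketch of the widening fold used in the `WidenTypes` hunk above. The `DataType` objects and the promotion table in this `findTightestCommonType` are simplified stand-ins rather than Spark's real type hierarchy; the point is to show why the fold seeds with the IN value's own type and stays at `None` after the first failed widening instead of restarting from the next element.

```scala
// Plain-Scala model of the IN type-widening fold; the type hierarchy and
// promotion table below are illustrative assumptions, not Spark internals.
object WidenInTypesSketch {
  sealed trait DataType
  case object IntType    extends DataType
  case object LongType   extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType

  // Rough analogue of HiveTypeCoercion.findTightestCommonType: equal types
  // widen to themselves, numerics widen to the larger numeric, and
  // everything else has no tightest common type.
  def findTightestCommonType(a: DataType, b: DataType): Option[DataType] = (a, b) match {
    case (x, y) if x == y                                => Some(x)
    case (IntType, LongType) | (LongType, IntType)       => Some(LongType)
    case (IntType, DoubleType) | (DoubleType, IntType)   => Some(DoubleType)
    case (LongType, DoubleType) | (DoubleType, LongType) => Some(DoubleType)
    case _                                               => None
  }

  // The fold from the coercion rule: seed with the IN value's type, widen
  // across every list element, and propagate None once widening fails.
  def widen(valueType: DataType, listTypes: Seq[DataType]): Option[DataType] =
    listTypes.foldLeft(Option(valueType)) {
      case (Some(dt), c) => findTightestCommonType(dt, c)
      case (None, _)     => None
    }

  def main(args: Array[String]): Unit = {
    println(widen(IntType, Seq(LongType, DoubleType))) // Some(DoubleType)
    // A fold that mapped None back to Some(nextType) would wrongly recover
    // and report Some(LongType) here; propagating None leaves the IN
    // expression unchanged instead.
    println(widen(IntType, Seq(StringType, LongType))) // None
  }
}
```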
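
Likewise, a rough runtime model of the two new `PromoteStrings` cases: a string value compared against numeric literals is cast to double, while a numeric value compared against a mixed list has only its string elements cast. The helpers `stringInNumerics` and `numericInMixed` are hypothetical local names, not Spark APIs, and plain `toDouble` is a simplification: Spark's `Cast` from string to double yields null for non-numeric input, so such a row simply fails the predicate rather than throwing.

```scala
// Behavioural sketch of the new IN promotion cases, under the assumptions
// named above; plain values stand in for evaluated expressions.
object StringInPromotionSketch {
  // "1" IN (1, 2): the string side is cast, mirroring
  // i.makeCopy(Array(Cast(a, DoubleType), b)).
  def stringInNumerics(value: String, list: Seq[Double]): Boolean =
    list.contains(value.toDouble)

  // 1 IN ('2', 3): only the string elements are cast, mirroring the
  // b.map that leaves numeric elements untouched.
  def numericInMixed(value: Double, list: Seq[Either[String, Double]]): Boolean =
    list.map {
      case Left(s)  => s.toDouble // Cast(e, DoubleType) analogue
      case Right(d) => d
    }.contains(value)

  def main(args: Array[String]): Unit = {
    println(stringInNumerics("1", Seq(1.0, 2.0)))            // true, as in the new test
    println(numericInMixed(3.0, Seq(Left("2"), Right(3.0)))) // true
  }
}
```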