Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-31227][SQL] Non-nullable null type in complex types should not coerce to nullable type #27991

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object TypeCoercion {
}
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
findTypeFunc(kt1, kt2)
.filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) }
.filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) }
.flatMap { kt =>
findTypeFunc(vt1, vt2).map { vt =>
MapType(kt, vt, valueContainsNull1 || valueContainsNull2 ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ object Cast {
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) &&
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
HyukjinKwon marked this conversation as resolved.
Show resolved Hide resolved
canCast(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

Expand All @@ -97,11 +98,6 @@ object Cast {
case _ => false
}

def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = {
  // A NullType key can only occur in an empty map, so casting it is always safe;
  // otherwise the cast must not force the (non-nullable) map key to become nullable.
  if (fromType == NullType) true else !forceNullable(fromType, toType)
}

/**
* Return true if we need to use the `timeZone` information casting `from` type to `to` type.
* The patterns matched reflect the current implementation in the Cast node.
Expand Down Expand Up @@ -210,8 +206,13 @@ object Cast {
case _ => false // overflow
}

/**
* Returns `true` if casting non-nullable values from `from` type to `to` type
* may return null. Note that the caller side should take care of input nullability
* first and only call this method if the input is not nullable.
*/
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => true
HyukjinKwon marked this conversation as resolved.
Show resolved Hide resolved
HyukjinKwon marked this conversation as resolved.
Show resolved Hide resolved
case (NullType, _) => false // empty array or map case
case (_, _) if from == to => false

case (StringType, BinaryType) => false
Expand Down Expand Up @@ -269,7 +270,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
}

override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)

protected def ansiEnabled: Boolean

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

class TypeCoercionSuite extends AnalysisTest {
import TypeCoercionSuite._

// scalastyle:off line.size.limit
// The following table shows all implicit data type conversions that are not visible to the user.
Expand Down Expand Up @@ -99,22 +100,6 @@ class TypeCoercionSuite extends AnalysisTest {
case _ => Literal.create(null, dataType)
}

val integralTypes: Seq[DataType] =
Seq(ByteType, ShortType, IntegerType, LongType)
val fractionalTypes: Seq[DataType] =
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
val atomicTypes: Seq[DataType] =
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
val complexTypes: Seq[DataType] =
Seq(ArrayType(IntegerType),
ArrayType(StringType),
MapType(StringType, StringType),
new StructType().add("a1", StringType),
new StructType().add("a1", StringType).add("a2", IntegerType))
val allTypes: Seq[DataType] =
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)

// Check whether the type `checkedType` can be cast to all the types in `castableTypes`,
// but cannot be cast to the other types in `allTypes`.
private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = {
Expand Down Expand Up @@ -497,6 +482,23 @@ class TypeCoercionSuite extends AnalysisTest {
.add("null", IntegerType, nullable = false),
Some(new StructType()
.add("null", IntegerType, nullable = true)))

widenTest(
ArrayType(NullType, containsNull = false),
ArrayType(IntegerType, containsNull = false),
Some(ArrayType(IntegerType, containsNull = false)))

widenTest(MapType(NullType, NullType, false),
MapType(IntegerType, StringType, false),
Some(MapType(IntegerType, StringType, false)))

widenTest(
new StructType()
.add("null", NullType, nullable = false),
new StructType()
.add("null", IntegerType, nullable = false),
Some(new StructType()
.add("null", IntegerType, nullable = false)))
}

test("wider common type for decimal and array") {
Expand Down Expand Up @@ -728,8 +730,6 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._

ruleTest(TypeCoercion.ImplicitTypeCasts,
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))
Expand All @@ -740,8 +740,6 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("cast NullType for binary operators") {
import TypeCoercionSuite._

ruleTest(TypeCoercion.ImplicitTypeCasts,
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
Expand Down Expand Up @@ -1548,6 +1546,22 @@ class TypeCoercionSuite extends AnalysisTest {

object TypeCoercionSuite {

// Whole-number atomic types, ordered by widening precedence.
val integralTypes: Seq[DataType] =
Seq(ByteType, ShortType, IntegerType, LongType)
// Floating-point and decimal types (both the system-default and a fixed-precision decimal).
val fractionalTypes: Seq[DataType] =
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
// All numeric types: integral plus fractional.
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
// Numeric types plus the remaining non-nested (atomic) types.
val atomicTypes: Seq[DataType] =
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
// Representative nested types: arrays, a map, and one- and two-field structs.
val complexTypes: Seq[DataType] =
Seq(ArrayType(IntegerType),
ArrayType(StringType),
MapType(StringType, StringType),
new StructType().add("a1", StringType),
new StructType().add("a1", StringType).add("a2", IntegerType))
// The full set of types exercised by the coercion/cast tests,
// including the special NullType and CalendarIntervalType.
val allTypes: Seq[DataType] =
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)

case class AnyTypeUnaryExpression(child: Expression)
extends UnaryExpression with ExpectsInputTypes with Unevaluable {
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
Expand Down Expand Up @@ -412,6 +413,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved)
checkEvaluation(ret, Seq(null, true, false, null))
}

{
val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
assert(ret.resolved)
checkEvaluation(ret, Seq.empty)
}

{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === false)
Expand Down Expand Up @@ -1157,6 +1166,19 @@ class CastSuite extends CastSuiteBase {
StructType(StructField("a", IntegerType, true) :: Nil)))
}

test("SPARK-31227: Non-nullable null type should not coerce to nullable type") {
  // For every supported target type, a non-nullable NullType element, key/value,
  // or struct field must remain castable to a non-nullable target of that type.
  TypeCoercionSuite.allTypes.foreach { targetType =>
    assert(Cast.canCast(ArrayType(NullType, false), ArrayType(targetType, false)))

    assert(Cast.canCast(
      MapType(NullType, NullType, false), MapType(targetType, targetType, false)))

    assert(Cast.canCast(
      StructType(StructField("a", NullType, false) :: Nil),
      StructType(StructField("a", targetType, false) :: Nil)))
  }
}

test("Cast should output null for invalid strings when ANSI is not enabled.") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(e.getMessage.contains("string, binary or array"))
}

test("SPARK-31227: Non-nullable null type should not coerce to nullable type in concat") {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about concat(array(), array(NULL))? That should have the same type as array(NULL).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, the output has the same type with array(null);

scala> sql("select concat(array(), array(NULL))").printSchema
root
 |-- concat(array(), array(NULL)): array (nullable = false)
 |    |-- element: null (containsNull = true)

scala> sql("select array()").printSchema
root
 |-- array(): array (nullable = false)
 |    |-- element: null (containsNull = false)

scala> sql("select array(null)").printSchema
root
 |-- array(NULL): array (nullable = false)
 |    |-- element: null (containsNull = true)

Any concern?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only that it should be tested, since it's an interesting corner case!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe those cases, covering all types, are already tested in CastSuite.scala and TypeCoercionSuite.scala, unless I missed something. I kept only this one end-to-end test here because it was the case reported in JIRA SPARK-31227.

val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr")
val expected = spark.range(1).selectExpr("array(1) as arr")
checkAnswer(actual, expected)
assert(actual.schema === expected.schema)
}

test("flatten function") {
// Test cases with a primitive type
val intDF = Seq(
Expand Down