Skip to content

Commit

Permalink
[SPARK-31227][SQL] Non-nullable null type in complex types should not…
Browse files Browse the repository at this point in the history
… coerce to nullable type

### What changes were proposed in this pull request?

This PR makes a non-nullable null type inside a complex type no longer coerce to a nullable type.

Non-nullable null-typed fields in a struct, elements in an array, and entries in a map can indicate an empty struct, array, or map. Because these are empty, there is no need to force nullability when we find common types.

This PR also reverts and supersedes d7b97a1

### Why are the changes needed?

To make type coercion coherent and consistent. Currently, we correctly keep the nullability even between non-nullable fields:

```scala
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
spark.range(1).select(array(lit(1)).cast(ArrayType(IntegerType, false))).printSchema()
spark.range(1).select(array(lit(1)).cast(ArrayType(DoubleType, false))).printSchema()
```
```scala
spark.range(1).selectExpr("concat(array(1), array(1)) as arr").printSchema()
```

### Does this PR introduce any user-facing change?

Yes.

```scala
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
spark.range(1).select(array().cast(ArrayType(IntegerType, false))).printSchema()
```
```scala
spark.range(1).selectExpr("concat(array(), array(1)) as arr").printSchema()
```

**Before:**

```
org.apache.spark.sql.AnalysisException: cannot resolve 'array()' due to data type mismatch: cannot cast array<null> to array<int>;;
'Project [cast(array() as array<int>) AS array()#68]
+- Range (0, 1, step=1, splits=Some(12))

  at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:149)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:140)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:333)
  at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:333)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
  at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
```

```
root
 |-- arr: array (nullable = false)
 |    |-- element: integer (containsNull = true)
```

**After:**

```
root
 |-- array(): array (nullable = false)
 |    |-- element: integer (containsNull = false)
```

```
root
 |-- arr: array (nullable = false)
 |    |-- element: integer (containsNull = false)
```

### How was this patch tested?

Unit tests were added, and the change was also tested manually.

Closes #27991 from HyukjinKwon/SPARK-31227.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
HyukjinKwon authored and cloud-fan committed Mar 26, 2020
1 parent 44bd36a commit 3bd10ce
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object TypeCoercion {
}
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
findTypeFunc(kt1, kt2)
.filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) }
.filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) }
.flatMap { kt =>
findTypeFunc(vt1, vt2).map { vt =>
MapType(kt, vt, valueContainsNull1 || valueContainsNull2 ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ object Cast {
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) &&
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

Expand All @@ -97,11 +98,6 @@ object Cast {
case _ => false
}

def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = {
// If the original map key type is NullType, it's OK as the map must be empty.
fromType == NullType || !forceNullable(fromType, toType)
}

/**
* Return true if we need to use the `timeZone` information casting `from` type to `to` type.
* The patterns matched reflect the current implementation in the Cast node.
Expand Down Expand Up @@ -210,8 +206,13 @@ object Cast {
case _ => false // overflow
}

/**
* Returns `true` if casting non-nullable values from `from` type to `to` type
* may return null. Note that the caller side should take care of input nullability
* first and only call this method if the input is not nullable.
*/
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => true
case (NullType, _) => false // empty array or map case
case (_, _) if from == to => false

case (StringType, BinaryType) => false
Expand Down Expand Up @@ -269,7 +270,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
}

override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)

protected def ansiEnabled: Boolean

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

class TypeCoercionSuite extends AnalysisTest {
import TypeCoercionSuite._

// scalastyle:off line.size.limit
// The following table shows all implicit data type conversions that are not visible to the user.
Expand Down Expand Up @@ -99,22 +100,6 @@ class TypeCoercionSuite extends AnalysisTest {
case _ => Literal.create(null, dataType)
}

val integralTypes: Seq[DataType] =
Seq(ByteType, ShortType, IntegerType, LongType)
val fractionalTypes: Seq[DataType] =
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
val atomicTypes: Seq[DataType] =
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
val complexTypes: Seq[DataType] =
Seq(ArrayType(IntegerType),
ArrayType(StringType),
MapType(StringType, StringType),
new StructType().add("a1", StringType),
new StructType().add("a1", StringType).add("a2", IntegerType))
val allTypes: Seq[DataType] =
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)

// Check whether the type `checkedType` can be cast to all the types in `castableTypes`,
// but cannot be cast to the other types in `allTypes`.
private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = {
Expand Down Expand Up @@ -497,6 +482,23 @@ class TypeCoercionSuite extends AnalysisTest {
.add("null", IntegerType, nullable = false),
Some(new StructType()
.add("null", IntegerType, nullable = true)))

widenTest(
ArrayType(NullType, containsNull = false),
ArrayType(IntegerType, containsNull = false),
Some(ArrayType(IntegerType, containsNull = false)))

widenTest(MapType(NullType, NullType, false),
MapType(IntegerType, StringType, false),
Some(MapType(IntegerType, StringType, false)))

widenTest(
new StructType()
.add("null", NullType, nullable = false),
new StructType()
.add("null", IntegerType, nullable = false),
Some(new StructType()
.add("null", IntegerType, nullable = false)))
}

test("wider common type for decimal and array") {
Expand Down Expand Up @@ -728,8 +730,6 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._

ruleTest(TypeCoercion.ImplicitTypeCasts,
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))
Expand All @@ -740,8 +740,6 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("cast NullType for binary operators") {
import TypeCoercionSuite._

ruleTest(TypeCoercion.ImplicitTypeCasts,
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
Expand Down Expand Up @@ -1548,6 +1546,22 @@ class TypeCoercionSuite extends AnalysisTest {

object TypeCoercionSuite {

val integralTypes: Seq[DataType] =
Seq(ByteType, ShortType, IntegerType, LongType)
val fractionalTypes: Seq[DataType] =
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
val atomicTypes: Seq[DataType] =
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
val complexTypes: Seq[DataType] =
Seq(ArrayType(IntegerType),
ArrayType(StringType),
MapType(StringType, StringType),
new StructType().add("a1", StringType),
new StructType().add("a1", StringType).add("a2", IntegerType))
val allTypes: Seq[DataType] =
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)

case class AnyTypeUnaryExpression(child: Expression)
extends UnaryExpression with ExpectsInputTypes with Unevaluable {
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
Expand Down Expand Up @@ -413,6 +414,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved)
checkEvaluation(ret, Seq(null, true, false, null))
}

{
val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
assert(ret.resolved)
checkEvaluation(ret, Seq.empty)
}

{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === false)
Expand Down Expand Up @@ -1158,6 +1167,19 @@ class CastSuite extends CastSuiteBase {
StructType(StructField("a", IntegerType, true) :: Nil)))
}

// Regression test for SPARK-31227: for every type `t` in `TypeCoercionSuite.allTypes`,
// `Cast.canCast` must accept a cast from a container of NullType to the same kind of
// container of `t` when the trailing nullability flag is `false` on both sides.
test("SPARK-31227: Non-nullable null type should not coerce to nullable type") {
TypeCoercionSuite.allTypes.foreach { t =>
// Array case: element type NullType -> t.
assert(Cast.canCast(ArrayType(NullType, false), ArrayType(t, false)))

// Map case: both key and value types NullType -> t.
assert(Cast.canCast(
MapType(NullType, NullType, false), MapType(t, t, false)))

// Struct case: a single field "a" of NullType -> t.
assert(Cast.canCast(
StructType(StructField("a", NullType, false) :: Nil),
StructType(StructField("a", t, false) :: Nil)))
}
}

test("Cast should output null for invalid strings when ANSI is not enabled.") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(e.getMessage.contains("string, binary or array"))
}

// Regression test for SPARK-31227: `concat(array(), array(1))` must yield the same rows
// and the same schema as `array(1)` on its own.
test("SPARK-31227: Non-nullable null type should not coerce to nullable type in concat") {
val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr")
val expected = spark.range(1).selectExpr("array(1) as arr")
// Row contents are compared first.
checkAnswer(actual, expected)
// The schema is asserted separately from the row comparison.
assert(actual.schema === expected.schema)
}

test("flatten function") {
// Test cases with a primitive type
val intDF = Seq(
Expand Down

0 comments on commit 3bd10ce

Please sign in to comment.