From 2b24a71fec021755f43db99628a56bd4a01518eb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 27 Aug 2019 22:13:23 +0800 Subject: [PATCH] [SPARK-28495][SQL] Introduce ANSI store assignment policy for table insertion ### What changes were proposed in this pull request? Introduce the ANSI store assignment policy for table insertion. With the ANSI policy, Spark performs the type coercion for table insertion as per ANSI SQL. ### Why are the changes needed? In Spark version 2.4 and earlier, when inserting into a table, Spark casts the data types of the input query to the data types of the target table by coercion. This can be very confusing, e.g. when users mistakenly write string values to an int column. In data source V2, by default, only upcasting is allowed when inserting data into a table, e.g. int -> long and int -> string are allowed, while decimal -> double and long -> int are not. The UpCast rules were originally created for Dataset type coercion. They are quite strict and differ from the behavior of all existing popular DBMSs. This is a breaking change: existing queries may break after the 3.0 release. Following the ANSI SQL standard makes Spark consistent with the table insertion behavior of popular DBMSs like PostgreSQL/Oracle/MySQL. ### Does this PR introduce any user-facing change? Yes, it adds a new optional store assignment policy for table insertion. ### How was this patch tested? Unit tests. Closes #25581 from gengliangwang/ANSImode. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../analysis/TableOutputResolver.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 30 +++ .../apache/spark/sql/internal/SQLConf.scala | 7 +- .../org/apache/spark/sql/types/DataType.scala | 25 ++- .../analysis/DataSourceV2AnalysisSuite.scala | 172 ++++++++++------ .../DataTypeWriteCompatibilitySuite.scala | 189 ++++++++++++------ .../spark/sql/sources/InsertSuite.scala | 52 +++++ .../sql/test/DataFrameReaderWriterSuite.scala | 22 ++ 8 files changed, 367 insertions(+), 135 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index f0991f1927985..6769773cfec45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -108,10 +108,11 @@ object TableOutputResolver { case StoreAssignmentPolicy.LEGACY => outputField - case StoreAssignmentPolicy.STRICT => + case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI => // run the type check first to ensure type errors are present val canWrite = DataType.canWrite( - queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name, addError) + queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name, + storeAssignmentPolicy, addError) if (queryExpr.nullable && !tableAttr.nullable) { addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index baabf193abfb8..452f084e19bc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -158,6 +158,36 @@ object Cast { case _ => false } + def
canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match { + case _ if from == to => true + case (_: NumericType, _: NumericType) => true + case (_: AtomicType, StringType) => true + case (_: CalendarIntervalType, StringType) => true + case (DateType, TimestampType) => true + case (TimestampType, DateType) => true + // Spark supports casting between long and timestamp, please see `longToTimestamp` and + // `timestampToLong` for details. + case (TimestampType, LongType) => true + case (LongType, TimestampType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + resolvableNullability(fn, tn) && canANSIStoreAssign(fromType, toType) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + resolvableNullability(fn, tn) && canANSIStoreAssign(fromKey, toKey) && + canANSIStoreAssign(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (f1, f2) => + resolvableNullability(f1.nullable, f2.nullable) && + canANSIStoreAssign(f1.dataType, f2.dataType) + } + + case _ => false + } + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index abe7353efd0b1..bde9a87b81f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1637,14 +1637,15 @@ object SQLConf { .createWithDefault(PartitionOverwriteMode.STATIC.toString) object StoreAssignmentPolicy extends Enumeration { - val LEGACY, STRICT = Value + val ANSI, LEGACY, STRICT = Value } val STORE_ASSIGNMENT_POLICY = buildConf("spark.sql.storeAssignmentPolicy") .doc("When inserting a value into a column with different data type, Spark will perform " + - "type coercion. Currently we support 2 policies for the type coercion rules: legacy and " + - "strict. With legacy policy, Spark allows casting any value to any data type. " + + "type coercion. Currently we support 3 policies for the type coercion rules: ansi, " + + "legacy and strict. With ansi policy, Spark performs the type coercion as per ANSI SQL. " + + "With legacy policy, Spark allows casting any value to any data type. " + "The legacy policy is the only behavior in Spark 2.x and it is compatible with Hive. " + "With strict policy, Spark doesn't allow any possible precision loss or data truncation " + "in type coercion, e.g. `int` to `long` and `float` to `double` are not allowed." 
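To make the new rule concrete, here is a minimal sketch (not part of the patch) of what `Cast.canANSIStoreAssign` accepts and rejects, following the cases added above; it only assumes the catalyst imports shown:

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// Any numeric-to-numeric store assignment is allowed under the ANSI policy,
// even when it may lose precision (e.g. double -> float, long -> int).
Cast.canANSIStoreAssign(DoubleType, FloatType)    // true
Cast.canANSIStoreAssign(LongType, IntegerType)    // true

// Atomic types can be written to string columns, and date/timestamp
// conversions are allowed in both directions.
Cast.canANSIStoreAssign(IntegerType, StringType)  // true
Cast.canANSIStoreAssign(TimestampType, DateType)  // true

// String and boolean values cannot be written to numeric columns.
Cast.canANSIStoreAssign(StringType, IntegerType)  // false
Cast.canANSIStoreAssign(BooleanType, DoubleType)  // false
```

This is deliberately looser than `Cast.canUpCast` (used by the STRICT policy), which rejects the first two calls because they may lose precision.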
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a35e971d08823..3a10a56f6937f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT} import org.apache.spark.util.Utils /** @@ -371,12 +373,14 @@ object DataType { byName: Boolean, resolver: Resolver, context: String, + storeAssignmentPolicy: StoreAssignmentPolicy.Value, addError: String => Unit): Boolean = { (write, read) match { case (wArr: ArrayType, rArr: ArrayType) => // run compatibility check first to produce all error messages val typesCompatible = canWrite( - wArr.elementType, rArr.elementType, byName, resolver, context + ".element", addError) + wArr.elementType, rArr.elementType, byName, resolver, context + ".element", + storeAssignmentPolicy, addError) if (wArr.containsNull && !rArr.containsNull) { addError(s"Cannot write nullable elements to array of non-nulls: '$context'") @@ -391,9 +395,11 @@ object DataType { // run compatibility check first to produce all error messages val keyCompatible = canWrite( - wMap.keyType, rMap.keyType, byName, resolver, context + ".key", addError) + wMap.keyType, rMap.keyType, byName, resolver, context + ".key", + storeAssignmentPolicy, addError) val valueCompatible = canWrite( - wMap.valueType, rMap.valueType, byName, resolver, context + ".value", addError) + wMap.valueType, rMap.valueType, byName, resolver, context + ".value", + storeAssignmentPolicy, addError) if (wMap.valueContainsNull && !rMap.valueContainsNull) { addError(s"Cannot write nullable values to map of non-nulls: '$context'") @@ -409,7 +415,8 @@ object DataType { val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) val fieldContext = s"$context.${rField.name}" val typesCompatible = canWrite( - wField.dataType, rField.dataType, byName, resolver, fieldContext, addError) + wField.dataType, rField.dataType, byName, resolver, fieldContext, + storeAssignmentPolicy, addError) if (byName && !nameMatch) { addError(s"Struct '$context' $i-th field name does not match " + @@ -441,7 +448,7 @@ object DataType { fieldCompatible - case (w: AtomicType, r: AtomicType) => + case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT => if (!Cast.canUpCast(w, r)) { addError(s"Cannot safely cast '$context': $w to $r") false @@ -449,6 +456,14 @@ object DataType { true } + case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI => + if (!Cast.canANSIStoreAssign(w, r)) { + addError(s"Cannot safely cast '$context': $w to $r") + false + } else { + true + } + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index c757015c754b7..eade9b6112fe4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ -class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -37,7 +37,17 @@ class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { } } -class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -47,7 +57,17 @@ class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuit } } -class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -104,6 +124,12 @@ class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { } } +class V2OverwriteByExpressionStrictAnalysisSuite extends V2OverwriteByExpressionANSIAnalysisSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } @@ -114,12 +140,85 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) override def skipSchemaResolution: Boolean = true } -abstract class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.ANSI) +} - override def getAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SQLConf() - .copy(SQLConf.CASE_SENSITIVE -> caseSensitive) +abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) + + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) + + assertNotResolved(parsedPlan) + 
assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + + test("byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } +} + +abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { + + protected def getSQLConf(caseSensitive: Boolean): SQLConf = + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + + override def getAnalyzer(caseSensitive: Boolean): Analyzer = { + val conf = getSQLConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) catalog.createDatabase( CatalogDatabase("default", "", new URI("loc"), Map.empty), @@ -254,15 +353,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("byName: fail canWrite check") { - val parsedPlan = byName(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byName: insert safe cast") { val x = table.output.head val y = table.output.last @@ -294,25 +384,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("byName: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byName(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot safely cast", "'x'", "DoubleType to FloatType", - "Cannot write nullable values to 
non-null column", "'x'", - "Cannot find data for output column", "'y'")) - } - test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), @@ -396,19 +467,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("byPosition: fail canWrite check") { - val widerTable = TestRelation(StructType(Seq( - StructField("a", DoubleType), - StructField("b", DoubleType))).toAttributes) - - val parsedPlan = byPosition(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), @@ -444,24 +502,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("byPosition: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byPosition(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot write nullable values to non-null column", "'x'", - "Cannot safely cast", "'x'", "DoubleType to FloatType")) - } - test("bypass output column resolution") { val table = TestRelationAcceptAnySchema(StructType(Seq( StructField("a", FloatType, nullable = false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 6b5fc5f0d4434..784cc7a70489f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -22,20 +22,136 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -class DataTypeWriteCompatibilitySuite extends SparkFunSuite { - private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, - DoubleType, DateType, TimestampType, StringType, BinaryType) +class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.STRICT - private val point2 = StructType(Seq( + override def canCast: (DataType, DataType) => Boolean = Cast.canUpCast + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + 
val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: casting Long to Integer is not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } +} + +class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.ANSI + + override def canCast: (DataType, DataType) => Boolean = Cast.canANSIStoreAssign + + test("Check map value types: unsafe casts are not allowed") { + val mapOfString = MapType(StringType, StringType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfString, mapOfInt, "m", + "Should not allow map of strings to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + private val stringPoint2 = StructType(Seq( + StructField("x", StringType, nullable = false), + StructField("y", StringType, nullable = false))) + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(stringPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfString = ArrayType(StringType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfString, arrayOfInt, "arr", + "Should not allow array of strings to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyString = MapType(StringType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyString, mapKeyInt, "m", + "Should not allow map of string keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } +} + +abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { + protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value + + protected def canCast: (DataType, DataType) => Boolean + + protected val 
atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType, StringType, BinaryType) + + protected val point2 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false))) - private val widerPoint2 = StructType(Seq( + protected val widerPoint2 = StructType(Seq( StructField("x", DoubleType, nullable = false), StructField("y", DoubleType, nullable = false))) - private val point3 = StructType(Seq( + protected val point3 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false), StructField("z", FloatType))) @@ -67,7 +183,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { test("Check atomic types: write allowed only when casting is safe") { atomicTypes.foreach { w => atomicTypes.foreach { r => - if (Cast.canUpCast(w, r)) { + if (canCast(w, r)) { assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") } else { @@ -172,18 +288,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { } } - test("Check struct types: unsafe casts are not allowed") { - assertNumErrors(widerPoint2, point2, "t", - "Should fail because types require unsafe casts", 2) { errs => - - assert(errs(0).contains("'t.x'"), "Should include the nested field name context") - assert(errs(0).contains("Cannot safely cast")) - - assert(errs(1).contains("'t.y'"), "Should include the nested field name context") - assert(errs(1).contains("Cannot safely cast")) - } - } - test("Check struct types: type promotion is allowed") { assertAllowed(point2, widerPoint2, "t", "Should allow widening float fields x and y to double") @@ -203,18 +307,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow writing point (x,y) to point(x,y,z=null)") } - test("Check array types: unsafe casts are not allowed") { - val arrayOfLong = ArrayType(LongType) - val arrayOfInt = ArrayType(IntegerType) - - assertSingleError(arrayOfLong, arrayOfInt, "arr", - "Should not allow array of longs to array of ints") { err => - assert(err.contains("'arr.element'"), - "Should identify problem with named array's element type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check array types: type promotion is allowed") { val arrayOfLong = ArrayType(LongType) val arrayOfInt = ArrayType(IntegerType) @@ -241,17 +333,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow array of required elements to array of optional elements") } - test("Check map value types: unsafe casts are not allowed") { - val mapOfLong = MapType(StringType, LongType) - val mapOfInt = MapType(StringType, IntegerType) - - assertSingleError(mapOfLong, mapOfInt, "m", - "Should not allow map of longs to map of ints") { err => - assert(err.contains("'m.value'"), "Should identify problem with named map's value type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map value types: type promotion is allowed") { val mapOfLong = MapType(StringType, LongType) val mapOfInt = MapType(StringType, IntegerType) @@ -278,17 +359,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow map of required elements to map of optional elements") } - test("Check map key types: unsafe casts are not allowed") { - val mapKeyLong = MapType(LongType, StringType) - val mapKeyInt = MapType(IntegerType, StringType) - - assertSingleError(mapKeyLong, mapKeyInt, "m", - "Should not allow map of long keys to map of int keys") 
{ err => - assert(err.contains("'m.key'"), "Should identify problem with named map's key type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map key types: type promotion is allowed") { val mapKeyLong = MapType(LongType, StringType) val mapKeyInt = MapType(IntegerType, StringType) @@ -317,9 +387,9 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { StructField("a", ArrayType(StringType)), StructField("arr_of_structs", ArrayType(point3)), StructField("bad_nested_type", point3), - StructField("m", MapType(DoubleType, DoubleType)), + StructField("m", MapType(StringType, BooleanType)), StructField("map_of_structs", MapType(StringType, missingMiddleField)), - StructField("y", LongType) + StructField("y", StringType) )) assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => @@ -342,11 +412,11 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { assert(errs(5).contains("'top.m.key'"), "Should identify bad type") assert(errs(5).contains("Cannot safely cast")) - assert(errs(5).contains("DoubleType to LongType")) + assert(errs(5).contains("StringType to LongType")) assert(errs(6).contains("'top.m.value'"), "Should identify bad type") assert(errs(6).contains("Cannot safely cast")) - assert(errs(6).contains("DoubleType to FloatType")) + assert(errs(6).contains("BooleanType to FloatType")) assert(errs(7).contains("'top.m'"), "Should identify bad type") assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) @@ -364,7 +434,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { assert(errs(11).contains("'top.x'"), "Should identify bad type") assert(errs(11).contains("Cannot safely cast")) - assert(errs(11).contains("LongType to IntegerType")) + assert(errs(11).contains("StringType to IntegerType")) assert(errs(12).contains("'top'"), "Should identify bad type") assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") @@ -386,6 +456,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { byName: Boolean = true): Unit = { assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, + storeAssignmentPolicy, errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc) } @@ -411,7 +482,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { val errs = new mutable.ArrayBuffer[String]() assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, - errMsg => errs += errMsg) === false, desc) + storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc) assert(errs.size === numErrs, s"Should produce $numErrs error messages") checkErrors(errs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fda97f4e33cee..8f6c47cae03cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.io.File +import java.sql.Date import org.apache.spark.SparkException import org.apache.spark.sql._ @@ -582,6 +583,57 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } + test("Throw exception on unsafe cast with ANSI casting policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + 
withTable("t") { + sql("create table t(i int, d double) using parquet") + var msg = intercept[AnalysisException] { + sql("insert into t values('a', 'b')") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && + msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + msg = intercept[AnalysisException] { + sql("insert into t values(now(), now())") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': TimestampType to IntegerType") && + msg.contains("Cannot safely cast 'd': TimestampType to DoubleType")) + msg = intercept[AnalysisException] { + sql("insert into t values(true, false)") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + } + } + } + + test("Allow on writing any numeric value to numeric type with ANSI policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i int, d float) using parquet") + sql("insert into t values(1L, 2.0)") + sql("insert into t values(3.0, 4)") + sql("insert into t values(5.0, 6L)") + checkAnswer(sql("select * from t"), Seq(Row(1, 2.0F), Row(3, 4.0F), Row(5, 6.0F))) + } + } + } + + test("Allow on writing timestamp value to date type with ANSI policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i date) using parquet") + sql("insert into t values(TIMESTAMP('2010-09-02 14:10:10'))") + checkAnswer(sql("select * from t"), Seq(Row(Date.valueOf("2010-09-02")))) + } + } + } + test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") { withTempPath { path => Seq((1, 1), (2, 2)).toDF("i", "part") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 441750e5a9bc4..369feb504153e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -327,6 +327,28 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } + test("Throw exception on unsafe cast with ANSI casting policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i int, d double) using parquet") + // Calling `saveAsTable` to an existing table with append mode results in table insertion. + var msg = intercept[AnalysisException] { + Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && + msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + + msg = intercept[AnalysisException] { + Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + } + } + } + test("test path option in load") { spark.read .format("org.apache.spark.sql.test")