[SPARK-28495][SQL] Introduce ANSI store assignment policy for table insertion

### What changes were proposed in this pull request?
 Introduce ANSI store assignment policy for table insertion.
With ANSI policy, Spark performs the type coercion of table insertion as per ANSI SQL.
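
As a rough sketch of the intended behavior (hypothetical table name, and assuming the target table goes through the new store assignment check, e.g. a Data Source V2 table):

```scala
// Hypothetical spark-shell sketch, not part of this patch.
spark.conf.set("spark.sql.storeAssignmentPolicy", "ANSI")
spark.sql("CREATE TABLE t (i INT) USING parquet")

// Numeric-to-numeric store assignment is allowed under the ANSI policy,
// even when it may narrow the type (e.g. BIGINT -> INT).
spark.sql("INSERT INTO t VALUES (CAST(1 AS BIGINT))")

// String -> int is not a valid ANSI store assignment, so this insert is
// expected to fail at analysis time rather than silently writing bad data.
spark.sql("INSERT INTO t VALUES ('hello')")
```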

### Why are the changes needed?
In Spark 2.4 and earlier, when inserting into a table, Spark casts the data type of the input query to the data type of the target table by coercion. This can be very confusing, e.g. a user may mistakenly write string values to an int column and Spark will silently cast them instead of failing.
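
For illustration, a sketch of that legacy behavior (hypothetical table name; the shown outcome is an assumption based on legacy cast semantics, where a non-numeric string casts to NULL):

```scala
// Hypothetical sketch of the legacy (Spark 2.4 and earlier) behavior.
spark.conf.set("spark.sql.storeAssignmentPolicy", "LEGACY")
spark.sql("CREATE TABLE ints (i INT) USING parquet")

// The strings are silently cast to INT; the non-numeric value becomes NULL
// instead of raising an error, which is the confusing behavior described above.
spark.sql("INSERT INTO ints VALUES ('1'), ('not a number')")
spark.sql("SELECT * FROM ints").show()
```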

In Data Source V2, by default, only upcasting is allowed when inserting data into a table, e.g. int -> long and int -> string are allowed, while decimal -> double and long -> int are not. The UpCast rules were originally created for Dataset type coercion. They are quite strict and differ from the behavior of all existing popular DBMSs. This is a breaking change, and existing queries may stop working after the 3.0 release.
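
For example, a sketch of how the upcast-only rule reacts to a decimal value written to a double column (hypothetical table name, assuming the strict check is applied on the insertion path):

```scala
// Hypothetical sketch of the strict, upcast-only behavior described above.
spark.conf.set("spark.sql.storeAssignmentPolicy", "STRICT")
spark.sql("CREATE TABLE doubles (d DOUBLE) USING parquet")

// decimal -> double is not an upcast, so the strict policy rejects this insert
// at analysis time, even though mainstream databases accept it.
spark.sql("INSERT INTO doubles VALUES (CAST(1.23 AS DECIMAL(10, 2)))")
```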

Following the ANSI SQL standard makes Spark consistent with the table insertion behavior of popular DBMSs such as PostgreSQL, Oracle, and MySQL.

### Does this PR introduce any user-facing change?
Yes. A new optional store assignment mode for table insertion, as sketched below.
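
Assuming a standard SparkSession, the policy can be selected either programmatically or via SQL:

```scala
// Two equivalent ways to opt in to the new policy.
spark.conf.set("spark.sql.storeAssignmentPolicy", "ANSI")
spark.sql("SET spark.sql.storeAssignmentPolicy=ANSI")
```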

### How was this patch tested?
Unit test

Closes #25581 from gengliangwang/ANSImode.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengliangwang authored and cloud-fan committed Aug 27, 2019
1 parent 70f4bbc commit 2b24a71
Showing 8 changed files with 367 additions and 135 deletions.
@@ -108,10 +108,11 @@ object TableOutputResolver {
case StoreAssignmentPolicy.LEGACY =>
outputField

case StoreAssignmentPolicy.STRICT =>
case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
// run the type check first to ensure type errors are present
val canWrite = DataType.canWrite(
queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name, addError)
queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name,
storeAssignmentPolicy, addError)
if (queryExpr.nullable && !tableAttr.nullable) {
addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'")
None
@@ -158,6 +158,36 @@ object Cast {
case _ => false
}

def canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match {
case _ if from == to => true
case (_: NumericType, _: NumericType) => true
case (_: AtomicType, StringType) => true
case (_: CalendarIntervalType, StringType) => true
case (DateType, TimestampType) => true
case (TimestampType, DateType) => true
// Spark supports casting between long and timestamp, please see `longToTimestamp` and
// `timestampToLong` for details.
case (TimestampType, LongType) => true
case (LongType, TimestampType) => true

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
resolvableNullability(fn, tn) && canANSIStoreAssign(fromType, toType)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
resolvableNullability(fn, tn) && canANSIStoreAssign(fromKey, toKey) &&
canANSIStoreAssign(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (f1, f2) =>
resolvableNullability(f1.nullable, f2.nullable) &&
canANSIStoreAssign(f1.dataType, f2.dataType)
}

case _ => false
}

private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
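
For reference, a few illustrative expectations for the new predicate above (an editorial sketch derived from the cases in this hunk, not part of the commit):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// Spot checks of canANSIStoreAssign based on the cases added above.
assert(Cast.canANSIStoreAssign(LongType, IntegerType))     // numeric -> numeric
assert(Cast.canANSIStoreAssign(DateType, StringType))      // atomic -> string
assert(Cast.canANSIStoreAssign(TimestampType, DateType))   // timestamp -> date
assert(!Cast.canANSIStoreAssign(StringType, IntegerType))  // string -> int rejected
assert(!Cast.canANSIStoreAssign(IntegerType, BooleanType)) // numeric -> boolean rejected
```
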
@@ -1637,14 +1637,15 @@ object SQLConf {
.createWithDefault(PartitionOverwriteMode.STATIC.toString)

object StoreAssignmentPolicy extends Enumeration {
val LEGACY, STRICT = Value
val ANSI, LEGACY, STRICT = Value
}

val STORE_ASSIGNMENT_POLICY =
buildConf("spark.sql.storeAssignmentPolicy")
.doc("When inserting a value into a column with different data type, Spark will perform " +
"type coercion. Currently we support 2 policies for the type coercion rules: legacy and " +
"strict. With legacy policy, Spark allows casting any value to any data type. " +
"type coercion. Currently we support 3 policies for the type coercion rules: ansi, " +
"legacy and strict. With ansi policy, Spark performs the type coercion as per ANSI SQL. " +
"With legacy policy, Spark allows casting any value to any data type. " +
"The legacy policy is the only behavior in Spark 2.x and it is compatible with Hive. " +
"With strict policy, Spark doesn't allow any possible precision loss or data truncation " +
"in type coercion, e.g. `int` to `long` and `float` to `double` are not allowed."
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
import org.apache.spark.util.Utils

/**
@@ -371,12 +373,14 @@ object DataType {
byName: Boolean,
resolver: Resolver,
context: String,
storeAssignmentPolicy: StoreAssignmentPolicy.Value,
addError: String => Unit): Boolean = {
(write, read) match {
case (wArr: ArrayType, rArr: ArrayType) =>
// run compatibility check first to produce all error messages
val typesCompatible = canWrite(
wArr.elementType, rArr.elementType, byName, resolver, context + ".element", addError)
wArr.elementType, rArr.elementType, byName, resolver, context + ".element",
storeAssignmentPolicy, addError)

if (wArr.containsNull && !rArr.containsNull) {
addError(s"Cannot write nullable elements to array of non-nulls: '$context'")
@@ -391,9 +395,11 @@

// run compatibility check first to produce all error messages
val keyCompatible = canWrite(
wMap.keyType, rMap.keyType, byName, resolver, context + ".key", addError)
wMap.keyType, rMap.keyType, byName, resolver, context + ".key",
storeAssignmentPolicy, addError)
val valueCompatible = canWrite(
wMap.valueType, rMap.valueType, byName, resolver, context + ".value", addError)
wMap.valueType, rMap.valueType, byName, resolver, context + ".value",
storeAssignmentPolicy, addError)

if (wMap.valueContainsNull && !rMap.valueContainsNull) {
addError(s"Cannot write nullable values to map of non-nulls: '$context'")
@@ -409,7 +415,8 @@
val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name)
val fieldContext = s"$context.${rField.name}"
val typesCompatible = canWrite(
wField.dataType, rField.dataType, byName, resolver, fieldContext, addError)
wField.dataType, rField.dataType, byName, resolver, fieldContext,
storeAssignmentPolicy, addError)

if (byName && !nameMatch) {
addError(s"Struct '$context' $i-th field name does not match " +
@@ -441,14 +448,22 @@ object DataType {

fieldCompatible

case (w: AtomicType, r: AtomicType) =>
case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT =>
if (!Cast.canUpCast(w, r)) {
addError(s"Cannot safely cast '$context': $w to $r")
false
} else {
true
}

case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI =>
if (!Cast.canANSIStoreAssign(w, r)) {
addError(s"Cannot safely cast '$context': $w to $r")
false
} else {
true
}

case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] =>
true

@@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types._

class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite {
class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
AppendData.byName(table, query)
}
@@ -37,7 +37,17 @@ class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite
}
}

class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite {
class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
AppendData.byName(table, query)
}

override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
AppendData.byPosition(table, query)
}
}

class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwritePartitionsDynamic.byName(table, query)
}
@@ -47,7 +57,17 @@ class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite
}
}

class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite {
class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwritePartitionsDynamic.byName(table, query)
}

override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwritePartitionsDynamic.byPosition(table, query)
}
}

class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwriteByExpression.byName(table, query, Literal(true))
}
@@ -104,6 +124,12 @@ class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite {
}
}

class V2OverwriteByExpressionStrictAnalysisSuite extends V2OverwriteByExpressionANSIAnalysisSuite {
override def getSQLConf(caseSensitive: Boolean): SQLConf =
super.getSQLConf(caseSensitive)
.copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT)
}

case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation {
override def name: String = "table-name"
}
@@ -114,12 +140,85 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference])
override def skipSchemaResolution: Boolean = true
}

abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite {
override def getSQLConf(caseSensitive: Boolean): SQLConf =
super.getSQLConf(caseSensitive)
.copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.ANSI)
}

override def getAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SQLConf()
.copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite {
override def getSQLConf(caseSensitive: Boolean): SQLConf =
super.getSQLConf(caseSensitive)
.copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT)

test("byName: fail canWrite check") {
val parsedPlan = byName(table, widerTable)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write", "'table-name'",
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}

test("byName: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)

val query = TestRelation(StructType(Seq(
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)

val parsedPlan = byName(xRequiredTable, query)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot safely cast", "'x'", "DoubleType to FloatType",
"Cannot write nullable values to non-null column", "'x'",
"Cannot find data for output column", "'y'"))
}


test("byPosition: fail canWrite check") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
StructField("b", DoubleType))).toAttributes)

val parsedPlan = byPosition(table, widerTable)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write", "'table-name'",
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}

test("byPosition: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)

val query = TestRelation(StructType(Seq(
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)

val parsedPlan = byPosition(xRequiredTable, query)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot write nullable values to non-null column", "'x'",
"Cannot safely cast", "'x'", "DoubleType to FloatType"))
}
}

abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest {

protected def getSQLConf(caseSensitive: Boolean): SQLConf =
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)

override def getAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = getSQLConf(caseSensitive)
val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf)
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
@@ -254,15 +353,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot find data for output column", "'x'"))
}

test("byName: fail canWrite check") {
val parsedPlan = byName(table, widerTable)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write", "'table-name'",
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}

test("byName: insert safe cast") {
val x = table.output.head
val y = table.output.last
@@ -294,25 +384,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'x', 'y', 'z'"))
}

test("byName: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)

val query = TestRelation(StructType(Seq(
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)

val parsedPlan = byName(xRequiredTable, query)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot safely cast", "'x'", "DoubleType to FloatType",
"Cannot write nullable values to non-null column", "'x'",
"Cannot find data for output column", "'y'"))
}

test("byPosition: basic behavior") {
val query = TestRelation(StructType(Seq(
StructField("a", FloatType),
@@ -396,19 +467,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'y'"))
}

test("byPosition: fail canWrite check") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
StructField("b", DoubleType))).toAttributes)

val parsedPlan = byPosition(table, widerTable)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write", "'table-name'",
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}

test("byPosition: insert safe cast") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
@@ -444,24 +502,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'a', 'b', 'c'"))
}

test("byPosition: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)

val query = TestRelation(StructType(Seq(
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)

val parsedPlan = byPosition(xRequiredTable, query)

assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot write nullable values to non-null column", "'x'",
"Cannot safely cast", "'x'", "DoubleType to FloatType"))
}

test("bypass output column resolution") {
val table = TestRelationAcceptAnySchema(StructType(Seq(
StructField("a", FloatType, nullable = false),