# Automatic type widening in INSERT (#2785)
#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
This change is part of the type widening table feature.
Type widening feature request: #2622
Type Widening protocol RFC: #2624

It adds automatic type widening as part of schema evolution in INSERT.
During resolution, when schema evolution and type widening are both
enabled, type differences between the input query and the target table
are handled as follows (see the sketch after this list):
- If the type difference qualifies for automatic type evolution: the
input type is left as-is, the data is inserted with the new type, and
the table schema is updated in `ImplicitMetadataOperation` (already
implemented as part of MERGE support).
- If the type difference doesn't qualify for automatic type evolution:
the current behavior is preserved and a cast is added from the input
type to the existing target type.
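
For illustration, here is a minimal, self-contained sketch of that decision. `isWideningSupported` is a stand-in for `TypeWidening.isTypeChangeSupportedForSchemaEvolution` and only covers an illustrative subset of widenings; it is not the actual `DeltaAnalysis` implementation.

```scala
import org.apache.spark.sql.types._

// Stand-in for TypeWidening.isTypeChangeSupportedForSchemaEvolution: true when the existing
// table type can be widened to the input type (illustrative subset only).
def isWideningSupported(fromType: AtomicType, toType: AtomicType): Boolean =
  (fromType, toType) match {
    case (ByteType, ShortType | IntegerType) => true
    case (ShortType, IntegerType) => true
    case _ => false
  }

// Mirrors the shape of the new branch in addCastToColumn: keep the (wider) input type when
// widening applies, otherwise fall back to casting the input to the existing table type.
def keepInputType(inputType: DataType, tableType: DataType, allowTypeWidening: Boolean): Boolean =
  (inputType, tableType) match {
    case (s: AtomicType, t: AtomicType) if allowTypeWidening => isWideningSupported(t, s)
    case _ => false
  }
```

For example, `keepInputType(IntegerType, ShortType, allowTypeWidening = true)` returns `true`, so the input keeps its `int` type and the table schema is widened; with `allowTypeWidening = false` a cast back to `short` is added instead.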

## How was this patch tested?
- Tests are added to `DeltaTypeWideningAutomaticSuite` to cover type evolution in INSERT.

## This PR introduces the following *user-facing* changes
The table feature is available in testing only; there are no user-facing changes as of now.

When automatic schema evolution is enabled in INSERT and the source
schema contains a type that is wider than the target schema:

- With type widening disabled: the type in the target schema is not changed; a cast to the existing target type is added to the input.
- With type widening enabled: the type in the target schema is updated to the wider source type.

```sql
-- target: key int, value short
-- source: key int, value int
INSERT INTO target SELECT * FROM source
```
After the INSERT operation, the target schema is `key int, value int`.
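
For context, here is a hypothetical end-to-end version of the example above using the Scala API. The conf and table property names (`spark.databricks.delta.schema.autoMerge.enabled`, `delta.enableTypeWidening`) are assumptions for illustration, and since the table feature is test-only at this point the snippet may not work as-is on a released build; it is a sketch of the intended behavior, not documentation.

```scala
import org.apache.spark.sql.SparkSession

// Standard Delta Lake session setup.
val spark = SparkSession.builder()
  .appName("type-widening-insert-example")
  .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
  .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
  .getOrCreate()

// Schema evolution on INSERT: DeltaOptions.canMergeSchema falls back to this session conf.
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")

// Assumption: 'delta.enableTypeWidening' is the table property gating the feature; the table
// feature itself is test-only at the time of this commit.
spark.sql("""
  CREATE TABLE target (key INT, value SMALLINT) USING delta
  TBLPROPERTIES ('delta.enableTypeWidening' = 'true')
""")
spark.sql("CREATE TABLE source (key INT, value INT) USING delta")

spark.sql("INSERT INTO target SELECT * FROM source")

// With schema evolution and type widening enabled, `value` is reported as IntegerType.
spark.table("target").printSchema()
```
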
johanl-db authored Mar 25, 2024
1 parent 36f95dd commit a172276
Showing 3 changed files with 577 additions and 27 deletions.
80 changes: 59 additions & 21 deletions spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -66,7 +66,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -81,8 +81,8 @@ class DeltaAnalysis(session: SparkSession)
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
// INSERT INTO by ordinal and df.insertInto()
case a @ AppendDelta(r, d) if !a.isByName &&
needsSchemaAdjustmentByOrdinal(d.name(), a.query, r.schema) =>
val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d.name())
needsSchemaAdjustmentByOrdinal(d, a.query, r.schema) =>
val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d)
if (projection != a.query) {
a.copy(query = projection)
} else {
@@ -208,8 +208,8 @@ class DeltaAnalysis(session: SparkSession)

// INSERT OVERWRITE by ordinal and df.insertInto()
case o @ OverwriteDelta(r, d) if !o.isByName &&
needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema) =>
val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
needsSchemaAdjustmentByOrdinal(d, o.query, r.schema) =>
val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d)
if (projection != o.query) {
val aliases = AttributeMap(o.query.output.zip(projection.output).collect {
case (l: AttributeReference, r: AttributeReference) if !l.sameRef(r) => (l, r)
@@ -245,9 +245,9 @@ case o @ DynamicPartitionOverwriteDelta(r, d) if o.resolved
case o @ DynamicPartitionOverwriteDelta(r, d) if o.resolved
=>
val adjustedQuery = if (!o.isByName &&
needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema)) {
needsSchemaAdjustmentByOrdinal(d, o.query, r.schema)) {
// INSERT OVERWRITE by ordinal and df.insertInto()
resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
resolveQueryColumnsByOrdinal(o.query, r.output, d)
} else if (o.isByName && o.origin.sqlText.nonEmpty &&
needsSchemaAdjustmentByName(o.query, r.output, d)) {
// INSERT OVERWRITE by name
@@ -850,12 +850,14 @@ class DeltaAnalysis(session: SparkSession)
* type column/field.
*/
private def resolveQueryColumnsByOrdinal(
query: LogicalPlan, targetAttrs: Seq[Attribute], tblName: String): LogicalPlan = {
query: LogicalPlan, targetAttrs: Seq[Attribute], deltaTable: DeltaTableV2): LogicalPlan = {
// always add a Cast. it will be removed in the optimizer if it is unnecessary.
val project = query.output.zipWithIndex.map { case (attr, i) =>
if (i < targetAttrs.length) {
val targetAttr = targetAttrs(i)
addCastToColumn(attr, targetAttr, tblName)
addCastToColumn(attr, targetAttr, deltaTable.name(),
allowTypeWidening = allowTypeWidening(deltaTable)
)
} else {
attr
}
@@ -890,47 +892,69 @@ class DeltaAnalysis(session: SparkSession)
.getOrElse {
throw DeltaErrors.missingColumn(attr, targetAttrs)
}
addCastToColumn(attr, targetAttr, deltaTable.name())
addCastToColumn(attr, targetAttr, deltaTable.name(),
allowTypeWidening = allowTypeWidening(deltaTable)
)
}
Project(project, query)
}

private def addCastToColumn(
attr: Attribute,
targetAttr: Attribute,
tblName: String): NamedExpression = {
tblName: String,
allowTypeWidening: Boolean): NamedExpression = {
val expr = (attr.dataType, targetAttr.dataType) match {
case (s, t) if s == t =>
attr
case (s: StructType, t: StructType) if s != t =>
addCastsToStructs(tblName, attr, s, t)
addCastsToStructs(tblName, attr, s, t, allowTypeWidening)
case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
if s != t && sNull == tNull =>
addCastsToArrayStructs(tblName, attr, s, t, sNull)
addCastsToArrayStructs(tblName, attr, s, t, sNull, allowTypeWidening)
case (s: AtomicType, t: AtomicType)
if allowTypeWidening && TypeWidening.isTypeChangeSupportedForSchemaEvolution(t, s) =>
// Keep the type from the query, the target schema will be updated to widen the existing
// type to match it.
attr
case _ =>
getCastFunction(attr, targetAttr.dataType, targetAttr.name)
}
Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
}

/**
* Whether inserting values that have a wider type than the table has is allowed. In that case,
* values are not downcasted to the current table type and the table schema is updated instead to
* use the wider type.
*/
private def allowTypeWidening(deltaTable: DeltaTableV2): Boolean = {
val options = new DeltaOptions(Map.empty[String, String], conf)
options.canMergeSchema && TypeWidening.isEnabled(
deltaTable.initialSnapshot.protocol,
deltaTable.initialSnapshot.metadata
)
}

/**
* With Delta, we ACCEPT_ANY_SCHEMA, meaning that Spark doesn't automatically adjust the schema
* of INSERT INTO. This allows us to perform better schema enforcement/evolution. Since Spark
* skips this step, we see if we need to perform any schema adjustment here.
*/
private def needsSchemaAdjustmentByOrdinal(
tableName: String,
deltaTable: DeltaTableV2,
query: LogicalPlan,
schema: StructType): Boolean = {
val output = query.output
if (output.length < schema.length) {
throw DeltaErrors.notEnoughColumnsInInsert(tableName, output.length, schema.length)
throw DeltaErrors.notEnoughColumnsInInsert(deltaTable.name(), output.length, schema.length)
}
// Now we should try our best to match everything that already exists, and leave the rest
// for schema evolution to WriteIntoDelta
val existingSchemaOutput = output.take(schema.length)
existingSchemaOutput.map(_.name) != schema.map(_.name) ||
!SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType)
!SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType,
allowTypeWidening = allowTypeWidening(deltaTable))
}

/**
Expand Down Expand Up @@ -984,7 +1008,10 @@ class DeltaAnalysis(session: SparkSession)
}
val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
!SchemaUtils.isReadCompatible(
specifiedTargetAttrs.toStructType.asNullable, query.output.toStructType)
specifiedTargetAttrs.toStructType.asNullable,
query.output.toStructType,
allowTypeWidening = allowTypeWidening(deltaTable)
)
}

// Get cast operation for the level of strictness in the schema a user asked for
@@ -1014,7 +1041,8 @@
tableName: String,
parent: NamedExpression,
source: StructType,
target: StructType): NamedExpression = {
target: StructType,
allowTypeWidening: Boolean): NamedExpression = {
if (source.length < target.length) {
throw DeltaErrors.notEnoughColumnsInInsert(
tableName, source.length, target.length, Some(parent.qualifiedName))
@@ -1025,12 +1053,20 @@
case t: StructType =>
val subField = Alias(GetStructField(parent, i, Option(name)), target(i).name)(
explicitMetadata = Option(metadata))
addCastsToStructs(tableName, subField, nested, t)
addCastsToStructs(tableName, subField, nested, t, allowTypeWidening)
case o =>
val field = parent.qualifiedName + "." + name
val targetName = parent.qualifiedName + "." + target(i).name
throw DeltaErrors.cannotInsertIntoColumn(tableName, field, targetName, o.simpleString)
}

case (StructField(name, dt: AtomicType, _, _), i) if i < target.length && allowTypeWidening &&
TypeWidening.isTypeChangeSupportedForSchemaEvolution(
target(i).dataType.asInstanceOf[AtomicType], dt) =>
val targetAttr = target(i)
Alias(
GetStructField(parent, i, Option(name)),
targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
case (other, i) if i < target.length =>
val targetAttr = target(i)
Alias(
@@ -1054,9 +1090,11 @@
parent: NamedExpression,
source: StructType,
target: StructType,
sourceNullable: Boolean): Expression = {
sourceNullable: Boolean,
allowTypeWidening: Boolean): Expression = {
val structConverter: (Expression, Expression) => Expression = (_, i) =>
addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
addCastsToStructs(
tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target, allowTypeWidening)
val transformLambdaFunc = {
val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
@@ -352,7 +352,8 @@ def normalizeColumnNamesInDataType(
* new schema of a Delta table can be used with a previously analyzed LogicalPlan. Our
* rules are to return false if:
* - Dropping any column that was present in the existing schema, if not allowMissingColumns
* - Any change of datatype
* - Any change of datatype, if not allowTypeWidening. Any non-widening change of datatype
* otherwise.
* - Change of partition columns. Although analyzed LogicalPlan is not changed,
* physical structure of data is changed and thus is considered not read compatible.
* - If `forbidTightenNullability` = true:
@@ -373,6 +374,7 @@
readSchema: StructType,
forbidTightenNullability: Boolean = false,
allowMissingColumns: Boolean = false,
allowTypeWidening: Boolean = false,
newPartitionColumns: Seq[String] = Seq.empty,
oldPartitionColumns: Seq[String] = Seq.empty): Boolean = {

@@ -387,7 +389,7 @@
def isDatatypeReadCompatible(existing: DataType, newtype: DataType): Boolean = {
(existing, newtype) match {
case (e: StructType, n: StructType) =>
isReadCompatible(e, n, forbidTightenNullability)
isReadCompatible(e, n, forbidTightenNullability, allowTypeWidening = allowTypeWidening)
case (e: ArrayType, n: ArrayType) =>
// if existing elements are non-nullable, so should be the new element
isNullabilityCompatible(e.containsNull, n.containsNull) &&
@@ -397,6 +399,8 @@
isNullabilityCompatible(e.valueContainsNull, n.valueContainsNull) &&
isDatatypeReadCompatible(e.keyType, n.keyType) &&
isDatatypeReadCompatible(e.valueType, n.valueType)
case (e: AtomicType, n: AtomicType) if allowTypeWidening =>
TypeWidening.isTypeChangeSupportedForSchemaEvolution(e, n)
case (a, b) => a == b
}
}
