Fixes schema evolution for complex types during INSERT OVERWRITE
GitOrigin-RevId: f3310fa7be21364b3c36f47115d18d24ddc47c42
Lukas Rupprecht authored and vkorukanti committed Jan 24, 2023
1 parent 0985436 commit 77c1e0a
Showing 2 changed files with 104 additions and 1 deletion.
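Previously, DeltaAnalysis added recursive per-field casts only when both sides were top-level structs; an array-of-structs column fell through to the generic cast branch, so an INSERT OVERWRITE whose source carried extra fields inside the array elements did not evolve the target's nested schema. A minimal sketch of the scenario, assuming a SparkSession `spark` with the Delta extensions configured (schemas borrowed from the tests below):

spark.sql("SET spark.databricks.delta.schema.autoMerge.enabled = true")
spark.sql("CREATE TABLE source (col ARRAY<STRUCT<f1: STRING, f2: STRING, f3: DATE>>) USING DELTA")
spark.sql("CREATE TABLE target (col ARRAY<STRUCT<f1: STRING, f2: STRING>>) USING DELTA")
// With this fix, the overwrite merges the extra element field f3 into the
// target schema instead of falling back to a plain array cast.
spark.sql("INSERT OVERWRITE target SELECT * FROM source")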
DeltaAnalysis.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -593,6 +593,9 @@ class DeltaAnalysis(session: SparkSession)
        attr
      case (s: StructType, t: StructType) if s != t =>
        addCastsToStructs(tblName, attr, s, t)
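      // Arrays of structs with differing element types but matching
      // nullability are cast element-by-element; see addCastsToArrayStructs
      // below.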
      case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
          if s != t && sNull == tNull =>
        addCastsToArrayStructs(tblName, attr, s, t, sNull)
      case _ =>
        getCastFunction(attr, targetAttr.dataType, targetAttr.name)
    }
@@ -680,6 +683,22 @@ class DeltaAnalysis(session: SparkSession)
      parent.exprId, parent.qualifier, Option(parent.metadata))
  }

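  /**
   * Casts an array-of-structs expression `parent` from the `source` element
   * type to the `target` element type by wrapping it in an ArrayTransform
   * whose lambda applies addCastsToStructs to every array element.
   */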
  private def addCastsToArrayStructs(
      tableName: String,
      parent: NamedExpression,
      source: StructType,
      target: StructType,
      sourceNullable: Boolean): Expression = {
    val structConverter: (Expression, Expression) => Expression = (_, i) =>
      addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
    val transformLambdaFunc = {
      val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
      val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
      LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
    }
    ArrayTransform(parent, transformLambdaFunc)
  }

  private def stripTempViewWrapper(plan: LogicalPlan): LogicalPlan = {
    DeltaViewHelper.stripTempView(plan, conf)
  }
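The ArrayTransform built above is the Catalyst form of SQL's higher-order transform function: each array element is bound to a lambda variable and pushed through the same per-field struct-cast logic used for top-level struct columns (the index variable only feeds the alias name of the intermediate struct). A hedged SQL-level illustration of the shape of the rewritten expression, reusing the schemas from the tests below:

// Roughly what the rewritten plan computes, expressed via spark.sql; the
// actual rewrite builds the per-field casts directly rather than one CAST.
spark.sql("""
  SELECT transform(col, (elementVar, indexVar) ->
           CAST(elementVar AS STRUCT<f1: STRING, f2: STRING, f3: DATE>))
  FROM source
""")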
@@ -75,6 +75,90 @@ class DeltaInsertIntoSQLSuite
    }
  }


Seq(("ordinal", ""), ("name", "(id, col2, col)")).foreach { case (testName, values) =>
test(s"INSERT OVERWRITE schema evolution works for array struct types - $testName") {
val sourceSchema = "id INT, col2 STRING, col ARRAY<STRUCT<f1: STRING, f2: STRING, f3: DATE>>"
val sourceRecord = "1, '2022-11-01', array(struct('s1', 's2', DATE'2022-11-01'))"
val targetSchema = "id INT, col2 DATE, col ARRAY<STRUCT<f1: STRING, f2: STRING>>"
val targetRecord = "1, DATE'2022-11-02', array(struct('t1', 't2'))"

runInsertOverwrite(sourceSchema, sourceRecord, targetSchema, targetRecord) {
(sourceTable, targetTable) =>
sql(s"INSERT OVERWRITE $targetTable $values SELECT * FROM $sourceTable")

// make sure table is still writeable
sql(s"""INSERT INTO $targetTable VALUES (2, DATE'2022-11-02',
| array(struct('s3', 's4', DATE'2022-11-02')))""".stripMargin)
sql(s"""INSERT INTO $targetTable VALUES (3, DATE'2022-11-03',
|array(struct('s5', 's6', NULL)))""".stripMargin)
val df = spark.sql(
"""SELECT 1 as id, DATE'2022-11-01' as col2,
| array(struct('s1', 's2', DATE'2022-11-01')) as col UNION
| SELECT 2 as id, DATE'2022-11-02' as col2,
| array(struct('s3', 's4', DATE'2022-11-02')) as col UNION
| SELECT 3 as id, DATE'2022-11-03' as col2,
| array(struct('s5', 's6', NULL)) as col""".stripMargin)
verifyTable(targetTable, df)
}
}
}

Seq(("ordinal", ""), ("name", "(id, col2, col)")).foreach { case (testName, values) =>
test(s"INSERT OVERWRITE schema evolution works for array nested types - $testName") {
val sourceSchema = "id INT, col2 STRING, " +
"col ARRAY<STRUCT<f1: INT, f2: STRUCT<f21: STRING, f22: DATE>, f3: STRUCT<f31: STRING>>>"
val sourceRecord = "1, '2022-11-01', " +
"array(struct(1, struct('s1', DATE'2022-11-01'), struct('s1')))"
val targetSchema = "id INT, col2 DATE, col ARRAY<STRUCT<f1: INT, f2: STRUCT<f21: STRING>>>"
val targetRecord = "2, DATE'2022-11-02', array(struct(2, struct('s2')))"

runInsertOverwrite(sourceSchema, sourceRecord, targetSchema, targetRecord) {
(sourceTable, targetTable) =>
sql(s"INSERT OVERWRITE $targetTable $values SELECT * FROM $sourceTable")

// make sure table is still writeable
sql(s"""INSERT INTO $targetTable VALUES (2, DATE'2022-11-02',
| array(struct(2, struct('s2', DATE'2022-11-02'), struct('s2'))))""".stripMargin)
sql(s"""INSERT INTO $targetTable VALUES (3, DATE'2022-11-03',
| array(struct(3, struct('s3', NULL), struct(NULL))))""".stripMargin)
val df = spark.sql(
"""SELECT 1 as id, DATE'2022-11-01' as col2,
| array(struct(1, struct('s1', DATE'2022-11-01'), struct('s1'))) as col UNION
| SELECT 2 as id, DATE'2022-11-02' as col2,
| array(struct(2, struct('s2', DATE'2022-11-02'), struct('s2'))) as col UNION
| SELECT 3 as id, DATE'2022-11-03' as col2,
| array(struct(3, struct('s3', NULL), struct(NULL))) as col
|""".stripMargin)
verifyTable(targetTable, df)
}
}
}

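  /**
   * Creates one-row Delta `source` and `target` tables with the given
   * schemas and records, enables schema auto-merge, and passes both table
   * names to `runAndVerify`.
   */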
  def runInsertOverwrite(
      sourceSchema: String,
      sourceRecord: String,
      targetSchema: String,
      targetRecord: String)(
      runAndVerify: (String, String) => Unit): Unit = {
    val sourceTable = "source"
    val targetTable = "target"
    withTable(sourceTable) {
      withTable(targetTable) {
        withSQLConf("spark.databricks.delta.schema.autoMerge.enabled" -> "true") {
          // prepare source table
          sql(s"""CREATE TABLE $sourceTable ($sourceSchema)
                 | USING DELTA""".stripMargin)
          sql(s"INSERT INTO $sourceTable VALUES ($sourceRecord)")
          // prepare target table
          sql(s"""CREATE TABLE $targetTable ($targetSchema)
                 | USING DELTA""".stripMargin)
          sql(s"INSERT INTO $targetTable VALUES ($targetRecord)")
          runAndVerify(sourceTable, targetTable)
        }
      }
    }
  }
}

class DeltaInsertIntoSQLByPathSuite
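For reference, a hedged sketch of the element type the first test expects `col` to carry after the overwrite; the follow-up INSERT INTO statements in that test only succeed because f3 was merged into the target schema (nullability flags assumed default):

import org.apache.spark.sql.types._
// Evolved element type of `col` in the first test, written out for illustration.
val evolvedElement = StructType(Seq(
  StructField("f1", StringType),
  StructField("f2", StringType),
  StructField("f3", DateType)))  // new field merged in from the source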
