diff --git a/core/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala b/core/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
index c7e0fdf0a84..e3237b0fa69 100644
--- a/core/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
+++ b/core/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -593,6 +593,9 @@ class DeltaAnalysis(session: SparkSession)
           attr
         case (s: StructType, t: StructType) if s != t =>
           addCastsToStructs(tblName, attr, s, t)
+        case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
+          if s != t && sNull == tNull =>
+          addCastsToArrayStructs(tblName, attr, s, t, sNull)
         case _ =>
           getCastFunction(attr, targetAttr.dataType, targetAttr.name)
       }
@@ -680,6 +683,22 @@ class DeltaAnalysis(session: SparkSession)
       parent.exprId, parent.qualifier, Option(parent.metadata))
   }
 
+  private def addCastsToArrayStructs(
+      tableName: String,
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType,
+      sourceNullable: Boolean): Expression = {
+    val structConverter: (Expression, Expression) => Expression = (_, i) =>
+      addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
+    val transformLambdaFunc = {
+      val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
+      val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
+      LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
+    }
+    ArrayTransform(parent, transformLambdaFunc)
+  }
+
   private def stripTempViewWrapper(plan: LogicalPlan): LogicalPlan = {
     DeltaViewHelper.stripTempView(plan, conf)
   }
diff --git a/core/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala
index e4c6b06dc7e..e633c0e3340 100644
--- a/core/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala
@@ -75,6 +75,90 @@ class DeltaInsertIntoSQLSuite
     }
   }
 
+
+  Seq(("ordinal", ""), ("name", "(id, col2, col)")).foreach { case (testName, values) =>
+    test(s"INSERT OVERWRITE schema evolution works for array struct types - $testName") {
+      val sourceSchema = "id INT, col2 STRING, col ARRAY<STRUCT<f1: STRING, f2: STRING, f3: DATE>>"
+      val sourceRecord = "1, '2022-11-01', array(struct('s1', 's2', DATE'2022-11-01'))"
+      val targetSchema = "id INT, col2 DATE, col ARRAY<STRUCT<f1: STRING, f2: STRING>>"
+      val targetRecord = "1, DATE'2022-11-02', array(struct('t1', 't2'))"
+
+      runInsertOverwrite(sourceSchema, sourceRecord, targetSchema, targetRecord) {
+        (sourceTable, targetTable) =>
+          sql(s"INSERT OVERWRITE $targetTable $values SELECT * FROM $sourceTable")
+
+          // make sure table is still writeable
+          sql(s"""INSERT INTO $targetTable VALUES (2, DATE'2022-11-02',
+               | array(struct('s3', 's4', DATE'2022-11-02')))""".stripMargin)
+          sql(s"""INSERT INTO $targetTable VALUES (3, DATE'2022-11-03',
+               |array(struct('s5', 's6', NULL)))""".stripMargin)
+          val df = spark.sql(
+            """SELECT 1 as id, DATE'2022-11-01' as col2,
+              | array(struct('s1', 's2', DATE'2022-11-01')) as col UNION
+              | SELECT 2 as id, DATE'2022-11-02' as col2,
+              | array(struct('s3', 's4', DATE'2022-11-02')) as col UNION
+              | SELECT 3 as id, DATE'2022-11-03' as col2,
+              | array(struct('s5', 's6', NULL)) as col""".stripMargin)
+          verifyTable(targetTable, df)
+      }
+    }
+  }
+
+  Seq(("ordinal", ""), ("name", "(id, col2, col)")).foreach { case (testName, values) =>
+    test(s"INSERT OVERWRITE schema evolution works for array nested types - $testName") {
+      val sourceSchema = "id INT, col2 STRING, " +
+        "col ARRAY<STRUCT<f1: INT, f2: STRUCT<name: STRING, dt: DATE>, f3: STRUCT<name: STRING>>>"
+      val sourceRecord = "1, '2022-11-01', " +
+        "array(struct(1, struct('s1', DATE'2022-11-01'), struct('s1')))"
+      val targetSchema = "id INT, col2 DATE, col ARRAY<STRUCT<f1: INT, f2: STRUCT<name: STRING>>>"
+      val targetRecord = "2, DATE'2022-11-02', array(struct(2, struct('s2')))"
+
+      runInsertOverwrite(sourceSchema, sourceRecord, targetSchema, targetRecord) {
+        (sourceTable, targetTable) =>
+          sql(s"INSERT OVERWRITE $targetTable $values SELECT * FROM $sourceTable")
+
+          // make sure table is still writeable
+          sql(s"""INSERT INTO $targetTable VALUES (2, DATE'2022-11-02',
+               | array(struct(2, struct('s2', DATE'2022-11-02'), struct('s2'))))""".stripMargin)
+          sql(s"""INSERT INTO $targetTable VALUES (3, DATE'2022-11-03',
+               | array(struct(3, struct('s3', NULL), struct(NULL))))""".stripMargin)
+          val df = spark.sql(
+            """SELECT 1 as id, DATE'2022-11-01' as col2,
+              | array(struct(1, struct('s1', DATE'2022-11-01'), struct('s1'))) as col UNION
+              | SELECT 2 as id, DATE'2022-11-02' as col2,
+              | array(struct(2, struct('s2', DATE'2022-11-02'), struct('s2'))) as col UNION
+              | SELECT 3 as id, DATE'2022-11-03' as col2,
+              | array(struct(3, struct('s3', NULL), struct(NULL))) as col
+              |""".stripMargin)
+          verifyTable(targetTable, df)
+      }
+    }
+  }
+
+  def runInsertOverwrite(
+      sourceSchema: String,
+      sourceRecord: String,
+      targetSchema: String,
+      targetRecord: String)(
+      runAndVerify: (String, String) => Unit): Unit = {
+    val sourceTable = "source"
+    val targetTable = "target"
+    withTable(sourceTable) {
+      withTable(targetTable) {
+        withSQLConf("spark.databricks.delta.schema.autoMerge.enabled" -> "true") {
+          // prepare source table
+          sql(s"""CREATE TABLE $sourceTable ($sourceSchema)
+               | USING DELTA""".stripMargin)
+          sql(s"INSERT INTO $sourceTable VALUES ($sourceRecord)")
+          // prepare target table
+          sql(s"""CREATE TABLE $targetTable ($targetSchema)
+               | USING DELTA""".stripMargin)
+          sql(s"INSERT INTO $targetTable VALUES ($targetRecord)")
+          runAndVerify(sourceTable, targetTable)
+        }
+      }
+    }
+  }
 }
 
 class DeltaInsertIntoSQLByPathSuite
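
For context, a minimal spark-shell style sketch of the scenario the new tests exercise; it is not part of the patch, and the table and column names (source, target, id, col, f1..f3) are illustrative, taken from the tests above. It assumes a SparkSession with Delta Lake configured.

```scala
// Sketch only: with schema auto-merge enabled, an INSERT OVERWRITE from a
// source whose array-of-struct column carries an extra field (f3) now
// resolves. The new ArrayType(StructType, _) case in DeltaAnalysis wraps the
// array column in an ArrayTransform that casts each struct element, and the
// target schema evolves to include f3.
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")

spark.sql("CREATE TABLE source (id INT, col ARRAY<STRUCT<f1: STRING, f2: STRING, f3: DATE>>) USING DELTA")
spark.sql("CREATE TABLE target (id INT, col ARRAY<STRUCT<f1: STRING, f2: STRING>>) USING DELTA")
spark.sql("INSERT INTO source VALUES (1, array(struct('s1', 's2', DATE'2022-11-01')))")

spark.sql("INSERT OVERWRITE target SELECT * FROM source")
spark.sql("SELECT * FROM target").show(truncate = false)
```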