diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
index 5dd249bd24f76..f2d6f0381a471 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql
 
+import org.apache.hudi.SparkAdapterSupport.sparkAdapter
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Like, Literal, SubqueryExpression, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 trait HoodieCatalystExpressionUtils {
 
@@ -39,10 +40,34 @@ trait HoodieCatalystExpressionUtils {
    * will keep the same ordering b1, b2, b3, ... with b1 = T(a1), b2 = T(a2), ...
    */
   def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference]
+
+  /**
+   * Verifies whether [[fromType]] can be up-casted to [[toType]]
+   */
+  def canUpCast(fromType: DataType, toType: DataType): Boolean
+
+  /**
+   * Un-applies [[Cast]] expression into
+   *   1. Casted [[Expression]]
+   *   2. Target [[DataType]]
+   *   3. (Optional) Timezone spec
+   *   4. Flag whether it's an ANSI cast or not
+   */
+  def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)]
 }
 
 object HoodieCatalystExpressionUtils {
 
+  /**
+   * Convenience extractor allowing to untuple [[Cast]] across Spark versions
+   */
+  object MatchCast {
+    def unapply(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+      sparkAdapter.getCatalystExpressionUtils.unapplyCastExpression(expr)
+  }
+
   /**
    * Generates instance of [[UnsafeProjection]] projecting row of one [[StructType]] into another [[StructType]]
    *
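Taken together, `unapplyCastExpression` and the `MatchCast` extractor give callers a single, version-agnostic way to look through a cast node. A minimal sketch of the contract (illustrative values; assumes a Hudi `sparkAdapter` is wired up on the classpath):

```scala
import org.apache.spark.sql.HoodieCatalystExpressionUtils.MatchCast
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.types.LongType

val expr: Expression = Cast(Literal(1), LongType)

expr match {
  // Matches Cast (and, on Spark 3.1+, AnsiCast) regardless of how many
  // constructor fields the running Spark version gives the node
  case MatchCast(child, targetType, _, _) =>
    assert(child == Literal(1) && targetType == LongType)
  case _ =>
    sys.error("not a cast")
}
```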
+ - " t.id = s.id and t.dt = from_unixtime(s.ts)") - }.toMap + val cleanedConditions = conditions.map(_.asInstanceOf[EqualTo]).map { + // Here we're unraveling superfluous casting of expressions on both sides of the matched-on condition, + // in case both of them are casted to the same type (which might be result of either explicit casting + // from the user, or auto-casting performed by Spark for type coercion), which has potential + // potential of rendering the whole operation as invalid (check out HUDI-4861 for more details) + case EqualTo(MatchCast(leftExpr, leftCastTargetType, _, _), MatchCast(rightExpr, rightCastTargetType, _, _)) + if leftCastTargetType.sameType(rightCastTargetType) => EqualTo(leftExpr, rightExpr) + + case c => c + } + + val exprUtils = sparkAdapter.getCatalystExpressionUtils + // Expressions of the following forms are supported: + // `target.id = ` (or ` = target.id`) + // `cast(target.id, ...) = ` (or ` = cast(target.id, ...)`) + // + // In the latter case, there are further restrictions: since cast will be dropped on the + // target table side (since we're gonna be matching against primary-key column as is) expression + // on the opposite side of the comparison should be cast-able to the primary-key column's data-type + // t/h "up-cast" (ie w/o any loss in precision) + val target2Source = cleanedConditions.map { + case EqualTo(CoercedAttributeReference(attr), expr) + if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) => + if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { + targetAttrs.find(f => resolver(f.name, attr.name)).get.name -> + castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf) + } else { + throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " + + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") + } + + case EqualTo(expr, CoercedAttributeReference(attr)) + if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) => + if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { + targetAttrs.find(f => resolver(f.name, attr.name)).get.name -> + castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf) + } else { + throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " + + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") + } + + case expr => + throw new AnalysisException(s"Invalid MERGE INTO matching condition: `${expr.sql}`: " + + "expected condition should be 'target.id = ', e.g. 
" + + "`t.id = s.id` or `t.id = cast(s.id, ...)`") + }.toMap + target2Source } @@ -516,3 +552,18 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie } } } + +object MergeIntoHoodieTableCommand { + + object CoercedAttributeReference { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + case attr: AttributeReference => Some(attr) + case MatchCast(attr: AttributeReference, _, _, _) => Some(attr) + + case _ => None + } + } + } + +} diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala index 93079ac554db9..4aa91498b110b 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hudi -import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers} +import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers, HoodieSparkUtils} import org.apache.hudi.common.fs.FSUtils class TestMergeIntoTable extends HoodieSparkSqlTestBase { @@ -949,4 +949,84 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase { ) } } + + test("Test Merge Into with target matched columns cast-ed") { + withTempDir { tmp => + val tableName = generateTableName + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | value int, + | ts long + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | tblproperties ( + | primaryKey ='id', + | preCombineField = 'ts' + | ) + """.stripMargin) + + spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)") + + // Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail) + val errorMsg = if (HoodieSparkUtils.gteqSpark3_2) { + "Invalid MERGE INTO matching condition: s0.id: can't cast s0.id (of LongType) to IntegerType" + } else { + "Invalid MERGE INTO matching condition: s0.`id`: can't cast s0.`id` (of LongType) to IntegerType" + } + + checkExceptionContain( + s""" + |merge into $tableName h0 + |using ( + | select cast(1 as long) as id, 1001 as ts + | ) s0 + | on cast(h0.id as long) = s0.id + | when matched then update set h0.ts = s0.ts + |""".stripMargin)(errorMsg) + + // Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail) + checkExceptionContain( + s""" + |merge into $tableName h0 + |using ( + | select cast(1 as long) as id, 1002 as ts + | ) s0 + | on h0.id = s0.id + | when matched then update set h0.ts = s0.ts + |""".stripMargin)(errorMsg) + + // Can up-cast incoming dataset's primary-key w/o loss of precision (should succeed) + spark.sql( + s""" + |merge into $tableName h0 + |using ( + | select cast(1 as short) as id, 1003 as ts + | ) s0 + | on h0.id = s0.id + | when matched then update set h0.ts = s0.ts + |""".stripMargin) + + checkAnswer(s"select id, name, value, ts from $tableName")( + Seq(1, "a1", 10, 1003) + ) + + // Can remove redundant symmetrical casting on both sides (should succeed) + spark.sql( + s""" + |merge into $tableName h0 + |using ( + | select cast(1 as int) as id, 1004 as ts + | ) s0 + | on cast(h0.id as string) = cast(s0.id as string) + | when matched then update set h0.ts = s0.ts + |""".stripMargin) + + checkAnswer(s"select id, name, value, ts from $tableName")( + Seq(1, "a1", 10, 1004) + ) + } + } } diff --git 
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
index 745bbdb14295b..854adcff6abb0 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
@@ -19,6 +19,7 @@
 package org.apache.spark.sql
 
 import HoodieSparkTypeUtils.isCastPreservingOrdering
 import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Like, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
 
 object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
 
@@ -29,6 +30,17 @@ object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtil
     }
   }
 
+  def canUpCast(fromType: DataType, toType: DataType): Boolean =
+    // Spark 2.x does not expose an up-cast check (there is no Cast.canUpCast),
+    // hence we conservatively require the types to actually be the same
+    fromType.sameType(toType)
+
+  override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+    expr match {
+      case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
+      case _ => None
+    }
+
   private object OrderPreservingTransformation {
     def unapply(expr: Expression): Option[AttributeReference] = {
       expr match {
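One practical consequence of the `sameType` fallback above, worth spelling out (inferred from the code shown here, not stated in the patch): lossless widenings that Spark 3's `Cast.canUpCast` accepts are rejected under Spark 2, where only exact type matches pass:

```scala
import org.apache.spark.sql.HoodieSpark2CatalystExpressionUtils
import org.apache.spark.sql.types.{IntegerType, LongType}

HoodieSpark2CatalystExpressionUtils.canUpCast(IntegerType, IntegerType) // true
HoodieSpark2CatalystExpressionUtils.canUpCast(IntegerType, LongType)    // false here,
                                                                        // true on Spark 3
```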
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
index 16ad19a33374b..0def10d6e20b0 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
@@ -19,7 +19,8 @@
 package org.apache.spark.sql
 
 import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
 
 object HoodieSpark31CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
 
@@ -30,6 +31,16 @@ object HoodieSpark31CatalystExpressionUtil
     }
   }
 
+  def canUpCast(fromType: DataType, toType: DataType): Boolean =
+    Cast.canUpCast(fromType, toType)
+
+  override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+    expr match {
+      case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
+      case AnsiCast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, true))
+      case _ => None
+    }
+
   private object OrderPreservingTransformation {
     def unapply(expr: Expression): Option[AttributeReference] = {
       expr match {
diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
index 91832845feaff..82e8cfd9b3131 100644
--- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql
 
 import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
 
 object HoodieSpark32CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
 
@@ -29,6 +30,18 @@ object HoodieSpark32CatalystExpressionUtil
     }
   }
 
+  def canUpCast(fromType: DataType, toType: DataType): Boolean =
+    Cast.canUpCast(fromType, toType)
+
+  override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+    expr match {
+      case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
+        Some((castedExpr, dataType, timeZoneId, ansiEnabled))
+      case AnsiCast(castedExpr, dataType, timeZoneId) =>
+        Some((castedExpr, dataType, timeZoneId, true))
+      case _ => None
+    }
+
   private object OrderPreservingTransformation {
     def unapply(expr: Expression): Option[AttributeReference] = {
       expr match {
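The 3.2 shim differs from the 3.1 one because of a Catalyst change: starting with Spark 3.2, `Cast` itself carries the ANSI flag as a fourth field, while `AnsiCast` still exists as a separate node; both normalize to the same tuple. A quick sketch with illustrative values:

```scala
import org.apache.spark.sql.HoodieSpark32CatalystExpressionUtils.unapplyCastExpression
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Literal}
import org.apache.spark.sql.types.LongType

// ANSI-ness is explicit on the Cast node in Spark 3.2+...
unapplyCastExpression(Cast(Literal(1), LongType, None, ansiEnabled = true))
// => Some((Literal(1), LongType, None, true))

// ...and a standalone AnsiCast normalizes to the same shape
unapplyCastExpression(AnsiCast(Literal(1), LongType, None))
// => Some((Literal(1), LongType, None, true))
```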
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
index 87404adb5e2e5..b540c3ec1c568 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql
 
 import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
 
 object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
 
@@ -29,6 +30,18 @@ object HoodieSpark33CatalystExpressionUtil
     }
   }
 
+  def canUpCast(fromType: DataType, toType: DataType): Boolean =
+    Cast.canUpCast(fromType, toType)
+
+  override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+    expr match {
+      case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
+        Some((castedExpr, dataType, timeZoneId, ansiEnabled))
+      case AnsiCast(castedExpr, dataType, timeZoneId) =>
+        Some((castedExpr, dataType, timeZoneId, true))
+      case _ => None
+    }
+
   private object OrderPreservingTransformation {
     def unapply(expr: Expression): Option[AttributeReference] = {
       expr match {
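Finally, for anyone adding the next Spark module: going by the members visible in this diff, a new shim only needs the same small surface. A hypothetical sketch (not part of the PR, and assuming the future version keeps the 3.2+ `Cast` shape):

```scala
import org.apache.spark.sql.HoodieCatalystExpressionUtils
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.DataType

object HoodieSparkNextCatalystExpressionUtils extends HoodieCatalystExpressionUtils {

  // Conservative stub; a real shim would pattern-match the order-preserving
  // transformations the way the version-specific objects above do
  override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] =
    None

  override def canUpCast(fromType: DataType, toType: DataType): Boolean =
    Cast.canUpCast(fromType, toType)

  override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
    expr match {
      case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) => Some((castedExpr, dataType, timeZoneId, ansiEnabled))
      case AnsiCast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, true))
      case _ => None
    }
}
```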