diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
index 5dd249bd24f76..f2d6f0381a471 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql
+import org.apache.hudi.SparkAdapterSupport.sparkAdapter
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Like, Literal, SubqueryExpression, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
trait HoodieCatalystExpressionUtils {
@@ -39,10 +40,34 @@ trait HoodieCatalystExpressionUtils {
* will keep the same ordering b1, b2, b3, ... with b1 = T(a1), b2 = T(a2), ...
*/
def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference]
+
+ /**
+   * Verifies whether [[fromType]] can be up-cast to [[toType]]
+ */
+ def canUpCast(fromType: DataType, toType: DataType): Boolean
+
+ /**
+   * Un-applies a [[Cast]] expression into:
+   *
+   * - The [[Expression]] being cast
+   * - Target [[DataType]]
+   * - (Optional) Timezone spec
+   * - A flag indicating whether it's an ANSI cast
+ *
+ */
+ def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)]
}
object HoodieCatalystExpressionUtils {
+ /**
+   * Convenience extractor allowing [[Cast]] expressions to be destructured uniformly across Spark versions
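+   *
+   * e.g. `expr match { case MatchCast(castedExpr, dataType, _, _) => ... }`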
+ */
+ object MatchCast {
+ def unapply(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+ sparkAdapter.getCatalystExpressionUtils.unapplyCastExpression(expr)
+ }
+
/**
* Generates instance of [[UnsafeProjection]] projecting row of one [[StructType]] into another [[StructType]]
*
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index d403f1998c6b3..17ff34c909be8 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -25,6 +25,7 @@ import org.apache.hudi.config.HoodieWriteConfig.TBL_NAME
import org.apache.hudi.hive.HiveSyncConfigHolder
import org.apache.hudi.sync.common.HoodieSyncConfig
import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, SparkAdapterSupport}
+import org.apache.spark.sql.HoodieCatalystExpressionUtils.MatchCast
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.Resolver
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
import org.apache.spark.sql.hudi.HoodieSqlUtils.getMergeIntoTargetTableId
+import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.CoercedAttributeReference
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._
import org.apache.spark.sql.hudi.{ProvidesHoodieConfig, SerDeUtils}
@@ -101,19 +103,53 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
}
val targetAttrs = mergeInto.targetTable.output
- val target2Source = conditions.map(_.asInstanceOf[EqualTo])
- .map {
- case EqualTo(left: AttributeReference, right)
- if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
- targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
- case EqualTo(left, right: AttributeReference)
- if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
- targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
- case eq =>
- throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
- "The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
- " t.id = s.id and t.dt = from_unixtime(s.ts)")
- }.toMap
+ val cleanedConditions = conditions.map(_.asInstanceOf[EqualTo]).map {
+      // Here we're unraveling superfluous casting of expressions on both sides of the matched-on condition,
+      // in case both of them are cast to the same type (which might be the result of either explicit casting
+      // by the user, or auto-casting performed by Spark for type coercion), which has the potential
+      // of rendering the whole operation invalid (check out HUDI-4861 for more details)
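+      //
+      // For example, `cast(t.id as string) = cast(s.id as string)` is simplified to `t.id = s.id`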
+ case EqualTo(MatchCast(leftExpr, leftCastTargetType, _, _), MatchCast(rightExpr, rightCastTargetType, _, _))
+ if leftCastTargetType.sameType(rightCastTargetType) => EqualTo(leftExpr, rightExpr)
+
+ case c => c
+ }
+
+ val exprUtils = sparkAdapter.getCatalystExpressionUtils
+ // Expressions of the following forms are supported:
+    //    `target.id = <expr>` (or `<expr> = target.id`)
+    //    `cast(target.id, ...) = <expr>` (or `<expr> = cast(target.id, ...)`)
+ //
+    // In the latter case there's a further restriction: since the cast will be dropped on the
+    // target table side (we will be matching against the primary-key column as is), the expression
+    // on the opposite side of the comparison has to be convertible to the primary-key column's
+    // data-type via an "up-cast" (i.e. w/o any loss in precision)
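+    //
+    // For example, if the target `id` column is of [[IntegerType]], an expression of [[ShortType]]
+    // on the opposite side is accepted (up-cast), while one of [[LongType]] is rejected (as it
+    // could lose precision)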
+ val target2Source = cleanedConditions.map {
+ case EqualTo(CoercedAttributeReference(attr), expr)
+ if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) =>
+ if (exprUtils.canUpCast(expr.dataType, attr.dataType)) {
+ targetAttrs.find(f => resolver(f.name, attr.name)).get.name ->
+ castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf)
+ } else {
+ throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: "
+ + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}")
+ }
+
+ case EqualTo(expr, CoercedAttributeReference(attr))
+ if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) =>
+ if (exprUtils.canUpCast(expr.dataType, attr.dataType)) {
+ targetAttrs.find(f => resolver(f.name, attr.name)).get.name ->
+ castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf)
+ } else {
+ throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: "
+ + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}")
+ }
+
+ case expr =>
+ throw new AnalysisException(s"Invalid MERGE INTO matching condition: `${expr.sql}`: "
+          + "expected condition should be 'target.id = <expr>', e.g. "
+ + "`t.id = s.id` or `t.id = cast(s.id, ...)`")
+ }.toMap
+
target2Source
}
@@ -516,3 +552,18 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
}
}
}
+
+object MergeIntoHoodieTableCommand {
+
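+  /**
+   * Matches an [[AttributeReference]] that is either referenced directly, or is wrapped into
+   * a [[Cast]] expression (for ex, one injected by Spark for type coercion)
+   */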
+ object CoercedAttributeReference {
+ def unapply(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ case attr: AttributeReference => Some(attr)
+ case MatchCast(attr: AttributeReference, _, _, _) => Some(attr)
+
+ case _ => None
+ }
+ }
+ }
+
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
index 93079ac554db9..4aa91498b110b 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hudi
-import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers}
+import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers, HoodieSparkUtils}
import org.apache.hudi.common.fs.FSUtils
class TestMergeIntoTable extends HoodieSparkSqlTestBase {
@@ -949,4 +949,84 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase {
)
}
}
+
+  test("Test Merge Into with casts on target matched columns") {
+ withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | name string,
+ | value int,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}/$tableName'
+ | tblproperties (
+ | primaryKey ='id',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)")
+
+ // Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail)
+ val errorMsg = if (HoodieSparkUtils.gteqSpark3_2) {
+ "Invalid MERGE INTO matching condition: s0.id: can't cast s0.id (of LongType) to IntegerType"
+ } else {
+ "Invalid MERGE INTO matching condition: s0.`id`: can't cast s0.`id` (of LongType) to IntegerType"
+ }
+
+ checkExceptionContain(
+ s"""
+ |merge into $tableName h0
+ |using (
+ | select cast(1 as long) as id, 1001 as ts
+ | ) s0
+ | on cast(h0.id as long) = s0.id
+ | when matched then update set h0.ts = s0.ts
+ |""".stripMargin)(errorMsg)
+
+ // Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail)
+ checkExceptionContain(
+ s"""
+ |merge into $tableName h0
+ |using (
+ | select cast(1 as long) as id, 1002 as ts
+ | ) s0
+ | on h0.id = s0.id
+ | when matched then update set h0.ts = s0.ts
+ |""".stripMargin)(errorMsg)
+
+ // Can up-cast incoming dataset's primary-key w/o loss of precision (should succeed)
+ spark.sql(
+ s"""
+ |merge into $tableName h0
+ |using (
+ | select cast(1 as short) as id, 1003 as ts
+ | ) s0
+ | on h0.id = s0.id
+ | when matched then update set h0.ts = s0.ts
+ |""".stripMargin)
+
+ checkAnswer(s"select id, name, value, ts from $tableName")(
+ Seq(1, "a1", 10, 1003)
+ )
+
+ // Can remove redundant symmetrical casting on both sides (should succeed)
+ spark.sql(
+ s"""
+ |merge into $tableName h0
+ |using (
+ | select cast(1 as int) as id, 1004 as ts
+ | ) s0
+ | on cast(h0.id as string) = cast(s0.id as string)
+ | when matched then update set h0.ts = s0.ts
+ |""".stripMargin)
+
+ checkAnswer(s"select id, name, value, ts from $tableName")(
+ Seq(1, "a1", 10, 1004)
+ )
+ }
+ }
}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
index 745bbdb14295b..854adcff6abb0 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Like, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
@@ -29,6 +30,17 @@ object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils
}
}
+  override def canUpCast(fromType: DataType, toType: DataType): Boolean =
+ // Spark 2.x does not support up-casting, hence we simply check whether types are
+ // actually the same
+ fromType.sameType(toType)
+
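+  // NOTE: [[Cast]] in Spark 2.x carries no ANSI flag, hence the flag is always reported as false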
+ override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+ expr match {
+ case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
+ case _ => None
+ }
+
private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
index 16ad19a33374b..0def10d6e20b0 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala
@@ -19,7 +19,8 @@
package org.apache.spark.sql
import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
object HoodieSpark31CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
@@ -30,6 +31,16 @@ object HoodieSpark31CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}
+  override def canUpCast(fromType: DataType, toType: DataType): Boolean =
+ Cast.canUpCast(fromType, toType)
+
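+  // NOTE: In Spark 3.1, ANSI casting is represented by the dedicated [[AnsiCast]] expression,
+  //       while [[Cast]] itself carries no ANSI flag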
+ override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+ expr match {
+ case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
+ case AnsiCast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, true))
+ case _ => None
+ }
+
private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
index 91832845feaff..82e8cfd9b3131 100644
--- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql
import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
object HoodieSpark32CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
@@ -29,6 +30,18 @@ object HoodieSpark32CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}
+  override def canUpCast(fromType: DataType, toType: DataType): Boolean =
+ Cast.canUpCast(fromType, toType)
+
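+  // NOTE: Starting from Spark 3.2, [[Cast]] carries the ANSI flag itself, while [[AnsiCast]]
+  //       is still matched separately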
+ override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+ expr match {
+ case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
+ Some((castedExpr, dataType, timeZoneId, ansiEnabled))
+ case AnsiCast(castedExpr, dataType, timeZoneId) =>
+ Some((castedExpr, dataType, timeZoneId, true))
+ case _ => None
+ }
+
private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
index 87404adb5e2e5..b540c3ec1c568 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql
import HoodieSparkTypeUtils.isCastPreservingOrdering
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+import org.apache.spark.sql.types.DataType
object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
@@ -29,6 +30,18 @@ object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}
+  override def canUpCast(fromType: DataType, toType: DataType): Boolean =
+ Cast.canUpCast(fromType, toType)
+
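+  // NOTE: Like in Spark 3.2, [[Cast]] in Spark 3.3 carries the ANSI flag itself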
+ override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
+ expr match {
+ case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
+ Some((castedExpr, dataType, timeZoneId, ansiEnabled))
+ case AnsiCast(castedExpr, dataType, timeZoneId) =>
+ Some((castedExpr, dataType, timeZoneId, true))
+ case _ => None
+ }
+
private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {