[HUDI-4861] Relaxing MERGE INTO constraints to permit limited casting operations w/in matched-on conditions (apache#6820)
Alexey Kudinkin authored and fengjian committed Apr 5, 2023
1 parent 7f1a296 commit dc64c1a
Showing 7 changed files with 223 additions and 18 deletions.
@@ -17,12 +17,13 @@

package org.apache.spark.sql

import org.apache.hudi.SparkAdapterSupport.sparkAdapter
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Like, Literal, SubqueryExpression, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}

trait HoodieCatalystExpressionUtils {

@@ -39,10 +40,34 @@ trait HoodieCatalystExpressionUtils {
* will keep the same ordering b1, b2, b3, ... with b1 = T(a1), b2 = T(a2), ...
*/
def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference]

/**
* Verifies whether [[fromType]] can be up-casted to [[toType]]
*/
def canUpCast(fromType: DataType, toType: DataType): Boolean

/**
* Un-applies [[Cast]] expression into
* <ol>
* <li>Casted [[Expression]]</li>
* <li>Target [[DataType]]</li>
* <li>(Optional) Timezone spec</li>
* <li>Flag whether it's an ANSI cast or not</li>
* </ol>
*/
def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)]
}

object HoodieCatalystExpressionUtils {

/**
* Convenience extractor allowing [[Cast]] to be un-applied uniformly across Spark versions
*/
object MatchCast {
def unapply(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
sparkAdapter.getCatalystExpressionUtils.unapplyCastExpression(expr)
}

/**
* Generates instance of [[UnsafeProjection]] projecting row of one [[StructType]] into another [[StructType]]
*
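To illustrate the new extractor, here is a minimal sketch (not part of the commit) of matching a cast expression through MatchCast; it assumes the Hudi Spark adapter for the running Spark version is on the classpath, since MatchCast delegates to the version-specific unapplyCastExpression:

import org.apache.spark.sql.HoodieCatalystExpressionUtils.MatchCast
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.types.LongType

val expr: Expression = Cast(Literal(1), LongType)

expr match {
  // MatchCast untuples the cast uniformly, regardless of how the running Spark
  // version represents it (Cast or AnsiCast)
  case MatchCast(child, targetType, _, _) =>
    println(s"casted expression: $child, target type: $targetType") // prints: casted expression: 1, target type: LongType
  case _ =>
    println("not a cast")
}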
@@ -25,6 +25,7 @@ import org.apache.hudi.config.HoodieWriteConfig.TBL_NAME
import org.apache.hudi.hive.HiveSyncConfigHolder
import org.apache.hudi.sync.common.HoodieSyncConfig
import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, SparkAdapterSupport}
import org.apache.spark.sql.HoodieCatalystExpressionUtils.MatchCast
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.Resolver
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
import org.apache.spark.sql.hudi.HoodieSqlUtils.getMergeIntoTargetTableId
import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.CoercedAttributeReference
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._
import org.apache.spark.sql.hudi.{ProvidesHoodieConfig, SerDeUtils}
@@ -101,19 +103,53 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
}
val targetAttrs = mergeInto.targetTable.output

val target2Source = conditions.map(_.asInstanceOf[EqualTo])
.map {
case EqualTo(left: AttributeReference, right)
if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
case EqualTo(left, right: AttributeReference)
if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
case eq =>
throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
"The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
" t.id = s.id and t.dt = from_unixtime(s.ts)")
}.toMap
val cleanedConditions = conditions.map(_.asInstanceOf[EqualTo]).map {
// Here we're unraveling superfluous casting of expressions on both sides of the matched-on condition,
// in case both of them are cast to the same type (which might be the result of either explicit casting
// by the user, or auto-casting performed by Spark for type coercion), which has the potential of
// rendering the whole operation invalid (check out HUDI-4861 for more details)
case EqualTo(MatchCast(leftExpr, leftCastTargetType, _, _), MatchCast(rightExpr, rightCastTargetType, _, _))
if leftCastTargetType.sameType(rightCastTargetType) => EqualTo(leftExpr, rightExpr)

case c => c
}
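// Illustrative example (not in the original source): a matched-on condition such as
//   ON cast(target.id AS STRING) = cast(source.id AS STRING)
// is simplified here to
//   ON target.id = source.id
// so that the primary-key matching below can resolve the target attribute directly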

val exprUtils = sparkAdapter.getCatalystExpressionUtils
// Expressions of the following forms are supported:
// `target.id = <expr>` (or `<expr> = target.id`)
// `cast(target.id, ...) = <expr>` (or `<expr> = cast(target.id, ...)`)
//
// In the latter case, there are further restrictions: since the cast will be dropped on the
// target-table side (we're going to match against the primary-key column as is), the expression
// on the opposite side of the comparison should be cast-able to the primary-key column's data type
// through an "up-cast" (i.e. without any loss in precision)
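// For instance (illustrative, mirroring the tests in TestMergeIntoTable): with an INT target
// primary key, matching on a SHORT source key is accepted (SHORT up-casts to INT without loss
// of precision), while matching on a LONG source key is rejected with an AnalysisException,
// since LONG cannot be up-cast to INT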
val target2Source = cleanedConditions.map {
case EqualTo(CoercedAttributeReference(attr), expr)
if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) =>
if (exprUtils.canUpCast(expr.dataType, attr.dataType)) {
targetAttrs.find(f => resolver(f.name, attr.name)).get.name ->
castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf)
} else {
throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: "
+ s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}")
}

case EqualTo(expr, CoercedAttributeReference(attr))
if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) =>
if (exprUtils.canUpCast(expr.dataType, attr.dataType)) {
targetAttrs.find(f => resolver(f.name, attr.name)).get.name ->
castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf)
} else {
throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: "
+ s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}")
}

case expr =>
throw new AnalysisException(s"Invalid MERGE INTO matching condition: `${expr.sql}`: "
+ "expected condition should be 'target.id = <source-column-expr>', e.g. "
+ "`t.id = s.id` or `t.id = cast(s.id, ...)`")
}.toMap

target2Source
}

@@ -516,3 +552,18 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
}
}
}

object MergeIntoHoodieTableCommand {

object CoercedAttributeReference {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
case attr: AttributeReference => Some(attr)
case MatchCast(attr: AttributeReference, _, _, _) => Some(attr)

case _ => None
}
}
}

}
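A quick sketch (illustrative only, not part of the commit) of how the CoercedAttributeReference extractor behaves: it matches a primary-key attribute both when it is referenced directly and when type coercion has wrapped it into a cast.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.CoercedAttributeReference
import org.apache.spark.sql.types.{IntegerType, LongType}

val id = AttributeReference("id", IntegerType)()

// Both the bare attribute and the cast-wrapped attribute yield the underlying reference
Seq(id, Cast(id, LongType)).foreach {
  case CoercedAttributeReference(attr) => assert(attr == id)
  case other => throw new AssertionError(s"unexpected expression: $other")
}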
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hudi

import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers}
import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers, HoodieSparkUtils}
import org.apache.hudi.common.fs.FSUtils

class TestMergeIntoTable extends HoodieSparkSqlTestBase {
@@ -949,4 +949,84 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase {
)
}
}

test("Test Merge Into with target matched columns cast-ed") {
withTempDir { tmp =>
val tableName = generateTableName
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| value int,
| ts long
|) using hudi
| location '${tmp.getCanonicalPath}/$tableName'
| tblproperties (
| primaryKey ='id',
| preCombineField = 'ts'
| )
""".stripMargin)

spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)")

// Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail)
val errorMsg = if (HoodieSparkUtils.gteqSpark3_2) {
"Invalid MERGE INTO matching condition: s0.id: can't cast s0.id (of LongType) to IntegerType"
} else {
"Invalid MERGE INTO matching condition: s0.`id`: can't cast s0.`id` (of LongType) to IntegerType"
}

checkExceptionContain(
s"""
|merge into $tableName h0
|using (
| select cast(1 as long) as id, 1001 as ts
| ) s0
| on cast(h0.id as long) = s0.id
| when matched then update set h0.ts = s0.ts
|""".stripMargin)(errorMsg)

// Can't down-cast incoming dataset's primary-key w/o loss of precision (should fail)
checkExceptionContain(
s"""
|merge into $tableName h0
|using (
| select cast(1 as long) as id, 1002 as ts
| ) s0
| on h0.id = s0.id
| when matched then update set h0.ts = s0.ts
|""".stripMargin)(errorMsg)

// Can up-cast incoming dataset's primary-key w/o loss of precision (should succeed)
spark.sql(
s"""
|merge into $tableName h0
|using (
| select cast(1 as short) as id, 1003 as ts
| ) s0
| on h0.id = s0.id
| when matched then update set h0.ts = s0.ts
|""".stripMargin)

checkAnswer(s"select id, name, value, ts from $tableName")(
Seq(1, "a1", 10, 1003)
)

// Can remove redundant symmetrical casting on both sides (should succeed)
spark.sql(
s"""
|merge into $tableName h0
|using (
| select cast(1 as int) as id, 1004 as ts
| ) s0
| on cast(h0.id as string) = cast(s0.id as string)
| when matched then update set h0.ts = s0.ts
|""".stripMargin)

checkAnswer(s"select id, name, value, ts from $tableName")(
Seq(1, "a1", 10, 1004)
)
}
}
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Like, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.types.DataType

object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils {

@@ -29,6 +30,17 @@ object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils
}
}

def canUpCast(fromType: DataType, toType: DataType): Boolean =
// Spark 2.x does not support up-casting, hence we simply check whether types are
// actually the same
fromType.sameType(toType)

override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
expr match {
case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
case _ => None
}

private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
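Because the Spark 2 adapter above falls back to a plain type-equality check, only identical types pass the up-cast test. A small sketch of the resulting behavior (illustrative, assuming the object is imported from org.apache.spark.sql):

import org.apache.spark.sql.HoodieSpark2CatalystExpressionUtils.canUpCast
import org.apache.spark.sql.types.{IntegerType, ShortType}

canUpCast(IntegerType, IntegerType) // true  -- same type
canUpCast(ShortType, IntegerType)   // false -- on Spark 3.x, Cast.canUpCast would accept this widening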
@@ -19,7 +19,8 @@
package org.apache.spark.sql

import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.types.DataType

object HoodieSpark31CatalystExpressionUtils extends HoodieCatalystExpressionUtils {

@@ -30,6 +31,16 @@ object HoodieSpark31CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}

def canUpCast(fromType: DataType, toType: DataType): Boolean =
Cast.canUpCast(fromType, toType)

override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
expr match {
case Cast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, false))
case AnsiCast(castedExpr, dataType, timeZoneId) => Some((castedExpr, dataType, timeZoneId, true))
case _ => None
}

private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
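On Spark 3.x the check delegates to Catalyst's Cast.canUpCast, which only allows lossless widening. A few illustrative values (a sketch; the exact rules are defined by Spark itself):

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{IntegerType, LongType, ShortType}

Cast.canUpCast(ShortType, IntegerType) // true  -- widening, no precision loss
Cast.canUpCast(IntegerType, LongType)  // true
Cast.canUpCast(LongType, IntegerType)  // false -- narrowing may lose precision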
@@ -18,7 +18,8 @@
package org.apache.spark.sql

import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.types.DataType

object HoodieSpark32CatalystExpressionUtils extends HoodieCatalystExpressionUtils {

@@ -29,6 +30,18 @@ object HoodieSpark32CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}

def canUpCast(fromType: DataType, toType: DataType): Boolean =
Cast.canUpCast(fromType, toType)

override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
expr match {
case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
Some((castedExpr, dataType, timeZoneId, ansiEnabled))
case AnsiCast(castedExpr, dataType, timeZoneId) =>
Some((castedExpr, dataType, timeZoneId, true))
case _ => None
}

private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
@@ -18,7 +18,8 @@
package org.apache.spark.sql

import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
import org.apache.spark.sql.types.DataType

object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtils {

@@ -29,6 +30,18 @@ object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtil
}
}

def canUpCast(fromType: DataType, toType: DataType): Boolean =
Cast.canUpCast(fromType, toType)

override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] =
expr match {
case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) =>
Some((castedExpr, dataType, timeZoneId, ansiEnabled))
case AnsiCast(castedExpr, dataType, timeZoneId) =>
Some((castedExpr, dataType, timeZoneId, true))
case _ => None
}

private object OrderPreservingTransformation {
def unapply(expr: Expression): Option[AttributeReference] = {
expr match {
