diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index d403f1998c6b..e212b5153e37 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -473,7 +473,8 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
     val targetTableDb = targetTableIdentify.database.getOrElse("default")
     val targetTableName = targetTableIdentify.identifier
     val path = hoodieCatalogTable.tableLocation
-    val catalogProperties = hoodieCatalogTable.catalogProperties
+    // Force ExpressionPayload as the PAYLOAD_CLASS_NAME in MergeIntoHoodieTableCommand, overriding any user-set payload class
+    val catalogProperties = hoodieCatalogTable.catalogProperties + (PAYLOAD_CLASS_NAME.key -> classOf[ExpressionPayload].getCanonicalName)
     val tableConfig = hoodieCatalogTable.tableConfig
     val tableSchema = hoodieCatalogTable.tableSchema
     val partitionColumns = tableConfig.getPartitionFieldProp.split(",").map(_.toLowerCase)
@@ -487,14 +488,13 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
     val hoodieProps = getHoodieProps(catalogProperties, tableConfig, sparkSession.sqlContext.conf)
     val hiveSyncConfig = buildHiveSyncConfig(hoodieProps, hoodieCatalogTable)
 
-    withSparkConf(sparkSession, hoodieCatalogTable.catalogProperties) {
+    withSparkConf(sparkSession, catalogProperties) {
       Map(
         "path" -> path,
         RECORDKEY_FIELD.key -> tableConfig.getRecordKeyFieldProp,
         PRECOMBINE_FIELD.key -> preCombineField,
         TBL_NAME.key -> hoodieCatalogTable.tableName,
         PARTITIONPATH_FIELD.key -> tableConfig.getPartitionFieldProp,
-        PAYLOAD_CLASS_NAME.key -> classOf[ExpressionPayload].getCanonicalName,
         HIVE_STYLE_PARTITIONING.key -> tableConfig.getHiveStylePartitioningEnable,
         URL_ENCODE_PARTITIONING.key -> tableConfig.getUrlEncodePartitioning,
         KEYGENERATOR_CLASS_NAME.key -> classOf[SqlKeyGenerator].getCanonicalName,
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
index 8e6acd1be58c..8a6aa9691d93 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
@@ -674,7 +674,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase {
     }
   }
 
-  test ("Test Merge into with String cast to Double") {
+  test("Test Merge into with String cast to Double") {
     withTempDir { tmp =>
       val tableName = generateTableName
       // Create a cow partitioned table.
@@ -713,4 +713,42 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase {
       )
     }
   }
+
+  test("Test Merge into where manually set DefaultHoodieRecordPayload") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a cow table with DefaultHoodieRecordPayload set explicitly, and check that it is overridden by ExpressionPayload.
+      // If it is not, this test cannot pass, since DefaultHoodieRecordPayload cannot promote int to long when ts is inserted as an Integer value.
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  id int,
+           |  name string,
+           |  ts long
+           | ) using hudi
+           | tblproperties (
+           |  type = 'cow',
+           |  primaryKey = 'id',
+           |  preCombineField = 'ts',
+           |  hoodie.datasource.write.payload.class = 'org.apache.hudi.common.model.DefaultHoodieRecordPayload'
+           | ) location '${tmp.getCanonicalPath}'
+         """.stripMargin)
+      // Insert data
+      spark.sql(s"insert into $tableName select 1, 'a1', 999")
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 'a2' as name, 1 as id, 1000 as ts
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set t0.name = s0.name, t0.ts = s0.ts
+           | when not matched then insert *
+         """.stripMargin
+      )
+      checkAnswer(s"select id,name,ts from $tableName")(
+        Seq(1, "a2", 1000)
+      )
+    }
+  }
 }
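Why folding the payload-class override into `catalogProperties` is sufficient: Scala's immutable `Map.+` returns a new map in which the appended pair replaces any existing entry with the same key, so a user-configured `hoodie.datasource.write.payload.class` (as exercised by the new test) is superseded. Below is a minimal, self-contained sketch in plain Scala; the config key is Hudi's real key, but the object name and the overriding class-name string are illustrative only, not Hudi API:

```scala
// Minimal sketch of the Map override pattern used in the patch.
object PayloadOverrideSketch extends App {
  val payloadClassKey = "hoodie.datasource.write.payload.class"

  // Properties as a user might declare them via tblproperties.
  val userProps: Map[String, String] = Map(
    "type" -> "cow",
    payloadClassKey -> "org.apache.hudi.common.model.DefaultHoodieRecordPayload"
  )

  // Same pattern as the patched line: `+` replaces any existing entry for the key.
  val effectiveProps =
    userProps + (payloadClassKey -> "illustrative.ExpressionPayload") // hypothetical class name

  assert(effectiveProps(payloadClassKey) == "illustrative.ExpressionPayload")
  println(effectiveProps(payloadClassKey)) // the user-set payload class has been overridden
}
```

Compared with the previous code, which only set `PAYLOAD_CLASS_NAME` inside the options `Map` literal passed to the writer, the override now lives in `catalogProperties` itself, so `getHoodieProps`, the Hive sync config, and `withSparkConf` all see the same payload class.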