From e32decb23657ec223ce5f6f9b12795bf1531b0fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Aug 2024 16:02:00 -0700 Subject: [PATCH 1/4] fix: Fallback to Spark for unsupported partitioning --- .../comet/CometSparkSessionExtensions.scala | 12 +- .../apache/comet/serde/QueryPlanSerde.scala | 65 ++++++++--- .../exec/CometColumnarShuffleSuite.scala | 105 ++++++++++++++++++ 3 files changed, 159 insertions(+), 23 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 028ea063a..948230527 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -212,7 +212,7 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode( conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) @@ -769,7 +769,7 @@ class CometSparkSessionExtensions // convert it to CometColumnarShuffle, case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") @@ -789,6 +789,7 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec => val isShuffleEnabled = isCometShuffleEnabled(conf) + val outputPartitioning = s.outputPartitioning val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available") val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason") val columnarShuffleEnabled = isCometJVMShuffleMode(conf) @@ -797,12 +798,13 @@ class CometSparkSessionExtensions .supportPartitioning(s.child.output, s.outputPartitioning) ._1, "Native shuffle: " + - s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}") + s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}") val msg3 = createMessage( isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde - .supportPartitioningTypes(s.child.output) + .supportPartitioningTypes(s.child.output, s.outputPartitioning) ._1, - s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}") + "JVM shuffle: " + + s"${QueryPlanSerde.supportPartitioningTypes(s.child.output, outputPartitioning)._2}") withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(",")) s diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d91ae5e4d..c0da27488 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import 
org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -2880,7 +2880,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle * which supports struct/array. */ - def supportPartitioningTypes(inputs: Seq[Attribute]): (Boolean, String) = { + def supportPartitioningTypes( + inputs: Seq[Attribute], + partitioning: Partitioning): (Boolean, String) = { def supportedDataType(dt: DataType): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | @@ -2904,14 +2906,37 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim false } - // Check if the datatypes of shuffle input are supported. var msg = "" - val supported = inputs.forall(attr => supportedDataType(attr.dataType)) + val supported = partitioning match { + case HashPartitioning(expressions, _) => + val supported = + expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + expressions.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $expressions" + } + supported + case SinglePartition => true + case RoundRobinPartitioning(_) => true + case RangePartitioning(orderings, _) => + val supported = + orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + orderings.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $orderings" + } + supported + case _ => + msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}" + false + } + if (!supported) { - msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" emitWarning(msg) + (false, msg) + } else { + (true, null) } - (supported, msg) } /** @@ -2930,23 +2955,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim false } - // Check if the datatypes of shuffle input are supported. 
- val supported = inputs.forall(attr => supportedDataType(attr.dataType)) + var msg = "" + val supported = partitioning match { + case HashPartitioning(expressions, _) => + val supported = + expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + expressions.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $expressions" + } + supported + case SinglePartition => true + case _ => + msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}" + false + } if (!supported) { - val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" emitWarning(msg) (false, msg) } else { - partitioning match { - case HashPartitioning(expressions, _) => - (expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined), null) - case SinglePartition => (true, null) - case other => - val msg = s"unsupported Spark partitioning: ${other.getClass.getName}" - emitWarning(msg) - (false, msg) - } + (true, null) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index c38be7c4a..f89f1b9c5 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -19,6 +19,8 @@ package org.apache.comet.exec +import java.util.Collections + import org.scalactic.source.Position import org.scalatest.Tag @@ -26,6 +28,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.distributions.Distributions +import org.apache.spark.sql.connector.expressions.Expressions._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -34,6 +40,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { protected val adaptiveExecutionEnabled: Boolean @@ -47,6 +54,16 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar protected val asyncShuffleEnable: Boolean + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) + } + + override def afterAll(): Unit = { + spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") + super.afterAll() + } + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = { super.test(testName, testTags: _*) { @@ -85,6 +102,94 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } } + private val emptyProps: java.util.Map[String, String] = { + Collections.emptyMap[String, String] + } + private val items: String = "items" + private val itemsColumns: Array[Column] = 
Array( + Column.create("id", LongType), + Column.create("name", StringType), + Column.create("price", FloatType), + Column.create("arrive_time", TimestampType)) + + private val purchases: String = "purchases" + private val purchasesColumns: Array[Column] = Array( + Column.create("item_id", LongType), + Column.create("price", FloatType), + Column.create("time", TimestampType)) + + protected def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asInstanceOf[InMemoryCatalog] + } + + private def createTable( + table: String, + columns: Array[Column], + partitions: Array[Transform], + catalog: InMemoryTableCatalog = catalog): Unit = { + catalog.createTable( + Identifier.of(Array("ns"), table), + columns, + partitions, + emptyProps, + Distributions.unspecified(), + Array.empty, + None, + None, + numRowsPerSplit = 1) + } + + private def selectWithMergeJoinHint(t1: String, t2: String): String = { + s"SELECT /*+ MERGE($t1, $t2) */ " + } + + private def createJoinTestDF( + keys: Seq[(String, String)], + extraColumns: Seq[String] = Nil, + joinType: String = ""): DataFrame = { + val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "") + sql(s""" + |${selectWithMergeJoinHint("i", "p")} + |id, name, i.price as purchase_price, p.price as sale_price $extraColList + |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p + |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")} + |ORDER BY id, purchase_price, sale_price $extraColList + |""".stripMargin) + } + + test("Fallback to Spark for unsupported partitioning") { + assume(isSpark40Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") + + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql( + s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql( + s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf( + SQLConf.V2_BUCKETING_ENABLED.key -> "true", + "spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkSparkAnswer(df) + } + } + } + test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => From 932e15c3c3e70ffd22d200b16e3dfa0ff3adafa9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Aug 2024 18:38:44 -0700 Subject: [PATCH 2/4] fix --- .../comet/exec/CometColumnarShuffleSuite.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index f89f1b9c5..b4c50b5ef 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -29,7 +29,6 @@ 
import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog} -import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions.Expressions._ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} @@ -128,16 +127,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar columns: Array[Column], partitions: Array[Transform], catalog: InMemoryTableCatalog = catalog): Unit = { - catalog.createTable( - Identifier.of(Array("ns"), table), - columns, - partitions, - emptyProps, - Distributions.unspecified(), - Array.empty, - None, - None, - numRowsPerSplit = 1) + catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps) } private def selectWithMergeJoinHint(t1: String, t2: String): String = { From a7ed952ee53f5188c24398da010e1cac090cee2a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Aug 2024 20:42:07 -0700 Subject: [PATCH 3/4] Move test --- .../exec/CometColumnarShuffleSuite.scala | 95 -------------- .../comet/exec/CometShuffle4_0Suite.scala | 122 ++++++++++++++++++ 2 files changed, 122 insertions(+), 95 deletions(-) create mode 100644 spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index b4c50b5ef..c38be7c4a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -19,8 +19,6 @@ package org.apache.comet.exec -import java.util.Collections - import org.scalactic.source.Position import org.scalatest.Tag @@ -28,9 +26,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} -import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog} -import org.apache.spark.sql.connector.expressions.Expressions._ -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -39,7 +34,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf -import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { protected val adaptiveExecutionEnabled: Boolean @@ -53,16 +47,6 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar protected val asyncShuffleEnable: Boolean - override def beforeAll(): Unit = { - super.beforeAll() - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) - } - - override def 
afterAll(): Unit = { - spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") - super.afterAll() - } - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = { super.test(testName, testTags: _*) { @@ -101,85 +85,6 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } } - private val emptyProps: java.util.Map[String, String] = { - Collections.emptyMap[String, String] - } - private val items: String = "items" - private val itemsColumns: Array[Column] = Array( - Column.create("id", LongType), - Column.create("name", StringType), - Column.create("price", FloatType), - Column.create("arrive_time", TimestampType)) - - private val purchases: String = "purchases" - private val purchasesColumns: Array[Column] = Array( - Column.create("item_id", LongType), - Column.create("price", FloatType), - Column.create("time", TimestampType)) - - protected def catalog: InMemoryCatalog = { - val catalog = spark.sessionState.catalogManager.catalog("testcat") - catalog.asInstanceOf[InMemoryCatalog] - } - - private def createTable( - table: String, - columns: Array[Column], - partitions: Array[Transform], - catalog: InMemoryTableCatalog = catalog): Unit = { - catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps) - } - - private def selectWithMergeJoinHint(t1: String, t2: String): String = { - s"SELECT /*+ MERGE($t1, $t2) */ " - } - - private def createJoinTestDF( - keys: Seq[(String, String)], - extraColumns: Seq[String] = Nil, - joinType: String = ""): DataFrame = { - val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "") - sql(s""" - |${selectWithMergeJoinHint("i", "p")} - |id, name, i.price as purchase_price, p.price as sale_price $extraColList - |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p - |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")} - |ORDER BY id, purchase_price, sale_price $extraColList - |""".stripMargin) - } - - test("Fallback to Spark for unsupported partitioning") { - assume(isSpark40Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) - - sql( - s"INSERT INTO testcat.ns.$items VALUES " + - "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + - "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + - "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") - - createTable(purchases, purchasesColumns, Array.empty) - sql( - s"INSERT INTO testcat.ns.$purchases VALUES " + - "(1, 42.0, cast('2020-01-01' as timestamp)), " + - "(3, 19.5, cast('2020-02-01' as timestamp)), " + - "(5, 26.0, cast('2023-01-01' as timestamp)), " + - "(6, 50.0, cast('2023-02-01' as timestamp))") - - Seq(true, false).foreach { shuffle => - withSQLConf( - SQLConf.V2_BUCKETING_ENABLED.key -> "true", - "spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString, - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { - val df = createJoinTestDF(Seq("id" -> "item_id")) - checkSparkAnswer(df) - } - } - } - test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => diff --git a/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala b/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala new file mode 100644 index 
000000000..517f3f3d2 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import java.util.Collections + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.expressions.Expressions.identity +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{FloatType, LongType, StringType, TimestampType} + +class CometShuffle4_0Suite extends CometColumnarShuffleSuite { + override protected val asyncShuffleEnable: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = true + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) + } + + override def afterAll(): Unit = { + spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") + super.afterAll() + } + + private val emptyProps: java.util.Map[String, String] = { + Collections.emptyMap[String, String] + } + private val items: String = "items" + private val itemsColumns: Array[Column] = Array( + Column.create("id", LongType), + Column.create("name", StringType), + Column.create("price", FloatType), + Column.create("arrive_time", TimestampType)) + + private val purchases: String = "purchases" + private val purchasesColumns: Array[Column] = Array( + Column.create("item_id", LongType), + Column.create("price", FloatType), + Column.create("time", TimestampType)) + + protected def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asInstanceOf[InMemoryCatalog] + } + + private def createTable( + table: String, + columns: Array[Column], + partitions: Array[Transform], + catalog: InMemoryTableCatalog = catalog): Unit = { + catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps) + } + + private def selectWithMergeJoinHint(t1: String, t2: String): String = { + s"SELECT /*+ MERGE($t1, $t2) */ " + } + + private def createJoinTestDF( + keys: Seq[(String, String)], + extraColumns: Seq[String] = Nil, + joinType: String = ""): DataFrame = { + val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "") + sql(s""" + |${selectWithMergeJoinHint("i", "p")} + |id, name, i.price as purchase_price, p.price as sale_price $extraColList + |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p + |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")} + |ORDER BY id, purchase_price, sale_price $extraColList + |""".stripMargin) + } + 
+ test("Fallback to Spark for unsupported partitioning") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql( + s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql( + s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf( + SQLConf.V2_BUCKETING_ENABLED.key -> "true", + "spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkSparkAnswer(df) + } + } + } +} From afd70e68556199347f40810126ba9a7dacab3064 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 1 Aug 2024 23:22:14 -0700 Subject: [PATCH 4/4] For review --- .../scala/org/apache/comet/CometSparkSessionExtensions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 948230527..96ec23fae 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -795,13 +795,13 @@ class CometSparkSessionExtensions val columnarShuffleEnabled = isCometJVMShuffleMode(conf) val msg2 = createMessage( isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde - .supportPartitioning(s.child.output, s.outputPartitioning) + .supportPartitioning(s.child.output, outputPartitioning) ._1, "Native shuffle: " + s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}") val msg3 = createMessage( isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde - .supportPartitioningTypes(s.child.output, s.outputPartitioning) + .supportPartitioningTypes(s.child.output, outputPartitioning) ._1, "JVM shuffle: " + s"${QueryPlanSerde.supportPartitioningTypes(s.child.output, outputPartitioning)._2}")
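
Context for reviewers (illustrative note, not part of the patch): after this series, QueryPlanSerde.supportPartitioningTypes (JVM columnar shuffle) takes the exchange's outputPartitioning and pattern-matches on it, accepting hash, range, round-robin, and single partitionings whose expressions can be serialized, while the native-shuffle check supportPartitioning accepts only hash and single. Any other partitioning (for example the KeyGroupedPartitioning produced by the V2 bucketed join in the new test) yields an "unsupported" message, so CometSparkSessionExtensions keeps the original ShuffleExchangeExec and Spark runs the shuffle. A minimal sketch of the columnar-shuffle check under these assumptions, with hypothetical stand-in helpers canSerialize and supportedType in place of QueryPlanSerde.exprToProto and the real data-type check:

    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.catalyst.plans.physical._
    import org.apache.spark.sql.types.DataType

    object PartitioningSupportSketch {
      // Stand-ins for QueryPlanSerde.exprToProto(_, inputs).isDefined and the
      // supportedDataType check; the real implementations live in QueryPlanSerde.
      def canSerialize(e: Expression, inputs: Seq[Attribute]): Boolean = true
      def supportedType(dt: DataType): Boolean = true

      // Returns (supported, reason); a false result makes the planner keep Spark's
      // ShuffleExchangeExec instead of replacing it with CometShuffleExchangeExec.
      def columnarShuffleSupports(
          inputs: Seq[Attribute],
          partitioning: Partitioning): (Boolean, String) = partitioning match {
        case HashPartitioning(exprs, _)
            if exprs.forall(e => canSerialize(e, inputs) && supportedType(e.dataType)) =>
          (true, null)
        case RangePartitioning(orders, _)
            if orders.forall(o => canSerialize(o, inputs) && supportedType(o.dataType)) =>
          (true, null)
        case SinglePartition | RoundRobinPartitioning(_) =>
          (true, null)
        case other =>
          (false, s"unsupported Spark partitioning: ${other.getClass.getName}")
      }
    }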