From 2dd1a3fed239af619e0c38d26f3b70b3498bbafd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Aug 2024 06:03:08 -0700 Subject: [PATCH] fix: Fallback to Spark for unsupported partitioning (#759) * fix: Fallback to Spark for unsupported partitioning * fix * Move test * For review (cherry picked from commit 2d95fea7890cf1ceca6b16004ac233d7188e8421) --- .../comet/CometSparkSessionExtensions.scala | 14 +- .../apache/comet/serde/QueryPlanSerde.scala | 65 +++++++--- .../comet/exec/CometShuffle4_0Suite.scala | 122 ++++++++++++++++++ 3 files changed, 177 insertions(+), 24 deletions(-) create mode 100644 spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 3b831c1407..30c0efbc93 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -212,7 +212,7 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode( conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) @@ -769,7 +769,7 @@ class CometSparkSessionExtensions // convert it to CometColumnarShuffle, case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") @@ -789,20 +789,22 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec => val isShuffleEnabled = isCometShuffleEnabled(conf) + val outputPartitioning = s.outputPartitioning val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available") val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason") val columnarShuffleEnabled = isCometJVMShuffleMode(conf) val msg2 = createMessage( isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde - .supportPartitioning(s.child.output, s.outputPartitioning) + .supportPartitioning(s.child.output, outputPartitioning) ._1, "Native shuffle: " + - s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}") + s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}") val msg3 = createMessage( isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde - .supportPartitioningTypes(s.child.output) + .supportPartitioningTypes(s.child.output, outputPartitioning) ._1, - s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}") + "JVM shuffle: " + + s"${QueryPlanSerde.supportPartitioningTypes(s.child.output, outputPartitioning)._2}") withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(",")) s diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 917916c126..948c534d09 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -27,7 +27,7 @@ 
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -2895,7 +2895,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle * which supports struct/array. */ - def supportPartitioningTypes(inputs: Seq[Attribute]): (Boolean, String) = { + def supportPartitioningTypes( + inputs: Seq[Attribute], + partitioning: Partitioning): (Boolean, String) = { def supportedDataType(dt: DataType): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | @@ -2919,14 +2921,37 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim false } - // Check if the datatypes of shuffle input are supported. var msg = "" - val supported = inputs.forall(attr => supportedDataType(attr.dataType)) + val supported = partitioning match { + case HashPartitioning(expressions, _) => + val supported = + expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + expressions.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $expressions" + } + supported + case SinglePartition => true + case RoundRobinPartitioning(_) => true + case RangePartitioning(orderings, _) => + val supported = + orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + orderings.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $orderings" + } + supported + case _ => + msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}" + false + } + if (!supported) { - msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" emitWarning(msg) + (false, msg) + } else { + (true, null) } - (supported, msg) } /** @@ -2945,23 +2970,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim false } - // Check if the datatypes of shuffle input are supported. 
- val supported = inputs.forall(attr => supportedDataType(attr.dataType)) + var msg = "" + val supported = partitioning match { + case HashPartitioning(expressions, _) => + val supported = + expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && + expressions.forall(e => supportedDataType(e.dataType)) + if (!supported) { + msg = s"unsupported Spark partitioning expressions: $expressions" + } + supported + case SinglePartition => true + case _ => + msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}" + false + } if (!supported) { - val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" emitWarning(msg) (false, msg) } else { - partitioning match { - case HashPartitioning(expressions, _) => - (expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined), null) - case SinglePartition => (true, null) - case other => - val msg = s"unsupported Spark partitioning: ${other.getClass.getName}" - emitWarning(msg) - (false, msg) - } + (true, null) } } diff --git a/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala b/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala new file mode 100644 index 0000000000..517f3f3d21 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/comet/exec/CometShuffle4_0Suite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.exec + +import java.util.Collections + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.expressions.Expressions.identity +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{FloatType, LongType, StringType, TimestampType} + +class CometShuffle4_0Suite extends CometColumnarShuffleSuite { + override protected val asyncShuffleEnable: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = true + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) + } + + override def afterAll(): Unit = { + spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") + super.afterAll() + } + + private val emptyProps: java.util.Map[String, String] = { + Collections.emptyMap[String, String] + } + private val items: String = "items" + private val itemsColumns: Array[Column] = Array( + Column.create("id", LongType), + Column.create("name", StringType), + Column.create("price", FloatType), + Column.create("arrive_time", TimestampType)) + + private val purchases: String = "purchases" + private val purchasesColumns: Array[Column] = Array( + Column.create("item_id", LongType), + Column.create("price", FloatType), + Column.create("time", TimestampType)) + + protected def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asInstanceOf[InMemoryCatalog] + } + + private def createTable( + table: String, + columns: Array[Column], + partitions: Array[Transform], + catalog: InMemoryTableCatalog = catalog): Unit = { + catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps) + } + + private def selectWithMergeJoinHint(t1: String, t2: String): String = { + s"SELECT /*+ MERGE($t1, $t2) */ " + } + + private def createJoinTestDF( + keys: Seq[(String, String)], + extraColumns: Seq[String] = Nil, + joinType: String = ""): DataFrame = { + val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "") + sql(s""" + |${selectWithMergeJoinHint("i", "p")} + |id, name, i.price as purchase_price, p.price as sale_price $extraColList + |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p + |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")} + |ORDER BY id, purchase_price, sale_price $extraColList + |""".stripMargin) + } + + test("Fallback to Spark for unsupported partitioning") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql( + s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql( + s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf( + SQLConf.V2_BUCKETING_ENABLED.key -> "true", + "spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + 
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkSparkAnswer(df) + } + } + } +}
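
Note on the change above: the core of this patch is that supportPartitioningTypes and supportPartitioning now receive the shuffle's outputPartitioning, pattern-match on it, and return a (supported, reason) pair, so CometSparkSessionExtensions can keep Spark's ShuffleExchangeExec (and report why) when Comet cannot handle the partitioning, instead of assuming every partitioning is translatable. The following is a minimal, self-contained sketch of that decision flow, not the actual Comet code: the Partitioning ADT, the serializableExprs set standing in for QueryPlanSerde.exprToProto, UnknownPartitioning, and the main method are simplified placeholders invented for illustration; the real checks live in QueryPlanSerde and operate on Seq[Attribute] plus Catalyst's physical Partitioning classes.

object PartitioningFallbackSketch {

  // Simplified stand-ins for Spark's physical Partitioning hierarchy.
  sealed trait Partitioning
  final case class HashPartitioning(expressions: Seq[String], numPartitions: Int)
      extends Partitioning
  final case class RangePartitioning(orderings: Seq[String], numPartitions: Int)
      extends Partitioning
  final case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
  case object SinglePartition extends Partitioning
  final case class UnknownPartitioning(numPartitions: Int) extends Partitioning

  // Stand-in for QueryPlanSerde.exprToProto: pretend only these expressions serialize.
  private val serializableExprs = Set("id", "item_id", "price")

  /** Returns (supported, reason); a false result keeps Spark's ShuffleExchangeExec. */
  def supportPartitioningTypes(partitioning: Partitioning): (Boolean, String) =
    partitioning match {
      case HashPartitioning(exprs, _) if exprs.forall(serializableExprs) => (true, null)
      case HashPartitioning(exprs, _) =>
        (false, s"unsupported Spark partitioning expressions: $exprs")
      case RangePartitioning(orderings, _) if orderings.forall(serializableExprs) =>
        (true, null)
      case RangePartitioning(orderings, _) =>
        (false, s"unsupported Spark partitioning expressions: $orderings")
      case SinglePartition | RoundRobinPartitioning(_) => (true, null)
      case other =>
        (false, s"unsupported Spark partitioning: ${other.getClass.getName}")
    }

  def main(args: Array[String]): Unit = {
    // Hash partitioning over serializable expressions stays on the Comet shuffle path.
    println(supportPartitioningTypes(HashPartitioning(Seq("id"), 10)))
    // A partitioning Comet does not recognise (analogous to the key-grouped partitioning
    // produced by the V2 bucketing join in the new test) now falls back to Spark with a
    // reason instead of failing inside Comet's shuffle.
    println(supportPartitioningTypes(UnknownPartitioning(10)))
  }
}

The new CometShuffle4_0Suite exercises exactly this path: with V2 bucketing enabled, the storage-partitioned join produces a partitioning Comet does not serialize, and checkSparkAnswer verifies the plan still runs correctly after the fallback.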