Skip to content

Commit

Permalink
fix: Fallback to Spark for unsupported partitioning (apache#759)
Browse files Browse the repository at this point in the history
* fix: Fallback to Spark for unsupported partitioning

* fix

* Move test

* For review
  • Loading branch information
viirya authored Aug 2, 2024
1 parent e33d560 commit 2d95fea
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class CometSparkSessionExtensions
case s: ShuffleExchangeExec
if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(
conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
Expand Down Expand Up @@ -769,7 +769,7 @@ class CometSparkSessionExtensions
// convert it to CometColumnarShuffle,
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")

Expand All @@ -789,20 +789,22 @@ class CometSparkSessionExtensions

case s: ShuffleExchangeExec =>
val isShuffleEnabled = isCometShuffleEnabled(conf)
val outputPartitioning = s.outputPartitioning
val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
val msg2 = createMessage(
isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
.supportPartitioning(s.child.output, s.outputPartitioning)
.supportPartitioning(s.child.output, outputPartitioning)
._1,
"Native shuffle: " +
s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}")
s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}")
val msg3 = createMessage(
isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
.supportPartitioningTypes(s.child.output)
.supportPartitioningTypes(s.child.output, outputPartitioning)
._1,
s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}")
"JVM shuffle: " +
s"${QueryPlanSerde.supportPartitioningTypes(s.child.output, outputPartitioning)._2}")
withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
s

Expand Down
65 changes: 47 additions & 18 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
Expand Down Expand Up @@ -2895,7 +2895,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
* which supports struct/array.
*/
def supportPartitioningTypes(inputs: Seq[Attribute]): (Boolean, String) = {
def supportPartitioningTypes(
inputs: Seq[Attribute],
partitioning: Partitioning): (Boolean, String) = {
def supportedDataType(dt: DataType): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
Expand All @@ -2919,14 +2921,37 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
false
}

// Check if the datatypes of shuffle input are supported.
var msg = ""
val supported = inputs.forall(attr => supportedDataType(attr.dataType))
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $expressions"
}
supported
case SinglePartition => true
case RoundRobinPartitioning(_) => true
case RangePartitioning(orderings, _) =>
val supported =
orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
orderings.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $orderings"
}
supported
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}

if (!supported) {
msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}"
emitWarning(msg)
(false, msg)
} else {
(true, null)
}
(supported, msg)
}

/**
Expand All @@ -2945,23 +2970,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
false
}

// Check if the datatypes of shuffle input are supported.
val supported = inputs.forall(attr => supportedDataType(attr.dataType))
var msg = ""
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $expressions"
}
supported
case SinglePartition => true
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}

if (!supported) {
val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}"
emitWarning(msg)
(false, msg)
} else {
partitioning match {
case HashPartitioning(expressions, _) =>
(expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined), null)
case SinglePartition => (true, null)
case other =>
val msg = s"unsupported Spark partitioning: ${other.getClass.getName}"
emitWarning(msg)
(false, msg)
}
(true, null)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.exec

import java.util.Collections

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.Expressions.identity
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{FloatType, LongType, StringType, TimestampType}

class CometShuffle4_0Suite extends CometColumnarShuffleSuite {
override protected val asyncShuffleEnable: Boolean = false

protected val adaptiveExecutionEnabled: Boolean = true

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
}

override def afterAll(): Unit = {
spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat")
super.afterAll()
}

private val emptyProps: java.util.Map[String, String] = {
Collections.emptyMap[String, String]
}
private val items: String = "items"
private val itemsColumns: Array[Column] = Array(
Column.create("id", LongType),
Column.create("name", StringType),
Column.create("price", FloatType),
Column.create("arrive_time", TimestampType))

private val purchases: String = "purchases"
private val purchasesColumns: Array[Column] = Array(
Column.create("item_id", LongType),
Column.create("price", FloatType),
Column.create("time", TimestampType))

protected def catalog: InMemoryCatalog = {
val catalog = spark.sessionState.catalogManager.catalog("testcat")
catalog.asInstanceOf[InMemoryCatalog]
}

private def createTable(
table: String,
columns: Array[Column],
partitions: Array[Transform],
catalog: InMemoryTableCatalog = catalog): Unit = {
catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps)
}

private def selectWithMergeJoinHint(t1: String, t2: String): String = {
s"SELECT /*+ MERGE($t1, $t2) */ "
}

private def createJoinTestDF(
keys: Seq[(String, String)],
extraColumns: Seq[String] = Nil,
joinType: String = ""): DataFrame = {
val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "")
sql(s"""
|${selectWithMergeJoinHint("i", "p")}
|id, name, i.price as purchase_price, p.price as sale_price $extraColList
|FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p
|ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")}
|ORDER BY id, purchase_price, sale_price $extraColList
|""".stripMargin)
}

test("Fallback to Spark for unsupported partitioning") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(
s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(
s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(
SQLConf.V2_BUCKETING_ENABLED.key -> "true",
"spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
checkSparkAnswer(df)
}
}
}
}

0 comments on commit 2d95fea

Please sign in to comment.