[SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle #46255

Closed
wants to merge 5 commits
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,6 +19,7 @@ package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.immutable.ArraySeq
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.log10
@@ -149,7 +150,9 @@ private[spark] class KeyGroupedPartitioner(
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
val normalizedKeys = ArraySeq.from(keys)

Member:
Curious, what does this do? Why is it normalized?

@szehon-ho (Contributor, Author), May 30, 2024:
IIRC, I hit a bug caused by comparing different Seq types (more info in the PR description).

This normalizes the valueMap key type in KeyGroupedPartitioner to a specific Seq implementation class. Previously the partitioner's map was initialized with keys as Vector, but the keys were then compared as ArraySeq; these seem to have different hash codes, so lookups would always create new entries with new partition IDs.

valueMap.getOrElseUpdate(normalizedKeys,
Utils.nonNegativeMod(normalizedKeys.hashCode, numPartitions))
}
}
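
To illustrate the fix discussed in the thread above, here is a minimal, self-contained sketch (not the actual Spark class; the class name and the inlined nonNegativeMod helper stand in for Spark's Utils.nonNegativeMod) of why routing every map access through one canonical Seq implementation keeps logically equal keys on a single partition id:

import scala.collection.immutable.ArraySeq
import scala.collection.mutable

// Simplified stand-in for KeyGroupedPartitioner (not the real Spark class).
class SketchKeyGroupedPartitioner(numPartitions: Int) {
  // Maps a grouping-key tuple to the partition id assigned to it.
  private val valueMap = mutable.Map.empty[Seq[Any], Int]

  // Same behavior as Spark's Utils.nonNegativeMod, inlined here to stay self-contained.
  private def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def getPartition(key: Any): Int = {
    val keys = key.asInstanceOf[Seq[Any]]
    // Normalize to one concrete Seq implementation before touching the map, so a key
    // first seen as a Vector and later presented as an ArraySeq hits the same entry.
    val normalizedKeys = ArraySeq.from(keys)
    valueMap.getOrElseUpdate(
      normalizedKeys, nonNegativeMod(normalizedKeys.hashCode, numPartitions))
  }
}

With the normalization in place, getPartition(Vector(1, "a")) and getPartition(ArraySeq(1, "a")) in this sketch resolve to the same map entry and therefore the same partition id.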

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType

/**
@@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType
case class TransformExpression(
function: BoundFunction,
children: Seq[Expression],
numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable {
numBucketsOpt: Option[Int] = None) extends Expression {

override def nullable: Boolean = true

@@ -113,4 +116,23 @@ case class TransformExpression(

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)

private lazy val resolvedFunction: Option[Expression] = this match {
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
Seq(Literal(numBuckets)) ++ arguments))
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
case _ => None
}

override def eval(input: InternalRow): Any = {
resolvedFunction match {
case Some(fn) => fn.eval(input)
case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -871,12 +871,30 @@ case class KeyGroupedShuffleSpec(
if (results.forall(p => p.isEmpty)) None else Some(results)
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
override def canCreatePartitioning: Boolean = {
// For now, allow one-side shuffle for SPJ only if partially-clustered distribution is
// not enabled, and, when join keys may be a subset of the partition keys, only if the
// partition expressions are not transforms.
val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {

Member:
Trying to understand the reason behind this. Also, it might be better to add some logging here if it is easy.

@szehon-ho (Contributor, Author), May 30, 2024:
IIRC, I hit a pretty hard bug when trying to enable the feature with v2BucketingAllowJoinKeysSubsetOfPartitionKeys (more in the PR description). Since we may need to rethink the logic of v2BucketingAllowJoinKeysSubsetOfPartitionKeys to fix it, I was going to disable it for now and try to fix it in a subsequent PR.

@szehon-ho (Contributor, Author):
I looked at it; maybe I will add the logging in another PR. There is no table name available here, so I am not sure it is valuable to log the decision.

Member:
OK, sounds good.

e: Expression => e.isInstanceOf[AttributeReference]
} else {
e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
}
SQLConf.get.v2BucketingShuffleEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
partitioning.expressions.forall(checkExprType)
}

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
case (c, e: TransformExpression) => TransformExpression(
e.function, Seq(c), e.numBucketsOpt)
case (c, _) => c
}
KeyGroupedPartitioning(newExpressions,
partitioning.numPartitions,
partitioning.partitionValues)
}
}
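
As a usage-level illustration of the guard above, here is a hedged sketch (not part of this PR): it assumes a SparkSession named spark and the same catalog layout as the tests below, i.e. testcat.ns.items partitioned by bucket(2, id) and identity(arrive_time) while testcat.ns.purchases reports no partitioning. SQLConf is Spark's internal config holder, referenced here only for illustration, mirroring how the tests set the same entries.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark: SparkSession = SparkSession.builder().getOrCreate()

// One-side shuffle for SPJ with transform expressions needs the shuffle feature on,
// partially-clustered distribution off, and (for now) the join keys covering all
// partition keys, i.e. the subset-of-partition-keys path left disabled.
spark.conf.set(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key, "true")
spark.conf.set(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key, "false")
spark.conf.set(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key, "false")

// Joining on all partition keys should then plan a shuffle (Exchange) only on the
// purchases side, since items already reports a compatible KeyGroupedPartitioning.
val items = spark.table("testcat.ns.items")
val purchases = spark.table("testcat.ns.purchases")
val joined = items.join(purchases,
  items("id") === purchases("item_id") && items("arrive_time") === purchases("time"))
joined.explain()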

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
val df = createJoinTestDF(Seq("arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 2, "partitioning with transform not work now")
assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
} else {
assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
" is not enabled")
@@ -1991,22 +1991,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { pushdownValues =>
Seq(true, false).foreach { partiallyClustered =>
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
-> partiallyClustered.toString,
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "SPJ should be triggered")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),
Row(1, "aa", 40.0, 89.0),
Row(3, "bb", 10.0, 19.5)))
}
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "SPJ should be triggered")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),
Row(1, "aa", 40.0, 89.0),
Row(3, "bb", 10.0, 19.5)))
}
}
}
@@ -2052,4 +2049,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}
}

test("SPARK-48012: one-side shuffle with partition transforms") {
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))

Seq(items_partitions, items_partitions2).foreach { partition =>
catalog.clearTables()

createTable(items, itemsColumns, partition)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
"(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
"(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 10.7, cast('2020-01-01' as timestamp))," +
"(3, 19.5, cast('2020-02-01' as timestamp))," +
"(4, 56.5, cast('2020-02-01' as timestamp))")

withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "only shuffle the side that does not report partitioning")

checkAnswer(df, Seq(
Row(1, "bb", 30.0, 42.0),
Row(1, "aa", 40.0, 42.0),
Row(4, "ee", 15.5, 56.5)))
}
}
}

test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") {
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 10.7, cast('2020-01-01' as timestamp))")

Seq(true, false).foreach { pushDown => {
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
pushDown.toString) {
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "only shuffle the side that does not report partitioning")

checkAnswer(df, Seq(
Row(1, "bb", 30.0, 42.0),
Row(1, "aa", 40.0, 42.0)))
}
}
}
}

test("SPARK-48012: one-side shuffle with partition transforms " +
"with fewer join keys than partition kes") {
val items_partitions = Array(bucket(2, "id"), identity("name"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(1, 89.0, cast('2020-01-03' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 2, "SPJ should not be triggered for transform expressions with " +
"fewer join keys than partition keys for now.")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),
Row(1, "aa", 40.0, 89.0),
Row(3, "bb", 10.0, 19.5)))
}
}
}
@@ -16,9 +16,11 @@
*/
package org.apache.spark.sql.connector.catalog.functions

import java.sql.Timestamp
import java.time.{Instant, LocalDate, ZoneId}
import java.time.temporal.ChronoUnit

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] {
override def name(): String = "years"
override def canonicalName(): String = name()

def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
val UTC: ZoneId = ZoneId.of("UTC")
val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate

def invoke(ts: Long): Long = {
val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
}
}
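
For reference, a standalone sketch (plain java.time code, independent of Spark's DateTimeUtils; the object and method names are illustrative only) of what the rewritten YearsFunction.invoke computes: whole years between the Unix epoch and the timestamp, with the input interpreted as microseconds since the epoch rather than milliseconds.

import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

object YearsSketch {
  private val Utc = ZoneId.of("UTC")
  private val EpochDate = Instant.EPOCH.atZone(Utc).toLocalDate

  // Microseconds since the epoch -> whole years elapsed since 1970-01-01.
  def years(micros: Long): Long = {
    val instant = Instant.ofEpochSecond(micros / 1000000L, (micros % 1000000L) * 1000L)
    ChronoUnit.YEARS.between(EpochDate, instant.atZone(Utc).toLocalDate)
  }

  def main(args: Array[String]): Unit = {
    val micros2020 = Instant.parse("2020-06-01T00:00:00Z").getEpochSecond * 1000000L
    println(years(micros2020)) // prints 50: 2020-06-01 is 50 whole years after the epoch
  }
}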

object DaysFunction extends BoundFunction {