[GLUTEN-4875][VL] Support spark sql conf sortBeforeRepartition to avoid stage partial retry causing result mismatch (apache#4872)
zjuwangg authored and taiyang-li committed Oct 8, 2024
1 parent 768bb13 commit 06d99c2
Showing 5 changed files with 52 additions and 5 deletions.
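
For context: the change teaches Gluten's Velox backend to honor Spark's spark.sql.execution.sortBeforeRepartition conf (true by default in Spark) for round-robin repartitioning, so that rows shuffle deterministically even when an upstream stage is partially retried. A minimal usage sketch; a Gluten-enabled SparkSession named `spark` and a registered `lineitem` table are assumptions, not part of this commit:

// Usage sketch only; `spark` and the `lineitem` table are assumed to exist.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true") // Spark's default
val df = spark.sql("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem")
// With the conf on, the round-robin shuffle is planned with a local sort on a
// per-row hash, as the first diff below shows.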
@@ -36,12 +36,12 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, StringSplit, StringTrim}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec}
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
@@ -232,7 +232,23 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
           TransformHints.tagNotTransformable(shuffle, validationResult)
           shuffle.withNewChildren(newChild :: Nil)
         }

+      case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>
+        val hashExpr = new Murmur3Hash(newChild.output)
+        val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
+        val projectTransformer = ProjectExecTransformer(projectList, newChild)
+        val sortOrder = SortOrder(projectTransformer.output.head, Ascending)
+        val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
+        val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode)
+        val validationResult = dropSortColumnTransformer.doValidate()
+        if (validationResult.isValid) {
+          ColumnarShuffleExchangeExec(
+            shuffle,
+            dropSortColumnTransformer,
+            dropSortColumnTransformer.output)
+        } else {
+          TransformHints.tagNotTransformable(shuffle, validationResult)
+          shuffle.withNewChildren(newChild :: Nil)
+        }
       case _ =>
         ColumnarShuffleExchangeExec(shuffle, newChild, null)
     }
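
Why the extra sort makes round-robin retry-safe (an illustration, not Gluten code): RoundRobinPartitioning assigns each row to partition index modulo n by arrival order, so a retried upstream task that scans the same rows in a different order sends them to different reducers than the original attempt did, while downstream consumers that already read the first attempt's output keep the old rows, and the result no longer matches. Sorting locally by a per-row hash pins the arrival order. A self-contained Scala sketch of that argument, using only the standard library:

import scala.util.hashing.MurmurHash3

// Round-robin assignment by arrival order, like RoundRobinPartitioning.
def roundRobin(rows: Seq[String], numPartitions: Int): Map[Int, Seq[String]] =
  rows.zipWithIndex
    .groupBy { case (_, i) => i % numPartitions }
    .map { case (p, rs) => p -> rs.map(_._1) }

val attempt1 = Seq("a", "b", "c", "d") // one task attempt's scan order
val attempt2 = Seq("b", "a", "d", "c") // a retry scanning the same rows differently

roundRobin(attempt1, 2) // Map(0 -> Seq(a, c), 1 -> Seq(b, d))
roundRobin(attempt2, 2) // Map(0 -> Seq(b, d), 1 -> Seq(a, c)): rows changed partitions

// Sorting by a content hash first makes the assignment order-independent,
// which is what the projected hash_partition_key plus the local sort achieve above.
def stableRoundRobin(rows: Seq[String], n: Int): Map[Int, Seq[String]] =
  roundRobin(rows.sortBy(r => MurmurHash3.stringHash(r)), n)

stableRoundRobin(attempt1, 2) == stableRoundRobin(attempt2, 2) // true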
@@ -1233,4 +1233,29 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
       checkOperatorMatch[HashAggregateExecTransformer]
     }
   }
+
+  test("test roundrobin with sort") {
+    // scalastyle:off
+    runQueryAndCompare("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem") {
+      /*
+        ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, [l_orderkey#16L, l_partkey#17L]
+        +- ^(2) ProjectExecTransformer [l_orderkey#16L, l_partkey#17L]
+           +- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS FIRST], false, 0
+              +- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L]
+                 +- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<l_orderkey:bigint,l_partkey:bigint>, PushedFilters: [] RuntimeFilters: []
+      */
+      checkOperatorMatch[SortExecTransformer]
+    }
+    // scalastyle:on
+
+    withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
+      runQueryAndCompare("""SELECT /*+ REPARTITION(3) */
+                           | l_orderkey, l_partkey FROM lineitem""".stripMargin) {
+        df =>
+          {
+            assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0)
+          }
+      }
+    }
+  }
 }
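
The plan shape this test asserts can also be checked interactively; a hedged spark-shell sketch under the same assumptions as above (Gluten enabled, lineitem registered):

val q = "SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem"
spark.sql(q).explain() // expect the Project -> Sort -> Project chain above the scan, as in the quoted plan
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
spark.sql(q).explain() // SortExecTransformer should no longer appear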
@@ -119,10 +119,10 @@ class GlutenImplicitsTest extends GlutenSQLTestsBaseTrait {
   testGluten("fallbackSummary with cached data and shuffle") {
     withAQEEnabledAndDisabled {
       val df = spark.sql("select * from t1").filter(_.getLong(0) > 0).cache.repartition()
-      assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
+      assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
       assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
       df.collect()
-      assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
+      assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
       assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
     }
   }
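
A plausible reading of the doubled expectation (my inference; the diff does not state it): the argument-less repartition() is a round-robin repartition, so with sortBeforeRepartition on it now plans three extra Gluten operators, lifting numGlutenNodes from 3 to 6. Sketch of the assumed per-repartition chain, matching the plan quoted in the TestOperator test above:

// Assumed operator chain per round-robin repartition after this change:
// ColumnarShuffleExchangeExec
//   +- ProjectExecTransformer        // drops hash_partition_key
//      +- SortExecTransformer        // local sort on hash_partition_key
//         +- ProjectExecTransformer  // computes hash_partition_key
//            +- <child plan>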
@@ -61,6 +61,8 @@ class GlutenReplaceHashWithSortAggSuite

Seq("FIRST", "COLLECT_LIST").foreach {
aggExpr =>
// Because repartition modification causing the result sort order not same and the
// result not same, so we add order by key before comparing the result.
val query =
s"""
|SELECT key, $aggExpr(key)
Expand All @@ -72,6 +74,7 @@ class GlutenReplaceHashWithSortAggSuite
| SORT BY key
|)
|GROUP BY key
|ORDER BY key
""".stripMargin
checkAggs(query, 2, 0, 2, 0)
}
@@ -60,6 +60,8 @@ class GlutenReplaceHashWithSortAggSuite

Seq("FIRST", "COLLECT_LIST").foreach {
aggExpr =>
// Because repartition modification causing the result sort order not same and the
// result not same, so we add order by key before comparing the result.
val query =
s"""
|SELECT key, $aggExpr(key)
Expand All @@ -71,6 +73,7 @@ class GlutenReplaceHashWithSortAggSuite
| SORT BY key
|)
|GROUP BY key
|ORDER BY key
""".stripMargin
checkAggs(query, 2, 0, 2, 0)
}
