[GLUTEN-4875][VL] Support Spark SQL conf sortBeforeRepartition to avoid stage partial retry causing result mismatch #4872

Merged: 10 commits, Mar 18, 2024
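Background, for readers of the diff below: Spark's round-robin repartitioning assigns rows to output partitions by their arrival order within each task, so if a stage is partially retried and the recomputed input arrives in a different order, the same row can land in a different partition than in the surviving attempt, yielding duplicate or missing rows downstream. Vanilla Spark guards against this with spark.sql.execution.sortBeforeRepartition (SPARK-23207), which sorts rows on a hash of all columns before the exchange; this PR implements the same guard for Gluten's Velox backend. The following plain-Scala simulation is only an illustrative sketch of the order-sensitivity argument, not code from the PR:

// Minimal model: round-robin assignment is a function of row order alone,
// so a reordered retry re-routes rows unless the input is sorted first.
object RoundRobinRetrySketch {
  def roundRobin(rows: Seq[String], numPartitions: Int): Map[String, Int] =
    rows.zipWithIndex.map { case (row, i) => row -> i % numPartitions }.toMap

  def main(args: Array[String]): Unit = {
    val firstAttempt = Seq("a", "b", "c", "d", "e", "f")
    val retry = Seq("b", "a", "c", "f", "e", "d") // same rows, new arrival order

    // Unsorted: "a" goes to partition 0 on the first attempt, partition 1 on retry.
    println(roundRobin(firstAttempt, 3) == roundRobin(retry, 3)) // false

    // Sorting on a per-row hash first makes the assignment order-independent.
    def byHash(rows: Seq[String]) = rows.sortBy(_.hashCode)
    println(roundRobin(byHash(firstAttempt), 3) == roundRobin(byHash(retry), 3)) // true
  }
}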
@@ -36,12 +36,12 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, StringSplit, StringTrim}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec}
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
@@ -232,7 +232,24 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
          TransformHints.tagNotTransformable(shuffle, validationResult)
          shuffle.withNewChildren(newChild :: Nil)
        }
+      case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>
+        val hashExpr = new Murmur3Hash(newChild.output)
+        val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
+        val projectTransformer = ProjectExecTransformer(projectList, newChild)
+        val sortOrder = SortOrder(projectTransformer.output.head, Ascending)
+        val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
+        val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode)
+        val validationResult = dropSortColumnTransformer.doValidate()
+        if (validationResult.isValid) {
+          ColumnarShuffleExchangeExec(
+            shuffle,
+            dropSortColumnTransformer,
+            dropSortColumnTransformer.output)
+        } else {
+          TransformHints.tagNotTransformable(shuffle, validationResult.reason.get)
+          shuffle.withNewChildren(newChild :: Nil)
+        }
      case _ =>
        ColumnarShuffleExchangeExec(shuffle, newChild, null)
    }
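To restate the new case in more familiar terms, here is a rough DataFrame-level equivalent built from public Spark APIs (a hedged sketch of the plan shape only; the PR itself works on transformer nodes, and the helper function name is mine):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, hash}

// hash() is Spark's Murmur3-based function, mirroring the Murmur3Hash
// expression above; "hash_partition_key" matches the alias used in the PR.
def sortBeforeRoundRobin(df: DataFrame, numPartitions: Int): DataFrame =
  df.withColumn("hash_partition_key", hash(df.columns.map(col): _*)) // Project: add hash column
    .sortWithinPartitions("hash_partition_key") // Sort: local (non-global), ascending
    .drop("hash_partition_key") // Project: drop the helper column
    .repartition(numPartitions) // round-robin exchange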
@@ -1233,4 +1233,29 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
      checkOperatorMatch[HashAggregateExecTransformer]
    }
  }

test("test roundrobine with sort") {
// scalastyle:off
runQueryAndCompare("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem") {
/*
ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, [l_orderkey#16L, l_partkey#17L)
+- ^(2) ProjectExecTransformer [l_orderkey#16L, l_partkey#17L]
+- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS FIRST], false, 0
+- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L]
+- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<l_orderkey:bigint,l_partkey:bigint>, PushedFilters: [] RuntimeFilters: []
*/
checkOperatorMatch[SortExecTransformer]
}
// scalastyle:on

withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
runQueryAndCompare("""SELECT /*+ REPARTITION(3) */
| l_orderkey, l_partkey FROM lineitem""".stripMargin) {
df =>
{
assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0)
}
}
}
}
}
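For completeness, the toggle the second half of the test exercises, as a spark-shell style snippet (assumes an active SparkSession named spark with the TPC-H lineitem table registered; the config key is Spark's own):

// Default is true; disabling it removes the hash/sort/project chain and
// restores a plain round-robin exchange.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
spark.sql("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem").collect()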
@@ -119,10 +119,10 @@ class GlutenImplicitsTest extends GlutenSQLTestsBaseTrait {
testGluten("fallbackSummary with cached data and shuffle") {
withAQEEnabledAndDisabled {
val df = spark.sql("select * from t1").filter(_.getLong(0) > 0).cache.repartition()
assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
df.collect()
assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
}
}
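The doubled numGlutenNodes expectation is consistent with the rewrite above; a hedged accounting (my inference from this diff, not stated in the PR):

// df.repartition() is round-robin, so the new case in SparkPlanExecApiImpl
// inserts three extra Gluten nodes ahead of the exchange:
//   ProjectExecTransformer -- adds hash_partition_key
//   SortExecTransformer    -- local ascending sort on it
//   ProjectExecTransformer -- drops it again
// hence 3 original Gluten nodes + 3 inserted = 6.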
@@ -72,6 +72,7 @@ class GlutenReplaceHashWithSortAggSuite
        | SORT BY key
        |)
        |GROUP BY key
+       |ORDER BY key
      """.stripMargin
    checkAggs(query, 2, 0, 2, 0)
  }
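A plausible reason for the added ORDER BY in both shim copies of this suite (an inference; the resolved review thread is not shown here):

// Without ORDER BY, rows produced after the new hash/sort/project chain can
// legitimately arrive in a different order than vanilla Spark's output, so an
// order-sensitive result comparison could fail even though the contents match;
// pinning "ORDER BY key" makes the comparison deterministic.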
@@ -71,6 +71,7 @@ class GlutenReplaceHashWithSortAggSuite
        | SORT BY key
        |)
        |GROUP BY key
+       |ORDER BY key
      """.stripMargin
    checkAggs(query, 2, 0, 2, 0)
  }