From bdb72fd1857d40f6feb852588e18e2d15d137be1 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 8 Jul 2023 15:35:50 +0800 Subject: [PATCH 1/3] [SPARK-44340][SQL] Define the computing logic through PartitionEvaluator API and use it in WindowGroupLimitExec --- .../window/WindowGroupLimitExec.scala | 39 +- .../sql/DataFrameWindowFunctionsSuite.scala | 340 +++++++++--------- 2 files changed, 190 insertions(+), 189 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala index a35c33577d0fb..b1f375f415102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, DenseRank, Expression, Rank, RowNumber, SortOrder, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} @@ -73,26 +73,23 @@ case class WindowGroupLimitExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - rankLikeFunction match { - case _: RowNumber if partitionSpec.isEmpty => - child.execute().mapPartitionsInternal(SimpleLimitIterator(_, limit, numOutputRows)) - case _: RowNumber => - child.execute().mapPartitionsInternal(new GroupedLimitIterator(_, output, partitionSpec, - (input: Iterator[InternalRow]) => SimpleLimitIterator(input, limit, numOutputRows))) - case _: Rank if partitionSpec.isEmpty => - child.execute().mapPartitionsInternal( - RankLimitIterator(output, _, orderSpec, limit, numOutputRows)) - case _: Rank => - child.execute().mapPartitionsInternal(new GroupedLimitIterator(_, output, partitionSpec, - (input: Iterator[InternalRow]) => - RankLimitIterator(output, input, orderSpec, limit, numOutputRows))) - case _: DenseRank if partitionSpec.isEmpty => - child.execute().mapPartitionsInternal( - DenseRankLimitIterator(output, _, orderSpec, limit, numOutputRows)) - case _: DenseRank => - child.execute().mapPartitionsInternal(new GroupedLimitIterator(_, output, partitionSpec, - (input: Iterator[InternalRow]) => - DenseRankLimitIterator(output, input, orderSpec, limit, numOutputRows))) + + val evaluatorFactory = + new WindowGroupLimitEvaluatorFactory( + partitionSpec, + orderSpec, + rankLikeFunction, + limit, + child.output, + numOutputRows) + + if (conf.usePartitionEvaluator) { + child.execute().mapPartitionsWithEvaluator(evaluatorFactory) + } else { + child.execute().mapPartitionsInternal { iter => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, iter) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index fe4a5cebc5d69..f2f645b126cbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -1289,188 +1289,192 @@ class DataFrameWindowFunctionsSuite extends QueryTest val window2 = Window.partitionBy($"key").orderBy($"order".desc_nulls_first) val window3 = Window.orderBy($"order".asc_nulls_first) - Seq(-1, 100).foreach { threshold => - withSQLConf(SQLConf.WINDOW_GROUP_LIMIT_THRESHOLD.key -> threshold.toString) { - Seq($"rn" === 0, $"rn" < 1, $"rn" <= 0).foreach { condition => - checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), - Seq.empty[Row] - ) - } - - Seq($"rn" === 1, $"rn" < 2, $"rn" <= 1).foreach { condition => - checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 1), - Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) - ) - ) - - checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 1), - Row("a", 4, "", 2.0, 1), - Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) - ) - ) - - checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 1), - Row("a", 4, "", 2.0, 1), - Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) - ) - ) - - checkAnswer(df.withColumn("rn", row_number().over(window3)).where(condition), - Seq( - Row("c", 2, null, 5.0, 1) + Seq(true, false).foreach { enableEvaluator => + withSQLConf(SQLConf.USE_PARTITION_EVALUATOR.key -> enableEvaluator.toString) { + Seq(-1, 100).foreach { threshold => + withSQLConf(SQLConf.WINDOW_GROUP_LIMIT_THRESHOLD.key -> threshold.toString) { + Seq($"rn" === 0, $"rn" < 1, $"rn" <= 0).foreach { condition => + checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), + Seq.empty[Row] + ) + } + + Seq($"rn" === 1, $"rn" < 2, $"rn" <= 1).foreach { condition => + checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 1), + Row("b", 1, "h", Double.NaN, 1), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 1), + Row("a", 4, "", 2.0, 1), + Row("b", 1, "h", Double.NaN, 1), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 1), + Row("a", 4, "", 2.0, 1), + Row("b", 1, "h", Double.NaN, 1), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", row_number().over(window3)).where(condition), + Seq( + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", rank().over(window3)).where(condition), + Seq( + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", dense_rank().over(window3)).where(condition), + Seq( + Row("c", 2, null, 5.0, 1) + ) + ) + } + + Seq($"rn" < 3, $"rn" <= 2).foreach { condition => + checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 1), + Row("a", 4, "", 2.0, 2), + Row("b", 1, "h", Double.NaN, 1), + Row("b", 1, "n", Double.PositiveInfinity, 2), + Row("c", 1, "a", -4.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 1), + Row("a", 4, "", 2.0, 1), + Row("b", 1, "h", Double.NaN, 1), + Row("b", 1, "n", Double.PositiveInfinity, 2), + Row("c", 1, "a", -4.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), + Seq( + Row("a", 
0, "c", 1.0, 2), + Row("a", 4, "", 2.0, 1), + Row("a", 4, "", 2.0, 1), + Row("b", 1, "h", Double.NaN, 1), + Row("b", 1, "n", Double.PositiveInfinity, 2), + Row("c", 1, "a", -4.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", row_number().over(window3)).where(condition), + Seq( + Row("a", 4, "", 2.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", rank().over(window3)).where(condition), + Seq( + Row("a", 4, "", 2.0, 2), + Row("a", 4, "", 2.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + + checkAnswer(df.withColumn("rn", dense_rank().over(window3)).where(condition), + Seq( + Row("a", 4, "", 2.0, 2), + Row("a", 4, "", 2.0, 2), + Row("c", 2, null, 5.0, 1) + ) + ) + } + + val condition = $"rn" === 2 && $"value2" > 0.5 + checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), + Seq( + Row("a", 4, "", 2.0, 2), + Row("b", 1, "n", Double.PositiveInfinity, 2) + ) ) - ) - checkAnswer(df.withColumn("rn", rank().over(window3)).where(condition), - Seq( - Row("c", 2, null, 5.0, 1) + checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), + Seq( + Row("b", 1, "n", Double.PositiveInfinity, 2) + ) ) - ) - checkAnswer(df.withColumn("rn", dense_rank().over(window3)).where(condition), - Seq( - Row("c", 2, null, 5.0, 1) + checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), + Seq( + Row("a", 0, "c", 1.0, 2), + Row("b", 1, "n", Double.PositiveInfinity, 2) + ) ) - ) - } - Seq($"rn" < 3, $"rn" <= 2).foreach { condition => - checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 1), - Row("a", 4, "", 2.0, 2), - Row("b", 1, "h", Double.NaN, 1), - Row("b", 1, "n", Double.PositiveInfinity, 2), - Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) - ) - ) - - checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 1), - Row("a", 4, "", 2.0, 1), - Row("b", 1, "h", Double.NaN, 1), - Row("b", 1, "n", Double.PositiveInfinity, 2), - Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) + val multipleRowNumbers = df + .withColumn("rn", row_number().over(window)) + .withColumn("rn2", row_number().over(window)) + .where('rn < 2 && 'rn2 < 3) + checkAnswer(multipleRowNumbers, + Seq( + Row("a", 4, "", 2.0, 1, 1), + Row("b", 1, "h", Double.NaN, 1, 1), + Row("c", 2, null, 5.0, 1, 1) + ) ) - ) - - checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), - Seq( - Row("a", 0, "c", 1.0, 2), - Row("a", 4, "", 2.0, 1), - Row("a", 4, "", 2.0, 1), - Row("b", 1, "h", Double.NaN, 1), - Row("b", 1, "n", Double.PositiveInfinity, 2), - Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) - ) - ) - checkAnswer(df.withColumn("rn", row_number().over(window3)).where(condition), - Seq( - Row("a", 4, "", 2.0, 2), - Row("c", 2, null, 5.0, 1) + val multipleRanks = df + .withColumn("rn", rank().over(window)) + .withColumn("rn2", rank().over(window)) + .where('rn < 2 && 'rn2 < 3) + checkAnswer(multipleRanks, + Seq( + Row("a", 4, "", 2.0, 1, 1), + Row("a", 4, "", 2.0, 1, 1), + Row("b", 1, "h", Double.NaN, 1, 1), + Row("c", 2, null, 5.0, 1, 1) + ) ) - ) - checkAnswer(df.withColumn("rn", rank().over(window3)).where(condition), - Seq( - Row("a", 4, "", 2.0, 2), - Row("a", 4, "", 2.0, 2), - Row("c", 2, null, 5.0, 1) + val multipleDenseRanks = df + .withColumn("rn", dense_rank().over(window)) + .withColumn("rn2", dense_rank().over(window)) + .where('rn < 2 && 'rn2 < 3) + checkAnswer(multipleDenseRanks, + Seq( + Row("a", 4, "", 
2.0, 1, 1), + Row("a", 4, "", 2.0, 1, 1), + Row("b", 1, "h", Double.NaN, 1, 1), + Row("c", 2, null, 5.0, 1, 1) + ) ) - ) - checkAnswer(df.withColumn("rn", dense_rank().over(window3)).where(condition), - Seq( - Row("a", 4, "", 2.0, 2), - Row("a", 4, "", 2.0, 2), - Row("c", 2, null, 5.0, 1) + val multipleWindows = df + .withColumn("rn2", row_number().over(window2)) + .withColumn("rn", row_number().over(window)) + .where('rn < 2 && 'rn2 < 3) + checkAnswer(multipleWindows, + Seq( + Row("b", 1, "h", Double.NaN, 2, 1), + Row("c", 2, null, 5.0, 1, 1) + ) ) - ) + } } - - val condition = $"rn" === 2 && $"value2" > 0.5 - checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), - Seq( - Row("a", 4, "", 2.0, 2), - Row("b", 1, "n", Double.PositiveInfinity, 2) - ) - ) - - checkAnswer(df.withColumn("rn", rank().over(window)).where(condition), - Seq( - Row("b", 1, "n", Double.PositiveInfinity, 2) - ) - ) - - checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), - Seq( - Row("a", 0, "c", 1.0, 2), - Row("b", 1, "n", Double.PositiveInfinity, 2) - ) - ) - - val multipleRowNumbers = df - .withColumn("rn", row_number().over(window)) - .withColumn("rn2", row_number().over(window)) - .where('rn < 2 && 'rn2 < 3) - checkAnswer(multipleRowNumbers, - Seq( - Row("a", 4, "", 2.0, 1, 1), - Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) - ) - ) - - val multipleRanks = df - .withColumn("rn", rank().over(window)) - .withColumn("rn2", rank().over(window)) - .where('rn < 2 && 'rn2 < 3) - checkAnswer(multipleRanks, - Seq( - Row("a", 4, "", 2.0, 1, 1), - Row("a", 4, "", 2.0, 1, 1), - Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) - ) - ) - - val multipleDenseRanks = df - .withColumn("rn", dense_rank().over(window)) - .withColumn("rn2", dense_rank().over(window)) - .where('rn < 2 && 'rn2 < 3) - checkAnswer(multipleDenseRanks, - Seq( - Row("a", 4, "", 2.0, 1, 1), - Row("a", 4, "", 2.0, 1, 1), - Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) - ) - ) - - val multipleWindows = df - .withColumn("rn2", row_number().over(window2)) - .withColumn("rn", row_number().over(window)) - .where('rn < 2 && 'rn2 < 3) - checkAnswer(multipleWindows, - Seq( - Row("b", 1, "h", Double.NaN, 2, 1), - Row("c", 2, null, 5.0, 1, 1) - ) - ) } } } From a8598329be2d25bf32191ab6dbed36fc45d018b9 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 8 Jul 2023 15:41:31 +0800 Subject: [PATCH 2/3] Update code --- .../WindowGroupLimitEvaluatorFactory.scala | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala new file mode 100644 index 0000000000000..6bea52e733787 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, DenseRank, Expression, Rank, RowNumber, SortOrder} +import org.apache.spark.sql.execution.metric.SQLMetric + +class WindowGroupLimitEvaluatorFactory( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + childOutput: Seq[Attribute], + numOutputRows: SQLMetric) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] { + + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { + rankLikeFunction match { + case _: RowNumber if partitionSpec.isEmpty => + new WindowGroupLimitPartitionEvaluator( + input => SimpleLimitIterator(input, limit, numOutputRows)) + case _: RowNumber => + new WindowGroupLimitPartitionEvaluator( + input => new GroupedLimitIterator(input, childOutput, partitionSpec, + (input: Iterator[InternalRow]) => SimpleLimitIterator(input, limit, numOutputRows))) + case _: Rank if partitionSpec.isEmpty => + new WindowGroupLimitPartitionEvaluator( + input => RankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows)) + case _: Rank => + new WindowGroupLimitPartitionEvaluator( + input => new GroupedLimitIterator(input, childOutput, partitionSpec, + (input: Iterator[InternalRow]) => + RankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows))) + case _: DenseRank if partitionSpec.isEmpty => + new WindowGroupLimitPartitionEvaluator( + input => DenseRankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows)) + case _: DenseRank => + new WindowGroupLimitPartitionEvaluator( + input => new GroupedLimitIterator(input, childOutput, partitionSpec, + (input: Iterator[InternalRow]) => + DenseRankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows))) + } + + } + + class WindowGroupLimitPartitionEvaluator(f: Iterator[InternalRow] => Iterator[InternalRow]) + extends PartitionEvaluator[InternalRow, InternalRow] { + + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + f(inputs.head) + } + } +} From ebc2bdd2324a90a8b19664e440b19e9e70b65c91 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 11 Jul 2023 11:33:34 +0800 Subject: [PATCH 3/3] Update code --- .../WindowGroupLimitEvaluatorFactory.scala | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala index 6bea52e733787..6777f6ae7ac66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitEvaluatorFactory.scala @@ -32,32 +32,23 @@ class WindowGroupLimitEvaluatorFactory( extends PartitionEvaluatorFactory[InternalRow, InternalRow] { override 
def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { - rankLikeFunction match { - case _: RowNumber if partitionSpec.isEmpty => - new WindowGroupLimitPartitionEvaluator( - input => SimpleLimitIterator(input, limit, numOutputRows)) + val limitFunc = rankLikeFunction match { case _: RowNumber => - new WindowGroupLimitPartitionEvaluator( - input => new GroupedLimitIterator(input, childOutput, partitionSpec, - (input: Iterator[InternalRow]) => SimpleLimitIterator(input, limit, numOutputRows))) - case _: Rank if partitionSpec.isEmpty => - new WindowGroupLimitPartitionEvaluator( - input => RankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows)) + (iter: Iterator[InternalRow]) => SimpleLimitIterator(iter, limit, numOutputRows) case _: Rank => - new WindowGroupLimitPartitionEvaluator( - input => new GroupedLimitIterator(input, childOutput, partitionSpec, - (input: Iterator[InternalRow]) => - RankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows))) - case _: DenseRank if partitionSpec.isEmpty => - new WindowGroupLimitPartitionEvaluator( - input => DenseRankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows)) + (iter: Iterator[InternalRow]) => + RankLimitIterator(childOutput, iter, orderSpec, limit, numOutputRows) case _: DenseRank => - new WindowGroupLimitPartitionEvaluator( - input => new GroupedLimitIterator(input, childOutput, partitionSpec, - (input: Iterator[InternalRow]) => - DenseRankLimitIterator(childOutput, input, orderSpec, limit, numOutputRows))) + (iter: Iterator[InternalRow]) => + DenseRankLimitIterator(childOutput, iter, orderSpec, limit, numOutputRows) } + if (partitionSpec.isEmpty) { + new WindowGroupLimitPartitionEvaluator(limitFunc) + } else { + new WindowGroupLimitPartitionEvaluator( + input => new GroupedLimitIterator(input, childOutput, partitionSpec, limitFunc)) + } } class WindowGroupLimitPartitionEvaluator(f: Iterator[InternalRow] => Iterator[InternalRow])
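
The pattern these three patches converge on: WindowGroupLimitExec no longer pattern-matches
on the rank-like function inside doExecute(). Instead it builds a serializable
WindowGroupLimitEvaluatorFactory, then either hands it to mapPartitionsWithEvaluator (when
SQLConf.USE_PARTITION_EVALUATOR is enabled) or calls createEvaluator().eval(0, iter) inside
the classic mapPartitionsInternal path. Below is a minimal, self-contained sketch of the same
PartitionEvaluator pattern, assuming Spark 3.5+; UpperCaseEvaluatorFactory and
PartitionEvaluatorDemo are hypothetical illustration classes, not part of this patch:

import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
import org.apache.spark.sql.SparkSession

// Hypothetical example factory (not part of the patch): upper-cases every
// element of a String RDD, creating one evaluator per partition.
class UpperCaseEvaluatorFactory extends PartitionEvaluatorFactory[String, String] {
  override def createEvaluator(): PartitionEvaluator[String, String] =
    new PartitionEvaluator[String, String] {
      // eval receives the partition index plus one iterator per input RDD;
      // a unary node like WindowGroupLimitExec only consumes inputs.head.
      override def eval(partitionIndex: Int, inputs: Iterator[String]*): Iterator[String] =
        inputs.head.map(_.toUpperCase)
    }
}

object PartitionEvaluatorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    val rdd = spark.sparkContext.parallelize(Seq("a", "b", "c"), numSlices = 2)

    // New code path: the factory itself is shipped to executors.
    println(rdd.mapPartitionsWithEvaluator(new UpperCaseEvaluatorFactory).collect().toSeq)

    // Fallback path, mirroring what doExecute() does when the evaluator API is
    // disabled: create the evaluator inside mapPartitions and delegate to eval.
    val factory = new UpperCaseEvaluatorFactory
    println(rdd.mapPartitions(iter => factory.createEvaluator().eval(0, iter)).collect().toSeq)

    spark.stop()
  }
}

The factory, rather than a closure, is what ships to executors, so each task constructs its
evaluator and any per-partition state on the executor side. That is also why patch 3 is a
net simplification: collapsing the row_number/rank/dense_rank dispatch into a single
limitFunc first, then wrapping it in GroupedLimitIterator only when partitionSpec is
non-empty, leaves exactly one evaluator construction site instead of six match arms.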