From 5c23822e9ae72dd106a199b78409366bbe52e4d2 Mon Sep 17 00:00:00 2001 From: Aleksandar Tomic Date: Wed, 20 Mar 2024 21:16:51 +0500 Subject: [PATCH] [SPARK-47443][SQL] Window Aggregate support for collations ### What changes were proposed in this pull request? This PR introduces support for Window Aggregates when partitioning is done against expressions with non-binary collation. The approach is same as for regular aggregates. Instead of doing byte-for-byte comparison against `UnsafeRow` we fall back to interpreted mode if there is a data type in grouping expressions that doesn't satisfy `isBinaryStable` constraint. ### Why are the changes needed? Previous implementation returned invalid results. ### Does this PR introduce _any_ user-facing change? yes - fixes incorrect behavior. ### How was this patch tested? New test is added in `CollationSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45568 from dbatomic/win_agg_support_for_collations. Authored-by: Aleksandar Tomic Signed-off-by: Max Gekk --- .../window/WindowEvaluatorFactory.scala | 13 ++++++++++-- .../org/apache/spark/sql/CollationSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala index fb4ea7f35c0db..9ff056a279466 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.window import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, NamedExpression, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InterpretedOrdering, JoinedRow, NamedExpression, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf @@ -54,6 +55,14 @@ class WindowEvaluatorFactory( // Get all relevant projections. val result = createResultProjection(expressions) val grouping = UnsafeProjection.create(partitionSpec, childOutput) + val groupEqualityCheck = + if (partitionSpec.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType))) { + (key1: UnsafeRow, key2: UnsafeRow) => key1.equals(key2) + } else { + val types = partitionSpec.map(_.dataType) + val ordering = InterpretedOrdering.forSchema(types) + (key1: UnsafeRow, key2: UnsafeRow) => ordering.compare(key1, key2) == 0 + } // Manage the stream and the grouping. var nextRow: UnsafeRow = null @@ -88,7 +97,7 @@ class WindowEvaluatorFactory( // clear last partition buffer.clear() - while (nextRowAvailable && nextGroup == currentGroup) { + while (nextRowAvailable && groupEqualityCheck(nextGroup, currentGroup)) { buffer.add(nextRow) fetchNextRow() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index ee2b34706e0ba..efb3c2f8ba8e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -639,4 +639,24 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "expressionStr" -> "UCASE(struct1.a)", "reason" -> "generation expression cannot contain non-default collated string type")) } + + test("window aggregates should respect collation") { + val t1 = "T_NON_BINARY" + val t2 = "T_BINARY" + + withTable(t1, t2) { + sql(s"CREATE TABLE $t1 (c STRING COLLATE UTF8_BINARY_LCASE, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES ('aA', 2), ('Aa', 1), ('ab', 3), ('aa', 1)") + + sql(s"CREATE TABLE $t2 (c STRING, i int) USING PARQUET") + // Same input but already normalized to lowercase. + sql(s"INSERT INTO $t2 VALUES ('aa', 2), ('aa', 1), ('ab', 3), ('aa', 1)") + + val dfNonBinary = + sql(s"SELECT lower(c), i, nth_value(i, 2) OVER (PARTITION BY c ORDER BY i) FROM $t1") + val dfBinary = + sql(s"SELECT c, i, nth_value(i, 2) OVER (PARTITION BY c ORDER BY i) FROM $t2") + checkAnswer(dfNonBinary, dfBinary) + } + } }