From 9f8dc91017eae7a2b01007d9cef274907dc2ac5c Mon Sep 17 00:00:00 2001 From: Nikola Mandic <nikola.mandic@databricks.com> Date: Fri, 22 Mar 2024 12:15:05 +0500 Subject: [PATCH] [SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings ### What changes were proposed in this pull request? Example of aggregation sequence: ``` create table t(a array<string collate utf8_binary_lcase>) using parquet; insert into t(a) values(array('a' collate utf8_binary_lcase)); insert into t(a) values(array('A' collate utf8_binary_lcase)); select distinct a from t; ``` Example of join sequence: ``` create table l(a array<string collate utf8_binary_lcase>) using parquet; create table r(a array<string collate utf8_binary_lcase>) using parquet; insert into l(a) values(array('a' collate utf8_binary_lcase)); insert into r(a) values(array('A' collate utf8_binary_lcase)); select * from l join r where l.a = r.a; ``` Both runs should yield one row since the arrays are considered equal. The problem is in the `isBinaryStable` function, which should return false if **any** of its subtypes is a non-binary collated string. ### Why are the changes needed? To support aggregates and joins on arrays of collated strings properly. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the described scenarios. ### How was this patch tested? Added new checks to collation suite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45611 from nikolamand-db/SPARK-47483. 
Authored-by: Nikola Mandic Signed-off-by: Max Gekk --- .../sql/catalyst/util/UnsafeRowUtils.scala | 6 +- .../catalyst/util/UnsafeRowUtilsSuite.scala | 68 ++++++++++++++- .../org/apache/spark/sql/CollationSuite.scala | 87 ++++++++++++++++++- 3 files changed, 156 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 0718cf110f75e..0c1ce5ffa8b0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -204,8 +204,8 @@ object UnsafeRowUtils { * e.g. this is not true for non-binary collations (any case/accent insensitive collation * can lead to rows being semantically equal even though their binary representations differ). */ - def isBinaryStable(dataType: DataType): Boolean = dataType.existsRecursively { - case st: StringType => CollationFactory.fetchCollation(st.collationId).isBinaryCollation - case _ => true + def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { + case st: StringType => !CollationFactory.fetchCollation(st.collationId).isBinaryCollation + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala index c7a8bc74f4ddc..b6e87c456de0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala @@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.types.{Decimal, DecimalType, 
IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType} class UnsafeRowUtilsSuite extends SparkFunSuite { @@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite { "fieldStatus:\n" + "[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,")) } + + test("isBinaryStable on complex types containing collated strings") { + val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")) + + // simple checks + assert(UnsafeRowUtils.isBinaryStable(IntegerType)) + assert(UnsafeRowUtils.isBinaryStable(StringType)) + assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType)) + + assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType))) + assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType))) + + assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType))) + + assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", nonBinaryStringType) :: Nil))) + + // nested complex types + assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType)))) + assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType)))) + assert(UnsafeRowUtils.isBinaryStable( + 
ArrayType(StructType(StructField("field", StringType) :: Nil)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable( + ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType))))) + assert(!UnsafeRowUtils.isBinaryStable( + ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType( + Seq(StructField("second", IntegerType), StructField("second", nonBinaryStringType)))))) + + assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType)))) + assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType))) + assert(UnsafeRowUtils.isBinaryStable( + MapType(StructType(StructField("field", StringType) :: Nil), IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType)))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(IntegerType, ArrayType(nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(MapType(IntegerType, nonBinaryStringType), IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType))) + + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", ArrayType(IntegerType)) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil))) + 
assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", + StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index efb3c2f8ba8e4..146ba63cf402a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -27,10 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName @@ -640,6 +641,90 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "reason" -> "generation expression cannot contain non-default collated string type")) } + test("Aggregation on complex containing collated strings") { + val table = "table_agg" + // array + withTable(table) { + sql(s"create table $table (a array) using parquet") + sql(s"insert into $table values (array('aaa')), (array('AAA'))") + checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + } + // map doesn't support aggregation + 
withTable(table) { + sql(s"create table $table (m map) using parquet") + val query = s"select distinct m from $table" + checkError( + exception = intercept[ExtendedAnalysisException](sql(query)), + errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + parameters = Map( + "colName" -> "`m`", + "dataType" -> toSQLType(MapType( + StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")), + StringType))), + context = ExpectedContext(query, 0, query.length - 1) + ) + } + // struct + withTable(table) { + sql(s"create table $table (s struct) using parquet") + sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))") + checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa"))) + } + } + + test("Joins on complex types containing collated strings") { + val tableLeft = "table_join_le" + val tableRight = "table_join_ri" + // array + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (a array) using parquet")) + Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).map{ + case (tab, data) => sql(s"insert into $tab values ($data)") + } + checkAnswer(sql( + s""" + |select $tableLeft.a from $tableLeft + |join $tableRight on $tableLeft.a = $tableRight.a + |""".stripMargin), Seq(Row(Seq("aaa")))) + } + // map doesn't support joins + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (m map) using parquet")) + val query = + s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m" + val ctx = s"$tableLeft.m = $tableRight.m" + checkError( + exception = intercept[AnalysisException](sql(query)), + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = Map( + "functionName" -> "`=`", + "dataType" -> toSQLType(MapType( + StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")), + StringType + )), + "sqlExpr" -> "\"(m = m)\""), + context = 
ExpectedContext(ctx, query.length - ctx.length, query.length - 1)) + } + // struct + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (s struct) using parquet")) + Seq( + (tableLeft, "named_struct('fld', 'aaa')"), + (tableRight, "named_struct('fld', 'AAA')") + ).map { + case (tab, data) => sql(s"insert into $tab values ($data)") + } + checkAnswer(sql( + s""" + |select $tableLeft.s.fld from $tableLeft + |join $tableRight on $tableLeft.s = $tableRight.s + |""".stripMargin), Seq(Row("aaa"))) + } + } + test("window aggregates should respect collation") { val t1 = "T_NON_BINARY" val t2 = "T_BINARY"