Skip to content

Commit

Permalink
[SPARK-47483][SQL] Add support for aggregation and join operations on…
Browse files Browse the repository at this point in the history
… arrays of collated strings

### What changes were proposed in this pull request?

Example of aggregation sequence:
```
create table t(a array<string collate utf8_binary_lcase>) using parquet;

insert into t(a) values(array('a' collate utf8_binary_lcase));
insert into t(a) values(array('A' collate utf8_binary_lcase));

select distinct a from t;
```
Example of join sequence:
```
create table l(a array<string collate utf8_binary_lcase>) using parquet;
create table r(a array<string collate utf8_binary_lcase>) using parquet;

insert into l(a) values(array('a' collate utf8_binary_lcase));
insert into r(a) values(array('A' collate utf8_binary_lcase));

select * from l join r where l.a = r.a;
```
Both runs should yield one row since the arrays are considered equal.

The problem is in the `isBinaryStable` function, which should return false if **any** of its subtypes is a non-binary collated string.

### Why are the changes needed?

To properly support aggregates and joins on arrays of collated strings.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the described scenarios.

### How was this patch tested?

Added new checks to collation suite.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45611 from nikolamand-db/SPARK-47483.

Authored-by: Nikola Mandic <nikola.mandic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
nikolamand-db authored and sweisdb committed Apr 1, 2024
1 parent 6174ae9 commit 9f8dc91
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ object UnsafeRowUtils {
* e.g. this is not true for non-binary collations (any case/accent insensitive collation
* can lead to rows being semantically equal even though their binary representations differ).
*/
def isBinaryStable(dataType: DataType): Boolean = dataType.existsRecursively {
case st: StringType => CollationFactory.fetchCollation(st.collationId).isBinaryCollation
case _ => true
def isBinaryStable(dataType: DataType): Boolean = {
  // A type is binary stable iff no string nested anywhere inside it (including array
  // elements, map keys/values and struct fields) uses a non-binary collation; for such
  // strings equal binary representations no longer imply semantic equality.
  val containsNonBinaryCollatedString = dataType.existsRecursively {
    case st: StringType => !CollationFactory.fetchCollation(st.collationId).isBinaryCollation
    case _ => false
  }
  !containsNonBinaryCollatedString
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite {
"fieldStatus:\n" +
"[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,"))
}

// Exhaustively checks isBinaryStable over simple and nested complex types: any type that
// contains a non-binary collated string anywhere inside it must be reported as unstable.
test("isBinaryStable on complex types containing collated strings") {
  val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))

  // simple checks
  assert(UnsafeRowUtils.isBinaryStable(IntegerType))
  assert(UnsafeRowUtils.isBinaryStable(StringType))
  assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType))

  assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType)))
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType)))

  // a collated string in either the key or the value makes the map unstable
  assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType)))

  assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", nonBinaryStringType) :: Nil)))

  // nested complex types: the check must recurse through arbitrary nesting
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType))))
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType))))
  assert(UnsafeRowUtils.isBinaryStable(
    ArrayType(StructType(StructField("field", StringType) :: Nil))))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType)))))
  assert(!UnsafeRowUtils.isBinaryStable(
    ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil))))
  // fixed duplicate field name: the two struct fields were both called "second"
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType(
    Seq(StructField("first", IntegerType), StructField("second", nonBinaryStringType))))))

  assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType))))
  assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType)))
  assert(UnsafeRowUtils.isBinaryStable(
    MapType(StructType(StructField("field", StringType) :: Nil), IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(IntegerType, ArrayType(nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(MapType(IntegerType, nonBinaryStringType), IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType)))

  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", ArrayType(IntegerType)) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field",
      StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
Expand Down Expand Up @@ -640,6 +641,90 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"reason" -> "generation expression cannot contain non-default collated string type"))
}

// Verifies that aggregation (`select distinct`, `group by`) over complex types containing
// non-binary (UTF8_BINARY_LCASE) collated strings treats semantically equal values as equal.
test("Aggregation on complex containing collated strings") {
  val table = "table_agg"
  // array: 'aaa' and 'AAA' compare equal under the lowercase collation, so distinct
  // must collapse both inserted rows into a single result row
  withTable(table) {
    sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
    sql(s"insert into $table values (array('aaa')), (array('AAA'))")
    checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa"))))
  }
  // map doesn't support aggregation: expect an analysis error, not a result
  withTable(table) {
    sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet")
    val query = s"select distinct m from $table"
    checkError(
      exception = intercept[ExtendedAnalysisException](sql(query)),
      errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
      parameters = Map(
        "colName" -> "`m`",
        "dataType" -> toSQLType(MapType(
          StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
          StringType))),
      context = ExpectedContext(query, 0, query.length - 1)
    )
  }
  // struct: grouping by a struct with a collated field must also be case-insensitive
  withTable(table) {
    sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet")
    sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))")
    checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa")))
  }
}

// Verifies that equi-joins over complex types containing non-binary (UTF8_BINARY_LCASE)
// collated strings match semantically equal values across the two sides.
// Idiom fix: side-effecting table setup loops use `foreach` instead of `map`, whose
// discarded Seq[DataFrame] result was misleading.
test("Joins on complex types containing collated strings") {
  val tableLeft = "table_join_le"
  val tableRight = "table_join_ri"
  // array: 'aaa' on the left joins with 'AAA' on the right under the lowercase collation
  withTable(tableLeft, tableRight) {
    Seq(tableLeft, tableRight).foreach(tab =>
      sql(s"create table $tab (a array<string collate utf8_binary_lcase>) using parquet"))
    Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).foreach {
      case (tab, data) => sql(s"insert into $tab values ($data)")
    }
    checkAnswer(sql(
      s"""
         |select $tableLeft.a from $tableLeft
         |join $tableRight on $tableLeft.a = $tableRight.a
         |""".stripMargin), Seq(Row(Seq("aaa"))))
  }
  // map doesn't support joins: equality on a map column must raise an analysis error
  withTable(tableLeft, tableRight) {
    Seq(tableLeft, tableRight).foreach(tab =>
      sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet"))
    val query =
      s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m"
    val ctx = s"$tableLeft.m = $tableRight.m"
    checkError(
      exception = intercept[AnalysisException](sql(query)),
      errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
      parameters = Map(
        "functionName" -> "`=`",
        "dataType" -> toSQLType(MapType(
          StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
          StringType
        )),
        "sqlExpr" -> "\"(m = m)\""),
      context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
  }
  // struct: joining on a struct with a collated field must also match case-insensitively
  withTable(tableLeft, tableRight) {
    Seq(tableLeft, tableRight).foreach(tab =>
      sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet"))
    Seq(
      (tableLeft, "named_struct('fld', 'aaa')"),
      (tableRight, "named_struct('fld', 'AAA')")
    ).foreach {
      case (tab, data) => sql(s"insert into $tab values ($data)")
    }
    checkAnswer(sql(
      s"""
         |select $tableLeft.s.fld from $tableLeft
         |join $tableRight on $tableLeft.s = $tableRight.s
         |""".stripMargin), Seq(Row("aaa")))
  }
}

test("window aggregates should respect collation") {
val t1 = "T_NON_BINARY"
val t2 = "T_BINARY"
Expand Down

0 comments on commit 9f8dc91

Please sign in to comment.