[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611

Closed
wants to merge 8 commits
@@ -204,8 +204,8 @@ object UnsafeRowUtils {
* e.g. this is not true for non-binary collations (any case/accent insensitive collation
* can lead to rows being semantically equal even though their binary representations differ).
*/
-def isBinaryStable(dataType: DataType): Boolean = dataType.existsRecursively {
-  case st: StringType => CollationFactory.fetchCollation(st.collationId).isBinaryCollation
-  case _ => true
+def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively {
@MaxGekk (Member) commented on Mar 21, 2024:
Why is isBinaryStable in UnsafeRowUtils? Is the implementation somehow bound to unsafe rows?

Why is it not in DataTypeUtils or Collation..., for example?

The Contributor Author replied:

I'm not sure what the consequences of moving this function would be; do you know if we can do that, @dbatomic?

+  case st: StringType => !CollationFactory.fetchCollation(st.collationId).isBinaryCollation
+  case _ => false
}
}
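
To make the doc comment concrete, here is a minimal sketch of why a case-insensitive collation breaks binary stability. It assumes the Collation object returned by CollationFactory.fetchCollation exposes a comparator over UTF8Strings (as the sorting path does); the values are illustrative:

import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.unsafe.types.UTF8String

val lcaseId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
val collation = CollationFactory.fetchCollation(lcaseId)

val a = UTF8String.fromString("aaa")
val b = UTF8String.fromString("AAA")

// Semantically equal under the case-insensitive collation...
assert(collation.comparator.compare(a, b) == 0)
// ...yet their byte representations differ, so a type containing such
// strings cannot be compared byte-for-byte and is not binary stable.
assert(a != b)
assert(!UnsafeRowUtils.isBinaryStable(StringType(lcaseId)))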
@@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

@@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite {
"fieldStatus:\n" +
"[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,"))
}

test("isBinaryStable on complex types containing collated strings") {
val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))

// simple checks
assert(UnsafeRowUtils.isBinaryStable(IntegerType))
assert(UnsafeRowUtils.isBinaryStable(StringType))
assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType))

assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType)))
assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType)))
assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType)))

assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType)))
assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType)))
assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType)))
assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType)))
assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType)))
assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType)))

assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil)))
assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil)))
assert(!UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", nonBinaryStringType) :: Nil)))

// nested complex types
assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType))))
assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType))))
assert(UnsafeRowUtils.isBinaryStable(
ArrayType(StructType(StructField("field", StringType) :: Nil))))
assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType))))
assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType))))
assert(!UnsafeRowUtils.isBinaryStable(
ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType)))))
assert(!UnsafeRowUtils.isBinaryStable(
ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil))))
assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType(
Seq(StructField("first", IntegerType), StructField("second", nonBinaryStringType))))))

assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType))))
assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType)))
assert(UnsafeRowUtils.isBinaryStable(
MapType(StructType(StructField("field", StringType) :: Nil), IntegerType)))
assert(!UnsafeRowUtils.isBinaryStable(
MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType))))
assert(!UnsafeRowUtils.isBinaryStable(
MapType(IntegerType, ArrayType(nonBinaryStringType))))
assert(!UnsafeRowUtils.isBinaryStable(
MapType(MapType(IntegerType, nonBinaryStringType), IntegerType)))
assert(!UnsafeRowUtils.isBinaryStable(
MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType)))

assert(UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", ArrayType(IntegerType)) :: Nil)))
assert(UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil)))
assert(UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil)))
assert(!UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil)))
assert(!UnsafeRowUtils.isBinaryStable(
StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil)))
assert(!UnsafeRowUtils.isBinaryStable(
StructType(StructField("field",
StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil)))
}
}
@@ -27,10 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
@@ -640,6 +641,90 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"reason" -> "generation expression cannot contain non-default collated string type"))
}

test("Aggregation on complex containing collated strings") {
val table = "table_agg"
// array
withTable(table) {
sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
sql(s"insert into $table values (array('aaa')), (array('AAA'))")
checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa"))))
}
// map doesn't support aggregation
withTable(table) {
sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet")
val query = s"select distinct m from $table"
checkError(
exception = intercept[ExtendedAnalysisException](sql(query)),
errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
parameters = Map(
"colName" -> "`m`",
"dataType" -> toSQLType(MapType(
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
StringType))),
context = ExpectedContext(query, 0, query.length - 1)
)
}
// struct
withTable(table) {
sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet")
sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))")
checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa")))
}
}

test("Joins on complex types containing collated strings") {
val tableLeft = "table_join_le"
val tableRight = "table_join_ri"
// array
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).foreach(tab =>
sql(s"create table $tab (a array<string collate utf8_binary_lcase>) using parquet"))
Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).foreach {
case (tab, data) => sql(s"insert into $tab values ($data)")
}
checkAnswer(sql(
s"""
|select $tableLeft.a from $tableLeft
|join $tableRight on $tableLeft.a = $tableRight.a
|""".stripMargin), Seq(Row(Seq("aaa"))))
}
// map doesn't support joins
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).foreach(tab =>
sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet"))
val query =
s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m"
val ctx = s"$tableLeft.m = $tableRight.m"
checkError(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
parameters = Map(
"functionName" -> "`=`",
"dataType" -> toSQLType(MapType(
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
StringType
)),
"sqlExpr" -> "\"(m = m)\""),
context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
}
// struct
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).foreach(tab =>
sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet"))
Seq(
(tableLeft, "named_struct('fld', 'aaa')"),
(tableRight, "named_struct('fld', 'AAA')")
).foreach {
case (tab, data) => sql(s"insert into $tab values ($data)")
}
checkAnswer(sql(
s"""
|select $tableLeft.s.fld from $tableLeft
|join $tableRight on $tableLeft.s = $tableRight.s
|""".stripMargin), Seq(Row("aaa")))
}
}
A Contributor commented:

The test pattern that can be useful here is what we did for window aggregates. In short, create one query that targets mixed-case data with an LCASE collation (e.g. "aA", "aa", "AA") and another that targets normalized data with UTF8_BINARY ("aa", "aa", "aa"). Aggregations and joins should return the same result in both cases.

You can find a test example here:
#45568
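
For reference, a minimal sketch of that pattern adapted to the aggregation case in this PR; table names and data are illustrative, not taken from the linked test, and withTable/sql/checkAnswer are the suite's usual test helpers:

// Sketch: the LCASE table holds mixed-case rows, the binary table holds the
// normalized equivalents; a group-by should collapse both to a single group.
withTable("t_lcase", "t_binary") {
  sql("create table t_lcase (c string collate utf8_binary_lcase) using parquet")
  sql("create table t_binary (c string) using parquet")
  sql("insert into t_lcase values ('aA'), ('aa'), ('AA')")
  sql("insert into t_binary values ('aa'), ('aa'), ('aa')")
  // Both queries should return exactly one group of three rows.
  checkAnswer(sql("select count(*) from t_lcase group by c"), Seq(Row(3)))
  checkAnswer(sql("select count(*) from t_binary group by c"), Seq(Row(3)))
}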

The Contributor Author replied:

Updated tests to support this and added case classes for better readability. Please check again.


test("window aggregates should respect collation") {
val t1 = "T_NON_BINARY"
val t2 = "T_BINARY"