[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611

Closed
wants to merge 8 commits
@@ -204,8 +204,8 @@ object UnsafeRowUtils {
* e.g. this is not true for non-binary collations (any case/accent insensitive collation
* can lead to rows being semantically equal even though their binary representations differ).
*/
- def isBinaryStable(dataType: DataType): Boolean = dataType.existsRecursively {
-   case st: StringType => CollationFactory.fetchCollation(st.collationId).isBinaryCollation
-   case _ => true
+ def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively {
@MaxGekk (Member) commented on Mar 21, 2024:

Why is isBinaryStable in UnsafeRowUtils? Is the implementation somehow bound to unsafe rows?

Why is it not in DataTypeUtils or Collation..., for example?

Contributor Author replied:

Not sure what the consequences of moving this function would be. Do you know if we can do that, @dbatomic?

+   case st: StringType => !CollationFactory.fetchCollation(st.collationId).isBinaryCollation
+   case _ => false
}
}
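
Note on the change: the old predicate matched every non-string node with `case _ => true`, so `existsRecursively` could succeed on the complex wrapper type itself (e.g. the `ArrayType`), reporting arrays of non-binary collated strings as binary stable. The new formulation instead asks whether any nested `StringType` uses a non-binary collation and negates the result. A rough, illustrative sketch of the behavior (assumes `StringType(collationId)` and `CollationFactory.collationNameToId` as used in the collation work, and `UnsafeRowUtils` in `org.apache.spark.sql.catalyst.util`; not code from this PR):

```scala
// Illustrative sketch only; helper names above are assumptions, not part of this diff.
import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
import org.apache.spark.sql.types.{ArrayType, StringType}

val lcase = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))

UnsafeRowUtils.isBinaryStable(StringType)        // true: default binary collation
UnsafeRowUtils.isBinaryStable(lcase)             // false under both the old and the new version
UnsafeRowUtils.isBinaryStable(ArrayType(lcase))  // old version: true (incorrect), new version: false
```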
191 changes: 191 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
+ import org.apache.spark.unsafe.types.UTF8String

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
@@ -640,6 +641,196 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"reason" -> "generation expression cannot contain non-default collated string type"))
}

trait ArrayCheck {
Member commented:

Why don't you put the new code at the end of the test suite?

def dataType: String
def dataTypeCollated: String
}

trait ArrayCheckSimple extends ArrayCheck {
override def dataType: String = s"array<string>"
override def dataTypeCollated: String = s"array<string collate utf8_binary_lcase>"
}

trait ArrayCheckNested extends ArrayCheck {
override def dataType: String = s"array<array<string>>"
override def dataTypeCollated: String = s"array<array<string collate utf8_binary_lcase>>"
}

test("Aggregation of arrays built on collated strings") {
abstract class AggCheck(val rows: Seq[String], val result: Seq[(Any, Int)]) extends ArrayCheck

case class AggCheckSimple(override val rows: Seq[String],
override val result: Seq[(Seq[String], Int)])
extends AggCheck(rows, result) with ArrayCheckSimple

case class AggCheckNested(override val rows: Seq[String],
override val result: Seq[(Seq[Seq[String]], Int)])
extends AggCheck(rows, result) with ArrayCheckNested

val tableName = "test_agg_arr_collated"
val tableNameLowercase = "test_agg_arr_collated_lowercase"

Seq(
// simple
AggCheckSimple(
rows = Seq("array('aaa')", "array('AAA')"),
result = Seq((Seq("aaa"), 2))
),
AggCheckSimple(
rows = Seq("array('aaa', 'bbb')", "array('AAA', 'BBB')"),
result = Seq((Seq("aaa", "bbb"), 2))
),
AggCheckSimple(
rows = Seq("array('aaa')", "array('bbb')", "array('AAA')", "array('BBB')"),
result = Seq((Seq("aaa"), 2), (Seq("bbb"), 2))
),
// nested
AggCheckNested(
rows = Seq("array(array('aaa'))", "array(array('AAA'))"),
result = Seq((Seq(Seq("aaa")), 2))
),
AggCheckNested(
rows = Seq("array(array('aaa'), array('bbb'))", "array(array('AAA'), array('bbb'))"),
result = Seq((Seq(Seq("aaa"), Seq("bbb")), 2))
),
AggCheckNested(
rows = Seq(
"array(array('aaa', 'aaa'), array('bbb', 'ccc'))",
"array(array('aaa', 'aaa'), array('bbb', 'ccc'), array('ddd'))",
"array(array('AAA', 'AAA'), array('BBB', 'CCC'))"
),
result = Seq(
(Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc")), 2),
(Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc"), Seq("ddd")), 1)
)
)
).map((check: AggCheck) =>
withTable(tableName, tableNameLowercase) {
def checkResults(table: String): Unit = {
checkAnswer(sql(s"select a, count(*) from $table group by a"),
Contributor commented:

It's bad practice to try to test all cases in end-to-end tests. We should improve the test coverage in the unit tests.

I suggest we add tests in UnsafeRowUtilsSuite to test the isBinaryStable function with different cases: array of string, array of array of string, struct of array of string, etc. The end-to-end test here should just run a few queries to show it works.
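
A minimal sketch of what such unit tests might look like (a sketch only: the suite and test names, `StringType(collationId)`, and `CollationFactory.collationNameToId` are assumptions, not code from this PR):

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
import org.apache.spark.sql.types._

// Hypothetical suite; in practice these cases would live in UnsafeRowUtilsSuite.
class UnsafeRowUtilsBinaryStableSuite extends SparkFunSuite {
  private val binary = StringType  // default UTF8_BINARY collation
  private val lcase = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))

  test("isBinaryStable for nested types with collated strings") {
    // Binary-collated strings keep the type binary stable, however deeply nested.
    assert(UnsafeRowUtils.isBinaryStable(ArrayType(binary)))
    assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(binary))))
    assert(UnsafeRowUtils.isBinaryStable(
      StructType(StructField("a", ArrayType(binary)) :: Nil)))

    // A non-binary collation anywhere in the type makes it binary unstable.
    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(lcase)))
    assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(lcase))))
    assert(!UnsafeRowUtils.isBinaryStable(
      StructType(StructField("a", ArrayType(lcase)) :: Nil)))
    assert(!UnsafeRowUtils.isBinaryStable(MapType(binary, ArrayType(lcase))))
  }
}
```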

check.result.map{ case (agg, cnt) => Row(agg, cnt) })
checkAnswer(sql(s"select distinct a from $table"),
check.result.map{ case (agg, _) => Row(agg) })
}

// check against non-binary collation
sql(s"create table $tableName(a ${check.dataTypeCollated}) using parquet")
check.rows.map(row => sql(s"insert into $tableName(a) values($row)"))
checkResults(tableName)

// binary collation with values converted to lowercase should match the results as well
sql(s"create table $tableNameLowercase(a ${check.dataType}) using parquet")
check.rows.map(row =>
// scalastyle:off caselocale
sql(
s"""
|insert into $tableNameLowercase(a)
|values(${UTF8String.fromString(row).toLowerCase})
|""".stripMargin)
// scalastyle:on caselocale
)
checkResults(tableNameLowercase)
}
)
}
Contributor commented:

A test pattern that can be useful here is what we did for Window Aggs.
In short, just create one query that targets mixed-case data with LCASE collation (e.g. "aA", "aa", "AA") and another that targets normalized data with UTF8_BINARY ("aa", "aa", "aa"). Aggs and joins should return the same result.

You can find a test example here:
#45568
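
For the array case, a minimal sketch of that pattern might look like the following (table names and literal values are illustrative, not taken from #45568; assumes it runs inside a suite with a SparkSession, like this one):

```scala
// Sketch: the LCASE-collated table holds mixed-case arrays, the binary-collated
// table holds the same data pre-normalized to lowercase; both group-bys should
// collapse to a single group of three rows.
withTable("arr_lcase", "arr_binary") {
  sql("create table arr_lcase(a array<string collate utf8_binary_lcase>) using parquet")
  sql("create table arr_binary(a array<string>) using parquet")
  Seq("aA", "aa", "AA").foreach(v => sql(s"insert into arr_lcase values(array('$v'))"))
  Seq("aa", "aa", "aa").foreach(v => sql(s"insert into arr_binary values(array('$v'))"))

  val lcaseGroups = sql("select a, count(*) from arr_lcase group by a").collect()
  val binaryGroups = sql("select a, count(*) from arr_binary group by a").collect()
  assert(lcaseGroups.length == 1 && lcaseGroups.head.getLong(1) == 3)
  assert(binaryGroups.length == 1 && binaryGroups.head.getLong(1) == 3)
}
```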

Contributor Author replied:

Updated tests to support this and added case classes for better readability. Please check again.


test("Join on arrays of collated strings") {
abstract class JoinCheck(val leftRows: Seq[String],
val rightRows: Seq[String],
val resultRows: Seq[Any])
extends ArrayCheck

case class JoinSimpleCheck(override val leftRows: Seq[String],
override val rightRows: Seq[String],
override val resultRows: Seq[Seq[String]])
extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckSimple

case class JoinNestedCheck(override val leftRows: Seq[String],
override val rightRows: Seq[String],
override val resultRows: Seq[Seq[Seq[String]]])
extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckNested

val tablePrefix = "test_join_arr_collated"
val tableLeft = s"${tablePrefix}_left"
val tableLeftLowercase = s"${tableLeft}_lowercase"
val tableRight = s"${tablePrefix}_right"
val tableRightLowercase = s"${tableRight}_lowercase"

Seq(
// simple
JoinSimpleCheck(
leftRows = Seq("array('aaa')"),
rightRows = Seq("array('AAA')"),
resultRows = Seq(Seq("aaa"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa', 'bbb')"),
rightRows = Seq("array('AAA', 'BBB')"),
resultRows = Seq(Seq("aaa", "bbb"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa')", "array('bbb')"),
rightRows = Seq("array('AAA')", "array('BBB')"),
resultRows = Seq(Seq("aaa"), Seq("bbb"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa')", "array('bbb')"),
rightRows = Seq("array('AAAA')", "array('BBBB')"),
resultRows = Seq()
),
// nested
JoinNestedCheck(
leftRows = Seq("array(array('aaa'))"),
rightRows = Seq("array(array('AAA'))"),
resultRows = Seq(Seq(Seq("aaa")))
),
JoinNestedCheck(
leftRows = Seq("array(array('aaa', 'bbb'))"),
rightRows = Seq("array(array('AAA', 'BBB'))"),
resultRows = Seq(Seq(Seq("aaa", "bbb")))
),
JoinNestedCheck(
Seq("array(array('aaa'), array('bbb'))"),
Seq("array(array('AAA'), array('BBB'))"),
Seq(Seq(Seq("aaa"), Seq("bbb")))
),
JoinNestedCheck(
Seq("array(array('aaa'), array('bbb'))"),
Seq("array(array('AAA'), array('CCC'))"),
Seq()
)
).map((check: JoinCheck) =>
Member commented:

Better to use .foreach here. map is usually used to build a new collection, which is not needed in this case.

withTable(tableLeft, tableLeftLowercase, tableRight, tableRightLowercase) {
def checkResults(left: String, right: String): Unit = {
checkAnswer(
sql(s"select $left.a from $left join $right where $left.a = $right.a"),
check.resultRows.map(Row(_))
)
}

// check against non-binary collation
Seq(tableLeft, tableRight).map(tab =>
sql(s"create table $tab(a ${check.dataTypeCollated}) using parquet"))
Seq((tableLeft, check.leftRows), (tableRight, check.rightRows)).foreach {
case (tab, data) => data.map(row => sql(s"insert into $tab(a) values($row)"))
}
checkResults(tableLeft, tableRight)

// binary collation with values converted to lowercase should match the results as well
Seq(tableLeftLowercase, tableRightLowercase).map(tab =>
sql(s"create table $tab(a ${check.dataType}) using parquet"))
Seq((tableLeftLowercase, check.leftRows), (tableRightLowercase, check.rightRows)).foreach {
case (tab, data) =>
// scalastyle:off caselocale
data.map(row =>
sql(s"insert into $tab(a) values(${UTF8String.fromString(row).toLowerCase})"))
// scalastyle:on caselocale
}
checkResults(tableLeftLowercase, tableRightLowercase)
}
)
}

test("window aggregates should respect collation") {
val t1 = "T_NON_BINARY"
val t2 = "T_BINARY"