-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611
[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611
Changes from 7 commits
f69b093
ca7664e
f78557d
824b92a
78e8310
588cca1
5f59648
514d2cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} | ||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} | ||
import org.apache.spark.sql.types.{StringType, StructField, StructType} | ||
import org.apache.spark.unsafe.types.UTF8String | ||
|
||
class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { | ||
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName | ||
|
@@ -640,6 +641,201 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { | |
"reason" -> "generation expression cannot contain non-default collated string type")) | ||
} | ||
|
||
/**
 * Supplies the SQL column types used by the array-of-collated-strings checks in this suite.
 * Each check creates one table with the collated type and one with the plain type.
 */
// Review note (from the PR discussion): Why don't you put the new code to the end of the
// test suite?
trait ArrayCheck {
  /** SQL type of the array column when no collation is specified on its strings. */
  def dataType: String

  /** SQL type of the same array column with an explicit collation on its strings. */
  def dataTypeCollated: String
}
|
||
/** Column types for checks that operate on a flat array of strings. */
trait ArrayCheckSimple extends ArrayCheck {
  /** Flat string array whose elements use the `utf8_binary_lcase` collation. */
  override def dataTypeCollated: String = "array<string collate utf8_binary_lcase>"

  /** Flat string array with no explicit collation. */
  override def dataType: String = "array<string>"
}
|
||
/** Column types for checks that operate on an array of string arrays. */
trait ArrayCheckNested extends ArrayCheck {
  /** Nested string array whose innermost elements use the `utf8_binary_lcase` collation. */
  override def dataTypeCollated: String = "array<array<string collate utf8_binary_lcase>>"

  /** Nested string array with no explicit collation. */
  override def dataType: String = "array<array<string>>"
}
|
||
test("Aggregation of arrays built on collated strings") {
  // A scenario: the SQL array literals to insert (one row each) and the expected
  // (group key, row count) pairs produced by grouping on the array column.
  abstract class AggCheck(val rows: Seq[String], val result: Seq[(Any, Int)]) extends ArrayCheck

  case class AggCheckSimple(
      override val rows: Seq[String],
      override val result: Seq[(Seq[String], Int)])
    extends AggCheck(rows, result) with ArrayCheckSimple

  case class AggCheckNested(
      override val rows: Seq[String],
      override val result: Seq[(Seq[Seq[String]], Int)])
    extends AggCheck(rows, result) with ArrayCheckNested

  val tableName = "test_agg_arr_collated"
  val tableNameLowercase = "test_agg_arr_collated_lowercase"

  val checks = Seq[AggCheck](
    // simple
    AggCheckSimple(
      rows = Seq("array('aaa')", "array('AAA')"),
      result = Seq((Seq("aaa"), 2))
    ),
    AggCheckSimple(
      rows = Seq("array('aaa', 'bbb')", "array('AAA', 'BBB')"),
      result = Seq((Seq("aaa", "bbb"), 2))
    ),
    AggCheckSimple(
      rows = Seq("array('aaa')", "array('bbb')", "array('AAA')", "array('BBB')"),
      result = Seq((Seq("aaa"), 2), (Seq("bbb"), 2))
    ),
    // nested
    AggCheckNested(
      rows = Seq("array(array('aaa'))", "array(array('AAA'))"),
      result = Seq((Seq(Seq("aaa")), 2))
    ),
    AggCheckNested(
      rows = Seq("array(array('aaa'), array('bbb'))", "array(array('AAA'), array('bbb'))"),
      result = Seq((Seq(Seq("aaa"), Seq("bbb")), 2))
    ),
    AggCheckNested(
      rows = Seq(
        "array(array('aaa', 'aaa'), array('bbb', 'ccc'))",
        "array(array('aaa', 'aaa'), array('bbb', 'ccc'), array('ddd'))",
        "array(array('AAA', 'AAA'), array('BBB', 'CCC'))"
      ),
      result = Seq(
        (Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc")), 2),
        (Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc"), Seq("ddd")), 1)
      )
    )
  )

  // The scenarios run purely for their side effects (DDL, inserts, assertions),
  // so iterate with foreach rather than building a discarded collection with map.
  checks.foreach { check =>
    withTable(tableName, tableNameLowercase) {
      // Verifies both GROUP BY and DISTINCT over column `a` of the given table.
      def checkResults(table: String): Unit = {
        // Review note (from the PR discussion, truncated in the source): It's a bad practice
        // that we try to test all cases in end-to-end tests. We should improve the test
        // coverage in the unit test. I suggest we add tests in ...
        checkAnswer(
          sql(s"select a, count(*) from $table group by a"),
          check.result.map { case (agg, cnt) => Row(agg, cnt) })
        checkAnswer(
          sql(s"select distinct a from $table"),
          check.result.map { case (agg, _) => Row(agg) })
      }

      // check against non-binary collation
      sql(s"create table $tableName(a ${check.dataTypeCollated}) using parquet")
      check.rows.foreach(row => sql(s"insert into $tableName(a) values($row)"))
      checkResults(tableName)

      // binary collation with values converted to lowercase should match the results as well
      sql(s"create table $tableNameLowercase(a ${check.dataType}) using parquet")
      check.rows.foreach { row =>
        // scalastyle:off caselocale
        sql(
          s"""
             |insert into $tableNameLowercase(a)
             |values(${UTF8String.fromString(row).toLowerCase})
             |""".stripMargin)
        // scalastyle:on caselocale
      }
      checkResults(tableNameLowercase)
    }
  }
}
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The test pattern that can be useful here is what we did for Window Aggs. You can find a test example here: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated tests to support this and added case classes for better readability. Please check again. |
||
|
||
test("Join on arrays of collated strings") {
  // A scenario: SQL array literals for the left and right tables (one row each) and the
  // values of the left-side column expected to survive an equality join on the array column.
  abstract class JoinCheck(
      val leftRows: Seq[String],
      val rightRows: Seq[String],
      val resultRows: Seq[Any])
    extends ArrayCheck

  case class JoinSimpleCheck(
      override val leftRows: Seq[String],
      override val rightRows: Seq[String],
      override val resultRows: Seq[Seq[String]])
    extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckSimple

  case class JoinNestedCheck(
      override val leftRows: Seq[String],
      override val rightRows: Seq[String],
      override val resultRows: Seq[Seq[Seq[String]]])
    extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckNested

  val tablePrefix = "test_join_arr_collated"
  val tableLeft = s"${tablePrefix}_left"
  val tableLeftLowercase = s"${tableLeft}_lowercase"
  val tableRight = s"${tablePrefix}_right"
  val tableRightLowercase = s"${tableRight}_lowercase"

  val checks = Seq[JoinCheck](
    // simple
    JoinSimpleCheck(
      leftRows = Seq("array('aaa')"),
      rightRows = Seq("array('AAA')"),
      resultRows = Seq(Seq("aaa"))
    ),
    JoinSimpleCheck(
      leftRows = Seq("array('aaa', 'bbb')"),
      rightRows = Seq("array('AAA', 'BBB')"),
      resultRows = Seq(Seq("aaa", "bbb"))
    ),
    JoinSimpleCheck(
      leftRows = Seq("array('aaa')", "array('bbb')"),
      rightRows = Seq("array('AAA')", "array('BBB')"),
      resultRows = Seq(Seq("aaa"), Seq("bbb"))
    ),
    JoinSimpleCheck(
      leftRows = Seq("array('aaa')", "array('bbb')"),
      rightRows = Seq("array('AAAA')", "array('BBBB')"),
      resultRows = Seq()
    ),
    // nested
    JoinNestedCheck(
      leftRows = Seq("array(array('aaa'))"),
      rightRows = Seq("array(array('AAA'))"),
      resultRows = Seq(Seq(Seq("aaa")))
    ),
    JoinNestedCheck(
      leftRows = Seq("array(array('aaa', 'bbb'))"),
      rightRows = Seq("array(array('AAA', 'BBB'))"),
      resultRows = Seq(Seq(Seq("aaa", "bbb")))
    ),
    JoinNestedCheck(
      leftRows = Seq("array(array('aaa'), array('bbb'))"),
      rightRows = Seq("array(array('AAA'), array('BBB'))"),
      resultRows = Seq(Seq(Seq("aaa"), Seq("bbb")))
    ),
    JoinNestedCheck(
      leftRows = Seq("array(array('aaa'), array('bbb'))"),
      rightRows = Seq("array(array('AAA'), array('CCC'))"),
      resultRows = Seq()
    )
  )

  // Review note (from the PR discussion, truncated in the source): Better to use ...
  // The scenarios run purely for their side effects (DDL, inserts, assertions),
  // so iterate with foreach rather than building a discarded collection with map.
  checks.foreach { check =>
    withTable(tableLeft, tableLeftLowercase, tableRight, tableRightLowercase) {
      // Joins the two given tables on equality of column `a` and compares the left-side
      // values against the scenario's expected rows.
      def checkResults(left: String, right: String): Unit = {
        checkAnswer(
          sql(s"select $left.a from $left join $right where $left.a = $right.a"),
          check.resultRows.map(Row(_))
        )
      }

      // check against non-binary collation
      Seq(tableLeft, tableRight).foreach(tab =>
        sql(s"create table $tab(a ${check.dataTypeCollated}) using parquet"))
      Seq((tableLeft, check.leftRows), (tableRight, check.rightRows)).foreach {
        case (tab, data) => data.foreach(row => sql(s"insert into $tab(a) values($row)"))
      }
      checkResults(tableLeft, tableRight)

      // binary collation with values converted to lowercase should match the results as well
      Seq(tableLeftLowercase, tableRightLowercase).foreach(tab =>
        sql(s"create table $tab(a ${check.dataType}) using parquet"))
      Seq((tableLeftLowercase, check.leftRows), (tableRightLowercase, check.rightRows)).foreach {
        case (tab, data) =>
          // scalastyle:off caselocale
          data.foreach(row =>
            sql(s"insert into $tab(a) values(${UTF8String.fromString(row).toLowerCase})"))
          // scalastyle:on caselocale
      }
      // Must target the lowercase tables; checking tableLeft/tableRight here would only
      // re-verify the collated tables and leave the binary-collation path untested.
      checkResults(tableLeftLowercase, tableRightLowercase)
    }
  }
}
|
||
test("window aggregates should respect collation") { | ||
val t1 = "T_NON_BINARY" | ||
val t2 = "T_BINARY" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is `isBinaryStable` in `UnsafeRowUtils`? Is the implementation somehow bound to unsafe rows? Why is it not in `DataTypeUtils` or `Collation...`, for example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what the consequences of moving this function are; do you know if we can do that, @dbatomic?