-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611
[SPARK-47483][SQL] Add support for aggregation and join operations on arrays of collated strings #45611
Changes from all commits
f69b093
ca7664e
f78557d
824b92a
78e8310
588cca1
5f59648
514d2cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,10 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu | |
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} | ||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper | ||
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership | ||
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType | ||
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | ||
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} | ||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} | ||
import org.apache.spark.sql.types.{StringType, StructField, StructType} | ||
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} | ||
|
||
class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { | ||
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName | ||
|
@@ -640,6 +641,90 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { | |
"reason" -> "generation expression cannot contain non-default collated string type")) | ||
} | ||
|
||
test("Aggregation on complex containing collated strings") {
  // Exercises DISTINCT / GROUP BY over array, map and struct columns whose string
  // elements use the utf8_binary_lcase collation. Every scenario reuses this table name.
  val table = "table_agg"

  // array: 'aaa' and 'AAA' are expected to collapse into one distinct value
  withTable(table) {
    sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
    sql(s"insert into $table values (array('aaa')), (array('AAA'))")
    val expectedRows = Seq(Row(Seq("aaa")))
    checkAnswer(sql(s"select distinct a from $table"), expectedRows)
  }

  // map doesn't support aggregation: DISTINCT on the map column must be rejected
  withTable(table) {
    sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet")
    val query = s"select distinct m from $table"
    // Run the failing query first, then build the expectations it is compared against.
    val analysisError = intercept[ExtendedAnalysisException](sql(query))
    val lcaseStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))
    val expectedParameters = Map(
      "colName" -> "`m`",
      "dataType" -> toSQLType(MapType(lcaseStringType, StringType)))
    checkError(
      exception = analysisError,
      errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
      parameters = expectedParameters,
      // The error context is expected to span the whole query text.
      context = ExpectedContext(query, 0, query.length - 1))
  }

  // struct: grouping by the struct is expected to merge the 'aaa' and 'AAA' rows
  withTable(table) {
    sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet")
    sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))")
    val expectedRows = Seq(Row("aaa"))
    checkAnswer(sql(s"select s.fld from $table group by s"), expectedRows)
  }
}
|
||
test("Joins on complex types containing collated strings") {
  // Exercises equi-joins on array, map and struct columns whose string elements use
  // the utf8_binary_lcase collation. The left side stores 'aaa', the right side 'AAA'.
  val tableLeft = "table_join_le"
  val tableRight = "table_join_ri"
  val bothTables = Seq(tableLeft, tableRight)

  // array: the two differently-cased values are expected to match in the join
  withTable(tableLeft, tableRight) {
    // foreach, not map: the statements are run only for their side effects.
    bothTables.foreach { tab =>
      sql(s"create table $tab (a array<string collate utf8_binary_lcase>) using parquet")
    }
    Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).foreach {
      case (tab, data) => sql(s"insert into $tab values ($data)")
    }
    checkAnswer(sql(
      s"""
         |select $tableLeft.a from $tableLeft
         |join $tableRight on $tableLeft.a = $tableRight.a
         |""".stripMargin), Seq(Row(Seq("aaa"))))
  }

  // map doesn't support joins: equality on the map column must fail analysis
  withTable(tableLeft, tableRight) {
    bothTables.foreach { tab =>
      sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet")
    }
    val query =
      s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m"
    // The join condition is the tail of the query, so the expected context is
    // located by counting back from the end of the query string.
    val ctx = s"$tableLeft.m = $tableRight.m"
    val analysisError = intercept[AnalysisException](sql(query))
    val lcaseStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))
    checkError(
      exception = analysisError,
      errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
      parameters = Map(
        "functionName" -> "`=`",
        "dataType" -> toSQLType(MapType(lcaseStringType, StringType)),
        "sqlExpr" -> "\"(m = m)\""),
      context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
  }

  // struct: the two differently-cased field values are expected to match in the join
  withTable(tableLeft, tableRight) {
    bothTables.foreach { tab =>
      sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet")
    }
    Seq(
      (tableLeft, "named_struct('fld', 'aaa')"),
      (tableRight, "named_struct('fld', 'AAA')")
    ).foreach {
      case (tab, data) => sql(s"insert into $tab values ($data)")
    }
    checkAnswer(sql(
      s"""
         |select $tableLeft.s.fld from $tableLeft
         |join $tableRight on $tableLeft.s = $tableRight.s
         |""".stripMargin), Seq(Row("aaa")))
  }
}
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A useful test pattern here is the one we used for window aggregates. You can find a test example here: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated the tests to follow this pattern and added case classes for better readability. Please check again. |
||
|
||
test("window aggregates should respect collation") { | ||
val t1 = "T_NON_BINARY" | ||
val t2 = "T_BINARY" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is `isBinaryStable` in `UnsafeRowUtils`? Is the implementation somehow bound to unsafe rows? Why is it not in `DataTypeUtils` or `Collation...`, for example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what the consequences of moving this function would be. Do you know if we can do that, @dbatomic?