Skip to content

Commit

Permalink
[SPARK-49354][SQL] split_part should check whether the collation
Browse files Browse the repository at this point in the history
…values of all parameter types are the same

### What changes were proposed in this pull request?
The same principle as #47825 (review): the parameter `delimiter` in the expression `split_part` is treated as a (`collation-dependent`) delimiter, rather than a (`collation-unaware`) regular expression.

### Why are the changes needed?
Strengthen the parameter data type check of the expression `split_part` to avoid potential issues.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Add some test cases to `collations.sql`.
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47845 from panbingkun/SPARK-49354.

Lead-authored-by: panbingkun <panbingkun@baidu.com>
Co-authored-by: panbingkun <pbk1982@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
2 people authored and MaxGekk committed Aug 26, 2024
1 parent 7e4d6bd commit ccef4df
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ object CollationTypeCasts extends TypeCoercionRule {
val newValues = collateToSingleType(mapCreate.values)
mapCreate.withNewChildren(newKeys.zip(newValues).flatMap(pair => Seq(pair._1, pair._2)))

case splitPart: SplitPart =>
val Seq(str, delimiter, partNum) = splitPart.children
val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
splitPart.withNewChildren(Seq(newStr, newDelimiter, partNum))

case stringSplitSQL: StringSplitSQL =>
val Seq(str, delimiter) = stringSplitSQL.children
val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
stringSplitSQL.withNewChildren(Seq(newStr, newDelimiter))

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs | _: Mask | _: StringReplace |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,55 @@ drop table t4
-- !query analysis
DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t4


-- !query
create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet
-- !query analysis
CreateDataSourceTableCommand `spark_catalog`.`default`.`t5`, false


-- !query
insert into t5 values('11AB12AB13', 'AB', 2)
-- !query analysis
InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [str, delimiter, partNum]
+- Project [cast(col1#x as string) AS str#x, cast(col2#x as string collate UTF8_LCASE) AS delimiter#x, cast(col3#x as int) AS partNum#x]
+- LocalRelation [col1#x, col2#x, col3#x]


-- !query
select split_part(str, delimiter, partNum) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "COLLATION_MISMATCH.IMPLICIT",
"sqlState" : "42P21"
}


-- !query
select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "COLLATION_MISMATCH.EXPLICIT",
"sqlState" : "42P21",
"messageParameters" : {
"explicitTypes" : "`string`, `string collate UTF8_LCASE`"
}
}


-- !query
select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5
-- !query analysis
Project [split_part(collate(str#x, utf8_binary), collate(delimiter#x, utf8_binary), partNum#x) AS split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary), partNum)#x]
+- SubqueryAlias spark_catalog.default.t5
+- Relation spark_catalog.default.t5[str#x,delimiter#x,partNum#x] parquet


-- !query
drop table t5
-- !query analysis
DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5
11 changes: 11 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/collations.sql
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,14 @@ select str_to_map(text collate utf8_binary, pairDelim collate utf8_lcase, keyVal
select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, keyValueDelim collate utf8_binary) from t4;

drop table t4;

-- create table for split_part
create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet;

insert into t5 values('11AB12AB13', 'AB', 2);

select split_part(str, delimiter, partNum) from t5;
select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5;
select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5;

drop table t5;
59 changes: 59 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/collations.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,62 @@ drop table t4
struct<>
-- !query output



-- !query
create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet
-- !query schema
struct<>
-- !query output



-- !query
insert into t5 values('11AB12AB13', 'AB', 2)
-- !query schema
struct<>
-- !query output



-- !query
select split_part(str, delimiter, partNum) from t5
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "COLLATION_MISMATCH.IMPLICIT",
"sqlState" : "42P21"
}


-- !query
select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "COLLATION_MISMATCH.EXPLICIT",
"sqlState" : "42P21",
"messageParameters" : {
"explicitTypes" : "`string`, `string collate UTF8_LCASE`"
}
}


-- !query
select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5
-- !query schema
struct<split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary), partNum):string>
-- !query output
12


-- !query
drop table t5
-- !query schema
struct<>
-- !query output

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.analysis.CollationTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -112,6 +113,51 @@ class CollationStringExpressionsSuite
})
}

test("Support `StringSplitSQL` string expression with collation") {
case class StringSplitSQLTestCase[R](
str: String,
delimiter: String,
collation: String,
result: R)
val testCases = Seq(
StringSplitSQLTestCase("1a2", "a", "UTF8_BINARY", Array("1", "2")),
StringSplitSQLTestCase("1a2", "a", "UNICODE", Array("1", "2")),
StringSplitSQLTestCase("1a2", "A", "UTF8_LCASE", Array("1", "2")),
StringSplitSQLTestCase("1a2", "A", "UNICODE_CI", Array("1", "2"))
)
testCases.foreach(t => {
// Unit test.
val str = Literal.create(t.str, StringType(t.collation))
val delimiter = Literal.create(t.delimiter, StringType(t.collation))
checkEvaluation(StringSplitSQL(str, delimiter), t.result)
})

// Because `StringSplitSQL` is an internal expression,
// E2E SQL test cannot be performed in `collations.sql`.
checkError(
exception = intercept[AnalysisException] {
val expr = StringSplitSQL(
Cast(Literal.create("1a2"), StringType("UTF8_BINARY")),
Cast(Literal.create("a"), StringType("UTF8_LCASE")))
CollationTypeCasts.transform(expr)
},
errorClass = "COLLATION_MISMATCH.IMPLICIT",
sqlState = "42P21",
parameters = Map.empty
)
checkError(
exception = intercept[AnalysisException] {
val expr = StringSplitSQL(
Collate(Literal.create("1a2"), "UTF8_BINARY"),
Collate(Literal.create("a"), "UTF8_LCASE"))
CollationTypeCasts.transform(expr)
},
errorClass = "COLLATION_MISMATCH.EXPLICIT",
sqlState = "42P21",
parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`")
)
}

test("Support `Contains` string expression with collation") {
case class ContainsTestCase[R](left: String, right: String, collation: String, result: R)
val testCases = Seq(
Expand Down

0 comments on commit ccef4df

Please sign in to comment.