Skip to content

Commit

Permalink
[SPARK-47988][SQL] When the collationId is invalid, throw `COLLATION_INVALID_ID`
Browse files Browse the repository at this point in the history
  • Loading branch information
panbingkun committed Apr 25, 2024
1 parent b4624bf commit 4da4f42
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.ibm.icu.util.ULocale;
import com.ibm.icu.text.Collator;

import org.apache.spark.SparkArrayIndexOutOfBoundsException;
import org.apache.spark.SparkException;
import org.apache.spark.unsafe.types.UTF8String;

Expand Down Expand Up @@ -217,17 +218,17 @@ public static StringSearch getStringSearch(
* Returns if the given collationName is valid one.
*/
public static boolean isValidCollation(String collationName) {
  // Locale.ROOT keeps the upper-casing locale-independent (e.g. avoids the
  // Turkish dotted/dotless-I problem), matching how the map keys are stored.
  return collationNameToIdMap.containsKey(collationName.toUpperCase(Locale.ROOT));
}

/**
* Returns closest valid name to collationName
*/
public static String getClosestCollation(String collationName) {
Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt(
c -> UTF8String.fromString(c.collationName).levenshteinDistance(
UTF8String.fromString(collationName.toUpperCase()))));
return suggestion.collationName;
public static Collation getClosestCollation(String collationName) {
String normalizedName = collationName.toUpperCase(Locale.ROOT);
return Collections.min(List.of(collationTable), Comparator.comparingInt(
c -> UTF8String.fromString(c.collationName).levenshteinDistance(
UTF8String.fromString(normalizedName))));
}

/**
Expand All @@ -245,13 +246,11 @@ public static StringSearch getStringSearch(
* Returns the collation id for the given collation name.
*/
public static int collationNameToId(String collationName) throws SparkException {
String normalizedName = collationName.toUpperCase();
String normalizedName = collationName.toUpperCase(Locale.ROOT);
if (collationNameToIdMap.containsKey(normalizedName)) {
return collationNameToIdMap.get(normalizedName);
} else {
Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt(
c -> UTF8String.fromString(c.collationName).levenshteinDistance(
UTF8String.fromString(normalizedName))));
Collation suggestion = getClosestCollation(collationName);

Map<String, String> params = new HashMap<>();
params.put("collationName", collationName);
Expand All @@ -263,6 +262,13 @@ public static int collationNameToId(String collationName) throws SparkException
}

public static Collation fetchCollation(int collationId) {
  // Happy path: the id indexes directly into the collation table.
  if (0 <= collationId && collationId < collationTable.length) {
    return collationTable[collationId];
  }
  // Out-of-range id: raise a classified error instead of letting the raw
  // ArrayIndexOutOfBoundsException escape.
  Map<String, String> errorParams = new HashMap<>();
  errorParams.put("collationId", String.valueOf(collationId));
  errorParams.put("valueRange", "[0, " + (collationTable.length - 1) + "]");
  throw new SparkArrayIndexOutOfBoundsException(
    "COLLATION_INVALID_ID", SparkException.constructMessageParams(errorParams), null);
}

Expand Down
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,12 @@
],
"sqlState" : "42704"
},
"COLLATION_INVALID_ID" : {
"message" : [
"The collationId value <collationId> must be between <valueRange>."
],
"sqlState" : "42704"
},
"COLLATION_INVALID_NAME" : {
"message" : [
"The value <collationName> does not represent a correct collation name. Suggested valid collation name: [<proposal>]."
Expand Down
45 changes: 24 additions & 21 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -361,27 +361,30 @@ abstract class SparkFunSuite
} else {
assert(expectedParameters === parameters)
}
val actualQueryContext = exception.getQueryContext()
assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
val actualQueryContext = exception.getQueryContext
if (actualQueryContext != null) {
assert(actualQueryContext.length === queryContext.length,
"Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ object SQLConf {
"DEFAULT_COLLATION",
name =>
Map(
"proposal" -> CollationFactory.getClosestCollation(name)
"proposal" -> CollationFactory.getClosestCollation(name).collationName
))
.createWithDefault("UTF8_BINARY")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql

import scala.jdk.CollectionConverters.MapHasAsJava

import org.apache.spark.SparkException
import org.apache.spark.{SparkArrayIndexOutOfBoundsException, SparkException}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.CollationFactory
Expand Down Expand Up @@ -144,6 +144,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
parameters = Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS"))
}

test("invalid collationId throws exception") {
  // StringType(8) carries collation id 8, which is outside the table's range,
  // so materializing the schema must fail with COLLATION_INVALID_ID.
  val rows = Seq(Row("Alice"), Row("Bob"), Row("bob"))
  val invalidSchema = StructType(StructField("s", StringType(8)) :: Nil)
  val frame = spark.createDataFrame(sparkContext.parallelize(rows), invalidSchema)
  checkError(
    exception = intercept[SparkArrayIndexOutOfBoundsException] {
      frame.schema.printTreeString()
    },
    errorClass = "COLLATION_INVALID_ID",
    parameters = Map("collationId" -> "8", "valueRange" -> "[0, 3]")
  )
}

test("disable bucketing on collated string column") {
def createTable(bucketColumns: String*): Unit = {
val tableName = "test_partition_tbl"
Expand Down

0 comments on commit 4da4f42

Please sign in to comment.