Skip to content

Commit

Permalink
[SPARK-48937][SQL] Add collation support for StringToMap string expressions
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Add collation awareness for `StringToMap` string expression.

### Why are the changes needed?
`StringToMap` should be collation aware when splitting strings on specified delimiters.

### Does this PR introduce _any_ user-facing change?
Yes, `StringToMap` is now collation aware.

### How was this patch tested?
New unit tests and E2E SQL tests for `str_to_map`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47621 from uros-db/fix-str-to-map.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed Aug 9, 2024
1 parent fd3069a commit d8aff6e
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@

import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

/**
* Utility class for collation-aware UTF8String operations.
Expand Down Expand Up @@ -1226,6 +1229,60 @@ public static UTF8String trimRight(
return UTF8String.fromString(src.substring(0, charIndex));
}

/**
 * Splits {@code input} on occurrences of {@code delim}, using the comparison semantics of
 * the collation identified by {@code collationId}.
 *
 * <p>Dispatches to the matching implementation: collations with binary equality use the
 * byte-wise {@link UTF8String#split}, collations with lowercase equality use a
 * case-insensitive regex split, and all remaining collations use an ICU
 * {@code StringSearch} based split.
 *
 * @param input the string to split
 * @param delim the delimiter to split on
 * @param limit maximum number of resulting parts (same semantics as {@code String.split})
 * @param collationId id of the collation that determines how delimiter matches are found
 * @return the array of split parts
 */
public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim,
    final int limit, final int collationId) {
  // Fetch the collation once instead of twice per call.
  final CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
  if (collation.supportsBinaryEquality) {
    return input.split(delim, limit);
  } else if (collation.supportsLowercaseEquality) {
    return lowercaseSplitSQL(input, delim, limit);
  } else {
    return icuSplitSQL(input, delim, limit, collationId);
  }
}

/**
 * Splits {@code string} on occurrences of {@code delimiter} using case-insensitive
 * (UTF8_LCASE) matching with Unicode-aware case folding.
 *
 * @param string the string to split
 * @param delimiter the delimiter to split on (matched literally, ignoring case)
 * @param limit maximum number of resulting parts (same semantics as {@code String.split})
 * @return the array of split parts
 */
public static UTF8String[] lowercaseSplitSQL(final UTF8String string, final UTF8String delimiter,
    final int limit) {
  // An empty delimiter yields the whole input as the single result.
  if (delimiter.numBytes() == 0) {
    return new UTF8String[] { string };
  }
  // An empty input yields a single empty string.
  if (string.numBytes() == 0) {
    return new UTF8String[] { UTF8String.EMPTY_UTF8 };
  }
  // Quote the delimiter so it is matched literally, then split case-insensitively.
  Pattern caseInsensitiveDelim = Pattern.compile(Pattern.quote(delimiter.toString()),
    CollationSupport.lowercaseRegexFlags);
  String[] parts = caseInsensitiveDelim.split(string.toString(), limit);
  UTF8String[] result = new UTF8String[parts.length];
  int idx = 0;
  for (String part : parts) {
    result[idx++] = UTF8String.fromString(part);
  }
  return result;
}

/**
 * Splits {@code string} on occurrences of {@code delimiter} using ICU string search under
 * the collation identified by {@code collationId}, so that delimiter matches follow the
 * collation's equivalence rules rather than byte equality.
 *
 * <p>The {@code limit} parameter mirrors {@code java.lang.String.split}: a positive limit
 * caps the number of parts, and a limit of zero discards trailing empty strings.
 *
 * @param string the string to split
 * @param delimiter the delimiter to split on
 * @param limit maximum number of resulting parts
 * @param collationId id of the collation used to locate delimiter matches
 * @return the array of split parts
 */
public static UTF8String[] icuSplitSQL(final UTF8String string, final UTF8String delimiter,
final int limit, final int collationId) {
// An empty delimiter yields the whole input as the single result.
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
// An empty input yields a single empty string.
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
List<UTF8String> strings = new ArrayList<>();
String target = string.toString(), pattern = delimiter.toString();
// ICU StringSearch locates delimiter occurrences under the collation's matching rules.
StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
int start = 0, end;
while ((end = stringSearch.next()) != StringSearch.DONE) {
// With a positive limit, stop after limit - 1 splits; the remainder is appended below.
if (limit > 0 && strings.size() == limit - 1) {
break;
}
strings.add(UTF8String.fromString(target.substring(start, end)));
// NOTE(review): getMatchLength() is used instead of pattern.length() — presumably the
// matched span can differ in length under collation (e.g. contractions); confirm with ICU docs.
start = end + stringSearch.getMatchLength();
}
// Append the trailing segment (empty when the input ends with a delimiter match).
if (start <= target.length()) {
strings.add(UTF8String.fromString(target.substring(start)));
}
if (limit == 0) {
// Remove trailing empty strings
// (limit == 0 follows String.split semantics: trailing empties are discarded).
int i = strings.size() - 1;
while (i >= 0 && strings.get(i).numBytes() == 0) {
strings.remove(i);
i--;
}
}
return strings.toArray(new UTF8String[0]);
}

// TODO: Add more collation-aware UTF8String operations here.

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -62,33 +60,11 @@ public static UTF8String[] execBinary(final UTF8String string, final UTF8String
return string.splitSQL(delimiter, -1);
}
public static UTF8String[] execLowercase(final UTF8String string, final UTF8String delimiter) {
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
CollationSupport.lowercaseRegexFlags);
String[] splits = pattern.split(string.toString(), -1);
UTF8String[] res = new UTF8String[splits.length];
for (int i = 0; i < res.length; i++) {
res[i] = UTF8String.fromString(splits[i]);
}
return res;
return CollationAwareUTF8String.lowercaseSplitSQL(string, delimiter, -1);
}
public static UTF8String[] execICU(final UTF8String string, final UTF8String delimiter,
final int collationId) {
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
List<UTF8String> strings = new ArrayList<>();
String target = string.toString(), pattern = delimiter.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
int start = 0, end;
while ((end = stringSearch.next()) != StringSearch.DONE) {
strings.add(UTF8String.fromString(target.substring(start, end)));
start = end + stringSearch.getMatchLength();
}
if (start <= target.length()) {
strings.add(UTF8String.fromString(target.substring(start)));
}
return strings.toArray(new UTF8String[0]);
return CollationAwareUTF8String.icuSplitSQL(string, delimiter, -1, collationId);
}
}

Expand Down Expand Up @@ -696,7 +672,7 @@ public static boolean supportsLowercaseRegex(final int collationId) {
return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality;
}

private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE;
static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE;
public static int collationAwareRegexFlags(final int collationId) {
return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationAwareUTF8String, CollationFactory, CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1529,6 +1529,7 @@ object CodeGenerator extends Logging {
classOf[TaskContext].getName,
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName,
classOf[CollationAwareUTF8String].getName,
classOf[CollationFactory].getName,
classOf[CollationSupport].getName,
QueryExecutionErrors.getClass.getName.stripSuffix("$")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,17 +585,20 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E

private lazy val mapBuilder = new ArrayBasedMapBuilder(first.dataType, first.dataType)

private final lazy val collationId: Int = text.dataType.asInstanceOf[StringType].collationId

override def nullSafeEval(
inputString: Any,
stringDelimiter: Any,
keyValueDelimiter: Any): Any = {
val keyValues =
inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1)
val keyValues = CollationAwareUTF8String.splitSQL(inputString.asInstanceOf[UTF8String],
stringDelimiter.asInstanceOf[UTF8String], -1, collationId)
val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String]

var i = 0
while (i < keyValues.length) {
val keyValueArray = keyValues(i).split(keyValueDelimiterUTF8String, 2)
val keyValueArray = CollationAwareUTF8String.splitSQL(
keyValues(i), keyValueDelimiterUTF8String, 2, collationId)
val key = keyValueArray(0)
val value = if (keyValueArray.length < 2) null else keyValueArray(1)
mapBuilder.put(key, value)
Expand All @@ -610,9 +613,9 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E

nullSafeCodeGen(ctx, ev, (text, pd, kvd) =>
s"""
|UTF8String[] $keyValues = $text.split($pd, -1);
|UTF8String[] $keyValues = CollationAwareUTF8String.splitSQL($text, $pd, -1, $collationId);
|for(UTF8String kvEntry: $keyValues) {
| UTF8String[] kv = kvEntry.split($kvd, 2);
| UTF8String[] kv = CollationAwareUTF8String.splitSQL(kvEntry, $kvd, 2, $collationId);
| $builderTerm.put(kv[0], kv.length == 2 ? kv[1] : null);
|}
|${ev.value} = $builderTerm.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.collection.immutable.Seq

import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.test.SharedSparkSession
Expand All @@ -34,8 +34,9 @@ import org.apache.spark.util.collection.OpenHashMap

// scalastyle:off nonascii
class CollationSQLExpressionsSuite
extends QueryTest
with SharedSparkSession {
extends QueryTest
with SharedSparkSession
with ExpressionEvalHelper {

private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")

Expand Down Expand Up @@ -964,25 +965,36 @@ class CollationSQLExpressionsSuite
})
}

test("Support StringToMap expression with collation") {
// Supported collations
case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R)
test("Support `StringToMap` expression with collation") {
case class StringToMapTestCase[R](
text: String,
pairDelim: String,
keyValueDelim: String,
collation: String,
result: R)
val testCases = Seq(
StringToMapTestCase("a:1,b:2,c:3", ",", ":", "UTF8_BINARY",
Map("a" -> "1", "b" -> "2", "c" -> "3")),
StringToMapTestCase("A-1;B-2;C-3", ";", "-", "UTF8_LCASE",
StringToMapTestCase("A-1xB-2xC-3", "X", "-", "UTF8_LCASE",
Map("A" -> "1", "B" -> "2", "C" -> "3")),
StringToMapTestCase("1:a,2:b,3:c", ",", ":", "UNICODE",
StringToMapTestCase("1:ax2:bx3:c", "x", ":", "UNICODE",
Map("1" -> "a", "2" -> "b", "3" -> "c")),
StringToMapTestCase("1/A!2/B!3/C", "!", "/", "UNICODE_CI",
StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
Map("1" -> "A", "2" -> "B", "3" -> "C"))
)
testCases.foreach(t => {
val query = s"SELECT str_to_map(collate('${t.t}', '${t.c}'), '${t.p}', '${t.k}');"
// Result & data type
checkAnswer(sql(query), Row(t.result))
val dataType = MapType(StringType(t.c), StringType(t.c), true)
assert(sql(query).schema.fields.head.dataType.sameType(dataType))
// Unit test.
val text = Literal.create(t.text, StringType(t.collation))
val pairDelim = Literal.create(t.pairDelim, StringType(t.collation))
val keyValueDelim = Literal.create(t.keyValueDelim, StringType(t.collation))
checkEvaluation(StringToMap(text, pairDelim, keyValueDelim), t.result)
// E2E SQL test.
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
val query = s"SELECT str_to_map('${t.text}', '${t.pairDelim}', '${t.keyValueDelim}')"
checkAnswer(sql(query), Row(t.result))
val dataType = MapType(StringType(t.collation), StringType(t.collation), true)
assert(sql(query).schema.fields.head.dataType.sameType(dataType))
}
})
}

Expand Down

0 comments on commit d8aff6e

Please sign in to comment.