Skip to content

Commit

Permalink
[SPARK-34094][SQL] Extends StringTranslate to support unicode charact…
Browse files Browse the repository at this point in the history
…ers whose code point >= U+10000

### What changes were proposed in this pull request?

This PR extends `StringTranslate` to support Unicode characters whose code point >= `U+10000`.

### Why are the changes needed?

To make it work with a wide variety of characters.

### Does this PR introduce _any_ user-facing change?

Yes. Users can use `StringTranslate` with Unicode characters whose code point >= `U+10000`.

### How was this patch tested?

New assertion added to the existing test.

Closes #31164 from sarutak/extends-translate.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
sarutak authored and srowen committed Jan 21, 2021
1 parent 28131a7 commit 116f4ca
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1075,16 +1075,20 @@ public UTF8String replace(UTF8String search, UTF8String replace) {
return buf.build();
}

// TODO: Need to use `Code Point` here instead of Char in case the character is longer than 2 bytes
public UTF8String translate(Map<Character, Character> dict) {
public UTF8String translate(Map<String, String> dict) {
String srcStr = this.toString();

StringBuilder sb = new StringBuilder();
for(int k = 0; k< srcStr.length(); k++) {
if (null == dict.get(srcStr.charAt(k))) {
sb.append(srcStr.charAt(k));
} else if ('\0' != dict.get(srcStr.charAt(k))){
sb.append(dict.get(srcStr.charAt(k)));
int charCount = 0;
for (int k = 0; k < srcStr.length(); k += charCount) {
int codePoint = srcStr.codePointAt(k);
charCount = Character.charCount(codePoint);
String subStr = srcStr.substring(k, k + charCount);
String translated = dict.get(subStr);
if (null == translated) {
sb.append(subStr);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return fromString(sb.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,27 +465,27 @@ public void translate() {
assertEquals(
fromString("1a2s3ae"),
fromString("translate").translate(ImmutableMap.of(
'r', '1',
'n', '2',
'l', '3',
't', '\0'
"r", "1",
"n", "2",
"l", "3",
"t", "\0"
)));
assertEquals(
fromString("translate"),
fromString("translate").translate(new HashMap<>()));
assertEquals(
fromString("asae"),
fromString("translate").translate(ImmutableMap.of(
'r', '\0',
'n', '\0',
'l', '\0',
't', '\0'
"r", "\0",
"n", "\0",
"l", "\0",
"t", "\0"
)));
assertEquals(
fromString("aa世b"),
fromString("花花世界").translate(ImmutableMap.of(
'花', 'a',
'界', 'b'
"花", "a",
"界", "b"
)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,17 +633,29 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
object StringTranslate {

def buildDict(matchingString: UTF8String, replaceString: UTF8String)
: JMap[Character, Character] = {
: JMap[String, String] = {
val matching = matchingString.toString()
val replace = replaceString.toString()
val dict = new HashMap[Character, Character]()
val dict = new HashMap[String, String]()
var i = 0
while (i < matching.length()) {
val rep = if (i < replace.length()) replace.charAt(i) else '\u0000'
if (null == dict.get(matching.charAt(i))) {
dict.put(matching.charAt(i), rep)
var j = 0

while (i < matching.length) {
val rep = if (j < replace.length) {
val repCharCount = Character.charCount(replace.codePointAt(j))
val repStr = replace.substring(j, j + repCharCount)
j += repCharCount
repStr
} else {
"\u0000"
}

val matchCharCount = Character.charCount(matching.codePointAt(i))
val matchStr = matching.substring(i, i + matchCharCount)
if (null == dict.get(matchStr)) {
dict.put(matchStr, rep)
}
i += 1
i += matchCharCount
}
dict
}
Expand Down Expand Up @@ -671,7 +683,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac

@transient private var lastMatching: UTF8String = _
@transient private var lastReplace: UTF8String = _
@transient private var dict: JMap[Character, Character] = _
@transient private var dict: JMap[String, String] = _

override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
if (matchingEval != lastMatching || replaceEval != lastReplace) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:off
// Non-ASCII characters are not allowed in the source code, so we disable scalastyle here.
checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
// test for unicode characters whose code point >= 0x10000
checkEvaluation(
StringTranslate(
Literal("\uD840\uDC0Bxyza\uD867\uDE49c123b\uD842\uDFB7\uD867\uDE3D"),
Literal("\uD867\uDE3Da\uD842\uDFB7b\uD840\uDC0Bc\uD867\uDE49c"),
Literal("1\uD83C\uDF3B2\uD83C\uDF37\uD83D\uDC15\uD83D\uDC08\uD83C\uDF38")),
"\uD83D\uDC15xyz\uD83C\uDF3B\uD83C\uDF38\uD83D\uDC08123\uD83C\uDF3721")
// scalastyle:on
}

Expand Down

0 comments on commit 116f4ca

Please sign in to comment.