Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47476][SQL] Support REPLACE function to work with collated strings #45704

Closed
wants to merge 43 commits into from
Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
2a5fce7
Update StringReplace class
miland-db Mar 25, 2024
e0ce699
Add UTF8_BINARY_LCASE collation support using custom function
miland-db Mar 25, 2024
0711242
Merge branch 'apache:master' into miland-db/string-replace
miland-db Mar 25, 2024
d2e90f8
Improve testReplace signature
miland-db Mar 26, 2024
1e41ebd
Merge branch 'master' into miland-db/string-replace
miland-db Mar 26, 2024
93c6eb7
Resolve merge problems with master
miland-db Mar 26, 2024
7a1b240
Improve scala style
miland-db Mar 26, 2024
c59d71e
Solve whitespace scala style problem
miland-db Mar 27, 2024
a5c75b3
Add lowercase StringSearch and remove lowercaseReplace
miland-db Apr 1, 2024
76878b9
Remove repeated code
miland-db Apr 1, 2024
572bd54
Improve naming of collation aware methods
miland-db Apr 2, 2024
839b39a
Merge branch 'master' into string-replace
miland-db Apr 2, 2024
e2bea13
Improve java style
miland-db Apr 3, 2024
d719fe2
Merge branch 'master' into string-replace
miland-db Apr 3, 2024
a194292
Remove unnecessary check for mathced length
miland-db Apr 3, 2024
7b6720b
Improve style in CollationFactory
miland-db Apr 3, 2024
3042d7e
Add doc comment
miland-db Apr 3, 2024
84e41a3
Improve comment style
miland-db Apr 3, 2024
cc940cb
Improve naming in getStringSearch
miland-db Apr 3, 2024
ec960b8
Merge branch 'master' into string-replace
miland-db Apr 4, 2024
4e93874
Remove type checks for collation missmatch
miland-db Apr 4, 2024
9cb0944
Remove checkInputDataTypes
miland-db Apr 4, 2024
41c3872
Add empty lines between imports
miland-db Apr 4, 2024
74f69b9
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
ea3730c
Improve Java style
miland-db Apr 5, 2024
5b2a9d3
Merge branch 'master' into string-replace
miland-db Apr 12, 2024
bc5c256
Refactor StringReplace
miland-db Apr 12, 2024
8a81536
Break lines to 100 characters
miland-db Apr 12, 2024
c456325
Refactor tests
miland-db Apr 15, 2024
68d55f2
Merge branch 'master' into string-replace
miland-db Apr 16, 2024
09f13d8
Sync with the latest master
miland-db Apr 16, 2024
67ecb47
Merge branch 'master' into string-replace
miland-db Apr 17, 2024
a67fc9b
Merge branch 'master' into string-replace
miland-db Apr 17, 2024
08d1462
Merge branch 'master' into string-replace
miland-db Apr 18, 2024
d9f56d6
Added new tests (2 failing)
miland-db Apr 18, 2024
ade12fc
Merge branch 'master' into string-replace
miland-db Apr 23, 2024
f6b4413
Merge branch 'master' into string-replace
miland-db Apr 24, 2024
0c725f9
Fix bug with case-variable lenght characters
miland-db Apr 24, 2024
816a49a
Fix java linter errors
miland-db Apr 24, 2024
feda2b9
Merge branch 'master' into string-replace
miland-db Apr 25, 2024
0ef49d0
Fix import scalastyle
miland-db Apr 25, 2024
a4747f1
Merge branch 'master' into miland-db/string-replace
uros-db Apr 26, 2024
91b32f2
Merge branch 'master' into miland-db/string-replace
uros-db Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,29 @@ public Collation(
*/

public static StringSearch getStringSearch(
final UTF8String left,
final UTF8String right,
final UTF8String targetUTF8String,
final UTF8String patternUTF8String,
final int collationId) {
String pattern = right.toString();
CharacterIterator target = new StringCharacterIterator(left.toString());

if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) {
return getStringSearch(targetUTF8String.toLowerCase(), patternUTF8String.toLowerCase());
}

String pattern = patternUTF8String.toString();
CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString());
Collator collator = CollationFactory.fetchCollation(collationId).collator;
return new StringSearch(pattern, target, (RuleBasedCollator) collator);
}

public static StringSearch getStringSearch(
final UTF8String targetUTF8String,
final UTF8String patternUTF8String) {
String pattern = patternUTF8String.toString();
String target = targetUTF8String.toString();

return new StringSearch(pattern, target);
}

/**
* Returns the collation id for the given collation name.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,64 @@ public UTF8String replace(UTF8String search, UTF8String replace) {
return buf.build();
}

/**
* Replace all occurrences of search in this with replace respecting collation with id = collationId.
* @param search the string to be searched
* @param replace the start position of the current string for searching
* @param collationId the id of applied collation
* @return the string with replace instead of search in all places
*/
public UTF8String replace(UTF8String search, UTF8String replace, int collationId) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
return this.replace(search, replace);
}
return collationAwareReplace(search, replace, collationId);
}

private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int collationId) {
// This collation aware implementation is based on existing implementation on UTF8String with default collation
if (numBytes == 0 || search.numBytes == 0) {
return this;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, search, collationId);

// Find the first occurrence of the search string.
int end = stringSearch.next();
if (end == StringSearch.DONE) {
// Search string was not found, so string is unchanged.
return this;
}

// Initialize byte positions
int c = 0;
int byteStart = 0; // position in byte
int byteEnd = 0; // position in byte
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}

// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, Math.abs(replace.numBytes - search.numBytes)) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
while (end != StringSearch.DONE) {
buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart);
buf.append(replace);
byteStart = byteEnd + search.numBytes;
// Go to next match
end = stringSearch.next();
// Update byte positions
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}
}
buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart);
return buf.build();
}

public UTF8String translate(Map<String, String> dict) {
String srcStr = this.toString();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -735,23 +735,43 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

def this(srcExpr: Expression, searchExpr: Expression) = {
this(srcExpr, searchExpr, Literal(""))
}

override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = {
srcEval.asInstanceOf[UTF8String].replace(
searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String])
searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (src, search, replace) => {
s"""${ev.value} = $src.replace($search, $replace);"""
})
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
nullSafeCodeGen(ctx, ev, (src, search, replace) => {
s"""${ev.value} = $src.replace($search, $replace);"""
})
} else {
nullSafeCodeGen(ctx, ev, (src, search, replace) => {
s"""${ev.value} = $src.replace($search, $replace, $collationId);"""
})
}
}

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
return defaultCheck
}

// Only srcExpr and searchExpr are checked for collation compatibility.
val collationId = first.dataType.asInstanceOf[StringType].collationId
CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(second.dataType))
miland-db marked this conversation as resolved.
Show resolved Hide resolved
}

override def dataType: DataType = srcExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
override def first: Expression = srcExpr
override def second: Expression = searchExpr
override def third: Expression = replaceExpr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat}
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat, StringReplace}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -73,6 +74,49 @@ class CollationStringExpressionsSuite extends QueryTest
})
}

test("REPLACE check result on explicitly collated strings") {
def testReplace(source: String, search: String, replace: String,
collationId: Integer, expected: String): Unit = {
val sourceLiteral = Literal.create(source, StringType(collationId))
val searchLiteral = Literal.create(search, StringType(collationId))
val replaceLiteral = Literal.create(replace, StringType(collationId))

checkEvaluation(StringReplace(sourceLiteral, searchLiteral, replaceLiteral), expected)
}

// scalastyle:off
var collationId = CollationFactory.collationNameToId("UTF8_BINARY")
testReplace("r世eplace", "pl", "123", collationId, "r世e123ace")
testReplace("replace", "pl", "", collationId, "reace")
testReplace("repl世ace", "Pl", "", collationId, "repl世ace")
testReplace("replace", "", "123", collationId, "replace")
testReplace("abcabc", "b", "12", collationId, "a12ca12c")
testReplace("abcdabcd", "bc", "", collationId, "adad")

collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
testReplace("r世eplace", "pl", "xx", collationId, "r世exxace")
testReplace("repl世ace", "PL", "AB", collationId, "reAB世ace")
testReplace("Replace", "", "123", collationId, "Replace")
testReplace("re世place", "世", "x", collationId, "rexplace")
testReplace("abcaBc", "B", "12", collationId, "a12ca12c")
testReplace("AbcdabCd", "Bc", "", collationId, "Adad")

collationId = CollationFactory.collationNameToId("UNICODE")
testReplace("re世place", "plx", "123", collationId, "re世place")
testReplace("世Replace", "re", "", collationId, "世Replace")
testReplace("replace世", "", "123", collationId, "replace世")
testReplace("aBc世abc", "b", "12", collationId, "aBc世a12c")
testReplace("abcdabcd", "bc", "", collationId, "adad")

collationId = CollationFactory.collationNameToId("UNICODE_CI")
testReplace("replace", "plx", "123", collationId, "replace")
testReplace("Replace", "re", "", collationId, "place")
testReplace("replace", "", "123", collationId, "replace")
testReplace("aBc世abc", "b", "12", collationId, "a12c世a12c")
testReplace("a世Bcdabcd", "bC", "", collationId, "a世dad")
// scalastyle:on
}

test("REPEAT check output type on explicitly collated string") {
def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = {
val s = Literal.create(input, StringType(collationId))
Expand Down