[SPARK-48549][SQL][PYTHON] Improve SQL function sentences
### What changes were proposed in this pull request?
This PR aims to improve the SQL function `sentences`. It includes:
- updating the description of the `sentences` expression to make it more accurate.
- adding `def sentences(string: Column, language: Column): Column` to the SQL functions API.
- `codegen` support for `sentences`.

### Why are the changes needed?
This fixes an inconsistency in how the function `sentences` can be invoked, shown in the following scenarios:
  <img width="1051" alt="image" src="https://github.com/apache/spark/assets/15246973/033c731d-5a2f-455f-8517-ed95bd6c1f6e">

- According to the definition of the function `sentences`, only the following two call forms should be allowed:
  A. `sentences(str)`
  B. `sentences(str, language, country)` (the parameters `language` and `country` must either both be present or both be absent)

  **In file `sql/core/src/main/scala/org/apache/spark/sql/functions.scala`, only the following two functions are defined**:
https://github.com/apache/spark/blob/f4434c36cc4f7b0147e0e8fe26ac0f177a5199cd/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L4273-L4282

- When we call the expression `sentences` directly, it actually supports the following:
  A. `df.select(sentences($"str", $"language", $"country"))`
  B. `df.select(sentences($"str", $"language"))`
  C. `df.select(sentences($"str"))`
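The three call forms above can be summarized with a minimal Python stub. This is hypothetical illustration code, not the Spark API; it only mirrors how omitted arguments fall back to the expression's default `Literal("")` parameters:

```python
def sentences_args(s, language="", country=""):
    # Hypothetical stub: omitted arguments default to '', just like the
    # Sentences expression's default Literal("") parameters.
    return (s, language, country)

# The three supported call forms resolve to the same ternary shape:
assert sentences_args("Hi there!") == ("Hi there!", "", "")
assert sentences_args("Hi there!", "en") == ("Hi there!", "en", "")
assert sentences_args("Hi there!", "en", "US") == ("Hi there!", "en", "US")
```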

## Let's align it

### Does this PR introduce _any_ user-facing change?
Yes. The SQL function `sentences` can now also be called with the parameters (`str`, `language`).

### How was this patch tested?
- Add new unit tests & update existing ones.
- Pass GA (GitHub Actions).
- Manual check:
```scala
scala> val df =  Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US")).toDF("str", "language", "country");
val df: org.apache.spark.sql.DataFrame = [str: string, language: string ... 1 more field]

scala> df.select(sentences($"str", $"language", $"country"));
val res0: org.apache.spark.sql.DataFrame = [sentences(str, language, country): array<array<string>>]

scala> df.select(sentences($"str", $"language"));
val res1: org.apache.spark.sql.DataFrame = [sentences(str, language, ): array<array<string>>]

scala> df.select(sentences($"str"));
val res2: org.apache.spark.sql.DataFrame = [sentences(str, , ): array<array<string>>]

scala> df.selectExpr("sentences(str, language, country)");
val res3: org.apache.spark.sql.DataFrame = [sentences(str, language, country): array<array<string>>]

scala> df.selectExpr("sentences(str, language)");
val res4: org.apache.spark.sql.DataFrame = [sentences(str, language, ): array<array<string>>]

scala> df.selectExpr("sentences(str)");
val res5: org.apache.spark.sql.DataFrame = [sentences(str, , ): array<array<string>>]
```

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46880 from panbingkun/sentences_improve.

Authored-by: panbingkun <panbingkun@baidu.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
panbingkun authored and MaxGekk committed Sep 11, 2024
1 parent a9502d4 commit e037953
Showing 15 changed files with 187 additions and 58 deletions.
@@ -1809,7 +1809,11 @@ class PlanGenerationTestSuite
fn.sentences(fn.col("g"))
}

functionTest("sentences with locale") {
functionTest("sentences with language") {
fn.sentences(fn.col("g"), lit("en"))
}

functionTest("sentences with language and country") {
fn.sentences(fn.col("g"), lit("en"), lit("US"))
}

22 changes: 21 additions & 1 deletion python/pyspark/sql/functions/builtin.py
@@ -11241,13 +11241,27 @@ def sentences(
) -> Column:
"""
Splits a string into arrays of sentences, where each sentence is an array of words.
The 'language' and 'country' arguments are optional, and if omitted, the default locale is used.
The `language` and `country` arguments are optional.
When they are omitted:
1. If both are omitted, the `Locale.ROOT - locale(language='', country='')` is used.
   `Locale.ROOT` is regarded as the base locale of all locales, and is used as the
   language/country neutral locale for locale-sensitive operations.
2. If only the `country` is omitted, the `locale(language, country='')` is used.
When they are null:
1. If both are `null`, the `Locale.US - locale(language='en', country='US')` is used.
2. If the `language` is null and the `country` is not null,
   the `Locale.US - locale(language='en', country='US')` is used.
3. If the `language` is not null and the `country` is null, the `locale(language)` is used.
4. If neither is `null`, the `locale(language, country)` is used.

.. versionadded:: 3.2.0

.. versionchanged:: 3.4.0
Supports Spark Connect.

.. versionchanged:: 4.0.0
Supports `sentences(string, language)`.

Parameters
----------
string : :class:`~pyspark.sql.Column` or str
@@ -11271,6 +11285,12 @@
+-----------------------------------+
|[[This, is, an, example, sentence]]|
+-----------------------------------+
>>> df.select(sentences(df.string, lit("en"))).show(truncate=False)
+-----------------------------------+
|sentences(string, en, ) |
+-----------------------------------+
|[[This, is, an, example, sentence]]|
+-----------------------------------+
>>> df = spark.createDataFrame([["Hello world. How are you?"]], ["s"])
>>> df.select(sentences("s")).show(truncate=False)
+---------------------------------+
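The omitted/null fallback rules documented in the docstring above can be sketched as a small self-contained Python helper. This is a model of the documented rules, not PySpark code; a locale is represented as a plain `(language, country)` tuple:

```python
_OMITTED = object()  # sentinel distinguishing "not passed" from None (SQL NULL)

def resolve_locale(language=_OMITTED, country=_OMITTED):
    """Model the documented fallback; a locale is a (language, country) tuple."""
    # Omitted arguments default to '' (Locale.ROOT is ('', '')).
    if language is _OMITTED:
        language = ""
    if country is _OMITTED:
        country = ""
    # NULL handling, per the documented rules:
    if language is None:
        return ("en", "US")      # null language -> Locale.US (covers both-null too)
    if country is None:
        return (language, "")    # null country -> locale(language)
    return (language, country)

assert resolve_locale() == ("", "")                # both omitted -> Locale.ROOT
assert resolve_locale("en") == ("en", "")          # country omitted
assert resolve_locale(None, None) == ("en", "US")  # both null -> Locale.US
assert resolve_locale(None, "US") == ("en", "US")  # null language -> Locale.US
assert resolve_locale("en", None) == ("en", "")    # null country -> locale(language)
assert resolve_locale("en", "US") == ("en", "US")
```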
9 changes: 9 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4349,6 +4349,15 @@ object functions {
def sentences(string: Column, language: Column, country: Column): Column =
Column.fn("sentences", string, language, country)

/**
* Splits a string into arrays of sentences, where each sentence is an array of words. The
* default `country` ('') is used.
* @group string_funcs
* @since 4.0.0
*/
def sentences(string: Column, language: Column): Column =
Column.fn("sentences", string, language)

/**
* Splits a string into arrays of sentences, where each sentence is an array of words. The
* default locale is used.
@@ -17,20 +17,25 @@

package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.SparkBuildInfo;
import org.apache.spark.sql.errors.QueryExecutionErrors;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.VersionUtils;

import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.text.BreakIterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.spark.SparkBuildInfo;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.sql.errors.QueryExecutionErrors;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.VersionUtils;

/**
* A utility class for constructing expressions.
@@ -272,4 +277,42 @@ private static byte[] aesInternal(
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
}
}

public static ArrayData getSentences(
UTF8String str,
UTF8String language,
UTF8String country) {
if (str == null) return null;
Locale locale;
if (language != null && country != null) {
locale = new Locale(language.toString(), country.toString());
} else if (language != null) {
locale = new Locale(language.toString());
} else {
locale = Locale.US;
}
String sentences = str.toString();
BreakIterator sentenceInstance = BreakIterator.getSentenceInstance(locale);
sentenceInstance.setText(sentences);

int sentenceIndex = 0;
List<GenericArrayData> res = new ArrayList<>();
while (sentenceInstance.next() != BreakIterator.DONE) {
String sentence = sentences.substring(sentenceIndex, sentenceInstance.current());
sentenceIndex = sentenceInstance.current();
BreakIterator wordInstance = BreakIterator.getWordInstance(locale);
wordInstance.setText(sentence);
int wordIndex = 0;
List<UTF8String> words = new ArrayList<>();
while (wordInstance.next() != BreakIterator.DONE) {
String word = sentence.substring(wordIndex, wordInstance.current());
wordIndex = wordInstance.current();
if (Character.isLetterOrDigit(word.charAt(0))) {
words.add(UTF8String.fromString(word));
}
}
res.add(new GenericArrayData(words.toArray(new UTF8String[0])));
}
return new GenericArrayData(res.toArray(new GenericArrayData[0]));
}
}
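The Java helper above relies on `java.text.BreakIterator` for sentence and word boundary analysis. A rough Python approximation of the same two-level split is sketched below; it is regex-based, so its boundaries are cruder than JDK break iteration, and it illustrates only the shape of the algorithm, not its exact behavior:

```python
import re

def get_sentences(text):
    # Crude stand-in for BreakIterator.getSentenceInstance: split after
    # sentence-ending punctuation followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    result = []
    for sentence in sentences:
        # Crude stand-in for BreakIterator.getWordInstance: the regex yields
        # only word-character tokens, playing the role of the
        # Character.isLetterOrDigit filter in the Java code.
        words = re.findall(r"\w+(?:'\w+)?", sentence)
        result.append(words)
    return result

assert get_sentences("Hi there! Good morning.") == [["Hi", "there"], ["Good", "morning"]]
```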
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.nio.{ByteBuffer, CharBuffer}
import java.nio.charset.CharacterCodingException
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64, HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer
@@ -3327,22 +3327,47 @@

/**
* Splits a string into arrays of sentences, where each sentence is an array of words.
* The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
* The `lang` and `country` arguments are optional; both default to ''.
* - When they are omitted:
* 1. If they are both omitted, the `Locale.ROOT - locale(language='', country='')` is used.
* The `Locale.ROOT` is regarded as the base locale of all locales, and is used as the
* language/country neutral locale for the locale sensitive operations.
* 2. If the `country` is omitted, the `locale(language, country='')` is used.
* - When they are null:
* 1. If they are both `null`, the `Locale.US - locale(language='en', country='US')` is used.
* 2. If the `language` is null and the `country` is not null,
* the `Locale.US - locale(language='en', country='US')` is used.
* 3. If the `language` is not null and the `country` is null, the `locale(language)` is used.
* 4. If neither is `null`, the `locale(language, country)` is used.
*/
@ExpressionDescription(
usage = "_FUNC_(str[, lang, country]) - Splits `str` into an array of array of words.",
usage = "_FUNC_(str[, lang[, country]]) - Splits `str` into an array of array of words.",
arguments = """
Arguments:
* str - A STRING expression to be parsed.
* lang - An optional STRING expression with a language code from ISO 639 Alpha-2 (e.g. 'DE'),
Alpha-3, or a language subtag of up to 8 characters.
* country - An optional STRING expression with a country code from ISO 3166 alpha-2 country
code or a UN M.49 numeric-3 area code.
""",
examples = """
Examples:
> SELECT _FUNC_('Hi there! Good morning.');
[["Hi","there"],["Good","morning"]]
> SELECT _FUNC_('Hi there! Good morning.', 'en');
[["Hi","there"],["Good","morning"]]
> SELECT _FUNC_('Hi there! Good morning.', 'en', 'US');
[["Hi","there"],["Good","morning"]]
""",
since = "2.0.0",
group = "string_funcs")
case class Sentences(
str: Expression,
language: Expression = Literal(""),
country: Expression = Literal(""))
extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
extends TernaryExpression
with ImplicitCastInputTypes
with RuntimeReplaceable {

def this(str: Expression) = this(str, Literal(""), Literal(""))
def this(str: Expression, language: Expression) = this(str, language, Literal(""))
@@ -3356,49 +3381,18 @@ case class Sentences(
override def second: Expression = language
override def third: Expression = country

override def eval(input: InternalRow): Any = {
val string = str.eval(input)
if (string == null) {
null
} else {
val languageStr = language.eval(input).asInstanceOf[UTF8String]
val countryStr = country.eval(input).asInstanceOf[UTF8String]
val locale = if (languageStr != null && countryStr != null) {
new Locale(languageStr.toString, countryStr.toString)
} else {
Locale.US
}
getSentences(string.asInstanceOf[UTF8String].toString, locale)
}
}

private def getSentences(sentences: String, locale: Locale) = {
val bi = BreakIterator.getSentenceInstance(locale)
bi.setText(sentences)
var idx = 0
val result = new ArrayBuffer[GenericArrayData]
while (bi.next != BreakIterator.DONE) {
val sentence = sentences.substring(idx, bi.current)
idx = bi.current

val wi = BreakIterator.getWordInstance(locale)
var widx = 0
wi.setText(sentence)
val words = new ArrayBuffer[UTF8String]
while (wi.next != BreakIterator.DONE) {
val word = sentence.substring(widx, wi.current)
widx = wi.current
if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
}
result += new GenericArrayData(words)
}
new GenericArrayData(result)
}
override def replacement: Expression =
StaticInvoke(
classOf[ExpressionImplUtils],
dataType,
"getSentences",
Seq(str, language, country),
inputTypes,
propagateNull = false)

override protected def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences =
copy(str = newFirst, language = newSecond, country = newThird)

}

/**
@@ -1987,7 +1987,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

// Test escaping of arguments
GenerateUnsafeProjection.generate(
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")).replacement :: Nil)
}

test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
@@ -1,2 +1,2 @@
Project [sentences(g#0, , ) AS sentences(g, , )#0]
Project [static_invoke(ExpressionImplUtils.getSentences(g#0, , )) AS sentences(g, , )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,2 @@
Project [static_invoke(ExpressionImplUtils.getSentences(g#0, en, )) AS sentences(g, en, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,2 @@
Project [static_invoke(ExpressionImplUtils.getSentences(g#0, en, US)) AS sentences(g, en, US)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

This file was deleted.

@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "sentences",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}, {
"literal": {
"string": "en"
}
}]
}
}]
}
}
Binary file not shown.
@@ -714,6 +714,34 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
df.select(sentences($"str", $"language", $"country")),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.selectExpr("sentences(str, language)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.select(sentences($"str", $"language")),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.selectExpr("sentences(str)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.select(sentences($"str")),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.selectExpr("sentences(str, null, null)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.selectExpr("sentences(str, '', null)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

checkAnswer(
df.selectExpr("sentences(str, null)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

// Type coercion
checkAnswer(
df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
