[SPARK-50963][ML][PYTHON][CONNECT] Support Tokenizers, SQLTransform and StopWordsRemover on Connect

### What changes were proposed in this pull request?

Support a group of text processing algorithms (a brief usage sketch follows this list):

- Tokenizer
- RegexTokenizer
- SQLTransformer
- StopWordsRemover
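
As a quick illustration of what this enables, here is a minimal sketch of driving these transformers through a Spark Connect session (the endpoint, sample data, and column names are illustrative, not taken from this patch):

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import Tokenizer, StopWordsRemover, SQLTransformer

# Hypothetical Connect endpoint; any reachable Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.createDataFrame([("Hi I heard about Spark",)], ["text"])

# Lowercase and split the text on whitespace.
words = Tokenizer(inputCol="text", outputCol="words").transform(df)

# Drop common English stop words from the token array.
filtered = StopWordsRemover(inputCol="words", outputCol="filtered").transform(words)

# Run an arbitrary SQL statement over the input table, referenced as __THIS__.
sql = SQLTransformer(statement="SELECT text, size(filtered) AS n_tokens FROM __THIS__")
sql.transform(filtered).show(truncate=False)
```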

### Why are the changes needed?
For feature parity between Spark Connect and classic PySpark ML.

### Does this PR introduce _any_ user-facing change?
Yes, these transformers can now be used on Spark Connect.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49624 from zhengruifeng/ml_connect_tokenizer.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
(cherry picked from commit 42b15c9)
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Jan 24, 2025
1 parent b080fd7 commit 5044b6d
Showing 3 changed files with 106 additions and 6 deletions.
@@ -19,6 +19,10 @@
# So register the supported transformer here if you're trying to add a new one.
########### Transformers
org.apache.spark.ml.feature.VectorAssembler
org.apache.spark.ml.feature.Tokenizer
org.apache.spark.ml.feature.RegexTokenizer
org.apache.spark.ml.feature.SQLTransformer
org.apache.spark.ml.feature.StopWordsRemover

########### Model for loading
# classification
21 changes: 15 additions & 6 deletions python/pyspark/ml/feature.py
@@ -4970,7 +4970,8 @@ class StopWordsRemover(
    Notes
    -----
    null values from input array are preserved unless adding null to stopWords explicitly.
    - null values from input array are preserved unless adding null to stopWords explicitly.
    - In Spark Connect Mode, the default value of parameter `locale` is not set.

    Examples
    --------
@@ -5069,11 +5070,19 @@ def __init__(
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.feature.StopWordsRemover", self.uid
        )
        self._setDefault(
            stopWords=StopWordsRemover.loadDefaultStopWords("english"),
            caseSensitive=False,
            locale=self._java_obj.getLocale(),
        )
        if isinstance(self._java_obj, str):
            # Skip setting the default value of 'locale', which needs to invoke a JVM method.
            # So if users don't explicitly set 'locale', then getLocale fails.
            self._setDefault(
                stopWords=StopWordsRemover.loadDefaultStopWords("english"),
                caseSensitive=False,
            )
        else:
            self._setDefault(
                stopWords=StopWordsRemover.loadDefaultStopWords("english"),
                caseSensitive=False,
                locale=self._java_obj.getLocale(),
            )
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
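
Because the Connect branch above skips the `locale` default, code that relies on locale-sensitive matching should set the parameter explicitly; a minimal sketch (the locale string and column names are just examples):

```python
from pyspark.ml.feature import StopWordsRemover

remover = StopWordsRemover(inputCol="raw", outputCol="filtered", caseSensitive=False)
# Per the comment above, on Connect there is no default, so getLocale() fails
# until a value is provided explicitly.
remover = remover.setLocale("en_US")
```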

87 changes: 87 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
@@ -30,6 +30,9 @@
    IDF,
    NGram,
    RFormula,
    Tokenizer,
    SQLTransformer,
    RegexTokenizer,
    StandardScaler,
    StandardScalerModel,
    MaxAbsScaler,
@@ -401,6 +404,90 @@ def test_word2vec(self):
            model2 = Word2VecModel.load(d)
            self.assertEqual(str(model), str(model2))

    def test_tokenizer(self):
        df = self.spark.createDataFrame([("a b c",)], ["text"])

        tokenizer = Tokenizer(outputCol="words")
        tokenizer.setInputCol("text")
        self.assertEqual(tokenizer.getInputCol(), "text")
        self.assertEqual(tokenizer.getOutputCol(), "words")

        output = tokenizer.transform(df)
        self.assertEqual(output.columns, ["text", "words"])
        self.assertEqual(output.count(), 1)
        self.assertEqual(output.head().words, ["a", "b", "c"])

        # save & load
        with tempfile.TemporaryDirectory(prefix="tokenizer") as d:
            tokenizer.write().overwrite().save(d)
            tokenizer2 = Tokenizer.load(d)
            self.assertEqual(str(tokenizer), str(tokenizer2))

    def test_regex_tokenizer(self):
        df = self.spark.createDataFrame([("A B c",)], ["text"])

        tokenizer = RegexTokenizer(outputCol="words")
        tokenizer.setInputCol("text")
        self.assertEqual(tokenizer.getInputCol(), "text")
        self.assertEqual(tokenizer.getOutputCol(), "words")

        output = tokenizer.transform(df)
        self.assertEqual(output.columns, ["text", "words"])
        self.assertEqual(output.count(), 1)
        self.assertEqual(output.head().words, ["a", "b", "c"])

        # save & load
        with tempfile.TemporaryDirectory(prefix="regex_tokenizer") as d:
            tokenizer.write().overwrite().save(d)
            tokenizer2 = RegexTokenizer.load(d)
            self.assertEqual(str(tokenizer), str(tokenizer2))

    def test_sql_transformer(self):
        df = self.spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])

        statement = "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"
        sql = SQLTransformer(statement=statement)
        self.assertEqual(sql.getStatement(), statement)

        output = sql.transform(df)
        self.assertEqual(output.columns, ["id", "v1", "v2", "v3", "v4"])
        self.assertEqual(output.count(), 2)
        self.assertEqual(
            output.collect(),
            [
                Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0),
                Row(id=2, v1=2.0, v2=5.0, v3=7.0, v4=10.0),
            ],
        )

        # save & load
        with tempfile.TemporaryDirectory(prefix="sql_transformer") as d:
            sql.write().overwrite().save(d)
            sql2 = SQLTransformer.load(d)
            self.assertEqual(str(sql), str(sql2))

    def test_stop_words_remover(self):
        df = self.spark.createDataFrame([(["a", "b", "c"],)], ["text"])

        remover = StopWordsRemover(stopWords=["b"])
        remover.setInputCol("text")
        remover.setOutputCol("words")

        self.assertEqual(remover.getStopWords(), ["b"])
        self.assertEqual(remover.getInputCol(), "text")
        self.assertEqual(remover.getOutputCol(), "words")

        output = remover.transform(df)
        self.assertEqual(output.columns, ["text", "words"])
        self.assertEqual(output.count(), 1)
        self.assertEqual(output.head().words, ["a", "c"])

        # save & load
        with tempfile.TemporaryDirectory(prefix="stop_words_remover") as d:
            remover.write().overwrite().save(d)
            remover2 = StopWordsRemover.load(d)
            self.assertEqual(str(remover), str(remover2))

    def test_binarizer(self):
        b0 = Binarizer()
        self.assertListEqual(