diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 0448117468198..03a05fdd56a2a 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -19,6 +19,10 @@
 # So register the supported transformer here if you're trying to add a new one.
 ########### Transformers
 org.apache.spark.ml.feature.VectorAssembler
+org.apache.spark.ml.feature.Tokenizer
+org.apache.spark.ml.feature.RegexTokenizer
+org.apache.spark.ml.feature.SQLTransformer
+org.apache.spark.ml.feature.StopWordsRemover
 
 ########### Model for loading
 # classification
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ff8555fadbd12..04e989481e7df 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -4970,7 +4970,8 @@ class StopWordsRemover(
 
     Notes
     -----
-    null values from input array are preserved unless adding null to stopWords explicitly.
+    - null values from input array are preserved unless adding null to stopWords explicitly.
+    - In Spark Connect mode, the default value of the `locale` parameter is not set.
 
     Examples
     --------
@@ -5069,11 +5070,19 @@ def __init__(
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.feature.StopWordsRemover", self.uid
         )
-        self._setDefault(
-            stopWords=StopWordsRemover.loadDefaultStopWords("english"),
-            caseSensitive=False,
-            locale=self._java_obj.getLocale(),
-        )
+        if isinstance(self._java_obj, str):
+            # Skip setting a default for 'locale', since computing it requires a JVM
+            # call that is unavailable in Connect mode; as a result, getLocale fails
+            # unless users explicitly set 'locale' first.
+            self._setDefault(
+                stopWords=StopWordsRemover.loadDefaultStopWords("english"),
+                caseSensitive=False,
+            )
+        else:
+            self._setDefault(
+                stopWords=StopWordsRemover.loadDefaultStopWords("english"),
+                caseSensitive=False,
+                locale=self._java_obj.getLocale(),
+            )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
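A brief usage sketch of the new docstring note (not part of the patch; it assumes a Spark Connect session is already available as `spark`): because Connect mode sets no default for `locale`, callers that rely on locale-sensitive matching should set it explicitly before transforming.

    from pyspark.ml.feature import StopWordsRemover

    df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])
    remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"])
    # No default locale under Connect, so pick one explicitly; "en_US" is an
    # illustrative choice, not something the patch prescribes.
    remover.setLocale("en_US")
    remover.transform(df).show()
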
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 9766ab1b02438..126e8245ae780 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -30,6 +30,9 @@
     IDF,
     NGram,
     RFormula,
+    Tokenizer,
+    SQLTransformer,
+    RegexTokenizer,
     StandardScaler,
     StandardScalerModel,
     MaxAbsScaler,
@@ -401,6 +404,90 @@ def test_word2vec(self):
             model2 = Word2VecModel.load(d)
             self.assertEqual(str(model), str(model2))
 
+    def test_tokenizer(self):
+        df = self.spark.createDataFrame([("a b c",)], ["text"])
+
+        tokenizer = Tokenizer(outputCol="words")
+        tokenizer.setInputCol("text")
+        self.assertEqual(tokenizer.getInputCol(), "text")
+        self.assertEqual(tokenizer.getOutputCol(), "words")
+
+        output = tokenizer.transform(df)
+        self.assertEqual(output.columns, ["text", "words"])
+        self.assertEqual(output.count(), 1)
+        self.assertEqual(output.head().words, ["a", "b", "c"])
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="tokenizer") as d:
+            tokenizer.write().overwrite().save(d)
+            tokenizer2 = Tokenizer.load(d)
+            self.assertEqual(str(tokenizer), str(tokenizer2))
+
+    def test_regex_tokenizer(self):
+        df = self.spark.createDataFrame([("A B c",)], ["text"])
+
+        tokenizer = RegexTokenizer(outputCol="words")
+        tokenizer.setInputCol("text")
+        self.assertEqual(tokenizer.getInputCol(), "text")
+        self.assertEqual(tokenizer.getOutputCol(), "words")
+
+        output = tokenizer.transform(df)
+        self.assertEqual(output.columns, ["text", "words"])
+        self.assertEqual(output.count(), 1)
+        self.assertEqual(output.head().words, ["a", "b", "c"])
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="regex_tokenizer") as d:
+            tokenizer.write().overwrite().save(d)
+            tokenizer2 = RegexTokenizer.load(d)
+            self.assertEqual(str(tokenizer), str(tokenizer2))
+
+    def test_sql_transformer(self):
+        df = self.spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
+
+        statement = "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"
+        sql = SQLTransformer(statement=statement)
+        self.assertEqual(sql.getStatement(), statement)
+
+        output = sql.transform(df)
+        self.assertEqual(output.columns, ["id", "v1", "v2", "v3", "v4"])
+        self.assertEqual(output.count(), 2)
+        self.assertEqual(
+            output.collect(),
+            [
+                Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0),
+                Row(id=2, v1=2.0, v2=5.0, v3=7.0, v4=10.0),
+            ],
+        )
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="sql_transformer") as d:
+            sql.write().overwrite().save(d)
+            sql2 = SQLTransformer.load(d)
+            self.assertEqual(str(sql), str(sql2))
+
+    def test_stop_words_remover(self):
+        df = self.spark.createDataFrame([(["a", "b", "c"],)], ["text"])
+
+        remover = StopWordsRemover(stopWords=["b"])
+        remover.setInputCol("text")
+        remover.setOutputCol("words")
+
+        self.assertEqual(remover.getStopWords(), ["b"])
+        self.assertEqual(remover.getInputCol(), "text")
+        self.assertEqual(remover.getOutputCol(), "words")
+
+        output = remover.transform(df)
+        self.assertEqual(output.columns, ["text", "words"])
+        self.assertEqual(output.count(), 1)
+        self.assertEqual(output.head().words, ["a", "c"])
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="stop_words_remover") as d:
+            remover.write().overwrite().save(d)
+            remover2 = StopWordsRemover.load(d)
+            self.assertEqual(str(remover), str(remover2))
+
     def test_binarizer(self):
         b0 = Binarizer()
         self.assertListEqual(
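For readers puzzled by the `isinstance(self._java_obj, str)` branch in `__init__` above, here is a minimal sketch of the dispatch it relies on; the class and method names are illustrative, not the real pyspark internals. Under Spark Connect there is no py4j handle to the JVM, so the wrapper holds a plain string, and any default that would require a JVM call (such as `getLocale()`) must be skipped.

    # Hypothetical stand-in for the wrapper state; not the actual pyspark classes.
    class WrapperSketch:
        def __init__(self, java_obj):
            # Classic PySpark: a py4j JavaObject. Connect mode: a str.
            self._java_obj = java_obj

        def default_locale(self):
            if isinstance(self._java_obj, str):
                # No JVM to ask; the user must call setLocale() explicitly.
                return None
            return self._java_obj.getLocale()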