From 3ba74bf9b509e1cddbda6bb4849782e26fa840ed Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 21 Jan 2025 16:03:13 +0800 Subject: [PATCH] [SPARK-50898][ML][PYTHON][CONNECT] Support `FPGrowth` on connect ### What changes were proposed in this pull request? Support `FPGrowth` on connect ### Why are the changes needed? for feature parity ### Does this PR introduce _any_ user-facing change? Yes, new algorithms supported on connect ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #49579 from zhengruifeng/ml_connect_fpm. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- dev/sparktestsupport/modules.py | 2 + .../services/org.apache.spark.ml.Estimator | 4 + .../services/org.apache.spark.ml.Transformer | 3 + .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 + python/pyspark/ml/fpm.py | 4 +- .../ml/tests/connect/test_parity_fpm.py | 37 ++++++++ python/pyspark/ml/tests/test_fpm.py | 94 +++++++++++++++++++ .../apache/spark/sql/connect/ml/MLUtils.scala | 4 +- 8 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 python/pyspark/ml/tests/connect/test_parity_fpm.py create mode 100644 python/pyspark/ml/tests/test_fpm.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index cacd4a83bbe4f..5fd3f73772767 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -664,6 +664,7 @@ def __hash__(self): # unittests "pyspark.ml.tests.test_algorithms", "pyspark.ml.tests.test_als", + "pyspark.ml.tests.test_fpm", "pyspark.ml.tests.test_base", "pyspark.ml.tests.test_evaluation", "pyspark.ml.tests.test_feature", @@ -1119,6 +1120,7 @@ def __hash__(self): "pyspark.ml.tests.connect.test_connect_pipeline", "pyspark.ml.tests.connect.test_connect_tuning", "pyspark.ml.tests.connect.test_parity_als", + "pyspark.ml.tests.connect.test_parity_fpm", "pyspark.ml.tests.connect.test_parity_classification", "pyspark.ml.tests.connect.test_parity_regression", "pyspark.ml.tests.connect.test_parity_clustering", diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator index a7d7d3da9df3b..4046cca07dc0f 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator @@ -39,3 +39,7 @@ org.apache.spark.ml.clustering.BisectingKMeans # recommendation org.apache.spark.ml.recommendation.ALS + + +# fpm +org.apache.spark.ml.fpm.FPGrowth diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer index 392115be98ba5..7c10796f9a877 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer @@ -38,3 +38,6 @@ org.apache.spark.ml.clustering.BisectingKMeansModel # recommendation org.apache.spark.ml.recommendation.ALSModel + +# fpm +org.apache.spark.ml.fpm.FPGrowthModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d054ea8ebdb47..d90124c62d54e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -223,6 +223,8 @@ class FPGrowthModel private[ml] ( private val numTrainingRecords: Long) extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { + private[ml] def this() = this(Identifiable.randomUID("fpgrowth"), null, Map.empty, 0L) + /** @group setParam */ @Since("2.2.0") def setMinConfidence(value: Double): this.type = set(minConfidence, value) diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 72fcfccf19e4c..c068b5f26ba84 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -20,7 +20,7 @@ from pyspark import keyword_only, since from pyspark.sql import DataFrame -from pyspark.ml.util import JavaMLWritable, JavaMLReadable +from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params @@ -126,6 +126,7 @@ def setPredictionCol(self, value: str) -> "FPGrowthModel": @property @since("2.2.0") + @try_remote_attribute_relation def freqItemsets(self) -> DataFrame: """ DataFrame with two columns: @@ -136,6 +137,7 @@ def freqItemsets(self) -> DataFrame: @property @since("2.2.0") + @try_remote_attribute_relation def associationRules(self) -> DataFrame: """ DataFrame with four columns: diff --git a/python/pyspark/ml/tests/connect/test_parity_fpm.py b/python/pyspark/ml/tests/connect/test_parity_fpm.py new file mode 100644 index 0000000000000..85ceba87a2f57 --- /dev/null +++ b/python/pyspark/ml/tests/connect/test_parity_fpm.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.ml.tests.test_fpm import FPMTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class FPMParityTests(FPMTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + from pyspark.ml.tests.connect.test_parity_fpm import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py new file mode 100644 index 0000000000000..8db35158978df --- /dev/null +++ b/python/pyspark/ml/tests/test_fpm.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest + +from pyspark.sql import SparkSession +import pyspark.sql.functions as sf +from pyspark.ml.fpm import ( + FPGrowth, + FPGrowthModel, +) + + +class FPMTestsMixin: + def test_fp_growth(self): + df = self.spark.createDataFrame( + [ + ["r z h k p"], + ["z y x w v u t s"], + ["s x o n r"], + ["x z y m t s q e"], + ["z"], + ["x z y r q t p"], + ], + ["items"], + ).select(sf.split("items", " ").alias("items")) + + fp = FPGrowth(minSupport=0.2, minConfidence=0.7) + fp.setNumPartitions(1) + self.assertEqual(fp.getMinSupport(), 0.2) + self.assertEqual(fp.getMinConfidence(), 0.7) + self.assertEqual(fp.getNumPartitions(), 1) + + # Estimator save & load + with tempfile.TemporaryDirectory(prefix="fp_growth") as d: + fp.write().overwrite().save(d) + fp2 = FPGrowth.load(d) + self.assertEqual(str(fp), str(fp2)) + + model = fp.fit(df) + + self.assertEqual(model.freqItemsets.columns, ["items", "freq"]) + self.assertEqual(model.freqItemsets.count(), 54) + + self.assertEqual( + model.associationRules.columns, + ["antecedent", "consequent", "confidence", "lift", "support"], + ) + self.assertEqual(model.associationRules.count(), 89) + + output = model.transform(df) + self.assertEqual(output.columns, ["items", "prediction"]) + self.assertEqual(output.count(), 6) + + # Model save & load + with tempfile.TemporaryDirectory(prefix="fp_growth_model") as d: + model.write().overwrite().save(d) + model2 = FPGrowthModel.load(d) + self.assertEqual(str(model), str(model2)) + + +class FPMTests(FPMTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession.builder.master("local[4]").getOrCreate() + + def tearDown(self) -> None: + self.spark.stop() + + +if __name__ == "__main__": + from pyspark.ml.tests.test_fpm import * # noqa: F401,F403 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index 4e93aec47ef03..b85bc6771f8ec 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -500,7 +500,9 @@ private[ml] object MLUtils { "recommendForAllUsers", // ALSModel "recommendForAllItems", // ALSModel "recommendForUserSubset", // ALSModel - "recommendForItemSubset" // ALSModel + "recommendForItemSubset", // ALSModel + "associationRules", // FPGrowthModel + "freqItemsets" // FPGrowthModel ) def invokeMethodAllowed(obj: Object, methodName: String): Object = {