Skip to content

Commit

Permalink
[SPARK-50898][ML][PYTHON][CONNECT] Support FPGrowth on connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support `FPGrowth` on connect

### Why are the changes needed?
for feature parity

### Does this PR introduce _any_ user-facing change?
Yes, new algorithms supported on connect

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49579 from zhengruifeng/ml_connect_fpm.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Jan 21, 2025
1 parent ce07396 commit 3ba74bf
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 2 deletions.
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ def __hash__(self):
# unittests
"pyspark.ml.tests.test_algorithms",
"pyspark.ml.tests.test_als",
"pyspark.ml.tests.test_fpm",
"pyspark.ml.tests.test_base",
"pyspark.ml.tests.test_evaluation",
"pyspark.ml.tests.test_feature",
Expand Down Expand Up @@ -1119,6 +1120,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_parity_als",
"pyspark.ml.tests.connect.test_parity_fpm",
"pyspark.ml.tests.connect.test_parity_classification",
"pyspark.ml.tests.connect.test_parity_regression",
"pyspark.ml.tests.connect.test_parity_clustering",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ org.apache.spark.ml.clustering.BisectingKMeans

# recommendation
org.apache.spark.ml.recommendation.ALS


# fpm
org.apache.spark.ml.fpm.FPGrowth
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ org.apache.spark.ml.clustering.BisectingKMeansModel

# recommendation
org.apache.spark.ml.recommendation.ALSModel

# fpm
org.apache.spark.ml.fpm.FPGrowthModel
2 changes: 2 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class FPGrowthModel private[ml] (
private val numTrainingRecords: Long)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("fpgrowth"), null, Map.empty, 0L)

/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
Expand Down
4 changes: 3 additions & 1 deletion python/pyspark/ml/fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pyspark import keyword_only, since
from pyspark.sql import DataFrame
from pyspark.ml.util import JavaMLWritable, JavaMLReadable
from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params

Expand Down Expand Up @@ -126,6 +126,7 @@ def setPredictionCol(self, value: str) -> "FPGrowthModel":

@property
@since("2.2.0")
@try_remote_attribute_relation
def freqItemsets(self) -> DataFrame:
"""
DataFrame with two columns:
Expand All @@ -136,6 +137,7 @@ def freqItemsets(self) -> DataFrame:

@property
@since("2.2.0")
@try_remote_attribute_relation
def associationRules(self) -> DataFrame:
"""
DataFrame with four columns:
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_fpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.ml.tests.test_fpm import FPMTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class FPMParityTests(FPMTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
from pyspark.ml.tests.connect.test_parity_fpm import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
94 changes: 94 additions & 0 deletions python/pyspark/ml/tests/test_fpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tempfile
import unittest

from pyspark.sql import SparkSession
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
FPGrowthModel,
)


class FPMTestsMixin:
def test_fp_growth(self):
df = self.spark.createDataFrame(
[
["r z h k p"],
["z y x w v u t s"],
["s x o n r"],
["x z y m t s q e"],
["z"],
["x z y r q t p"],
],
["items"],
).select(sf.split("items", " ").alias("items"))

fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
fp.setNumPartitions(1)
self.assertEqual(fp.getMinSupport(), 0.2)
self.assertEqual(fp.getMinConfidence(), 0.7)
self.assertEqual(fp.getNumPartitions(), 1)

# Estimator save & load
with tempfile.TemporaryDirectory(prefix="fp_growth") as d:
fp.write().overwrite().save(d)
fp2 = FPGrowth.load(d)
self.assertEqual(str(fp), str(fp2))

model = fp.fit(df)

self.assertEqual(model.freqItemsets.columns, ["items", "freq"])
self.assertEqual(model.freqItemsets.count(), 54)

self.assertEqual(
model.associationRules.columns,
["antecedent", "consequent", "confidence", "lift", "support"],
)
self.assertEqual(model.associationRules.count(), 89)

output = model.transform(df)
self.assertEqual(output.columns, ["items", "prediction"])
self.assertEqual(output.count(), 6)

# Model save & load
with tempfile.TemporaryDirectory(prefix="fp_growth_model") as d:
model.write().overwrite().save(d)
model2 = FPGrowthModel.load(d)
self.assertEqual(str(model), str(model2))


class FPMTests(FPMTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[4]").getOrCreate()

def tearDown(self) -> None:
self.spark.stop()


if __name__ == "__main__":
from pyspark.ml.tests.test_fpm import * # noqa: F401,F403

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ private[ml] object MLUtils {
"recommendForAllUsers", // ALSModel
"recommendForAllItems", // ALSModel
"recommendForUserSubset", // ALSModel
"recommendForItemSubset" // ALSModel
"recommendForItemSubset", // ALSModel
"associationRules", // FPGrowthModel
"freqItemsets" // FPGrowthModel
)

def invokeMethodAllowed(obj: Object, methodName: String): Object = {
Expand Down

0 comments on commit 3ba74bf

Please sign in to comment.