Skip to content

Commit

Permalink
[SPARK-50879][ML][PYTHON][CONNECT][FOLLOW-UP] Support RobustScaler on…
Browse files Browse the repository at this point in the history
… Connect

### What changes were proposed in this pull request?
Support RobustScaler on Connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes, new feature

### How was this patch tested?
added test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#49597 from zhengruifeng/ml_connect_robust_scaler.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Jan 22, 2025
1 parent 8611d0f commit 454463b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ org.apache.spark.ml.fpm.FPGrowth
org.apache.spark.ml.feature.StandardScaler
org.apache.spark.ml.feature.MaxAbsScaler
org.apache.spark.ml.feature.MinMaxScaler
org.apache.spark.ml.feature.RobustScaler
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ org.apache.spark.ml.fpm.FPGrowthModel
org.apache.spark.ml.feature.StandardScalerModel
org.apache.spark.ml.feature.MaxAbsScalerModel
org.apache.spark.ml.feature.MinMaxScalerModel
org.apache.spark.ml.feature.RobustScalerModel
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class RobustScalerModel private[ml] (

import RobustScalerModel._

private[ml] def this() = this(Identifiable.randomUID("robustScal"), Vectors.empty, Vectors.empty)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

Expand Down
47 changes: 44 additions & 3 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
MaxAbsScalerModel,
MinMaxScaler,
MinMaxScalerModel,
RobustScaler,
RobustScalerModel,
StopWordsRemover,
StringIndexer,
StringIndexerModel,
Expand Down Expand Up @@ -103,7 +105,7 @@ def test_standard_scaler(self):
["index", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
.sortWithinPartitions("index")
.select("features")
)
scaler = StandardScaler(inputCol="features", outputCol="scaled")
Expand Down Expand Up @@ -141,7 +143,7 @@ def test_maxabs_scaler(self):
["index", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
.sortWithinPartitions("index")
.select("features")
)

Expand Down Expand Up @@ -179,7 +181,7 @@ def test_minmax_scaler(self):
["index", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("weight")
.sortWithinPartitions("index")
.select("features")
)

Expand Down Expand Up @@ -207,6 +209,45 @@ def test_minmax_scaler(self):
model2 = MinMaxScalerModel.load(d)
self.assertEqual(str(model), str(model2))

def test_robust_scaler(self):
df = (
self.spark.createDataFrame(
[
(1, 1.0, Vectors.dense([0.0])),
(2, 2.0, Vectors.dense([2.0])),
(3, 3.0, Vectors.sparse(1, [(0, 3.0)])),
],
["index", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("index")
.select("features")
)

scaler = RobustScaler(inputCol="features", outputCol="scaled")
self.assertEqual(scaler.getInputCol(), "features")
self.assertEqual(scaler.getOutputCol(), "scaled")

# Estimator save & load
with tempfile.TemporaryDirectory(prefix="robust_scaler") as d:
scaler.write().overwrite().save(d)
scaler2 = RobustScaler.load(d)
self.assertEqual(str(scaler), str(scaler2))

model = scaler.fit(df)
self.assertTrue(np.allclose(model.range.toArray(), [3.0], atol=1e-4))
self.assertTrue(np.allclose(model.median.toArray(), [2.0], atol=1e-4))

output = model.transform(df)
self.assertEqual(output.columns, ["features", "scaled"])
self.assertEqual(output.count(), 3)

# Model save & load
with tempfile.TemporaryDirectory(prefix="robust_scaler_model") as d:
model.write().overwrite().save(d)
model2 = RobustScalerModel.load(d)
self.assertEqual(str(model), str(model2))

def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ private[ml] object MLUtils {
"maxAbs", // MaxAbsScalerModel
"originalMax", // MinMaxScalerModel
"originalMin", // MinMaxScalerModel
"range", // RobustScalerModel
"median", // RobustScalerModel
"toString",
"toDebugString",
"numFeatures",
Expand Down

0 comments on commit 454463b

Please sign in to comment.