Skip to content

Commit

Permalink
[SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRe…
Browse files Browse the repository at this point in the history
…gression and IsotonicRegression on Connect

### What changes were proposed in this pull request?
 Support AFTSurvivalRegression and IsotonicRegression on Connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#49687 from zhengruifeng/ml_connect_aft.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Jan 27, 2025
1 parent 6baddd0 commit 3ba76bf
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ org.apache.spark.ml.classification.RandomForestClassifier
org.apache.spark.ml.classification.GBTClassifier

# regression
org.apache.spark.ml.regression.AFTSurvivalRegression
org.apache.spark.ml.regression.IsotonicRegression
org.apache.spark.ml.regression.LinearRegression
org.apache.spark.ml.regression.GeneralizedLinearRegression
org.apache.spark.ml.regression.DecisionTreeRegressor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ org.apache.spark.ml.classification.RandomForestClassificationModel
org.apache.spark.ml.classification.GBTClassificationModel

# regression
org.apache.spark.ml.regression.AFTSurvivalRegressionModel
org.apache.spark.ml.regression.IsotonicRegressionModel
org.apache.spark.ml.regression.LinearRegressionModel
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
org.apache.spark.ml.regression.DecisionTreeRegressionModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ class AFTSurvivalRegressionModel private[ml] (
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
Vectors.empty, Double.NaN, Double.NaN)

@Since("3.0.0")
override def numFeatures: Int = coefficients.size

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ class IsotonicRegressionModel private[ml] (
private val oldModel: MLlibIsotonicRegressionModel)
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("isoReg"), null)

/** @group setParam */
@Since("1.5.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
Expand Down
102 changes: 102 additions & 0 deletions python/pyspark/ml/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.ml.regression import (
AFTSurvivalRegression,
AFTSurvivalRegressionModel,
IsotonicRegression,
IsotonicRegressionModel,
LinearRegression,
LinearRegressionModel,
GeneralizedLinearRegression,
Expand Down Expand Up @@ -57,6 +61,104 @@ def df(self):
.sortWithinPartitions("weight")
)

def test_aft_survival(self):
spark = self.spark
df = spark.createDataFrame(
[(1.0, Vectors.dense(1.0), 1.0), (1e-40, Vectors.sparse(1, [], []), 0.0)],
["label", "features", "censor"],
)

aft = AFTSurvivalRegression()
aft.setMaxIter(1)
self.assertEqual(aft.getMaxIter(), 1)

model = aft.fit(df)
self.assertEqual(aft.uid, model.uid)
self.assertEqual(model.numFeatures, 1)
self.assertTrue(np.allclose(model.intercept, 0.0, atol=1e-4), model.intercept)
self.assertTrue(
np.allclose(model.coefficients.toArray(), [0.0], atol=1e-4), model.coefficients
)
self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4), model.scale)

vec = Vectors.dense(6.3)
pred = model.predict(vec)
self.assertEqual(pred, 1.0)
pred = model.predictQuantiles(vec)
self.assertTrue(
np.allclose(
pred,
[
0.010050335853501444,
0.051293294387550536,
0.1053605156578263,
0.2876820724517809,
0.6931471805599453,
1.3862943611198906,
2.302585092994046,
2.9957322735539895,
4.60517018598809,
],
atol=1e-4,
),
pred,
)

output = model.transform(df)
expected_cols = ["label", "features", "censor", "prediction"]
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 2)

# Model save & load
with tempfile.TemporaryDirectory(prefix="aft_survival") as d:
aft.write().overwrite().save(d)
aft2 = AFTSurvivalRegression.load(d)
self.assertEqual(str(aft), str(aft2))

model.write().overwrite().save(d)
model2 = AFTSurvivalRegressionModel.load(d)
self.assertEqual(str(model), str(model2))

def test_isotonic_regression(self):
spark = self.spark
df = spark.createDataFrame(
[(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))], ["label", "features"]
)

ir = IsotonicRegression(
isotonic=True,
featureIndex=0,
)
self.assertTrue(ir.getIsotonic())
self.assertEqual(ir.getFeatureIndex(), 0)

model = ir.fit(df)
self.assertEqual(model.numFeatures, 1)
self.assertTrue(
np.allclose(model.boundaries.toArray(), [0.0, 1.0], atol=1e-4), model.boundaries
)
self.assertTrue(
np.allclose(model.predictions.toArray(), [0.0, 1.0], atol=1e-4), model.predictions
)

pred = model.predict(1.0)
self.assertTrue(np.allclose(pred, 1.0, atol=1e-4), pred)

output = model.transform(df)
expected_cols = ["label", "features", "prediction"]
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 2)

# Model save & load
with tempfile.TemporaryDirectory(prefix="isotonic_regression") as d:
ir.write().overwrite().save(d)
ir2 = IsotonicRegression.load(d)
self.assertEqual(str(ir), str(ir2))

model.write().overwrite().save(d)
model2 = IsotonicRegressionModel.load(d)
self.assertEqual(str(model), str(model2))

def test_linear_regression(self):
df = self.df
lr = LinearRegression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,12 @@ private[ml] object MLUtils {
(classOf[MultilayerPerceptronClassificationModel], Set("weights", "evaluate")),

// Regression Models
(
classOf[AFTSurvivalRegressionModel],
Set("intercept", "coefficients", "scale", "predictQuantiles")),
(
classOf[IsotonicRegressionModel],
Set("boundaries", "predictions", "numFeatures", "predict")),
(
classOf[GeneralizedLinearRegressionModel],
Set("intercept", "coefficients", "numFeatures", "evaluate")),
Expand Down

0 comments on commit 3ba76bf

Please sign in to comment.