[SPARK-50941][ML][PYTHON][CONNECT] Add support for TrainValidationSplit
### What changes were proposed in this pull request?

This PR adds support for TrainValidationSplit and TrainValidationSplitModel on Spark Connect.
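
With this change the existing tuning API works unchanged against a Spark Connect session. A minimal usage sketch (illustrative only; `df` and the save path are placeholders):

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel

# `df` is assumed to be a labeled DataFrame ("features", "label") obtained from
# a Spark Connect session, e.g. SparkSession.builder.remote("sc://...").getOrCreate().
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
tvs = TrainValidationSplit(
    estimator=lr,
    estimatorParamMaps=grid,
    evaluator=BinaryClassificationEvaluator(),
    trainRatio=0.75,
)
model = tvs.fit(df)                          # fitted on the Connect server
model.write().save("/tmp/tvs_model")         # dispatched to the remote writer added here
loaded = TrainValidationSplitModel.load("/tmp/tvs_model")
```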

### Why are the changes needed?
Feature parity for ML on Spark Connect: CrossValidator is already supported, and this adds the same support for TrainValidationSplit.

### Does this PR introduce _any_ user-facing change?
Yes. TrainValidationSplit and TrainValidationSplitModel can now be fit, saved, and loaded against a Spark Connect session.

### How was this patch tested?
CI passes; this PR also adds a new test, test_train_validation_split, in python/pyspark/ml/tests/test_tuning.py.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#49688 from wbo4958/train_validation_split.

Lead-authored-by: Bobby Wang <wbo4958@gmail.com>
Co-authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
wbo4958 and zhengruifeng committed Jan 27, 2025
1 parent 3ba76bf commit 9d0e888
Showing 3 changed files with 125 additions and 13 deletions.
52 changes: 49 additions & 3 deletions python/pyspark/ml/connect/readwrite.py
@@ -19,7 +19,12 @@

import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
from pyspark.ml.tuning import (
CrossValidatorModelWriter,
CrossValidatorModel,
TrainValidationSplitModel,
TrainValidationSplitModelWriter,
)
from pyspark.ml.util import MLWriter, MLReader, RL
from pyspark.ml.wrapper import JavaWrapper

@@ -42,6 +47,19 @@ def __init__(
self.session(session) # type: ignore[arg-type]


class RemoteTrainValidationSplitModelWriter(TrainValidationSplitModelWriter):
def __init__(
self,
instance: "TrainValidationSplitModel",
optionMap: Dict[str, Any] = {},
session: Optional["SparkSession"] = None,
):
super(RemoteTrainValidationSplitModelWriter, self).__init__(instance)
self.instance = instance
self.optionMap = optionMap
self.session(session) # type: ignore[arg-type]


class RemoteMLWriter(MLWriter):
def __init__(self, instance: "JavaMLWritable") -> None:
super().__init__()
@@ -76,7 +94,7 @@ def saveInstance(
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit

# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -155,6 +173,20 @@ def saveInstance(
warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
cvm_writer.save(path)
elif isinstance(instance, TrainValidationSplit):
from pyspark.ml.tuning import TrainValidationSplitWriter

if shouldOverwrite:
# TODO(SPARK-50954): Support client side model path overwrite
warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
tvs_writer = TrainValidationSplitWriter(instance)
tvs_writer.save(path)
elif isinstance(instance, TrainValidationSplitModel):
if shouldOverwrite:
# TODO(SPARK-50954): Support client side model path overwrite
warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
tvsm_writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for {instance.__class__}")

@@ -182,7 +214,7 @@ def loadInstance(
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit

if (
issubclass(clazz, JavaModel)
@@ -261,5 +293,19 @@ def _get_class() -> Type[RL]:
cvm_reader.session(session)
return cvm_reader.load(path)

elif issubclass(clazz, TrainValidationSplit):
from pyspark.ml.tuning import TrainValidationSplitReader

tvs_reader = TrainValidationSplitReader(TrainValidationSplit)
tvs_reader.session(session)
return tvs_reader.load(path)

elif issubclass(clazz, TrainValidationSplitModel):
from pyspark.ml.tuning import TrainValidationSplitModelReader

tvs_reader = TrainValidationSplitModelReader(TrainValidationSplitModel)
tvs_reader.session(session)
return tvs_reader.load(path)

else:
raise RuntimeError(f"Unsupported read for {clazz}")
60 changes: 59 additions & 1 deletion python/pyspark/ml/tests/test_tuning.py
@@ -24,11 +24,69 @@
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml.tuning import (
ParamGridBuilder,
CrossValidator,
CrossValidatorModel,
TrainValidationSplit,
TrainValidationSplitModel,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase


class TuningTestsMixin:
def test_train_validation_split(self):
dataset = self.spark.createDataFrame(
[
(Vectors.dense([0.0]), 0.0),
(Vectors.dense([0.4]), 1.0),
(Vectors.dense([0.5]), 0.0),
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0),
]
* 10, # Repeat the data 10 times
["features", "label"],
).repartition(1)

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()

tvs = TrainValidationSplit(
estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1, seed=42
)
self.assertEqual(tvs.getEstimator(), lr)
self.assertEqual(tvs.getEvaluator(), evaluator)
self.assertEqual(tvs.getParallelism(), 1)
self.assertEqual(tvs.getEstimatorParamMaps(), grid)

tvs_model = tvs.fit(dataset)

# Access the train ratio
self.assertEqual(tvs_model.getTrainRatio(), 0.75)
print("----------- ", tvs_model.validationMetrics)
self.assertTrue(np.isclose(tvs_model.validationMetrics[0], 0.5, atol=1e-4))
self.assertTrue(np.isclose(tvs_model.validationMetrics[1], 0.8857142857142857, atol=1e-4))

evaluation_score = evaluator.evaluate(tvs_model.transform(dataset))
self.assertTrue(np.isclose(evaluation_score, 0.8333333333333333, atol=1e-4))

# save & load
with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
path1 = os.path.join(d, "cv")
tvs.write().save(path1)
tvs2 = TrainValidationSplit.load(path1)
self.assertEqual(str(tvs), str(tvs2))
self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))

path2 = os.path.join(d, "cv_model")
tvs_model.write().save(path2)
model2 = TrainValidationSplitModel.load(path2)
self.assertEqual(str(tvs_model), str(model2))
self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))

def test_cross_validator(self):
dataset = self.spark.createDataFrame(
[
26 changes: 17 additions & 9 deletions python/pyspark/ml/tuning.py
@@ -1186,12 +1186,12 @@ def __init__(self, cls: Type["TrainValidationSplit"]):
self.cls = cls

def load(self, path: str) -> "TrainValidationSplit":
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
path, self.sc, metadata
path, self.sparkSession, metadata
)
tvs = TrainValidationSplit(
estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -1209,7 +1209,7 @@ def __init__(self, instance: "TrainValidationSplit"):

def saveImpl(self, path: str) -> None:
_ValidatorSharedReadWrite.validateParams(self.instance)
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)


@inherit_doc
@@ -1219,15 +1219,17 @@ def __init__(self, cls: Type["TrainValidationSplitModel"]):
self.cls = cls

def load(self, path: str) -> "TrainValidationSplitModel":
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
path, self.sc, metadata
path, self.sparkSession, metadata
)
bestModelPath = os.path.join(path, "bestModel")
bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
bestModel: Model = DefaultParamsReader.loadParamsInstance(
bestModelPath, self.sparkSession
)
validationMetrics = metadata["validationMetrics"]
persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]

@@ -1236,7 +1238,7 @@ def load(self, path: str) -> "TrainValidationSplitModel":
for paramIndex in range(len(estimatorParamMaps)):
modelPath = os.path.join(path, "subModels", f"{paramIndex}")
subModels[paramIndex] = DefaultParamsReader.loadParamsInstance(
modelPath, self.sc
modelPath, self.sparkSession
)
else:
subModels = None
@@ -1273,7 +1275,9 @@ def saveImpl(self, path: str) -> None:
"validationMetrics": instance.validationMetrics,
"persistSubModels": persistSubModels,
}
_ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
_ValidatorSharedReadWrite.saveImpl(
path, instance, self.sparkSession, extraMetadata=extraMetadata
)
bestModelPath = os.path.join(path, "bestModel")
cast(MLWritable, instance.bestModel).save(bestModelPath)
if persistSubModels:
@@ -1473,7 +1477,7 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
subModels = [None for i in range(numModels)]

tasks = map(
inheritable_thread_target,
inheritable_thread_target(dataset.sparkSession),
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
@@ -1529,6 +1533,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplit":
return newTVS

@since("2.3.0")
@try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1537,6 +1542,7 @@ def write(self) -> MLWriter:

@classmethod
@since("2.3.0")
@try_remote_read
def read(cls) -> TrainValidationSplitReader:
"""Returns an MLReader instance for this class."""
return TrainValidationSplitReader(cls)
Expand Down Expand Up @@ -1649,6 +1655,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplitModel
)

@since("2.3.0")
@try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1657,6 +1664,7 @@ def write(self) -> MLWriter:

@classmethod
@since("2.3.0")
@try_remote_read
def read(cls) -> TrainValidationSplitModelReader:
"""Returns an MLReader instance for this class."""
return TrainValidationSplitModelReader(cls)
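
For context on the two decorators added above: `try_remote_write` and `try_remote_read` come from `pyspark.ml.util` and reroute `write()` / `read()` to the Connect-backed `RemoteMLWriter` / `RemoteMLReader` when the session is remote. A simplified sketch of that dispatch pattern (illustrative only, not the actual pyspark implementation):

```python
import functools
import os


def try_remote_write(f):
    """Hand write() to the Connect-backed writer when the session is remote."""

    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        # Stand-in check for a remote (Spark Connect) session; pyspark uses its
        # own is_remote() helper rather than reading this variable directly.
        if "SPARK_CONNECT_MODE_ENABLED" in os.environ:
            from pyspark.ml.connect.readwrite import RemoteMLWriter

            return RemoteMLWriter(self)
        return f(self, *args, **kwargs)

    return wrapped
```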