From cd701aeed8cfa3f4e1fb925af566e8199351964a Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 30 Jul 2024 09:06:34 -0400
Subject: [PATCH] [SPARK-49053][PYTHON][ML] Make model save/load helper functions accept spark session

### What changes were proposed in this pull request?

Make the model save/load helper functions accept a `SparkSession` in addition to a `SparkContext`, so callers that already hold a session can pass it through directly. (A minimal illustrative sketch of the resulting call pattern is appended after the diff.)

### Why are the changes needed?

1. Avoid unnecessary Spark session creation when the caller already holds a session.
2. Stay consistent with the Scala-side changes: https://github.com/apache/spark/pull/47467 and https://github.com/apache/spark/pull/47477

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

CI

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47527 from zhengruifeng/py_ml_save_metadata_spark.

Authored-by: Ruifeng Zheng
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/ml/classification.py | 29 +++++++++++++++++++----------
 python/pyspark/ml/pipeline.py       | 19 +++++++++++--------
 python/pyspark/ml/util.py           | 23 ++++++++++++++---------
 3 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1eb42f8029b6c..937753b50bb13 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -89,7 +89,7 @@
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.linalg import Matrix, Vector, Vectors, VectorUDT
-from pyspark.sql import DataFrame, Row
+from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
@@ -3678,7 +3678,7 @@ class _OneVsRestSharedReadWrite:
     @staticmethod
     def saveImpl(
         instance: Union[OneVsRest, "OneVsRestModel"],
-        sc: "SparkContext",
+        sc: Union["SparkContext", SparkSession],
         path: str,
         extraMetadata: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -3691,7 +3691,10 @@ def saveImpl(
         cast(MLWritable, instance.getClassifier()).save(classifierPath)
 
     @staticmethod
-    def loadClassifier(path: str, sc: "SparkContext") -> Union[OneVsRest, "OneVsRestModel"]:
+    def loadClassifier(
+        path: str,
+        sc: Union["SparkContext", SparkSession],
+    ) -> Union[OneVsRest, "OneVsRestModel"]:
         classifierPath = os.path.join(path, "classifier")
         return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
 
@@ -3716,11 +3719,13 @@ def __init__(self, cls: Type[OneVsRest]) -> None:
         self.cls = cls
 
     def load(self, path: str) -> OneVsRest:
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
-            classifier = cast(Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sc))
+            classifier = cast(
+                Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
+            )
             ova: OneVsRest = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
             DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
             return ova
@@ -3734,7 +3739,7 @@ def __init__(self, instance: OneVsRest):
 
     def saveImpl(self, path: str) -> None:
         _OneVsRestSharedReadWrite.validateParams(self.instance)
-        _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
+        _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)
 
 
 class OneVsRestModel(
@@ -3963,16 +3968,18 @@ def __init__(self, cls: Type[OneVsRestModel]):
         self.cls = cls
 
     def load(self, path: str) -> OneVsRestModel:
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
-            classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
+            classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
             numClasses = metadata["numClasses"]
             subModels = [None] * numClasses
             for idx in range(numClasses):
                 subModelPath = os.path.join(path, f"model_{idx}")
-                subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
+                subModels[idx] = DefaultParamsReader.loadParamsInstance(
+                    subModelPath, self.sparkSession
+                )
             ovaModel = OneVsRestModel(cast(List[ClassificationModel], subModels))._resetUid(
                 metadata["uid"]
             )
@@ -3992,7 +3999,9 @@ def saveImpl(self, path: str) -> None:
         instance = self.instance
         numClasses = len(instance.models)
         extraMetadata = {"numClasses": numClasses}
-        _OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
+        _OneVsRestSharedReadWrite.saveImpl(
+            instance, self.sparkSession, path, extraMetadata=extraMetadata
+        )
         for idx in range(numClasses):
             subModelPath = os.path.join(path, f"model_{idx}")
             cast(MLWritable, instance.models[idx]).save(subModelPath)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index c8415f89670b7..01339283839e1 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -35,6 +35,7 @@
 )
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.common import inherit_doc
+from pyspark.sql import SparkSession
 from pyspark.sql.dataframe import DataFrame
 
 if TYPE_CHECKING:
@@ -230,7 +231,7 @@ def __init__(self, instance: Pipeline):
     def saveImpl(self, path: str) -> None:
         stages = self.instance.getStages()
         PipelineSharedReadWrite.validateStages(stages)
-        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sparkSession, path)
 
 
 @inherit_doc
@@ -244,11 +245,11 @@ def __init__(self, cls: Type[Pipeline]):
         self.cls = cls
 
     def load(self, path: str) -> Pipeline:
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
             return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path)
         else:
-            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
             return Pipeline(stages=stages)._resetUid(uid)
 
 
@@ -266,7 +267,7 @@ def saveImpl(self, path: str) -> None:
         stages = self.instance.stages
         PipelineSharedReadWrite.validateStages(cast(List["PipelineStage"], stages))
         PipelineSharedReadWrite.saveImpl(
-            self.instance, cast(List["PipelineStage"], stages), self.sc, path
+            self.instance, cast(List["PipelineStage"], stages), self.sparkSession, path
         )
 
 
@@ -281,11 +282,11 @@ def __init__(self, cls: Type["PipelineModel"]):
         self.cls = cls
 
     def load(self, path: str) -> "PipelineModel":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
             return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path)
         else:
-            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
             return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
 
 
@@ -403,7 +404,7 @@ def validateStages(stages: List["PipelineStage"]) -> None:
     def saveImpl(
         instance: Union[Pipeline, PipelineModel],
         stages: List["PipelineStage"],
-        sc: "SparkContext",
+        sc: Union["SparkContext", SparkSession],
         path: str,
     ) -> None:
         """
@@ -422,7 +423,9 @@ def saveImpl(
 
     @staticmethod
     def load(
-        metadata: Dict[str, Any], sc: "SparkContext", path: str
+        metadata: Dict[str, Any],
+        sc: Union["SparkContext", SparkSession],
+        path: str,
     ) -> Tuple[str, List["PipelineStage"]]:
         """
         Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 89e2f9631564b..9bbd64d2aef5a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -32,6 +32,7 @@
     TypeVar,
     cast,
     TYPE_CHECKING,
+    Union,
 )
 
 from pyspark import since
@@ -424,7 +425,7 @@ def __init__(self, instance: "Params"):
         self.instance = instance
 
     def saveImpl(self, path: str) -> None:
-        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
+        DefaultParamsWriter.saveMetadata(self.instance, path, self.sparkSession)
 
     @staticmethod
     def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str, Any]:
@@ -438,7 +439,7 @@
     def saveMetadata(
         instance: "Params",
         path: str,
-        sc: "SparkContext",
+        sc: Union["SparkContext", SparkSession],
         extraMetadata: Optional[Dict[str, Any]] = None,
         paramMap: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -464,7 +465,7 @@ def saveMetadata(
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
             instance, sc, extraMetadata, paramMap
         )
-        spark = SparkSession._getActiveSessionOrCreate()
+        spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
         spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
             metadataPath
         )
@@ -472,7 +473,7 @@
     @staticmethod
     def _get_metadata_to_save(
         instance: "Params",
-        sc: "SparkContext",
+        sc: Union["SparkContext", SparkSession],
         extraMetadata: Optional[Dict[str, Any]] = None,
         paramMap: Optional[Dict[str, Any]] = None,
     ) -> str:
@@ -560,7 +561,7 @@ def __get_class(clazz: str) -> Type[RL]:
         return getattr(m, parts[-1])
 
     def load(self, path: str) -> RL:
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
         instance = py_type()
         cast("Params", instance)._resetUid(metadata["uid"])
@@ -568,19 +569,23 @@ def load(self, path: str) -> RL:
         return instance
 
     @staticmethod
-    def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> Dict[str, Any]:
+    def loadMetadata(
+        path: str,
+        sc: Union["SparkContext", SparkSession],
+        expectedClassName: str = "",
+    ) -> Dict[str, Any]:
         """
         Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
 
         Parameters
         ----------
         path : str
-        sc : :py:class:`pyspark.SparkContext`
+        sc : :py:class:`pyspark.SparkContext` or :py:class:`pyspark.sql.SparkSession`
         expectedClassName : str, optional
             If non empty, this is checked against the loaded metadata.
""" metadataPath = os.path.join(path, "metadata") - spark = SparkSession._getActiveSessionOrCreate() + spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate() metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[index] loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName) return loadedVals @@ -641,7 +646,7 @@ def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool: return metadata["class"].startswith("pyspark.ml.") @staticmethod - def loadParamsInstance(path: str, sc: "SparkContext") -> RL: + def loadParamsInstance(path: str, sc: Union["SparkContext", SparkSession]) -> RL: """ Load a :py:class:`Params` instance from the given path, and return it. This assumes the instance inherits from :py:class:`MLReadable`.