[SPARK-49053][PYTHON][ML] Make model save/load helper functions accept spark session

### What changes were proposed in this pull request?
Make the model save/load helper functions accept a `SparkSession` in addition to a `SparkContext`.

### Why are the changes needed?
1. Avoid unnecessary SparkSession creations (see the usage sketch below).
2. Stay consistent with the Scala-side changes: apache#47467 and apache#47477.
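
For illustration, a minimal sketch of the widened helpers, assuming a live PySpark session; the `Binarizer` stage and the `/tmp` path are arbitrary choices for this example, not part of the PR:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import Binarizer
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter

spark = SparkSession.builder.getOrCreate()
stage = Binarizer(threshold=0.5, inputCol="values", outputCol="features")

# The helpers now accept the session itself; previously they were typed to take a
# SparkContext and always resolved a session via SparkSession._getActiveSessionOrCreate().
DefaultParamsWriter.saveMetadata(stage, "/tmp/binarizer_meta", spark)
metadata = DefaultParamsReader.loadMetadata("/tmp/binarizer_meta", spark)
```

Passing a session the caller already holds is what avoids the extra session lookup or creation inside the helpers.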

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47527 from zhengruifeng/py_ml_save_metadata_spark.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and attilapiros committed Oct 4, 2024
1 parent 8299ebd commit cd701ae
Showing 3 changed files with 44 additions and 27 deletions.
29 changes: 19 additions & 10 deletions python/pyspark/ml/classification.py
@@ -89,7 +89,7 @@
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import Matrix, Vector, Vectors, VectorUDT
from pyspark.sql import DataFrame, Row
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
@@ -3678,7 +3678,7 @@ class _OneVsRestSharedReadWrite:
@staticmethod
def saveImpl(
instance: Union[OneVsRest, "OneVsRestModel"],
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
path: str,
extraMetadata: Optional[Dict[str, Any]] = None,
) -> None:
@@ -3691,7 +3691,10 @@ def saveImpl(
cast(MLWritable, instance.getClassifier()).save(classifierPath)

@staticmethod
def loadClassifier(path: str, sc: "SparkContext") -> Union[OneVsRest, "OneVsRestModel"]:
def loadClassifier(
path: str,
sc: Union["SparkContext", SparkSession],
) -> Union[OneVsRest, "OneVsRestModel"]:
classifierPath = os.path.join(path, "classifier")
return DefaultParamsReader.loadParamsInstance(classifierPath, sc)

@@ -3716,11 +3719,13 @@ def __init__(self, cls: Type[OneVsRest]) -> None:
self.cls = cls

def load(self, path: str) -> OneVsRest:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
classifier = cast(Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sc))
classifier = cast(
Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
)
ova: OneVsRest = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
return ova
@@ -3734,7 +3739,7 @@ def __init__(self, instance: OneVsRest):

def saveImpl(self, path: str) -> None:
_OneVsRestSharedReadWrite.validateParams(self.instance)
_OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
_OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)


class OneVsRestModel(
@@ -3963,16 +3968,18 @@ def __init__(self, cls: Type[OneVsRestModel]):
self.cls = cls

def load(self, path: str) -> OneVsRestModel:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
numClasses = metadata["numClasses"]
subModels = [None] * numClasses
for idx in range(numClasses):
subModelPath = os.path.join(path, f"model_{idx}")
subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
subModels[idx] = DefaultParamsReader.loadParamsInstance(
subModelPath, self.sparkSession
)
ovaModel = OneVsRestModel(cast(List[ClassificationModel], subModels))._resetUid(
metadata["uid"]
)
@@ -3992,7 +3999,9 @@ def saveImpl(self, path: str) -> None:
instance = self.instance
numClasses = len(instance.models)
extraMetadata = {"numClasses": numClasses}
_OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
_OneVsRestSharedReadWrite.saveImpl(
instance, self.sparkSession, path, extraMetadata=extraMetadata
)
for idx in range(numClasses):
subModelPath = os.path.join(path, f"model_{idx}")
cast(MLWritable, instance.models[idx]).save(subModelPath)
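
As a usage sketch tied to the "no user-facing change" note above, the public `OneVsRest` save/load calls look exactly the same before and after this diff; only the Python-side readers/writers changed, threading `self.sparkSession` instead of `self.sc` into the shared helpers. The base classifier and save path below are illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression, OneVsRest

spark = SparkSession.builder.getOrCreate()

ovr = OneVsRest(classifier=LogisticRegression(maxIter=5))
ovr.write().overwrite().save("/tmp/ovr_demo")  # same public API as before
loaded = OneVsRest.load("/tmp/ovr_demo")
```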
19 changes: 11 additions & 8 deletions python/pyspark/ml/pipeline.py
@@ -35,6 +35,7 @@
)
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

if TYPE_CHECKING:
@@ -230,7 +231,7 @@ def __init__(self, instance: Pipeline):
def saveImpl(self, path: str) -> None:
stages = self.instance.getStages()
PipelineSharedReadWrite.validateStages(stages)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sparkSession, path)


@inherit_doc
@@ -244,11 +245,11 @@ def __init__(self, cls: Type[Pipeline]):
self.cls = cls

def load(self, path: str) -> Pipeline:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
return Pipeline(stages=stages)._resetUid(uid)


@@ -266,7 +267,7 @@ def saveImpl(self, path: str) -> None:
stages = self.instance.stages
PipelineSharedReadWrite.validateStages(cast(List["PipelineStage"], stages))
PipelineSharedReadWrite.saveImpl(
self.instance, cast(List["PipelineStage"], stages), self.sc, path
self.instance, cast(List["PipelineStage"], stages), self.sparkSession, path
)


@@ -281,11 +282,11 @@ def __init__(self, cls: Type["PipelineModel"]):
self.cls = cls

def load(self, path: str) -> "PipelineModel":
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)


@@ -403,7 +404,7 @@ def validateStages(stages: List["PipelineStage"]) -> None:
def saveImpl(
instance: Union[Pipeline, PipelineModel],
stages: List["PipelineStage"],
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
path: str,
) -> None:
"""
@@ -422,7 +423,9 @@ def saveImpl(

@staticmethod
def load(
metadata: Dict[str, Any], sc: "SparkContext", path: str
metadata: Dict[str, Any],
sc: Union["SparkContext", SparkSession],
path: str,
) -> Tuple[str, List["PipelineStage"]]:
"""
Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
23 changes: 14 additions & 9 deletions python/pyspark/ml/util.py
@@ -32,6 +32,7 @@
TypeVar,
cast,
TYPE_CHECKING,
Union,
)

from pyspark import since
@@ -424,7 +425,7 @@ def __init__(self, instance: "Params"):
self.instance = instance

def saveImpl(self, path: str) -> None:
DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
DefaultParamsWriter.saveMetadata(self.instance, path, self.sparkSession)

@staticmethod
def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str, Any]:
@@ -438,7 +439,7 @@ def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str
def saveMetadata(
instance: "Params",
path: str,
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
extraMetadata: Optional[Dict[str, Any]] = None,
paramMap: Optional[Dict[str, Any]] = None,
) -> None:
@@ -464,15 +465,15 @@ def saveMetadata(
metadataJson = DefaultParamsWriter._get_metadata_to_save(
instance, sc, extraMetadata, paramMap
)
spark = SparkSession._getActiveSessionOrCreate()
spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
metadataPath
)

@staticmethod
def _get_metadata_to_save(
instance: "Params",
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
extraMetadata: Optional[Dict[str, Any]] = None,
paramMap: Optional[Dict[str, Any]] = None,
) -> str:
@@ -560,27 +561,31 @@ def __get_class(clazz: str) -> Type[RL]:
return getattr(m, parts[-1])

def load(self, path: str) -> RL:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
instance = py_type()
cast("Params", instance)._resetUid(metadata["uid"])
DefaultParamsReader.getAndSetParams(instance, metadata)
return instance

@staticmethod
def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> Dict[str, Any]:
def loadMetadata(
path: str,
sc: Union["SparkContext", SparkSession],
expectedClassName: str = "",
) -> Dict[str, Any]:
"""
Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
Parameters
----------
path : str
sc : :py:class:`pyspark.SparkContext`
sc : :py:class:`pyspark.SparkContext` or :py:class:`pyspark.sql.SparkSession`
expectedClassName : str, optional
If non empty, this is checked against the loaded metadata.
"""
metadataPath = os.path.join(path, "metadata")
spark = SparkSession._getActiveSessionOrCreate()
spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[index]
loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
return loadedVals
@@ -641,7 +646,7 @@ def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
return metadata["class"].startswith("pyspark.ml.")

@staticmethod
def loadParamsInstance(path: str, sc: "SparkContext") -> RL:
def loadParamsInstance(path: str, sc: Union["SparkContext", SparkSession]) -> RL:
"""
Load a :py:class:`Params` instance from the given path, and return it.
This assumes the instance inherits from :py:class:`MLReadable`.
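
To summarize the pattern the `util.py` hunks introduce, here is a standalone sketch; the `_resolve_session` helper name is hypothetical (pyspark inlines this check at each call site rather than defining such a function):

```python
from typing import Union

from pyspark import SparkContext
from pyspark.sql import SparkSession


def _resolve_session(sc: Union[SparkContext, SparkSession]) -> SparkSession:
    # Reuse the caller's session when one is passed; otherwise fall back to the
    # active (or newly created) session, matching the pre-change behaviour.
    if isinstance(sc, SparkSession):
        return sc
    return SparkSession._getActiveSessionOrCreate()
```

Callers that already hold a `SparkSession` skip the lookup entirely, which is the "unnecessary spark session creations" the description refers to.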
