Skip to content

Commit

Permalink
[SPARK-48988][ML] Make DefaultParamsReader/Writer handle metadata w…
Browse files Browse the repository at this point in the history
…ith spark session

### What changes were proposed in this pull request?
`DefaultParamsReader/Writer` handle metadata with spark session

### Why are the changes needed?
In existing ml implementations, when loading/saving a model, it loads/saves the metadata with `SparkContext` then loads/saves the coefficients with `SparkSession`.

This PR aims to also load/save the metadata with `SparkSession`, by introducing new helper functions.

- Note I: 3-rd libraries (e.g. [xgboost](https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala#L38-L53) ) likely depends on existing implementation of saveMetadata/loadMetadata, so we cannot simply remove them even though they are `private[ml]`.

- Note II: this PR only handles `loadMetadata` and `saveMetadata`, there are similar cases for meta algorithms and param read/write, but I want to ignore the remaining part first, to avoid touching too many files in single PR.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47467 from zhengruifeng/ml_load_with_spark.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and attilapiros committed Oct 4, 2024
1 parent 34f9fb1 commit 293f6a7
Show file tree
Hide file tree
Showing 39 changed files with 137 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
Expand All @@ -309,7 +309,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica

override def load(path: String): DecisionTreeClassificationModel = {
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
factors: Matrix)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -351,7 +351,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
private val className = classOf[FMClassificationModel].getName

override def load(path: String): FMClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -459,7 +459,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
private val className = classOf[LinearSVCModel].getName

override def load(path: String): LinearSVCModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val Row(coefficients: Vector, intercept: Double) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
Expand All @@ -1325,7 +1325,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
private val className = classOf[LogisticRegressionModel].getName

override def load(path: String): LogisticRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ object MultilayerPerceptronClassificationModel

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
Expand All @@ -380,7 +380,7 @@ object MultilayerPerceptronClassificationModel
private val className = classOf[MultilayerPerceptronClassificationModel].getName

override def load(path: String): MultilayerPerceptronClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString

instance.getModelType match {
Expand All @@ -602,7 +602,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

override def load(path: String): NaiveBayesModel = {
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.parentModel.save(sc, dataPath)
}
Expand All @@ -198,7 +198,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
private val className = classOf[BisectingKMeansModel].getName

override def load(path: String): BisectingKMeansModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
val model = new BisectingKMeansModel(metadata.uid, mllibModel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights and gaussians
val weights = instance.weights
val gaussians = instance.gaussians
Expand All @@ -253,7 +253,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private val className = classOf[GaussianMixtureModel].getName

override def load(path: String): GaussianMixtureModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi
override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[KMeansModel]
val sc = sparkSession.sparkContext
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: cluster centers
val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
case (center, idx) =>
Expand Down Expand Up @@ -272,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val sparkSession = super.sparkSession
import sparkSession.implicits._

val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString

val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
Expand Down
10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
gammaShape: Double)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val oldModel = instance.oldLocalModel
val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
Expand All @@ -668,7 +668,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
private val className = classOf[LocalLDAModel].getName

override def load(path: String): LocalLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
Expand Down Expand Up @@ -809,7 +809,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val modelPath = new Path(path, "oldModel").toString
instance.oldDistributedModel.save(sc, modelPath)
}
Expand All @@ -820,7 +820,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
private val className = classOf[DistributedLDAModel].getName

override def load(path: String): DistributedLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
Expand Down Expand Up @@ -1008,7 +1008,7 @@ object LDA extends MLReadable[LDA] {
private val className = classOf[LDA].getName

override def load(path: String): LDA = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val model = new LDA(metadata.uid)
LDAParams.getAndSetParams(model, metadata)
model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
private case class Data(randUnitVectors: Matrix)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randMatrix)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -241,7 +241,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
private val className = classOf[BucketedRandomProjectionLSHModel].getName

override def load(path: String): BucketedRandomProjectionLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
private case class Data(selectedFeatures: Seq[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -185,7 +185,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
private val className = classOf[ChiSqSelectorModel].getName

override def load(path: String): ChiSqSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private case class Data(vocabulary: Seq[String])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.vocabulary.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -384,7 +384,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private val className = classOf[CountVectorizerModel].getName

override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("vocabulary")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ object HashingTF extends DefaultParamsReadable[HashingTF] {
private val className = classOf[HashingTF].getName

override def load(path: String): HashingTF = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

// We support loading old `HashingTF` saved by previous Spark versions.
// Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses
Expand Down
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object IDFModel extends MLReadable[IDFModel] {
private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.idf, instance.docFreq, instance.numDocs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -210,7 +210,7 @@ object IDFModel extends MLReadable[IDFModel] {
private val className = classOf[IDFModel].getName

override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
Expand All @@ -319,7 +319,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
private case class Data(maxAbs: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = new Data(instance.maxAbs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -174,7 +174,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
private val className = classOf[MaxAbsScalerModel].getName

override def load(path: String): MaxAbsScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
.select("maxAbs")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private case class Data(randCoefficients: Array[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
Expand All @@ -233,7 +233,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private val className = classOf[MinHashLSHModel].getName

override def load(path: String): MinHashLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head()
Expand Down
Loading

0 comments on commit 293f6a7

Please sign in to comment.