[SPARK-48988][ML] Make DefaultParamsReader/Writer handle metadata with spark session #47467

Closed · wants to merge 3 commits
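This PR threads the active `SparkSession` through the ML persistence helpers instead of a bare `SparkContext`: every `DefaultParamsWriter.saveMetadata(instance, path, sc, ...)` and `DefaultParamsReader.loadMetadata(path, sc, className)` call site in the touched readers and writers switches to the session-based form. The sketch below shows the shape of the change as implied by the call sites in this diff; the real helpers live in `org.apache.spark.ml.util` and are `private[ml]`, and the parameter names here are assumed, not copied from the PR.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.Params
import org.apache.spark.sql.SparkSession
import org.json4s.JsonAST.JObject

// Overload shapes implied by the call sites below (assumed, for illustration):
trait SaveMetadataBefore {
  // Old form: metadata written via a SparkContext.
  def saveMetadata(instance: Params, path: String, sc: SparkContext,
      extraMetadata: Option[JObject] = None): Unit
}

trait SaveMetadataAfter {
  // New form: metadata written via the SparkSession the writer already holds.
  def saveMetadata(instance: Params, path: String, spark: SparkSession,
      extraMetadata: Option[JObject] = None): Unit
}
```

`DefaultParamsReader.loadMetadata` mirrors the same swap on the read side, taking the session in place of the context. Because the writers and readers expose both handles, each call site only changes the argument it passes.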
Changes from all commits
@@ -293,7 +293,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses)
-DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
@@ -309,7 +309,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica

override def load(path: String): DecisionTreeClassificationModel = {
implicit val format = DefaultFormats
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)

@@ -339,7 +339,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
factors: Matrix)

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -351,7 +351,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
private val className = classOf[FMClassificationModel].getName

override def load(path: String): FMClassificationModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)


@@ -446,7 +446,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -459,7 +459,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
private val className = classOf[LinearSVCModel].getName

override def load(path: String): LinearSVCModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val Row(coefficients: Vector, intercept: Double) =

@@ -1310,7 +1310,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
@@ -1325,7 +1325,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
private val className = classOf[LogisticRegressionModel].getName

override def load(path: String): LogisticRegressionModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString

@@ -365,7 +365,7 @@ object MultilayerPerceptronClassificationModel

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
@@ -380,7 +380,7 @@ object MultilayerPerceptronClassificationModel
private val className = classOf[MultilayerPerceptronClassificationModel].getName

override def load(path: String): MultilayerPerceptronClassificationModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString

@@ -580,7 +580,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString

instance.getModelType match {
@@ -602,7 +602,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

override def load(path: String): NaiveBayesModel = {
implicit val format = DefaultFormats
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)

val dataPath = new Path(path, "data").toString

@@ -186,7 +186,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.parentModel.save(sc, dataPath)
}
@@ -198,7 +198,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
private val className = classOf[BisectingKMeansModel].getName

override def load(path: String): BisectingKMeansModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
val model = new BisectingKMeansModel(metadata.uid, mllibModel)
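Note that the model data here is still persisted through the legacy `spark.mllib` implementation (`instance.parentModel.save(sc, dataPath)` and `MLlibBisectingKMeansModel.load(sc, dataPath)` above), whose persistence API takes a `SparkContext`; the same pattern recurs in the LDA hunks below. Only the metadata call switches to the session. A minimal sketch of why both handles stay available — member shapes assumed, not copied from Spark's `BaseReadWrite`:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Assumed sketch: a reader/writer holds one session and derives the legacy
// context from it, so session-based metadata I/O and context-based
// spark.mllib persistence run against the same underlying SparkContext.
class ReadWriteHandles(val sparkSession: SparkSession) {
  def sc: SparkContext = sparkSession.sparkContext
}
```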

@@ -235,7 +235,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights and gaussians
val weights = instance.weights
val gaussians = instance.gaussians
@@ -253,7 +253,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private val className = classOf[GaussianMixtureModel].getName

override def load(path: String): GaussianMixtureModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()

@@ -219,9 +219,8 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi
override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[KMeansModel]
-val sc = sparkSession.sparkContext
// Save metadata and Params
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: cluster centers
val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
case (center, idx) =>
@@ -272,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val sparkSession = super.sparkSession
import sparkSession.implicits._

-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString

val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
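`InternalKMeansModelWriter.write` receives `sparkSession` directly as a parameter, so once `saveMetadata` accepts a session, the derived binding `val sc = sparkSession.sparkContext` becomes dead code and is simply dropped. The after-state, condensed from the hunk above (elisions marked in comments):

```scala
override def write(path: String, sparkSession: SparkSession,
    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
  val instance = stage.asInstanceOf[KMeansModel]
  // Save metadata and Params, passing the session parameter straight through.
  DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
  // ... cluster centers are written as before (elided in this sketch) ...
}
```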
10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -654,7 +654,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
gammaShape: Double)

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val oldModel = instance.oldLocalModel
val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
@@ -668,7 +668,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
private val className = classOf[LocalLDAModel].getName

override def load(path: String): LocalLDAModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
@@ -809,7 +809,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val modelPath = new Path(path, "oldModel").toString
instance.oldDistributedModel.save(sc, modelPath)
}
@@ -820,7 +820,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
private val className = classOf[DistributedLDAModel].getName

override def load(path: String): DistributedLDAModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
@@ -1008,7 +1008,7 @@ object LDA extends MLReadable[LDA] {
private val className = classOf[LDA].getName

override def load(path: String): LDA = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val model = new LDA(metadata.uid)
LDAParams.getAndSetParams(model, metadata)
model

@@ -227,7 +227,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
private case class Data(randUnitVectors: Matrix)

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randMatrix)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -241,7 +241,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
private val className = classOf[BucketedRandomProjectionLSHModel].getName

override def load(path: String): BucketedRandomProjectionLSHModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)

@@ -173,7 +173,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
private case class Data(selectedFeatures: Seq[Int])

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -185,7 +185,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
private val className = classOf[ChiSqSelectorModel].getName

override def load(path: String): ChiSqSelectorModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray

@@ -372,7 +372,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private case class Data(vocabulary: Seq[String])

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.vocabulary.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -384,7 +384,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private val className = classOf[CountVectorizerModel].getName

override def load(path: String): CountVectorizerModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("vocabulary")

@@ -154,7 +154,7 @@ object HashingTF extends DefaultParamsReadable[HashingTF] {
private val className = classOf[HashingTF].getName

override def load(path: String): HashingTF = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

// We support loading old `HashingTF` saved by previous Spark versions.
// Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -198,7 +198,7 @@ object IDFModel extends MLReadable[IDFModel] {
private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.idf, instance.docFreq, instance.numDocs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -210,7 +210,7 @@ object IDFModel extends MLReadable[IDFModel] {
private val className = classOf[IDFModel].getName

override def load(path: String): IDFModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)


@@ -308,7 +308,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
@@ -319,7 +319,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)

@@ -162,7 +162,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
private case class Data(maxAbs: Vector)

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = new Data(instance.maxAbs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -174,7 +174,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
private val className = classOf[MaxAbsScalerModel].getName

override def load(path: String): MaxAbsScalerModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
.select("maxAbs")

@@ -220,7 +220,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private case class Data(randCoefficients: Array[Int])

override protected def saveImpl(path: String): Unit = {
-DefaultParamsWriter.saveMetadata(instance, path, sc)
+DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -233,7 +233,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private val className = classOf[MinHashLSHModel].getName

override def load(path: String): MinHashLSHModel = {
-val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head()
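A round trip through the public API exercises both new code paths, since for models like these `MLWriter.save` goes through `DefaultParamsWriter.saveMetadata` and `MLReader.load` goes through `DefaultParamsReader.loadMetadata`. A hypothetical smoke test (standard Spark ML usage, not code from this PR; the save path is made up):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object MetadataRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("round-trip").getOrCreate()
    import spark.implicits._

    val train = Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0))
    ).toDF("label", "features")

    val model = new LogisticRegression().fit(train)
    val path = "/tmp/lr-round-trip" // hypothetical location
    model.write.overwrite().save(path)              // -> DefaultParamsWriter.saveMetadata
    val loaded = LogisticRegressionModel.load(path) // -> DefaultParamsReader.loadMetadata
    assert(loaded.uid == model.uid) // uid is restored from the saved metadata
    spark.stop()
  }
}
```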