diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7deefda2eeaff..c5f1d7f39b6b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -293,7 +293,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numClasses" -> instance.numClasses) - DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.numNodes) @@ -309,7 +309,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica override def load(path: String): DecisionTreeClassificationModel = { implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val root = loadTreeNodes(path, metadata, sparkSession) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index aec740a932acf..33e7c1fdd5e05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -339,7 +339,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] { factors: Matrix) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.intercept, instance.linear, instance.factors) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -351,7 +351,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] { private val className = classOf[FMClassificationModel].getName override def load(path: String): FMClassificationModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 3e27f781d561b..161e8f4cbd2c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -446,7 +446,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.coefficients, instance.intercept) val dataPath = new Path(path, 
"data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -459,7 +459,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { private val className = classOf[LinearSVCModel].getName override def load(path: String): LinearSVCModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) val Row(coefficients: Vector, intercept: Double) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ac0682f1df5bf..745cb61bb7aa1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1310,7 +1310,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: numClasses, numFeatures, intercept, coefficients val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector, instance.coefficientMatrix, instance.isMultinomial) @@ -1325,7 +1325,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 16984bf9aed8a..106282b9dc3a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -365,7 +365,7 @@ object MultilayerPerceptronClassificationModel override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: weights val data = Data(instance.weights) val dataPath = new Path(path, "data").toString @@ -380,7 +380,7 @@ object MultilayerPerceptronClassificationModel private val className = classOf[MultilayerPerceptronClassificationModel].getName override def load(path: String): MultilayerPerceptronClassificationModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 52486cb8aa245..4a511581d31a9 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -580,7 +580,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val dataPath = new Path(path, "data").toString instance.getModelType match { @@ -602,7 +602,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { override def load(path: String): NaiveBayesModel = { implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 2d809151384b8..b4f1565362b02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -186,7 +186,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val dataPath = new Path(path, "data").toString instance.parentModel.save(sc, dataPath) } @@ -198,7 +198,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { private val className = classOf[BisectingKMeansModel].getName override def load(path: String): BisectingKMeansModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 0f6648bb4cda7..d0db5dcba87b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -235,7 +235,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: weights and gaussians val weights = instance.weights val gaussians = instance.gaussians @@ -253,7 +253,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { private val className = classOf[GaussianMixtureModel].getName override def load(path: String): GaussianMixtureModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head() diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 04f76660aee6a..50fb18bb620a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -219,9 +219,8 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi override def write(path: String, sparkSession: SparkSession, optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { val instance = stage.asInstanceOf[KMeansModel] - val sc = sparkSession.sparkContext // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: cluster centers val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => @@ -272,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val sparkSession = super.sparkSession import sparkSession.implicits._ - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 7cbfc732a19ca..b3d3c84db0511 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -654,7 +654,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { gammaShape: Double) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val oldModel = instance.oldLocalModel val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, oldModel.topicConcentration, oldModel.gammaShape) @@ -668,7 +668,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { private val className = classOf[LocalLDAModel].getName override def load(path: String): LocalLDAModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") @@ -809,7 +809,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val modelPath = new Path(path, "oldModel").toString instance.oldDistributedModel.save(sc, modelPath) } @@ -820,7 +820,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { private val className = classOf[DistributedLDAModel].getName override def load(path: String): DistributedLDAModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) val model = new 
DistributedLDAModel(metadata.uid, oldModel.vocabSize, @@ -1008,7 +1008,7 @@ object LDA extends MLReadable[LDA] { private val className = classOf[LDA].getName override def load(path: String): LDA = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val model = new LDA(metadata.uid) LDAParams.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index d30962088cb81..537cb5020c88d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -227,7 +227,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject private case class Data(randUnitVectors: Matrix) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.randMatrix) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -241,7 +241,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject private val className = classOf[BucketedRandomProjectionLSHModel].getName override def load(path: String): BucketedRandomProjectionLSHModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 3062f643e950d..eb2122b09b2fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -173,7 +173,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { private case class Data(selectedFeatures: Seq[Int]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -185,7 +185,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { private val className = classOf[ChiSqSelectorModel].getName override def load(path: String): ChiSqSelectorModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head() val selectedFeatures = data.getAs[Seq[Int]](0).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index b81914f86fbb7..611b5c710add1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -372,7 +372,7 @@ object 
CountVectorizerModel extends MLReadable[CountVectorizerModel] { private case class Data(vocabulary: Seq[String]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.vocabulary.toImmutableArraySeq) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -384,7 +384,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) .select("vocabulary") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index f4223bc85943d..3b42105958c72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -154,7 +154,7 @@ object HashingTF extends DefaultParamsReadable[HashingTF] { private val className = classOf[HashingTF].getName override def load(path: String): HashingTF = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) // We support loading old `HashingTF` saved by previous Spark versions. // Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 696e1516582d0..3025a7b04af53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -198,7 +198,7 @@ object IDFModel extends MLReadable[IDFModel] { private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.idf, instance.docFreq, instance.numDocs) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -210,7 +210,7 @@ object IDFModel extends MLReadable[IDFModel] { private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ae65b17d7a810..38fb25903dcaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -308,7 +308,7 @@ object ImputerModel extends MLReadable[ImputerModel] { private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, 
path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val dataPath = new Path(path, "data").toString instance.surrogateDF.repartition(1).write.parquet(dataPath) } @@ -319,7 +319,7 @@ object ImputerModel extends MLReadable[ImputerModel] { private val className = classOf[ImputerModel].getName override def load(path: String): ImputerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 05ee59d1627da..1a378cd85f3e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -162,7 +162,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { private case class Data(maxAbs: Vector) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = new Data(instance.maxAbs) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -174,7 +174,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { private val className = classOf[MaxAbsScalerModel].getName override def load(path: String): MaxAbsScalerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath) .select("maxAbs") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index d94aadd1ce1f9..3f2a3327128a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -220,7 +220,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { private case class Data(randCoefficients: Array[Int]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2))) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -233,7 +233,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { private val className = classOf[MinHashLSHModel].getName override def load(path: String): MinHashLSHModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 4111e559a5c20..c311f4260424d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -247,7 +247,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private case class Data(originalMin: Vector, originalMax: Vector) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = new Data(instance.originalMin, instance.originalMax) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -259,7 +259,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) val Row(originalMin: Vector, originalMax: Vector) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e7cf0105754a9..823f767eebbe0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -403,7 +403,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { private case class Data(categorySizes: Array[Int]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.categorySizes) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -415,7 +415,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { private val className = classOf[OneHotEncoderModel].getName override def load(path: String): OneHotEncoderModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) .select("categorySizes") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index f7ec18b38a0e3..0bd9a3c38a1e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -184,7 +184,7 @@ object PCAModel extends MLReadable[PCAModel] { private case class Data(pc: DenseMatrix, explainedVariance: DenseVector) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.pc, instance.explainedVariance) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -205,7 +205,7 @@ object PCAModel extends MLReadable[PCAModel] { * @return a [[PCAModel]] */ override def load(path: String): PCAModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val model = if (majorVersion(metadata.sparkVersion) >= 2) { diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7a47e73e5ef48..77bd18423ef1b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -427,7 +427,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: resolvedFormula val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(instance.resolvedFormula)) @@ -444,7 +444,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { private val className = classOf[RFormulaModel].getName override def load(path: String): RFormulaModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() @@ -502,7 +502,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: columnsToPrune val data = Data(instance.columnsToPrune.toSeq) val dataPath = new Path(path, "data").toString @@ -516,7 +516,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { private val className = classOf[ColumnPruner].getName override def load(path: String): ColumnPruner = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head() @@ -594,7 +594,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: vectorCol, prefixesToRewrite val data = Data(instance.vectorCol, instance.prefixesToRewrite) val dataPath = new Path(path, "data").toString @@ -608,7 +608,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite private val className = classOf[VectorAttributeRewriter].getName override def load(path: String): VectorAttributeRewriter = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala index 0950dc55dccba..f3e068f049205 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala @@ -284,7 +284,7 @@ object RobustScalerModel extends MLReadable[RobustScalerModel] { private case class Data(range: Vector, median: 
Vector) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.range, instance.median) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -296,7 +296,7 @@ object RobustScalerModel extends MLReadable[RobustScalerModel] { private val className = classOf[RobustScalerModel].getName override def load(path: String): RobustScalerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) val Row(range: Vector, median: Vector) = MLUtils diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index c0a6392c29c3e..f1e48b053d883 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -205,7 +205,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private case class Data(std: Vector, mean: Vector) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -217,7 +217,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 2ca640445b553..94d4fa6fe6f20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -509,7 +509,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private case class Data(labelsArray: Array[Array[String]]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.labelsArray) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -521,7 +521,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString // We support loading old `StringIndexerModel` saved by previous Spark versions. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 29a0910124953..9c2033c28430e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -349,7 +349,7 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect private case class Data(selectedFeatures: Seq[Int]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -363,7 +363,7 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect private val className = classOf[UnivariateFeatureSelectorModel].getName override def load(path: String): UnivariateFeatureSelectorModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) .select("selectedFeatures").head() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala index df57e19f1a723..d767e113144c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala @@ -187,7 +187,7 @@ object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelect private case class Data(selectedFeatures: Seq[Int]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.selectedFeatures.toImmutableArraySeq) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -201,7 +201,7 @@ object VarianceThresholdSelectorModel extends MLReadable[VarianceThresholdSelect private val className = classOf[VarianceThresholdSelectorModel].getName override def load(path: String): VarianceThresholdSelectorModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) .select("selectedFeatures").head() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 4fed325e19e91..ff89dee68ea38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -519,7 +519,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = 
Data(instance.numFeatures, instance.categoryMaps) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -531,7 +531,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { private val className = classOf[VectorIndexerModel].getName override def load(path: String): VectorIndexerModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) .select("numFeatures", "categoryMaps") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 66b56f8b88ef1..0329190a239ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -352,7 +352,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val wordVectors = instance.wordVectors.getVectors val dataPath = new Path(path, "data").toString @@ -407,7 +407,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { val spark = sparkSession import spark.implicits._ - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 081a40bfbe801..d054ea8ebdb47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -336,7 +336,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { override protected def saveImpl(path: String): Unit = { val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords) - DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata)) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession, + extraMetadata = Some(extraMetadata)) val dataPath = new Path(path, "data").toString instance.freqItemsets.write.parquet(dataPath) } @@ -349,7 +350,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { override def load(path: String): FPGrowthModel = { implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) val numTrainingRecords = if (major < 2 || (major == 2 && minor < 4)) { // 2.3 and before don't store the count diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 50f94a5799444..1a004f71749e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -556,7 +556,7 @@ object ALSModel extends MLReadable[ALSModel] { override protected def saveImpl(path: 
String): Unit = { val extraMetadata = "rank" -> instance.rank - DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) val itemPath = new Path(path, "itemFactors").toString @@ -570,7 +570,7 @@ object ALSModel extends MLReadable[ALSModel] { private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) implicit val format = DefaultFormats val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index d77d79dae4b8c..6451cbf0329d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -494,7 +494,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: coefficients, intercept, scale val data = Data(instance.coefficients, instance.intercept, instance.scale) val dataPath = new Path(path, "data").toString @@ -508,7 +508,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] private val className = classOf[AFTSurvivalRegressionModel].getName override def load(path: String): AFTSurvivalRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 481e8c8357f16..dace99f214b16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -302,7 +302,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode override protected def saveImpl(path: String): Unit = { val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures) - DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.numNodes) @@ -318,7 +318,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode override def load(path: String): DecisionTreeRegressionModel = { implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val numFeatures = (metadata.metadata \ 
"numFeatures").extract[Int] val root = loadTreeNodes(path, metadata, sparkSession) val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index 8c797295e6715..182107a443c1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -504,7 +504,7 @@ object FMRegressionModel extends MLReadable[FMRegressionModel] { factors: Matrix) override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) val data = Data(instance.intercept, instance.linear, instance.factors) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) @@ -516,7 +516,7 @@ object FMRegressionModel extends MLReadable[FMRegressionModel] { private val className = classOf[FMRegressionModel].getName override def load(path: String): FMRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 181a1a03e6f38..dc0b553e2c91d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1141,7 +1141,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString @@ -1156,7 +1156,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr private val className = classOf[GeneralizedLinearRegressionModel].getName override def load(path: String): GeneralizedLinearRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 29d8a00a43844..d624270af89d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -301,7 +301,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: boundaries, predictions, isotonic val data = Data( 
instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) @@ -316,7 +316,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { private val className = classOf[IsotonicRegressionModel].getName override def load(path: String): IsotonicRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d5dce782770b4..abac9db8df024 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -780,7 +780,7 @@ private class InternalLinearRegressionModelWriter val instance = stage.asInstanceOf[LinearRegressionModel] val sc = sparkSession.sparkContext // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession) // Save model data: intercept, coefficients, scale val data = Data(instance.intercept, instance.coefficients, instance.scale) val dataPath = new Path(path, "data").toString @@ -824,7 +824,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 6a7615fb149b8..cdd40ae355037 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -471,7 +471,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sparkSession: SparkSession, extraMetadata: JObject): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sparkSession.sparkContext, Some(extraMetadata)) + DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata)) val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, treeID) => (treeID, DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession.sparkContext), @@ -510,7 +510,7 @@ private[ml] object EnsembleModelReadWrite { treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sparkSession.implicits._ implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sparkSession.sparkContext, className) + val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) // Get impurity to construct ImpurityCalculator for each node val impurityType: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index c127575e14707..d338c267d823c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -382,7 +382,7 @@ trait 
DefaultParamsReadable[T] extends MLReadable[T] {
 private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
 
   override protected def saveImpl(path: String): Unit = {
-    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
   }
 }
 
@@ -403,20 +403,58 @@ private[ml] object DefaultParamsWriter {
    *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
    *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
+  @deprecated("use saveMetadata with SparkSession", "4.0.0")
   def saveMetadata(
       instance: Params,
       path: String,
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
-      paramMap: Option[JValue] = None): Unit = {
+      paramMap: Option[JValue] = None): Unit =
+    saveMetadata(
+      instance,
+      path,
+      SparkSession.builder().sparkContext(sc).getOrCreate(),
+      extraMetadata,
+      paramMap)
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - defaultParamMap
+   *  - paramMap
+   *  - (optionally, extra metadata)
+   *
+   * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
+   * @param paramMap  If given, this is saved in the "paramMap" field.
+   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
+   */
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      spark: SparkSession,
+      extraMetadata: Option[JObject],
+      paramMap: Option[JValue]): Unit = {
     val metadataPath = new Path(path, "metadata").toString
-    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
-    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+    val metadataJson = getMetadataToSave(instance, spark.sparkContext, extraMetadata, paramMap)
     // Note that we should write single file. If there are more than one row
     // it produces more partitions.
     spark.createDataFrame(Seq(Tuple1(metadataJson))).write.text(metadataPath)
   }
 
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      spark: SparkSession,
+      extraMetadata: Option[JObject]): Unit =
+    saveMetadata(instance, path, spark, extraMetadata, None)
+
+  def saveMetadata(instance: Params, path: String, spark: SparkSession): Unit =
+    saveMetadata(instance, path, spark, None, None)
+
   /**
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
@@ -466,7 +504,7 @@ private[ml] object DefaultParamsWriter {
 private[ml] class DefaultParamsReader[T] extends MLReader[T] {
 
   override def load(path: String): T = {
-    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val metadata = DefaultParamsReader.loadMetadata(path, sparkSession)
     val cls = Utils.classForName(metadata.className)
     val instance =
       cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
@@ -586,13 +624,22 @@ private[ml] object DefaultParamsReader {
    * @param expectedClassName  If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
-  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+  @deprecated("use loadMetadata with SparkSession", "4.0.0")
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata =
+    loadMetadata(
+      path,
+      SparkSession.builder().sparkContext(sc).getOrCreate(),
+      expectedClassName)
+
+  def loadMetadata(path: String, spark: SparkSession, expectedClassName: String): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
-    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     val metadataStr = spark.read.text(metadataPath).first().getString(0)
     parseMetadata(metadataStr, expectedClassName)
   }
 
+  def loadMetadata(path: String, spark: SparkSession): Metadata =
+    loadMetadata(path, spark, "")
+
   /**
    * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
    * This is a helper function for [[loadMetadata()]].
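
Usage sketch (not part of the patch): a minimal illustration of the new SparkSession-based overloads introduced above. It assumes the calling code lives inside the org.apache.spark.ml package, since DefaultParamsWriter and DefaultParamsReader are private[ml]; the Binarizer stage and the /tmp paths are placeholders chosen only for the example.

import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("metadata-io").getOrCreate()
val stage = new Binarizer().setThreshold(0.5)

// New overloads: pass the SparkSession directly.
DefaultParamsWriter.saveMetadata(stage, "/tmp/binarizer", spark)
val metadata = DefaultParamsReader.loadMetadata("/tmp/binarizer", spark, classOf[Binarizer].getName)

// Deprecated overloads are kept for source compatibility: they wrap the SparkContext via
// SparkSession.builder().sparkContext(sc).getOrCreate() and delegate to the variants above.
DefaultParamsWriter.saveMetadata(stage, "/tmp/binarizer_old", spark.sparkContext)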