Skip to content

Commit

Permalink
[ML-185]Select label and features columns and cache data (#196)
Browse files Browse the repository at this point in the history
* update

Signed-off-by: minmingzhu <minming.zhu@intel.com>

* Update NaiveBayes.scala
  • Loading branch information
minmingzhu authored Apr 1, 2022
1 parent d57cc82 commit 401891e
Showing 1 changed file with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,14 @@ class NaiveBayes @Since("1.5.0") (

val sc = spark.sparkContext

val executor_num = Utils.sparkExecutorNum(sc)
val executor_cores = Utils.sparkExecutorCores()
// select label and features columns and cache data.
val naiveBayesData = dataset.select($(labelCol), $(featuresCol)).cache()
naiveBayesData.count()

logInfo(s"NaiveBayesDAL fit using $executor_num Executors")
val executorNum = Utils.sparkExecutorNum(sc)
val executorCores = Utils.sparkExecutorCores()

logInfo(s"NaiveBayesDAL fit using $executorNum Executors")

// DAL only supports labels in [0..numClasses), so original labels should be mapped using StringIndexer
// TODO: optimize getting the number of classes
Expand All @@ -146,17 +150,17 @@ class NaiveBayes @Since("1.5.0") (
// numClasses should be explicitly included in the parquet metadata
// This can be done by applying StringIndexer to the label column
val numClasses = confClasses match {
case -1 => getNumClasses(dataset)
case -1 => getNumClasses(naiveBayesData)
case _ => confClasses
}

instr.logNumClasses(numClasses)

val labeledPointsDS = dataset
val labeledPointsDS = naiveBayesData
.select(col(getLabelCol), DatasetUtils.columnToVector(dataset, getFeaturesCol))

val dalModel = new NaiveBayesDALImpl(uid, numClasses,
executor_num, executor_cores).train(labeledPointsDS, ${labelCol}, ${featuresCol})
executorNum, executorCores).train(labeledPointsDS, ${labelCol}, ${featuresCol})

val model = copyValues(new NaiveBayesModel(
dalModel.uid, dalModel.pi, dalModel.theta, dalModel.sigma))
Expand Down Expand Up @@ -332,4 +336,4 @@ class NaiveBayes @Since("1.5.0") (
val sigma = new DenseMatrix(numLabels, numFeatures, sigmaArray, true)
new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed)
}
}
}

0 comments on commit 401891e

Please sign in to comment.