From ea069d6989a0ff573fd16b79181d455e96f910df Mon Sep 17 00:00:00 2001 From: NickEdwards7502 Date: Wed, 11 Sep 2024 15:38:37 +1000 Subject: [PATCH] REFACTOR: Removed model definition and training from importanceAnalysis method of AnalyticsFunctions (#237) --- .../variantspark/api/AnalyticsFunctions.scala | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/main/scala/au/csiro/variantspark/api/AnalyticsFunctions.scala b/src/main/scala/au/csiro/variantspark/api/AnalyticsFunctions.scala index 54bd5177..571cc7cb 100644 --- a/src/main/scala/au/csiro/variantspark/api/AnalyticsFunctions.scala +++ b/src/main/scala/au/csiro/variantspark/api/AnalyticsFunctions.scala @@ -4,6 +4,8 @@ import au.csiro.variantspark.input.FeatureSource import au.csiro.variantspark.input.LabelSource import au.csiro.variantspark.algo.PairwiseOperation import au.csiro.variantspark.algo.LowerTriangMatrix +import au.csiro.variantspark.algo.RandomForestModel +import org.apache.spark.SparkContext /** * Extends a feature with analytical functions. @@ -14,22 +16,13 @@ class AnalyticsFunctions(val featureSource: FeatureSource) extends AnyVal { * Builds random forest classifier for the provided labels and estimates variable * importance using gini importance. * - * @param labelSource: labels to use for importance analysis. - * @param nTrees: the number of trees to build in the forest. - * @param mtryFraction: the fraction of variables to try at each split. - * @param oob: should OOB (Out of Bag) error estimate be calculated. - * @param seed: random seed to use. - * @param batchSize: the number of trees to build in one batch. - * @param varOrdinalLevels: the number levels in the ordinal features. + * @param rfModel: The trained random forest model. 
* * @return [[au.csiro.variantspark.api.ImportanceAnalysis.apply]] importance analysis model */ - def importanceAnalysis(labelSource: LabelSource, nTrees: Int = 1000, - mtryFraction: Option[Double] = None, oob: Boolean = true, seed: Option[Long] = None, - batchSize: Int = 100, varOrdinalLevels: Int = 3)( + def importanceAnalysis(rfModel: RandomForestModel)( implicit vsContext: SqlContextHolder): ImportanceAnalysis = { - ImportanceAnalysis(featureSource, labelSource, nTrees, mtryFraction, oob, seed, batchSize, - varOrdinalLevels) + ImportanceAnalysis(featureSource, rfModel) } /**