Skip to content

Commit

Permalink
REFACTOR: Removed model definition and training
Browse files Browse the repository at this point in the history
from importanceAnalysis method of AnalyticsFunctions (#237)
  • Loading branch information
NickEdwards7502 committed Sep 11, 2024
1 parent 0fc736f commit ea069d6
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions src/main/scala/au/csiro/variantspark/api/AnalyticsFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import au.csiro.variantspark.input.FeatureSource
import au.csiro.variantspark.input.LabelSource
import au.csiro.variantspark.algo.PairwiseOperation
import au.csiro.variantspark.algo.LowerTriangMatrix
import au.csiro.variantspark.algo.RandomForestModel
import org.apache.spark.SparkContext

/**
* Extends a feature with analytical functions.
Expand All @@ -14,22 +16,13 @@ class AnalyticsFunctions(val featureSource: FeatureSource) extends AnyVal {
* Builds random forest classifier for the provided labels and estimates variable
* importance using gini importance.
*
* @param labelSource: labels to use for importance analysis.
* @param nTrees: the number of trees to build in the forest.
* @param mtryFraction: the fraction of variables to try at each split.
* @param oob: should OOB (Out of Bag) error estimate be calculated.
* @param seed: random seed to use.
* @param batchSize: the number of trees to build in one batch.
* @param varOrdinalLevels: the number of levels in the ordinal features.
* @param rfModel: The trained random forest model.
*
* @return [[au.csiro.variantspark.api.ImportanceAnalysis.apply]] importance analysis model
*/
def importanceAnalysis(labelSource: LabelSource, nTrees: Int = 1000,
mtryFraction: Option[Double] = None, oob: Boolean = true, seed: Option[Long] = None,
batchSize: Int = 100, varOrdinalLevels: Int = 3)(
def importanceAnalysis(rfModel: RandomForestModel)(
implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
ImportanceAnalysis(featureSource, labelSource, nTrees, mtryFraction, oob, seed, batchSize,
varOrdinalLevels)
ImportanceAnalysis(featureSource, rfModel)
}

/**
Expand Down

0 comments on commit ea069d6

Please sign in to comment.