DEV: Updated ImportanceAnalysis scala class (#237)
REFACTOR: Removed model training from object instantiation and
updated class to accept a model as a parameter

REFACTOR: Added normalisation as an optional parameter for
variable importance methods

FEAT: Updated variableImportance method to include splitCount in its return value, as it is required for local FDR analysis
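For context, a minimal sketch of the call pattern this change implies, pieced together from the removed and added lines below. It assumes featureSource, labelSource and rfParams are already in scope, as in the removed code; the nTrees and batch size values (1000, 100) are illustrative only.

    // Training now happens outside the class, mirroring the removed lazy rfModel block
    val labels = labelSource.getLabels(featureSource.sampleNames)
    val inputData = featureSource.features.zipWithIndex().cache()
    val rf = new RandomForest(rfParams)
    val rfModel: RandomForestModel = rf.batchTrain(inputData, labels, 1000, 100)

    // The fitted model is then passed to ImportanceAnalysis as a parameter
    val analysis = ImportanceAnalysis(featureSource, rfModel)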
NickEdwards7502 committed Sep 11, 2024
1 parent b8b39fd commit 4bfaac9
Showing 1 changed file with 47 additions and 49 deletions.
96 changes: 47 additions & 49 deletions src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
@@ -6,45 +6,33 @@ import au.csiro.pbdava.ssparkle.spark.SparkUtils
 import au.csiro.variantspark.algo.{RandomForest, RandomForestModel, RandomForestParams}
 import au.csiro.variantspark.data.BoundedOrdinalVariable
 import au.csiro.variantspark.input.{FeatureSource, LabelSource}
-import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
+import it.unimi.dsi.fastutil.longs.{Long2DoubleOpenHashMap, Long2LongOpenHashMap}
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, StringType, LongType, StructField, StructType}

 import scala.collection.JavaConverters._

 /**
  * A class to represent an instance of the Importance Analysis
  *
  * @constructor Create a new `Importance Analysis` by specifying the parameters listed below
  *
  * @param sqlContext The SQL context.
  * @param featureSource The feature source.
- * @param labelSource The label source.
- * @param rfParams The Random Forest parameters.
- * @param nTrees The number of Decision Trees.
- * @param rfBatchSize The batch size of the Random Forest
- * @param varOrdinalLevels The level of ordinal
- *
+ * @param rfModel The trained random forest model
  * @example class ImportanceAnalysis(featureSource, labelSource, nTrees = 1000)
  */
 class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureSource,
-    val labelSource: LabelSource, val rfParams: RandomForestParams, val nTrees: Int,
-    val rfBatchSize: Int, varOrdinalLevels: Int) {
-
+    val rfModel: RandomForestModel) {
   private def sc = featureSource.features.sparkContext
-  private lazy val inputData = featureSource.features.zipWithIndex().cache()

   val variableImportanceSchema: StructType =
     StructType(Seq(StructField("variable", StringType, true),
       StructField("importance", DoubleType, true)))

-  lazy val rfModel: RandomForestModel = {
-    val labels = labelSource.getLabels(featureSource.sampleNames)
-    val rf = new RandomForest(rfParams)
-    rf.batchTrain(inputData, labels, nTrees, rfBatchSize)
-  }
-
-  val oobError: Double = rfModel.oobError
+  val variableImportanceWithSplitCountSchema: StructType =
+    StructType(Seq(StructField("variant_id", StringType, true),
+      StructField("importance", DoubleType, true), StructField("splitCount", LongType, true)))

   private lazy val br_normalizedVariableImportance = {
     val indexImportance = rfModel.normalizedVariableImportance()
@@ -53,18 +41,44 @@ class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureS
         indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]].asJava))
   }

-  def variableImportance: DataFrame = {
-    val local_br_normalizedVariableImportance = br_normalizedVariableImportance
+  private lazy val br_variableImportance = {
+    val indexImportance = rfModel.variableImportance
+    sc.broadcast(
+      new Long2DoubleOpenHashMap(
+        indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]].asJava))
+  }
+
+  private lazy val br_splitCounts = {
+    val splitCounts = rfModel.variableSplitCount
+    sc.broadcast(
+      new Long2LongOpenHashMap(
+        splitCounts.asInstanceOf[Map[java.lang.Long, java.lang.Long]].asJava))
+  }
+
+  private lazy val inputData = featureSource.features.zipWithIndex().cache()
+
+  def variableImportance(normalized: Boolean = false): DataFrame = {
+    val local_br_variableImportance = if (!normalized) {
+      br_variableImportance
+    } else {
+      br_normalizedVariableImportance
+    }
+    val local_br_splitCounts = br_splitCounts
+
     val importanceRDD = inputData.map({
-      case (f, i) => Row(f.label, local_br_normalizedVariableImportance.value.get(i))
+      case (f, i) =>
+        Row(f.label, local_br_variableImportance.value.get(i), local_br_splitCounts.value.get(i))
     })
-    sqlContext.createDataFrame(importanceRDD, variableImportanceSchema)
+    sqlContext.createDataFrame(importanceRDD, variableImportanceWithSplitCountSchema)
   }

-  def importantVariables(nTopLimit: Int = 100): Seq[(String, Double)] = {
-    // build index for names
-    val topImportantVariables =
+  def importantVariables(nTopLimit: Int = 100,
+      normalized: Boolean = false): Seq[(String, Double)] = {
+    val topImportantVariables = if (!normalized) {
       rfModel.variableImportance.toSeq.sortBy(-_._2).take(nTopLimit)
+    } else {
+      rfModel.normalizedVariableImportance().toSeq.sortBy(-_._2).take(nTopLimit)
+    }
     val topImportantVariableIndexes = topImportantVariables.map(_._1).toSet

     val index =
@@ -79,35 +93,19 @@ class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureS
     topImportantVariables.map({ case (i, importance) => (index(i), importance) })
   }

-  def importantVariablesJavaMap(nTopLimit: Int = 100): util.Map[String, Double] = {
-    val impVarMap = collection.mutable.Map(importantVariables(nTopLimit).toMap.toSeq: _*)
+  def importantVariablesJavaMap(nTopLimit: Int = 100,
+      normalized: Boolean = false): util.Map[String, Double] = {
+    val impVarMap =
+      collection.mutable.Map(importantVariables(nTopLimit, normalized).toMap.toSeq: _*)
     impVarMap.map { case (k, v) => k -> double2Double(v) }
     impVarMap.asJava
   }
 }

 object ImportanceAnalysis {

-  val defaultRFParams: RandomForestParams = RandomForestParams()
-
-  def apply(featureSource: FeatureSource, labelSource: LabelSource, nTrees: Int = 1000,
-      mtryFraction: Option[Double] = None, oob: Boolean = true, seed: Option[Long] = None,
-      batchSize: Int = 100, varOrdinalLevels: Int = 3)(
+  def apply(featureSource: FeatureSource, rfModel: RandomForestModel)(
       implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
-
-    new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource,
-      rfParams = RandomForestParams(
-        nTryFraction = mtryFraction.getOrElse(defaultRFParams.nTryFraction),
-        seed = seed.getOrElse(defaultRFParams.seed), oob = oob),
-      nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
+    new ImportanceAnalysis(vsContext.sqlContext, featureSource, rfModel)
   }
-
-  def fromParams(featureSource: FeatureSource, labelSource: LabelSource,
-      rfParams: RandomForestParams, nTrees: Int = 1000, batchSize: Int = 100,
-      varOrdinalLevels: Int = 3)(implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
-
-    new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource, rfParams = rfParams,
-      nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
-  }
-
 }
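A follow-up sketch of the updated query methods, using the analysis value from the example above; the column names come from the new variableImportanceWithSplitCountSchema:

    // Raw importances plus split counts: variant_id, importance, splitCount
    val importanceDF = analysis.variableImportance()

    // Opt in to normalisation via the new parameter
    val normalizedDF = analysis.variableImportance(normalized = true)

    // Top 50 (variable, importance) pairs, normalised
    val top = analysis.importantVariables(nTopLimit = 50, normalized = true)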
