diff --git a/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala b/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
index 686bf0a2..73272963 100644
--- a/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
+++ b/src/main/scala/au/csiro/variantspark/api/ImportanceAnalysis.scala
@@ -6,9 +6,10 @@ import au.csiro.pbdava.ssparkle.spark.SparkUtils
 import au.csiro.variantspark.algo.{RandomForest, RandomForestModel, RandomForestParams}
 import au.csiro.variantspark.data.BoundedOrdinalVariable
 import au.csiro.variantspark.input.{FeatureSource, LabelSource}
-import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
+import it.unimi.dsi.fastutil.longs.{Long2DoubleOpenHashMap, Long2LongOpenHashMap}
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, StringType, LongType, StructField, StructType}
 
 import scala.collection.JavaConverters._
 
@@ -16,35 +17,22 @@ import scala.collection.JavaConverters._
  * A class to represent an instance of the Importance Analysis
  *
  * @constructor Create a new `Importance Analysis` by specifying the parameters listed below
- *
  * @param sqlContext The SQL context.
  * @param featureSource The feature source.
- * @param labelSource The label source.
- * @param rfParams The Random Forest parameters.
- * @param nTrees The number of Decision Trees.
- * @param rfBatchSize The batch size of the Random Forest
- * @param varOrdinalLevels The level of ordinal
- *
- * @example class ImportanceAnalysis(featureSource, labelSource, nTrees = 1000)
+ * @param rfModel The trained random forest model.
+ * @example val importanceAnalysis = ImportanceAnalysis(featureSource, rfModel)
  */
 class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureSource,
-    val labelSource: LabelSource, val rfParams: RandomForestParams, val nTrees: Int,
-    val rfBatchSize: Int, varOrdinalLevels: Int) {
-
+    val rfModel: RandomForestModel) {
   private def sc = featureSource.features.sparkContext
-  private lazy val inputData = featureSource.features.zipWithIndex().cache()
 
   val variableImportanceSchema: StructType = StructType(
     Seq(StructField("variable", StringType, true), StructField("importance", DoubleType, true)))
 
-  lazy val rfModel: RandomForestModel = {
-    val labels = labelSource.getLabels(featureSource.sampleNames)
-    val rf = new RandomForest(rfParams)
-    rf.batchTrain(inputData, labels, nTrees, rfBatchSize)
-  }
-
-  val oobError: Double = rfModel.oobError
+  val variableImportanceWithSplitCountSchema: StructType =
+    StructType(Seq(StructField("variant_id", StringType, true),
+      StructField("importance", DoubleType, true), StructField("splitCount", LongType, true)))
 
   private lazy val br_normalizedVariableImportance = {
     val indexImportance = rfModel.normalizedVariableImportance()
@@ -53,18 +41,44 @@ class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureS
       indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]].asJava))
   }
 
-  def variableImportance: DataFrame = {
-    val local_br_normalizedVariableImportance = br_normalizedVariableImportance
+  private lazy val br_variableImportance = {
+    val indexImportance = rfModel.variableImportance
+    sc.broadcast(
+      new Long2DoubleOpenHashMap(
+        indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]].asJava))
+  }
+
+  private lazy val br_splitCounts = {
+    val splitCounts = rfModel.variableSplitCount
+    sc.broadcast(
+      new Long2LongOpenHashMap(
+        splitCounts.asInstanceOf[Map[java.lang.Long, java.lang.Long]].asJava))
+  }
+
+  private lazy val inputData = featureSource.features.zipWithIndex().cache()
+
+  def variableImportance(normalized: Boolean = false): DataFrame = {
+    val local_br_variableImportance = if (!normalized) {
+      br_variableImportance
+    } else {
+      br_normalizedVariableImportance
+    }
+    val local_br_splitCounts = br_splitCounts
+
     val importanceRDD = inputData.map({
-      case (f, i) => Row(f.label, local_br_normalizedVariableImportance.value.get(i))
+      case (f, i) =>
+        Row(f.label, local_br_variableImportance.value.get(i), local_br_splitCounts.value.get(i))
     })
-    sqlContext.createDataFrame(importanceRDD, variableImportanceSchema)
+    sqlContext.createDataFrame(importanceRDD, variableImportanceWithSplitCountSchema)
   }
 
-  def importantVariables(nTopLimit: Int = 100): Seq[(String, Double)] = {
-    // build index for names
-    val topImportantVariables =
+  def importantVariables(nTopLimit: Int = 100,
+      normalized: Boolean = false): Seq[(String, Double)] = {
+    val topImportantVariables = if (!normalized) {
+      rfModel.variableImportance.toSeq.sortBy(-_._2).take(nTopLimit)
+    } else {
       rfModel.normalizedVariableImportance().toSeq.sortBy(-_._2).take(nTopLimit)
+    }
 
     val topImportantVariableIndexes = topImportantVariables.map(_._1).toSet
 
     val index =
@@ -79,35 +93,19 @@ class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureS
     topImportantVariables.map({ case (i, importance) => (index(i), importance) })
   }
 
-  def importantVariablesJavaMap(nTopLimit: Int = 100): util.Map[String, Double] = {
-    val impVarMap = collection.mutable.Map(importantVariables(nTopLimit).toMap.toSeq: _*)
+  def importantVariablesJavaMap(nTopLimit: Int = 100,
+      normalized: Boolean = false): util.Map[String, Double] = {
+    val impVarMap =
+      collection.mutable.Map(importantVariables(nTopLimit, normalized).toMap.toSeq: _*)
     impVarMap.map { case (k, v) => k -> double2Double(v) }
     impVarMap.asJava
   }
 }
 
 object ImportanceAnalysis {
-
-  val defaultRFParams: RandomForestParams = RandomForestParams()
-
-  def apply(featureSource: FeatureSource, labelSource: LabelSource, nTrees: Int = 1000,
-      mtryFraction: Option[Double] = None, oob: Boolean = true, seed: Option[Long] = None,
-      batchSize: Int = 100, varOrdinalLevels: Int = 3)(
+  def apply(featureSource: FeatureSource, rfModel: RandomForestModel)(
       implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
-    new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource,
-      rfParams = RandomForestParams(
-        nTryFraction = mtryFraction.getOrElse(defaultRFParams.nTryFraction),
-        seed = seed.getOrElse(defaultRFParams.seed), oob = oob),
-      nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
+    new ImportanceAnalysis(vsContext.sqlContext, featureSource, rfModel)
  }
-
-  def fromParams(featureSource: FeatureSource, labelSource: LabelSource,
-      rfParams: RandomForestParams, nTrees: Int = 1000, batchSize: Int = 100,
-      varOrdinalLevels: Int = 3)(implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
-
-    new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource, rfParams = rfParams,
-      nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
-  }
-
 }
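
Reviewer note, not part of the patch: since training no longer happens inside
ImportanceAnalysis, callers now build the RandomForestModel themselves and pass it in.
Below is a minimal migration sketch assembled only from calls visible in this diff (the
removed in-class training code shows the equivalent steps). featureSource, labelSource
and the implicit SqlContextHolder are assumed to already exist; variable names are
illustrative.

    import au.csiro.variantspark.algo.{RandomForest, RandomForestParams}
    import au.csiro.variantspark.api.ImportanceAnalysis

    // Train the forest explicitly; this mirrors the removed `lazy val rfModel` body.
    val labels = labelSource.getLabels(featureSource.sampleNames)
    val trainingData = featureSource.features.zipWithIndex().cache()
    val rf = new RandomForest(RandomForestParams(oob = true))
    val rfModel = rf.batchTrain(trainingData, labels, 1000, 100) // nTrees, batch size

    // Hand the trained model to the analysis (implicit SqlContextHolder in scope).
    val analysis = ImportanceAnalysis(featureSource, rfModel)
    val top = analysis.importantVariables(nTopLimit = 20, normalized = true)
    val importanceDf = analysis.variableImportance() // raw importance plus splitCount

Two behavioural changes worth flagging in review: oobError is no longer exposed on
ImportanceAnalysis (read rfModel.oobError instead), and variableImportance() now always
uses variableImportanceWithSplitCountSchema, so downstream code selecting the old
"variable" column must switch to "variant_id".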