Skip to content

Commit

Permalink
DEV: Created scala function that trains a forest
Browse files Browse the repository at this point in the history
and passes back to python context (#237)
  • Loading branch information
NickEdwards7502 committed Sep 11, 2024
1 parent 4bfaac9 commit 0fc736f
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/main/scala/au/csiro/variantspark/api/GetRFModel.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package au.csiro.variantspark.api

import au.csiro.variantspark.algo.{
RandomForest,
RandomForestModel,
RandomForestParams
}
import au.csiro.variantspark.input.{FeatureSource, LabelSource}

/** Trains a random forest and hands the resulting model back to the caller
  * (intended for use from the python wrapper).
  */
object RFModelTrainer {

  /** Trains a [[RandomForest]] on the supplied features and labels.
    *
    * Labels are looked up for the sample names reported by `featureSource`,
    * and each feature is paired with its index before training.
    *
    * @param featureSource provides the features and the sample names
    * @param labelSource   provides the labels for the samples of `featureSource`
    * @param params        parameters the [[RandomForest]] is constructed with
    * @param nTrees        tree count forwarded to `RandomForest.batchTrain`
    * @param rfBatchSize   batch size forwarded to `RandomForest.batchTrain`
    * @return the model produced by `RandomForest.batchTrain`
    */
  def trainModel(
      featureSource: FeatureSource,
      labelSource: LabelSource,
      params: RandomForestParams,
      nTrees: Int,
      rfBatchSize: Int
  ): RandomForestModel = {
    val labels = labelSource.getLabels(featureSource.sampleNames)
    val rf = new RandomForest(params)
    // A plain val is enough: the indexed data is consumed right away, so a
    // `lazy val` would only add initialisation overhead. It is created after
    // the forest so the evaluation order matches the previous lazy behaviour.
    val inputData = featureSource.features.zipWithIndex().cache()
    try rf.batchTrain(inputData, labels, nTrees, rfBatchSize)
    finally {
      // The cached data is local to this call; release it so repeated calls
      // do not pile up cached copies of the features.
      // NOTE(review): assumes batchTrain completes all work on the data
      // before returning the model — confirm.
      inputData.unpersist()
    }
  }
}

0 comments on commit 0fc736f

Please sign in to comment.