Skip to content

Commit

Permalink
STYLE: Format with scalamft (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickEdwards7502 committed Sep 19, 2024
1 parent 30732ba commit 3381e68
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/main/scala/au/csiro/variantspark/api/GetRFModel.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package au.csiro.variantspark.api

import au.csiro.variantspark.algo.{
RandomForest,
RandomForestModel,
RandomForestParams
}
import au.csiro.variantspark.algo.{RandomForest, RandomForestModel, RandomForestParams}
import au.csiro.variantspark.input.{FeatureSource, LabelSource}

/** Passes a trained random forest model back to the python wrapper
*/
object RFModelTrainer {
def trainModel(
featureSource: FeatureSource,
labelSource: LabelSource,
params: RandomForestParams,
nTrees: Int,
rfBatchSize: Int
): RandomForestModel = {

/** Trains a random forest model with provided data and parameters
*
* @param featureSource: FeatureSource object containing training X
* @param labelSource: LabelSource object containing training y
* @param params: Random forest hyperparameters (passed to model on initialisation)
* @param nTrees: Number of trees to compute (passed to model during training)
* @param rfBatchSize: Number of trees per batch (passed to model during training)
*
* @return Trained random forest model
*/
def trainModel(featureSource: FeatureSource, labelSource: LabelSource,
params: RandomForestParams, nTrees: Int, rfBatchSize: Int): RandomForestModel = {
val labels = labelSource.getLabels(featureSource.sampleNames)
lazy val inputData = featureSource.features.zipWithIndex.cache()
val rf = new RandomForest(params)
Expand Down

0 comments on commit 3381e68

Please sign in to comment.