-
Notifications
You must be signed in to change notification settings - Fork 397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ensure correct metrics despite model failures on some CV folds #404
Changes from 3 commits
ca2afde
cf373e8
e5b76c0
d9f07eb
5f92294
7f40ad2
8fe125e
7f8eb4b
85eeab2
917f13c
872fe35
2714cd5
842d016
39189a2
71649f1
17cb9fa
6687a34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,16 +57,23 @@ private[op] class OpCrossValidation[M <: Model[_], E <: Estimator[_]] | |
private def findBestModel( | ||
folds: Seq[ValidatedModel[E]] | ||
): ValidatedModel[E] = { | ||
val metrics = folds.map(_.metrics).reduce(_ + _) | ||
blas.dscal(metrics.length, 1.0 / numFolds, metrics, 1) | ||
val ValidatedModel(est, _, _, grid) = folds.head | ||
log.info(s"Average cross-validation for $est metrics: {}", metrics.toSeq.mkString(",")) | ||
val (bestMetric, bestIndex) = | ||
if (evaluator.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) | ||
else metrics.zipWithIndex.minBy(_._1) | ||
log.info(s"Best set of parameters:\n${grid(bestIndex)}") | ||
require(folds.map(_.model.uid).toSet.size == 1) // Should be called only on instances of the same model | ||
val gridCounts = folds.map(_.grids.map(_ -> 1).toMap).reduce(_ + _) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it the same as `folds.flatMap(_.grids.map(_ -> 1)).sumByKey`? |
||
val maxFolds = gridCounts.maxBy(_._2)._2 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For better readability please replace |
||
val gridsIn = gridCounts.filter(_._2 == maxFolds).keySet | ||
val gridMetrics = folds.map(f => f.grids.zip(f.metrics).toMap).reduce(_ + _) | ||
.filterKeys(gridsIn.contains) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let us filter first, so that we have less to reduce; maybe: val gridMetrics = folds.flatMap(f => f.grids.zip(f.metrics))
.collect { case (pm, met) if gridsIn.contains(pm) => (pm, met / maxFolds) }
.sumByKey |
||
.map{ case (key, met) => key -> met / maxFolds} | ||
.toSeq | ||
val ((bestGrid, bestMetric), bestIndex) = | ||
if (evaluator.isLargerBetter) gridMetrics.zipWithIndex.maxBy(_._1._2) | ||
else gridMetrics.zipWithIndex.minBy(_._1._2) | ||
val ValidatedModel(est, _, _, _) = folds.head | ||
log.info(s"Average cross-validation for $est metrics: {}", gridMetrics.mkString(",")) | ||
log.info(s"Best set of parameters:\n$bestGrid") | ||
log.info(s"Best cross-validation metric: $bestMetric.") | ||
ValidatedModel(est, bestIndex, metrics, grid) | ||
val (grid, metrics) = gridMetrics.unzip | ||
ValidatedModel(est, bestIndex, metrics.toArray, grid.toArray) | ||
} | ||
|
||
private[op] override def validate[T]( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is good to be defensive; my concern is that we call it on every iteration, in a private method that is entirely within scope here, and by construction we already know that the folds are for the same model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I can remove it and just put a description on the method.