Skip to content

Commit

Permalink
PUB-1103: Add option to automatically call SplitFrame from model buil…
Browse files Browse the repository at this point in the history
…der page if holdout_fraction is > 0.
  • Loading branch information
arnocandel committed Dec 30, 2014
1 parent 22e89ae commit e054707
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/main/java/water/Job.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package water;

import hex.FrameSplitter;
import static water.util.Utils.difference;
import static water.util.Utils.isEmpty;

Expand Down Expand Up @@ -808,6 +809,9 @@ public static abstract class ValidatedJob extends ModelJob {
/**
 * Number of cross-validation folds. Defaults to 0; the holdout-split logic in this class
 * treats any non-zero value as a request for n-fold cross-validation.
 */
@API(help = "Number of folds for cross-validation (if no validation data is specified)", filter = Default.class, json = true)
public int n_folds = 0;

/**
 * Fraction of the training frame to hold out as a validation set. Defaults to 0 (no holdout).
 * When positive, the job splits {@code source} with a FrameSplitter, continues training on the
 * first split and uses the second split as {@code validation}. Combining a positive value with
 * an explicit validation frame or a non-zero {@link #n_folds} is rejected with an
 * IllegalArgumentException.
 */
@API(help = "Fraction of training data (from end) to hold out for validation (if no validation data is specified)", filter = Default.class, json = true)
public float holdout_fraction = 0;

/**
 * Whether to keep the dataset splits created for cross-validation. Defaults to false.
 * NOTE(review): the code that removes the splits is not visible here - confirm the exact
 * cleanup behavior against genericCrossValidation.
 */
@API(help = "Keep cross-validation dataset splits", filter = Default.class, json = true)
public boolean keep_cross_validation_splits = false;

Expand Down Expand Up @@ -918,6 +922,23 @@ final protected void genericCrossValidation(Frame[] splits, long[] offsets, int
}
_responseName = source._names != null && rIndex >= 0 ? source._names[rIndex] : "response";

// Optional holdout validation: when a positive holdout_fraction is given, carve the
// validation set out of the training frame instead of requiring a separate frame.
if (holdout_fraction > 0) {
  // A fraction of 1 or more would pass a zero/negative ratio to the splitter and leave
  // nothing to train on, so reject it up front like the other invalid combinations.
  if (holdout_fraction >= 1)
    throw new IllegalArgumentException("Holdout fraction must be less than 1, but was " + holdout_fraction + ".");
  if (validation != null)
    throw new IllegalArgumentException("Cannot specify both a holdout fraction and a validation frame.");
  if (n_folds != 0)
    throw new IllegalArgumentException("Cannot specify both a holdout fraction and n-fold cross-validation.");

  Log.info("Holding out last " + Utils.formatPct(holdout_fraction) + " of training data.");
  // A single ratio is passed; the code below relies on the splitter returning two frames
  // (presumably the leading 1 - holdout_fraction of the rows, then the remainder - confirm
  // against FrameSplitter).
  FrameSplitter fs = new FrameSplitter(source, new float[]{1 - holdout_fraction});
  H2O.submitTask(fs).join(); // block until the split has finished
  Frame[] splits = fs.getResult();
  // From here on the job trains on the first split only.
  source = splits[0];
  // Re-bind the response to the column of the new training frame at the same index.
  response = source.vecs()[rIndex];
  validation = splits[1];
  Log.warn("Allocating data split frames: " + source._key.toString() + " and " + validation._key.toString());
  Log.warn("Both will be kept after the model is trained. It's the user's responsibility to manage their lifetime.");
}

_train = selectVecs(source);
_names = new String[cols.length];
for( int i = 0; i < cols.length; i++ )
Expand Down

0 comments on commit e054707

Please sign in to comment.