Fix PUB-1101 by using GBM's variable importance method for DRF:
Follow tree split decisions instead of randomizing OOB data columns.
arnocandel committed Dec 29, 2014
1 parent d268fd7 commit 0026200
Showing 1 changed file with 92 additions and 53 deletions.
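
In short, the change swaps permutation-based (mean decrease in accuracy) importance for GBM-style split-based importance: every time a tree splits on a column, that split's squared-error improvement is added to the column's running total, and the totals are averaged over all trees. A minimal standalone sketch of that accumulation (using hypothetical Split and SplitImportance types for illustration, not the actual H2O DTree classes) might look like this:

// Hypothetical, simplified types for illustration only; not the H2O DTree/DecidedNode classes.
final class Split {
  final int col;            // column this node splits on, -1 for none
  final float improvement;  // squared-error reduction achieved by the split
  Split(int col, float improvement) { this.col = col; this.improvement = improvement; }
}

final class SplitImportance {
  private final float[] improvPerVar;   // running improvement total per column

  SplitImportance(int ncols) { improvPerVar = new float[ncols]; }

  // Accumulate the improvement of every real split of one tree.
  void addTree(Iterable<Split> splits) {
    for (Split s : splits)
      if (s.col != -1)                  // skip degenerate splits (leaves)
        improvPerVar[s.col] += s.improvement;
  }

  // Average over all trees and optionally rescale so the best column is 1.0.
  float[] importance(int ntreesTotal, boolean scale) {
    float[] varimp = improvPerVar.clone();
    float max = 0;
    for (int v = 0; v < varimp.length; v++) {
      varimp[v] /= ntreesTotal;
      max = Math.max(max, varimp[v]);
    }
    if (scale && max > 0)
      for (int v = 0; v < varimp.length; v++) varimp[v] /= max;
    return varimp;
  }
}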
145 changes: 92 additions & 53 deletions src/main/java/hex/drf/DRF.java
@@ -67,6 +67,8 @@ class myClassFilter extends DRFCopyDataBoolean { myClassFilter() { super("source
private transient TreeMeasures _treeMeasuresOnOOB;
// Tree votes/SSE per individual feature on permuted OOB rows
private transient TreeMeasures[/*features*/] _treeMeasuresOnSOOB;
// Variable importance based on tree split decisions
private transient float[/*nfeatures*/] _improvPerVar;

/** DRF model holding serialized tree and implementing logic for scoring a row */
public static class DRFModel extends DTree.TreeModel {
@@ -275,14 +277,15 @@ public DRFModel atomic(DRFModel m) {
if( Job.isRunning(self()) ) { // do not perform final scoring and finish
model = doScoring(model, fr, ktrees, tid, tstats, true, !hasValidation(), build_tree_one_node);
// Make sure that we did not miss any votes
assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0/*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!";
// assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0/*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!";
}

return model;
}

private void initTreeMeasurements() {
assert importance : "Tree votes should be initialized only if variable importance is requested!";
_improvPerVar = new float[_ncols];
// Preallocate tree votes
if (classification) {
_treeMeasuresOnOOB = new TreeVotes(ntrees);
@@ -295,60 +298,95 @@ private void initTreeMeasurements() {
}
}

/** On-the-fly version for varimp. After generating a new tree, its tree votes are collected on shuffled
* OOB rows and variable importance is recomputed.
* <p>
* The <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
* <cite>
* "In every tree grown in the forest, put down the oob cases and count the number of votes cast for the correct class.
* Now randomly permute the values of variable m in the oob cases and put these cases down the tree.
* Subtract the number of votes for the correct class in the variable-m-permuted oob data from the number of votes
* for the correct class in the untouched oob data.
* The average of this number over all trees in the forest is the raw importance score for variable m."
* </cite>
* </p>
* */
@Override
protected VarImp doVarImpCalc(final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
// Check if we have already serialized 'ktrees'-trees in the model
assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid="+tid;
assert _treeMeasuresOnOOB.npredictors()-1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
// Compute tree votes over shuffled data
final CompressedTree[/*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys
final int nclasses = model.nclasses();
Futures fs = new Futures();
for (int var=0; var<_ncols; var++) {
final int variable = var;
H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() {
@Override public void compute2() {
// Compute this tree votes over all data over given variable
TreeVotes cd = TreeMeasuresCollector.collectVotes(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
assert cd.npredictors() == 1;
asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
tryComplete();
}
} : /* regression */ new H2OCountedCompleter() {
@Override public void compute2() {
// Compute this tree votes over all data over given variable
TreeSSE cd = TreeMeasuresCollector.collectSSE(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
assert cd.npredictors() == 1;
asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
tryComplete();
}
};
fs.add(task4var);
H2O.submitTask(task4var); // Fork computation
// /** On-the-fly version for varimp. After generating a new tree, its tree votes are collected on shuffled
// * OOB rows and variable importance is recomputed.
// * <p>
// * The <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
// * <cite>
// * "In every tree grown in the forest, put down the oob cases and count the number of votes cast for the correct class.
// * Now randomly permute the values of variable m in the oob cases and put these cases down the tree.
// * Subtract the number of votes for the correct class in the variable-m-permuted oob data from the number of votes
// * for the correct class in the untouched oob data.
// * The average of this number over all trees in the forest is the raw importance score for variable m."
// * </cite>
// * </p>
// * */
// @Override
// protected VarImp doVarImpCalc(final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
// // Check if we have already serialized 'ktrees'-trees in the model
// assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid="+tid;
// assert _treeMeasuresOnOOB.npredictors()-1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
// // Compute tree votes over shuffled data
// final CompressedTree[/*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys
// final int nclasses = model.nclasses();
// Futures fs = new Futures();
// for (int var=0; var<_ncols; var++) {
// final int variable = var;
// H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() {
// @Override public void compute2() {
// // Compute this tree votes over all data over given variable
// TreeVotes cd = TreeMeasuresCollector.collectVotes(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
// assert cd.npredictors() == 1;
// asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
// tryComplete();
// }
// } : /* regression */ new H2OCountedCompleter() {
// @Override public void compute2() {
// // Compute this tree votes over all data over given variable
// TreeSSE cd = TreeMeasuresCollector.collectSSE(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
// assert cd.npredictors() == 1;
// asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
// tryComplete();
// }
// };
// fs.add(task4var);
// H2O.submitTask(task4var); // Fork computation
// }
// fs.blockForPending(); // Wait for results
// // Compute varimp for individual features (_ncols)
// final float[] varimp = new float[_ncols]; // output variable importance
// final float[] varimpSD = new float[_ncols]; // output variable importance sd
// for (int var=0; var<_ncols; var++) {
// double[/*2*/] imp = classification ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB)) : asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB));
// varimp [var] = (float) imp[0];
// varimpSD[var] = (float) imp[1];
// }
// return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees());
// }
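
For comparison, the block commented out above follows Breiman's permutation recipe quoted in its Javadoc: count correct votes on the untouched OOB rows, re-score them with column m shuffled, and take the drop in correct votes as column m's raw importance. A rough sketch of that idea follows; correctVotes(...) is a hypothetical stand-in for what TreeMeasuresCollector.collectVotes does, not an actual H2O API:

import java.util.Random;

// Hedged sketch of Breiman's permutation importance; correctVotes(...) is a hypothetical
// stand-in for scoring OOB rows down one tree and counting votes for the correct class.
final class PermutationImportance {
  interface TreeScorer { int correctVotes(double[][] oobRows, int[] oobLabels); }

  static double rawImportance(TreeScorer tree, double[][] oobRows, int[] oobLabels, int m, Random rng) {
    int baseline = tree.correctVotes(oobRows, oobLabels);       // votes on the untouched OOB data

    double[][] permuted = new double[oobRows.length][];         // copy rows, then shuffle column m only
    for (int i = 0; i < oobRows.length; i++) permuted[i] = oobRows[i].clone();
    for (int i = permuted.length - 1; i > 0; i--) {             // Fisher-Yates shuffle of column m
      int j = rng.nextInt(i + 1);
      double tmp = permuted[i][m]; permuted[i][m] = permuted[j][m]; permuted[j][m] = tmp;
    }

    int shuffled = tree.correctVotes(permuted, oobLabels);      // votes with column m permuted
    return baseline - shuffled;                                 // drop in correct votes = raw importance of column m
  }
}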

/** Compute relative variable importance for RF model.
*
* See formulas (45) and (35) in Friedman: Greedy Function Approximation: A Gradient Boosting Machine.
* The algorithm used here can also be used to compute individual importance of features per output class. */
@Override protected VarImp doVarImpCalc(DRFModel model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale) {
assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "varimp computation expects a model with already serialized trees: tid="+tid;
// Iterate over the trees built in this round and accumulate split improvements per column
for (DTree t : ktrees) {
if (t!=null) {
for (int n = 0; n < t.len()-t.leaves; n++)
if (t.node(n) instanceof DecidedNode) { // it is a split node
DTree.Split split = t.decided(n)._split;
if (split.col()!=-1) // Skip impossible splits ~ leaves
_improvPerVar[split.col()] += split.improvement(); // least-squares improvement
}
}
}
// Compute variable importance for all trees in the model
float[] varimp = new float[model.nfeatures()];

int ntreesTotal = model.ntrees() * model.nclasses();
int maxVar = 0;
for (int var=0; var<_improvPerVar.length; var++) {
varimp[var] = _improvPerVar[var] / ntreesTotal;
if (varimp[var] > varimp[maxVar]) maxVar = var;
}
// Scale varimp so that the most important variable is 1.0
if (scale) {
float maxVal = varimp[maxVar];
for (int var=0; var<varimp.length; var++) varimp[var] /= maxVal;
}

return new VarImp.VarImpRI(varimp);
}
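
The Javadoc's pointer to Friedman can be read roughly as follows (a paraphrase, not a quote of the cited formulas): the influence of variable j in one tree is the sum of the improvements of the internal nodes that split on j, and the ensemble value averages that over all trees, which is what the loop above does with _improvPerVar and ntreesTotal. Friedman states this in terms of squared error-reduction terms; the code accumulates whatever split.improvement() reports. In LaTeX form:

% Paraphrase of the relative-influence idea, not a verbatim copy of Friedman's formulas (45)/(35):
% one tree T with internal nodes t, each splitting on variable v_t with improvement \hat{i}_t
\hat{I}_j(T) = \sum_{t \in \mathrm{internal}(T)} \hat{i}_t \, \mathbf{1}(v_t = j)
% ensemble importance: average over the M = ntrees * nclasses class trees, then optionally rescale by \max_k \hat{I}_k
\hat{I}_j = \frac{1}{M} \sum_{m=1}^{M} \hat{I}_j(T_m)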

@Override public boolean supportsBagging() { return true; }
@@ -684,6 +722,7 @@ static class Sample extends MRTask2<Sample> {
@Override public void crossValidate(Frame[] splits, Frame[] cv_preds, long[] offsets, int i) {
// Train a clone with slightly modified parameters (to account for cross-validation)
DRF cv = (DRF) this.clone();
// cv.importance = false; //Don't compute variable importance for N CV-folds
cv.genericCrossValidation(splits, offsets, i);
cv_preds[i] = ((DRFModel) UKV.get(cv.dest())).score(cv.validation); // cv_preds is escaping the context of this function and needs to be DELETED by the caller!!!
}