Skip to content

Commit

Permalink
Fix PUB-1118 - AutoEncoder POJO scoring.
Browse files · Browse the repository at this point in the history
Also fixed a bug in initializing the activation layer for categoricals.
  • Loading branch information
arnocandel committed Jan 31, 2015
1 parent 4115ee0 commit 6505a1f
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 21 deletions.
2 changes: 1 addition & 1 deletion R/tests/Utils/shared_javapredict_DL.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ iris_train.hex <- h2o.uploadFile(conn, train)

heading("Creating DL model in H2O")
balance_classes <- if (exists("balance_classes")) balance_classes else FALSE
iris.dl.h2o <- h2o.deeplearning(x = x, y = y, data = iris_train.hex, hidden = hidden, balance_classes = balance_classes, classification = classification, activation = activation, epochs = epochs)
iris.dl.h2o <- h2o.deeplearning(x = x, y = y, data = iris_train.hex, seed = 1234, reproducible = T, hidden = hidden, balance_classes = balance_classes, classification = classification, activation = activation, epochs = epochs, autoencoder = autoencoder)
print(iris.dl.h2o)

heading("Downloading Java prediction model code from H2O")
Expand Down
16 changes: 10 additions & 6 deletions R/tests/testdir_javapredict/PredictCSV.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,14 @@ public static void main(String[] args) throws Exception{
}

// Print outputCSV column names.
output.write("predict");
for (int i = 0; i < model.getNumResponseClasses(); i++) {
output.write(",");
output.write(model.getDomainValues(model.getResponseIdx())[i]);
if (model.isAutoEncoder()) {
output.write(model.getHeader());
} else {
output.write("predict");
for (int i = 0; i < model.getNumResponseClasses(); i++) {
output.write(",");
output.write(model.getDomainValues(model.getResponseIdx())[i]);
}
}
output.write("\n");

Expand All @@ -132,7 +136,7 @@ public static void main(String[] args) throws Exception{
// Parse the CSV line. Don't handle quoted commas. This isn't a parser test.
String trimmedLine = line.trim();
String[] inputColumnsArray = trimmedLine.split(",");
int numInputColumns = model.getNames().length-1; // we do not need response !
int numInputColumns = model.isAutoEncoder() ? model.getNames().length : model.getNames().length-1; // we do not need response !
if (inputColumnsArray.length != numInputColumns) {
System.out.println("WARNING: Line " + lineno + " has " + inputColumnsArray.length + " columns (expected " + numInputColumns + ")");
}
Expand Down Expand Up @@ -201,7 +205,7 @@ public static void main(String[] args) throws Exception{
} else {
if (i > 0) output.write(",");
output.write(Double.toHexString(preds[i]));
if (!model.isClassifier()) break;
if (!model.isClassifier() && !model.isAutoEncoder()) break;
}
}
output.write("\n");
Expand Down
44 changes: 42 additions & 2 deletions R/tests/testdir_javapredict/runit_DL_javapredict_iris_large.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ train <- locate("smalldata/iris/iris_train.csv")
test <- locate("smalldata/iris/iris_test.csv")
x = c("sepal_len","sepal_wid","petal_len","petal_wid");
y = "species"
classification = T
activation = "Tanh"
epochs = 2
autoencoder = F

#----------------------------------------------------------------------
# Run the tests
#----------------------------------------------------------------------

# CLASSIFICATION
classification = T

# large network
hidden = c(500,500,500)
source('../Utils/shared_javapredict_DL.R')
Expand Down Expand Up @@ -58,8 +61,11 @@ source('../Utils/shared_javapredict_DL.R')
activation = "MaxoutWithDropout"
source('../Utils/shared_javapredict_DL.R')

# regression


# REGRESSION
classification = F

activation = "Tanh"
x = c("species","sepal_len","sepal_wid","petal_len")
y = c("petal_wid")
Expand All @@ -84,3 +90,37 @@ source('../Utils/shared_javapredict_DL.R')

activation = "MaxoutWithDropout"
source('../Utils/shared_javapredict_DL.R')



# AUTOENCODER
autoencoder = T

activation = "Rectifier"
hidden = c(5,3,2)
epochs = 3

x = c("species","sepal_len","sepal_wid","petal_len","petal_wid");
y = c("petal_wid") #ignored
source('../Utils/shared_javapredict_DL.R')

# only numericals
x = c("sepal_len","sepal_wid","petal_len","petal_wid");
y = c("petal_wid") #ignored
source('../Utils/shared_javapredict_DL.R')

# mixed numericals and categoricals
x = c("species","sepal_len","sepal_wid","petal_len","petal_wid");
y = c("petal_wid") #ignored
source('../Utils/shared_javapredict_DL.R')

activation = "Tanh"
x = c("species","sepal_len","sepal_wid","petal_len","petal_wid");
y = c("petal_wid") #ignored
source('../Utils/shared_javapredict_DL.R')

hidden = c(3)
activation = "Tanh"
x = c("species","sepal_len","sepal_wid","petal_len","petal_wid");
y = c("petal_wid") #ignored
source('../Utils/shared_javapredict_DL.R')
83 changes: 74 additions & 9 deletions src/main/java/hex/deeplearning/DeepLearningModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,21 @@ public String toStringAll() {
return sb.toString();
}

/**
 * Builds the CSV header for autoencoder reconstruction output: one
 * "reconstr_"-prefixed column per expanded (one-hot + numeric) input feature.
 * Only valid for autoencoder models, which have no response column.
 *
 * @return comma-separated reconstruction column names
 */
public String getHeader() {
  assert get_params().autoencoder;
  assert (model_info().data_info()._responses == 0); // autoencoder: no response column
  final String[] coefnames = model_info().data_info().coefNames();
  final int len = model_info().data_info().fullN();
  assert (len == coefnames.length); // expanded width must match coefficient names
  StringBuilder sb = new StringBuilder();
  for (int c = 0; c < len; c++) {
    if (c > 0) sb.append(",");
    // chain appends instead of concatenating inside append() — avoids an
    // intermediate String allocation per column
    sb.append("reconstr_").append(coefnames[c]);
  }
  return sb.toString();
}

/**
* This is an overridden version of Model.score(). Make either a prediction or a reconstruction.
* @param frame Test dataset
Expand Down Expand Up @@ -1248,6 +1263,10 @@ private double score_autoencoder(double[] data, float[] preds, Neurons[] neurons
DeepLearningTask.step(-1, neurons, model_info, false, null); // reconstructs data in expanded space
float[] in = neurons[0]._a.raw(); //input (expanded)
float[] out = neurons[neurons.length - 1]._a.raw(); //output (expanded)
// DEBUGGING
// Log.info(Arrays.toString(data));
// Log.info(Arrays.toString(in));
// Log.info(Arrays.toString(out));
assert(in.length == out.length);

// First normalize categorical reconstructions to be probabilities
Expand All @@ -1265,6 +1284,9 @@ private double score_autoencoder(double[] data, float[] preds, Neurons[] neurons
model_info().data_info().unScaleNumericals(out, out); //only modifies the numericals
System.arraycopy(out, 0, preds, 0, out.length); //copy reconstruction into preds
}
// DEBUGGING
// Log.info(Arrays.toString(preds));
// Log.info("");
return l2;
}

Expand Down Expand Up @@ -1673,6 +1695,17 @@ else if(e.num_folds > 0) {
return true;
}

// For autoencoders, NCLASSES in the generated POJO holds the number of
// reconstructed output features (the size of the last layer) rather than a
// class count; non-autoencoder models defer to the base implementation.
@Override
protected SB toJavaNCLASSES(SB sb) {
  if (get_params().autoencoder) {
    final int outputFeatures = model_info.units[model_info.units.length - 1];
    return JCodeGen.toStaticVar(sb, "NCLASSES", outputFeatures, "Number of output features (same as features of training data).");
  }
  return super.toJavaNCLASSES(sb);
}

// Autoencoder POJOs fill preds[] directly with the reconstruction, so the
// base-class preds[0] fill is skipped; all other models keep the default.
@Override
protected void toJavaFillPreds0(SB bodySb) {
  if (get_params().autoencoder) return;
  super.toJavaFillPreds0(bodySb);
}

public void toJavaHtml(StringBuilder sb) {
sb.append("<br /><br /><div class=\"pull-right\"><a href=\"#\" onclick=\'$(\"#javaModel\").toggleClass(\"hide\");\'" +
"class=\'btn btn-inverse btn-mini\'>Java Model</a></div><br /><div class=\"hide\" id=\"javaModel\">");
Expand Down Expand Up @@ -1733,6 +1766,13 @@ else if( model_info().size() > 100000 ) {
layers[i] = neurons[i].units;
JCodeGen.toStaticVar(sb, "NEURONS", layers, "Number of neurons for each layer.");

if (get_params().autoencoder) {
sb.i(1).p("@Override public int getPredsSize() { return " + model_info.units[model_info.units.length-1] + "; }").nl();
sb.i(1).p("@Override public boolean isAutoEncoder() { return true; }").nl();
sb.i(1).p("@Override public String getHeader() { return \"" + getHeader() + "\"; }").nl();

}

// activation storage
sb.i(1).p("// Storage for neuron activation values.").nl();
sb.i(1).p("public static final float[][] ACTIVATION = new float[][] {").nl();
Expand Down Expand Up @@ -1826,6 +1866,7 @@ else if( model_info().size() > 100000 ) {
}
bodySb.i(0).p("}").nl();
}
bodySb.i().p("java.util.Arrays.fill(ACTIVATION[0],0);").nl();
if (cats > 0) {
bodySb.i().p("for (i=0; i<ncats; ++i) ACTIVATION[0][CATS[i]] = 1f;").nl();
}
Expand All @@ -1839,6 +1880,8 @@ else if( model_info().size() > 100000 ) {
boolean relu=(get_params().activation == DeepLearning.Activation.Rectifier || get_params().activation == DeepLearning.Activation.RectifierWithDropout);
boolean maxout=(get_params().activation == DeepLearning.Activation.Maxout || get_params().activation == DeepLearning.Activation.MaxoutWithDropout);

final String stopping = get_params().autoencoder ? "(i<=ACTIVATION.length-1)" : "(i<ACTIVATION.length-1)";

// make prediction: forward propagation
bodySb.i().p("for (i=1; i<ACTIVATION.length; ++i) {").nl();
bodySb.i(1).p("java.util.Arrays.fill(ACTIVATION[i],0f);").nl();
Expand All @@ -1854,20 +1897,20 @@ else if( model_info().size() > 100000 ) {
if (!maxout) {
bodySb.i(3).p("ACTIVATION[i][r] += ACTIVATION[i-1][c] * WEIGHT[i][r*cols+c];").nl();
} else {
bodySb.i(3).p("if (i<ACTIVATION.length-1) cmax = Math.max(ACTIVATION[i-1][c] * WEIGHT[i][r*cols+c], cmax);").nl();
bodySb.i(3).p("if " + stopping + " cmax = Math.max(ACTIVATION[i-1][c] * WEIGHT[i][r*cols+c], cmax);").nl();
bodySb.i(3).p("else ACTIVATION[i][r] += ACTIVATION[i-1][c] * WEIGHT[i][r*cols+c];").nl();
}
bodySb.i(2).p("}").nl();
if (maxout) {
bodySb.i(2).p("if (i<ACTIVATION.length-1) ACTIVATION[i][r] = Float.isInfinite(cmax) ? 0f : cmax;").nl();
bodySb.i(2).p("if "+ stopping +" ACTIVATION[i][r] = Float.isInfinite(cmax) ? 0f : cmax;").nl();
}
bodySb.i(2).p("ACTIVATION[i][r] += BIAS[i][r];").nl();
if (maxout) {
bodySb.i(2).p("if (i<ACTIVATION.length-1) rmax = Math.max(rmax, ACTIVATION[i][r]);").nl();
bodySb.i(2).p("if " + stopping + " rmax = Math.max(rmax, ACTIVATION[i][r]);").nl();
}
bodySb.i(1).p("}").nl();

if (!maxout) bodySb.i(1).p("if (i<ACTIVATION.length-1) {").nl();
if (!maxout) bodySb.i(1).p("if " + stopping + " {").nl();
bodySb.i(2).p("for (int r=0; r<ACTIVATION[i].length; ++r) {").nl();
if (tanh) {
bodySb.i(3).p("ACTIVATION[i][r] = 1f - 2f / (1f + (float)Math.exp(2*ACTIVATION[i][r]));").nl();
Expand All @@ -1877,7 +1920,7 @@ else if( model_info().size() > 100000 ) {
bodySb.i(3).p("if (rmax > 1 ) ACTIVATION[i][r] /= rmax;").nl();
}
if (get_params().hidden_dropout_ratios != null) {
if (maxout) bodySb.i(1).p("if (i<ACTIVATION.length-1) {").nl();
if (maxout) bodySb.i(1).p("if " + stopping + " {").nl();
bodySb.i(3).p("ACTIVATION[i][r] *= HIDDEN_DROPOUT_RATIOS[i-1];").nl();
if (maxout) bodySb.i(1).p("}").nl();
}
Expand All @@ -1903,18 +1946,40 @@ else if( model_info().size() > 100000 ) {
bodySb.i(2).p("}").nl();
bodySb.i(1).p("}").nl();
bodySb.i().p("}").nl();
} else {
} else if (!get_params().autoencoder) { //Regression
bodySb.i().p("}").nl();
bodySb.i().p("if (i == ACTIVATION.length-1) {").nl();
// regression: set preds[1], FillPreds0 will put it into preds[0]
if (model_info().data_info()._normRespMul != null) {
bodySb.i().p("preds[1] = (float) (ACTIVATION[ACTIVATION.length-1][0] / NORMRESPMUL[0] + NORMRESPSUB[0]);").nl();
bodySb.i().p("preds[1] = (float) (ACTIVATION[i][0] / NORMRESPMUL[0] + NORMRESPSUB[0]);").nl();
}
else {
bodySb.i().p("preds[1] = ACTIVATION[ACTIVATION.length-1][0];").nl();
bodySb.i().p("preds[1] = ACTIVATION[i][0];").nl();
}
bodySb.i().p("if (Float.isNaN(preds[1])) throw new RuntimeException(\"Predicted regression target NaN!\");").nl();
bodySb.i().p("}").nl();
} else { //AutoEncoder
bodySb.i(1).p("if (i == ACTIVATION.length-1) {").nl();
bodySb.i(2).p("for (int r=0; r<ACTIVATION[i].length; r++) {").nl();
bodySb.i(3).p("if (Float.isNaN(ACTIVATION[i][r]))").nl();
bodySb.i(4).p("throw new RuntimeException(\"Numerical instability, reconstructed NaN.\");").nl();
bodySb.i(3).p("preds[r] = ACTIVATION[i][r];").nl();
bodySb.i(2).p("}").nl();
if (model_info().data_info()._nums > 0) {
int ns = model_info().data_info().numStart();
bodySb.i(2).p("for (int k=" + ns + "; k<" + model_info().data_info().fullN() + "; ++k) {").nl();
bodySb.i(3).p("preds[k] = preds[k] / (float)NORMMUL[k-" + ns + "] + (float)NORMSUB[k-" + ns + "];").nl();
bodySb.i(2).p("}").nl();
}
bodySb.i(1).p("}").nl();
bodySb.i().p("}").nl();
// DEBUGGING
// bodySb.i().p("System.out.println(java.util.Arrays.toString(data));").nl();
// bodySb.i().p("System.out.println(java.util.Arrays.toString(ACTIVATION[0]));").nl();
// bodySb.i().p("System.out.println(java.util.Arrays.toString(ACTIVATION[ACTIVATION.length-1]));").nl();
// bodySb.i().p("System.out.println(java.util.Arrays.toString(preds));").nl();
// bodySb.i().p("System.out.println(\"\");").nl();
}

fileCtxSb.p(model);
toJavaUnifyPreds(bodySb);
toJavaFillPreds0(bodySb);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/water/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ protected SB toJavaSuper( SB sb ) {
return sb;
}
private SB toJavaNAMES( SB sb ) { return JCodeGen.toStaticVar(sb, "NAMES", _names, "Names of columns used by model."); }
private SB toJavaNCLASSES( SB sb ) { return isClassifier() ? JCodeGen.toStaticVar(sb, "NCLASSES", nclasses(), "Number of output classes included in training data response column.") : sb; }
protected SB toJavaNCLASSES( SB sb ) { return isClassifier() ? JCodeGen.toStaticVar(sb, "NCLASSES", nclasses(), "Number of output classes included in training data response column.") : sb; } // protected (not private) so model subclasses can override NCLASSES generation
private SB toJavaDOMAINS( SB sb, SB fileContextSB ) {
sb.nl();
sb.ii(1);
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/water/api/DeepLearningModelView.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ public static Response redirect(Request req, Key modelKey) {
}
@Override protected String serveJava() {
deeplearning_model = UKV.get(_modelKey);
if (deeplearning_model!=null
&& !deeplearning_model.get_params().autoencoder) //not yet implemented
if (deeplearning_model!=null)
return deeplearning_model.toJava();
else
return "";
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/water/genmodel/GeneratedModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public static boolean grpContains(byte[] gcmp, int offset, int idx) {
return (gcmp[idx >> 3] & ((byte)1 << (idx % 8))) != 0;
}

public String getHeader() { return null; } // default: no header; generated autoencoder models override with reconstruction column names
public boolean isAutoEncoder() { return false; } // default: false; generated autoencoder models override to return true
@Override public int getColIdx(String name) {
String[] names = getNames();
for (int i=0; i<names.length; i++) if (names[i].equals(name)) return i;
Expand Down

0 comments on commit 6505a1f

Please sign in to comment.