-
Notifications
You must be signed in to change notification settings - Fork 1
/
AdaBoost.java
134 lines (116 loc) · 4.25 KB
/
AdaBoost.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/**
* @author Jonathan Lack and Eric Denovitzer
* COS 402 P6
* AdaBoost.java
*/
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.*;
/**
* This class implements the AdaBoost algorithm, with decision stumps serving as
* the weak learning algorithm.
*/
public class AdaBoost implements Classifier {

    /** Number of boosting rounds (one decision stump is chosen per round). */
    private static final int NUM_ROUNDS = 100;

    /** Guards against log(0)/division-by-zero when a stump has error 0 or 1. */
    private static final double EPSILON = 1e-10;

    /** Current distribution over training examples; re-normalized each round. */
    private double weight[];

    /** alpha[t] is the vote weight of the stump selected in round t. */
    private double alpha[] = new double[NUM_ROUNDS];

    /** hypotheses[t] is the weak hypothesis (decision stump) chosen in round t. */
    private DecisionStump hypotheses[] = new DecisionStump[NUM_ROUNDS];

    private String author = "Jonathan Lack and Eric Denovitzer";
    private String description = "The AdaBoost algorithm with decision stumps " +
            "as the weak learning algorithm.";

    /**
     * Trains an AdaBoost ensemble on the given dataset: for each of
     * {@code NUM_ROUNDS} rounds, picks the decision stump (over all attributes,
     * in both polarities) with the lowest weighted training error, records its
     * vote weight alpha, and re-weights the training examples.
     *
     * @param d the binary dataset to train on
     */
    public AdaBoost(BinaryDataSet d) {
        weight = new double[d.numTrainExs];
        // Start from the uniform distribution. NOTE: the original code used
        // integer division (1 / d.numTrainExs), which evaluates to 0 for any
        // dataset with more than one example, zeroing every weight.
        for (int i = 0; i < d.numTrainExs; i++)
            weight[i] = 1.0 / d.numTrainExs;

        for (int t = 0; t < NUM_ROUNDS; t++) {
            // Find the stump with minimum weighted error for the CURRENT
            // weights. minError/best must be reset every round; carrying them
            // across rounds (as the original did) can silently reuse a stale
            // stump and a stale error value.
            double minError = Double.POSITIVE_INFINITY;
            DecisionStump best = null;
            for (int i = 0; i < d.numAttrs; i++) {
                DecisionStump stump = new DecisionStump(i, true);
                DecisionStump negStump = new DecisionStump(i, false);
                double sameError = stumpError(stump, d);
                double negError = stumpError(negStump, d);
                if (sameError < minError) {
                    minError = sameError;
                    best = stump;
                }
                if (negError < minError) {
                    minError = negError;
                    best = negStump;
                }
            }

            // Store the hypothesis and its vote weight. EPSILON keeps the log
            // argument finite when minError is exactly 0 (perfect stump) or 1.
            hypotheses[t] = best;
            alpha[t] = 0.5 * Math.log((1 - minError + EPSILON) / (minError + EPSILON));

            // Re-weight: increase misclassified examples, decrease correct ones.
            for (int i = 0; i < d.numTrainExs; i++) {
                if (best.predict(d.trainEx[i]) != d.trainLabel[i])
                    weight[i] *= Math.exp(alpha[t]);
                else
                    weight[i] *= Math.exp(-alpha[t]);
            }
            weight = normalize(weight);
        }
    }

    /**
     * Predicts the label favored by the alpha-weighted vote of all stumps.
     *
     * @param ex the example (one value per attribute) to classify
     * @return 1 if the weighted vote for label 1 strictly exceeds the vote
     *         for label 0, otherwise 0
     */
    public int predict(int[] ex) {
        double weightTrue = 0, weightFalse = 0;
        for (int i = 0; i < hypotheses.length; i++) {
            if (hypotheses[i].predict(ex) == 0)
                weightFalse += alpha[i];
            else
                weightTrue += alpha[i];
        }
        if (weightTrue > weightFalse)
            return 1;
        return 0;
    }

    /** This method returns a description of the learning algorithm. */
    public String algorithmDescription() {
        return description;
    }

    /** This method returns the author of this program. */
    public String author() {
        return author;
    }

    /**
     * Returns a copy of {@code N} scaled so its entries sum to 1.
     * Assumes the entries sum to a nonzero total.
     */
    private double[] normalize(double[] N) {
        double[] weights = new double[N.length];
        double total = 0;
        for (int i = 0; i < N.length; i++)
            total += N[i];
        for (int i = 0; i < N.length; i++)
            weights[i] = N[i] / total;
        return weights;
    }

    /**
     * Computes the weighted training error of {@code stump}: the sum of the
     * current weights of the examples it misclassifies.
     */
    private double stumpError(DecisionStump stump, BinaryDataSet d) {
        double error = 0;
        for (int i = 0; i < d.numTrainExs; i++) {
            if (stump.predict(d.trainEx[i]) != d.trainLabel[i])
                error += weight[i];
        }
        return error;
    }

    /**
     * A simple main for testing this algorithm. This main reads a filestem from
     * the command line, runs the learning algorithm on this dataset, and prints
     * the test predictions to filestem.testout.
     */
    public static void main(String argv[]) throws FileNotFoundException,
            IOException {
        if (argv.length < 1) {
            System.err.println("argument: filestem");
            return;
        }
        String filestem = argv[0];
        BinaryDataSet d = new BinaryDataSet(filestem);
        Classifier c = new AdaBoost(d);
        d.printTestPredictions(c, filestem);
    }
}