Skip to content

Commit

Permalink
Included cross-validation procedure
Browse files Browse the repository at this point in the history
Now, it is able to execute cross-validation procedure if no test set is provided. It needs the number of folds to execute.
  • Loading branch information
i02momuj committed Feb 27, 2019
1 parent b95ae06 commit 43c86af
Show file tree
Hide file tree
Showing 27 changed files with 522 additions and 225 deletions.
37 changes: 34 additions & 3 deletions src/executeMulan/ExecuteMulanAlgorithm.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.Evaluator;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.*;
import mulan.evaluation.measures.regression.example.ExampleBasedRMaxSE;
import mulan.evaluation.measures.regression.macro.MacroMAE;
Expand All @@ -24,16 +25,25 @@ public class ExecuteMulanAlgorithm {
public MultiLabelInstances testSet = null;
public Evaluator eval = new Evaluator();
public Evaluation results;
public MultipleEvaluation mResults;
public List<Measure> measures = new ArrayList<Measure>();
public long time_in, time_fin, total_time;
public int nFolds;

public void prepareExecution(String tvalue, String Tvalue, String xvalue, String ovalue) {
public void prepareExecution(String tvalue, String Tvalue, String xvalue, String ovalue, int fvalue) {
try {
trainingSet = new MultiLabelInstances(tvalue, xvalue);
testSet = new MultiLabelInstances(Tvalue, xvalue);

pw = new PrintWriter(new FileWriter(ovalue, true));
if(fvalue <= 0) {
testSet = new MultiLabelInstances(Tvalue, xvalue);
}
else {
testSet = null;
}
nFolds = fvalue;

pw = new PrintWriter(new FileWriter(ovalue, true));

} catch (Exception e) {
e.printStackTrace();
}
Expand Down Expand Up @@ -77,6 +87,27 @@ public void printResults(String tvalue, boolean lvalue, String algorithm) throws
pw.println();
}

public void printResultsCV(String tvalue, boolean lvalue, String algorithm) throws Exception {
String [] p = tvalue.split("\\/");
String datasetName = p[p.length-1].split("\\.")[0];
pw.print(algorithm + "_" + datasetName + ";");

for(Measure m : measures) {
pw.print(mResults.getMean(m.getName()) + ";");

if((lvalue) && (m.getClass().getName().contains("Macro")))
{
for(int l=0; l<trainingSet.getNumLabels(); l++)
{
pw.print(mResults.getMean(m.getName(), l) + ";");
}
}
}

pw.print(total_time + ";");
pw.println();
}

public void execute(String tvalue, String Tvalue, String xvalue, String ovalue, boolean lvalue) {
System.out.println("Method not implemented");
System.exit(-1);
Expand Down
82 changes: 51 additions & 31 deletions src/executeMulan/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ public static void showUse()
System.out.println("Parameters:");

//Files
System.out.println("\t -t Train .arff file");
System.out.println("\t -T Test .arff file");
System.out.println("\t -t Train .arff file (full data in case of CV)");
System.out.println("\t -T Test .arff file (do not use if CV)");
System.out.println("\t -x Labels .xml file");

//Algorithm
Expand Down Expand Up @@ -47,6 +47,9 @@ public static void showUse()
System.out.println("\t\t\tST -> Single Target");
System.out.println("\t\t\tSST -> Stacked ST");

//Number of folds if cross-validation
System.out.println("\t -f Number of folds for cross-validation. If not set, do not perform CV.");

//Number of seed numbers == Number of executions
System.out.println("\t -i Number of random seeds");

Expand All @@ -59,12 +62,14 @@ public static void showUse()

public static void main(String [] args)
{
String tvalue=null, Tvalue=null, xvalue=null, avalue=null, ovalue=null, lvalueStr=null, ivalueStr=null;
String tvalue=null, Tvalue=null, xvalue=null, avalue=null, ovalue=null, lvalueStr=null, ivalueStr=null, fvalueStr=null;

//By default, macro measures are not shown for each label
boolean lvalue = false;
//By default, 10 random seeds are used
int ivalue = 10;
//By default, cross-validation is not performed
int fvalue = -1;

try {
tvalue = Utils.getOption("t", args);
Expand All @@ -73,14 +78,11 @@ public static void main(String [] args)
avalue = Utils.getOption("a", args);
ovalue = Utils.getOption("o", args);

if((tvalue.length() == 0) || (Tvalue.length() == 0) || (xvalue.length() == 0) || (avalue.length() == 0) || (ovalue.length() == 0))
if((tvalue.length() == 0) || (xvalue.length() == 0) || (avalue.length() == 0) || (ovalue.length() == 0))
{
if(tvalue.length() == 0) {
System.out.println("Please enter the train dataset filename.");
}
if(Tvalue.length() == 0) {
System.out.println("Please enter the test dataset filename.");
}
if(xvalue.length() == 0) {
System.out.println("Please enter the xml dataset filename.");
}
Expand Down Expand Up @@ -109,6 +111,24 @@ public static void main(String [] args)
if(ivalueStr.length() != 0) {
ivalue = Integer.parseInt(ivalueStr);
}

fvalueStr = Utils.getOption("f", args);
if(fvalueStr.length() != 0) {
fvalue = Integer.parseInt(fvalueStr);
}

if(fvalue > 0 && (Tvalue.length() > 0)) {
System.out.println("Including both test file and CV procedure is not allowed.");
showUse();
System.exit(1);
}
else if(fvalue <= 0 && (Tvalue.length() == 0)) {
System.out.println("Please set test file or number of folds for CV procedure.");
showUse();
System.exit(1);
}


} catch (Exception e) {
showUse();
System.exit(1);
Expand All @@ -118,102 +138,102 @@ public static void main(String [] args)
if(avalue.equalsIgnoreCase("AdaBoostMH"))
{
ExecuteAdaBoostMH a = new ExecuteAdaBoostMH();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("BPMLL"))
{
ExecuteBPMLL a = new ExecuteBPMLL();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("BR"))
{
ExecuteBR a = new ExecuteBR();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("CC"))
{
ExecuteCC a = new ExecuteCC();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("CLR"))
{
ExecuteCLR a = new ExecuteCLR();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("EBR"))
{
ExecuteEBR a = new ExecuteEBR();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("ECC"))
{
ExecuteECC a = new ExecuteECC();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("EPS"))
{
ExecuteEPS a = new ExecuteEPS();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("EPS_1"))
{
ExecuteEPS_1 a = new ExecuteEPS_1();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("CDE"))
{
ExecuteCDE a = new ExecuteCDE();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("HOMER"))
{
ExecuteHOMER a = new ExecuteHOMER();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("IBLR"))
{
ExecuteIBLR a = new ExecuteIBLR();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("LP"))
{
ExecuteLP a = new ExecuteLP();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("LPBR"))
{
ExecuteLPBR a = new ExecuteLPBR();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("MLKNN"))
{
ExecuteMLkNN a = new ExecuteMLkNN();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("MLS"))
{
ExecuteMLS a = new ExecuteMLS();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("PS"))
{
ExecutePS a = new ExecutePS();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.toUpperCase().equalsIgnoreCase("RAKEL"))
{
ExecuteRAkEL a = new ExecuteRAkEL();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("RFPCT"))
{
ExecuteRFPCT a = new ExecuteRFPCT();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("ERC"))
{
ExecuteERC a = new ExecuteERC();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
// else if(avalue.equalsIgnoreCase("MORF"))
// {
Expand All @@ -223,22 +243,22 @@ else if(avalue.equalsIgnoreCase("ERC"))
else if(avalue.equalsIgnoreCase("RC"))
{
ExecuteRC a = new ExecuteRC();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("RLC"))
{
ExecuteRLC a = new ExecuteRLC();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, ivalue, fvalue);
}
else if(avalue.equalsIgnoreCase("ST"))
{
ExecuteST a = new ExecuteST();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else if(avalue.equalsIgnoreCase("SST"))
{
ExecuteSST a = new ExecuteSST();
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue);
a.execute(tvalue, Tvalue, xvalue, ovalue, lvalue, fvalue);
}
else
{
Expand Down
25 changes: 18 additions & 7 deletions src/executeMulan/mlc/ExecuteAdaBoostMH.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

public class ExecuteAdaBoostMH extends ExecuteMulanAlgorithm {

public void execute (String tvalue, String Tvalue, String xvalue, String ovalue, boolean lvalue)
public void execute (String tvalue, String Tvalue, String xvalue, String ovalue, boolean lvalue, int fvalue)
{
try{
prepareExecution(tvalue, Tvalue, xvalue, ovalue);
prepareExecution(tvalue, Tvalue, xvalue, ovalue, fvalue);

AdaBoostMH learner = null;

Expand All @@ -16,10 +16,16 @@ public void execute (String tvalue, String Tvalue, String xvalue, String ovalue,
time_in = System.currentTimeMillis();

learner = new AdaBoostMH();
learner.build(trainingSet);

measures = prepareMeasuresClassification(trainingSet);
results = eval.evaluate(learner, testSet, measures);

measures = prepareMeasuresClassification(trainingSet);

if(nFolds > 0) {
mResults = eval.crossValidate(learner, trainingSet, measures, nFolds);
}
else {
learner.build(trainingSet);
results = eval.evaluate(learner, testSet, measures);
}

time_fin = System.currentTimeMillis();

Expand All @@ -28,7 +34,12 @@ public void execute (String tvalue, String Tvalue, String xvalue, String ovalue,
System.out.println("Execution time (ms): " + total_time);

printHeader(lvalue);
printResults(Tvalue, lvalue, "AdaBoost.MH");
if(nFolds <= 0) {
printResults(Tvalue, lvalue, "AdaBoost.MH");
}
else {
printResultsCV(tvalue, lvalue, "AdaBoost.MH");
}
}
catch(Exception e1)
{
Expand Down
24 changes: 17 additions & 7 deletions src/executeMulan/mlc/ExecuteBPMLL.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

public class ExecuteBPMLL extends ExecuteMulanAlgorithm {

public void execute (String tvalue, String Tvalue, String xvalue, String ovalue, boolean lvalue, int nIter)
public void execute (String tvalue, String Tvalue, String xvalue, String ovalue, boolean lvalue, int nIter, int fvalue)
{
try{
prepareExecution(tvalue, Tvalue, xvalue, ovalue);
prepareExecution(tvalue, Tvalue, xvalue, ovalue, fvalue);

BPMLL learner = null;

Expand All @@ -18,10 +18,15 @@ public void execute (String tvalue, String Tvalue, String xvalue, String ovalue,

learner = new BPMLL(i*10);

learner.build(trainingSet);

measures = prepareMeasuresClassification(trainingSet);
results = eval.evaluate(learner, testSet, measures);
measures = prepareMeasuresClassification(trainingSet);

if(nFolds > 0) {
mResults = eval.crossValidate(learner, trainingSet, measures, nFolds);
}
else {
learner.build(trainingSet);
results = eval.evaluate(learner, testSet, measures);
}

time_fin = System.currentTimeMillis();

Expand All @@ -34,7 +39,12 @@ public void execute (String tvalue, String Tvalue, String xvalue, String ovalue,
printHeader(lvalue);
}

printResults(Tvalue, lvalue, "BPMLL");
if(nFolds <= 0) {
printResults(Tvalue, lvalue, "BPMLL");
}
else {
printResultsCV(tvalue, lvalue, "BPMLL");
}

}//End for

Expand Down
Loading

0 comments on commit 43c86af

Please sign in to comment.