Skip to content

Commit

Permalink
OVA multi-class codegen support (dotnet#321)
Browse files Browse the repository at this point in the history
* dummy

* multiova implementation

* fix tests

* remove inclusion list

* fix tests and console helper
  • Loading branch information
srsaggam authored Mar 27, 2019
1 parent 73d141b commit 64f5ba1
Show file tree
Hide file tree
Showing 18 changed files with 452 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ namespace TestNamespace.Train
}


public static void PrintBinaryClassificationFoldsAverageMetrics(
TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>[] crossValResults)
public static void PrintBinaryClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>[] crossValResults)
{
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

Expand All @@ -77,8 +76,21 @@ namespace TestNamespace.Train

}

public static void PrintMulticlassClassificationFoldsAverageMetrics(
TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResults)
/// <summary>
/// Prints a console report of quality metrics for a multi-class classification model.
/// </summary>
/// <param name="metrics">Metrics produced by <c>MulticlassClassification.Evaluate</c>.</param>
public static void PrintMultiClassClassificationMetrics(MultiClassClassifierMetrics metrics)
{
    Console.WriteLine("************************************************************");
    Console.WriteLine("* Metrics for multi-class classification model ");
    Console.WriteLine("*-----------------------------------------------------------");
    Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
    Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
    Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
    // Report every class rather than hard-coding classes 1-3: the original
    // indexed PerClassLogLoss[0..2] directly, which throws for fewer than
    // three classes and silently ignores any additional ones.
    int classNumber = 0;
    foreach (var classLogLoss in metrics.PerClassLogLoss)
    {
        classNumber++;
        Console.WriteLine($" LogLoss for class {classNumber} = {classLogLoss:0.####}, the closer to 0, the better");
    }
    Console.WriteLine("************************************************************");
}

public static void PrintMulticlassClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResults)
{
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//*****************************************************************************************
//* *
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
//* *
//*****************************************************************************************

using System;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.Data.DataView;
using TestNamespace.Model.DataModels;

namespace TestNamespace.Train
{
class Program
{
    // Placeholder paths substituted by the ML.NET CLI code generator.
    private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv";
    private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv";
    private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip";

    /// <summary>
    /// End-to-end training workflow: load data, build the pipeline, train,
    /// evaluate on the held-out test set, and persist the model to disk.
    /// </summary>
    static void Main(string[] args)
    {
        // Create MLContext to be shared across the model creation workflow objects
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        MLContext mlContext = new MLContext(seed: 1);

        // Load Data
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
                                        path: TRAIN_DATA_FILEPATH,
                                        hasHeader: true,
                                        separatorChar: ',',
                                        allowQuoting: true,
                                        allowSparse: true);

        IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
                                        path: TEST_DATA_FILEPATH,
                                        hasHeader: true,
                                        separatorChar: ',',
                                        allowQuoting: true,
                                        allowSparse: true);

        // Build training pipeline
        IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

        // Train Model
        ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

        // Evaluate quality of Model
        EvaluateModel(mlContext, mlModel, testDataView);

        // Save model
        SaveModel(mlContext, mlModel, MODEL_FILEPATH);

        Console.WriteLine("=============== End of process, hit any key to finish ===============");
        Console.ReadKey();
    }

    /// <summary>
    /// Builds the transform + trainer pipeline: concatenates the input column and
    /// wraps a binary FastForest learner in a One-Versus-All multi-class trainer.
    /// </summary>
    public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
    {
        // Data process configuration with pipeline data transformations
        var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
                                  .AppendCacheCheckpoint(mlContext);

        // Set the training algorithm: OVA turns the binary FastForest into a multi-class trainer.
        var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(
            mlContext.BinaryClassification.Trainers.FastForest(numLeaves: 2, labelColumnName: "Label", featureColumnName: "Features"),
            labelColumnName: "Label");
        var trainingPipeline = dataProcessPipeline.Append(trainer);

        return trainingPipeline;
    }

    /// <summary>
    /// Fits the given pipeline on the training data and returns the trained model.
    /// </summary>
    public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        Console.WriteLine("=============== Training model ===============");

        ITransformer model = trainingPipeline.Fit(trainingDataView);

        Console.WriteLine("=============== End of training process ===============");
        return model;
    }

    /// <summary>
    /// Scores the test set and prints multi-class classification metrics.
    /// </summary>
    private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView)
    {
        // Evaluate the model and show accuracy stats
        Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
        IDataView predictions = mlModel.Transform(testDataView);
        var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score");
        ConsoleHelper.PrintMultiClassClassificationMetrics(metrics);
    }

    /// <summary>
    /// Persists the trained model to a .zip file at the given path (relative to the assembly).
    /// </summary>
    private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath)
    {
        // Save/persist the trained model to a .ZIP file
        Console.WriteLine("=============== Saving the model ===============");
        // NOTE(review): FileShare.Write allows another process to write the file
        // concurrently while we are writing it — FileShare.None would be safer; confirm intent.
        using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
            mlContext.Model.Save(mlModel, fs);

        Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
    }

    /// <summary>
    /// Resolves a path relative to the folder containing the executing assembly.
    /// </summary>
    public static string GetAbsolutePath(string relativePath)
    {
        FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
        string assemblyFolderPath = _dataRoot.Directory.FullName;

        string fullPath = Path.Combine(assemblyFolderPath, relativePath);

        return fullPath;
    }
}
}
77 changes: 73 additions & 4 deletions src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using ApprovalTests;
using ApprovalTests.Reporters;
Expand All @@ -18,7 +19,8 @@ namespace mlnet.Test
[UseReporter(typeof(DiffReporter))]
public class ConsoleCodeGeneratorTests
{
private Pipeline pipeline;
private Pipeline mockedPipeline;
private Pipeline mockedOvaPipeline;
private ColumnInferenceResults columnInference = default;
private string namespaceValue = "TestNamespace";

Expand Down Expand Up @@ -46,6 +48,29 @@ public void ConsoleHelperFileContentTest()
Approvals.Verify(result.Item3);
}

[TestMethod]
[UseReporter(typeof(DiffReporter))]
[MethodImpl(MethodImplOptions.NoInlining)]
public void TrainProgramCSFileContentOvaTest()
{
    // Arrange: a mocked OVA (one-versus-all) pipeline and its column inference.
    var (ovaPipeline, inference) = GetMockedOvaPipelineAndInference();

    var settings = new CodeGeneratorSettings()
    {
        MlTask = TaskKind.MulticlassClassification,
        OutputBaseDir = null,
        OutputName = "MyNamespace",
        TrainDataset = "x:\\dummypath\\dummy_train.csv",
        TestDataset = "x:\\dummypath\\dummy_test.csv",
        LabelName = "Label",
        ModelPath = "x:\\models\\model.zip"
    };
    var generator = new CodeGenerator(ovaPipeline, inference, settings);

    // Act: generate the train-project source files.
    var contents = generator.GenerateTrainProjectContents(namespaceValue, typeof(float));

    // Assert: the generated Program.cs (Item1) must match the approved snapshot.
    Approvals.Verify(contents.Item1);
}

[TestMethod]
[UseReporter(typeof(DiffReporter))]
[MethodImpl(MethodImplOptions.NoInlining)]
Expand All @@ -70,6 +95,7 @@ public void TrainProgramCSFileContentTest()
}



[TestMethod]
[UseReporter(typeof(DiffReporter))]
[MethodImpl(MethodImplOptions.NoInlining)]
Expand Down Expand Up @@ -211,7 +237,7 @@ public void PredictProjectFileContentTest()

private (Pipeline, ColumnInferenceResults) GetMockedPipelineAndInference()
{
if (pipeline == null)
if (mockedPipeline == null)
{
MLContext context = new MLContext();
// same learners with different hyperparams
Expand All @@ -224,7 +250,48 @@ public void PredictProjectFileContentTest()
var inferredPipeline1 = new SuggestedPipeline(transforms1, new List<SuggestedTransform>(), trainer1, context, true);
var inferredPipeline2 = new SuggestedPipeline(transforms2, new List<SuggestedTransform>(), trainer2, context, false);

this.pipeline = inferredPipeline1.ToPipeline();
this.mockedPipeline = inferredPipeline1.ToPipeline();
var textLoaderArgs = new TextLoader.Options()
{
Columns = new[] {
new TextLoader.Column("Label", DataKind.Boolean, 0),
new TextLoader.Column("col1", DataKind.Single, 1),
new TextLoader.Column("col2", DataKind.Single, 0),
new TextLoader.Column("col3", DataKind.String, 0),
new TextLoader.Column("col4", DataKind.Int32, 0),
new TextLoader.Column("col5", DataKind.UInt32, 0),
},
AllowQuoting = true,
AllowSparse = true,
HasHeader = true,
Separators = new[] { ',' }
};

this.columnInference = new ColumnInferenceResults()
{
TextLoaderOptions = textLoaderArgs,
ColumnInformation = new ColumnInformation() { LabelColumn = "Label" }
};
}
return (mockedPipeline, columnInference);
}

private (Pipeline, ColumnInferenceResults) GetMockedOvaPipelineAndInference()
{
if (mockedOvaPipeline == null)
{
MLContext context = new MLContext();
// same learners with different hyperparams
var hyperparams1 = new Microsoft.ML.Auto.ParameterSet(new List<Microsoft.ML.Auto.IParameterValue>() { new LongParameterValue("NumLeaves", 2) });
var hyperparams2 = new Microsoft.ML.Auto.ParameterSet(new List<Microsoft.ML.Auto.IParameterValue>() { new LongParameterValue("NumLeaves", 6) });
var trainer1 = new SuggestedTrainer(context, new FastForestOvaExtension(), new ColumnInformation(), hyperparams1);
var trainer2 = new SuggestedTrainer(context, new FastForestOvaExtension(), new ColumnInformation(), hyperparams2);
var transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
var transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
var inferredPipeline1 = new SuggestedPipeline(transforms1, new List<SuggestedTransform>(), trainer1, context, true);
var inferredPipeline2 = new SuggestedPipeline(transforms2, new List<SuggestedTransform>(), trainer2, context, false);

this.mockedOvaPipeline = inferredPipeline1.ToPipeline();
var textLoaderArgs = new TextLoader.Options()
{
Columns = new[] {
Expand All @@ -241,13 +308,15 @@ public void PredictProjectFileContentTest()
Separators = new[] { ',' }
};


this.columnInference = new ColumnInferenceResults()
{
TextLoaderOptions = textLoaderArgs,
ColumnInformation = new ColumnInformation() { LabelColumn = "Label" }
};

}
return (pipeline, columnInference);
return (mockedOvaPipeline, columnInference);
}
}
}
6 changes: 3 additions & 3 deletions src/mlnet.Test/CodeGenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void TrainerGeneratorBasicAdvancedParameterTest()
string expectedTrainer = "LightGbm(new Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})";
string expectedUsing = "using Microsoft.ML.LightGBM;\r\n";
Assert.AreEqual(expectedTrainer, actual.Item1);
Assert.AreEqual(expectedUsing, actual.Item2);
Assert.AreEqual(expectedUsing, actual.Item2[0]);
}

[TestMethod]
Expand Down Expand Up @@ -80,7 +80,7 @@ public void TransformGeneratorUsingTest()
string expectedTransform = "Categorical.OneHotEncoding(new []{new OneHotEncodingEstimator.ColumnOptions(\"Label\",\"Label\")})";
var expectedUsings = "using Microsoft.ML.Transforms;\r\n";
Assert.AreEqual(expectedTransform, actual[0].Item1);
Assert.AreEqual(expectedUsings, actual[0].Item2);
Assert.AreEqual(expectedUsings, actual[0].Item2[0]);
}

[TestMethod]
Expand Down Expand Up @@ -130,7 +130,7 @@ public void TrainerComplexParameterTest()
string expectedTrainer = "LightGbm(new Options(){Booster=new TreeBooster(){},LabelColumn=\"Label\",FeatureColumn=\"Features\"})";
var expectedUsings = "using Microsoft.ML.LightGBM;\r\n";
Assert.AreEqual(expectedTrainer, actual.Item1);
Assert.AreEqual(expectedUsings, actual.Item2);
Assert.AreEqual(expectedUsings, actual.Item2[0]);
}
}
}
Loading

0 comments on commit 64f5ba1

Please sign in to comment.