From 64f5ba1d0ae2a3e201cc34fe3d0a52dedaf6a1e8 Mon Sep 17 00:00:00 2001 From: Srujan Saggam <41802116+srsaggam@users.noreply.github.com> Date: Wed, 27 Mar 2019 14:42:58 -0700 Subject: [PATCH] Ova Multi class codegen support (#321) * dummy * multiova implementation * fix tests * remove inclusion list * fix tests and console helper --- ....ConsoleHelperFileContentTest.approved.txt | 20 ++- ...inProgramCSFileContentOvaTest.approved.txt | 110 ++++++++++++++ .../ConsoleCodeGeneratorTests.cs | 77 +++++++++- src/mlnet.Test/CodeGenTests.cs | 6 +- src/mlnet.Test/TrainerGeneratorTests.cs | 36 ++--- src/mlnet.Test/TransformGeneratorTests.cs | 12 +- src/mlnet/AutoML/AutoMLEngine.cs | 24 +-- .../CodeGenerator/CSharp/CodeGenerator.cs | 40 +++-- .../CSharp/TrainerGeneratorBase.cs | 25 ++- .../CSharp/TrainerGeneratorFactory.cs | 12 +- .../CodeGenerator/CSharp/TrainerGenerators.cs | 71 +++++++-- .../CSharp/TransformGeneratorBase.cs | 4 +- .../CSharp/TransformGeneratorFactory.cs | 2 +- .../CSharp/TransformGenerators.cs | 22 +-- src/mlnet/Templates/Console/ConsoleHelper.cs | 143 ++++++++++-------- src/mlnet/Templates/Console/ConsoleHelper.tt | 20 ++- src/mlnet/Templates/Console/TrainProgram.cs | 4 +- src/mlnet/Templates/Console/TrainProgram.tt | 2 +- 18 files changed, 452 insertions(+), 178 deletions(-) create mode 100644 src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt diff --git a/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt index 1053b7d591..a98a044ab2 100644 --- a/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt +++ b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt @@ -58,8 +58,7 @@ namespace TestNamespace.Train } - public static void PrintBinaryClassificationFoldsAverageMetrics( - TrainCatalogBase.CrossValidationResult[] crossValResults) + public static void PrintBinaryClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult[] crossValResults) { var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); @@ -77,8 +76,21 @@ namespace TestNamespace.Train } - public static void PrintMulticlassClassificationFoldsAverageMetrics( - TrainCatalogBase.CrossValidationResult[] crossValResults) + public static void PrintMultiClassClassificationMetrics(MultiClassClassifierMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for multi-class classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better"); + Console.WriteLine($"************************************************************"); + } + + public static void PrintMulticlassClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult[] crossValResults) { var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); diff --git a/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt new file mode 100644 index 0000000000..7935b383a2 --- /dev/null +++ b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt @@ -0,0 +1,110 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.Data.DataView; +using TestNamespace.Model.DataModels; + +namespace TestNamespace.Train +{ + class Program + { + private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv"; + private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; + private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip"; + + static void Main(string[] args) + { + // Create MLContext to be shared across the model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. + MLContext mlContext = new MLContext(seed: 1); + + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + + IDataView testDataView = mlContext.Data.LoadFromTextFile( + path: TEST_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + // Build training pipeline + IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); + + // Train Model + ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); + + // Evaluate quality of Model + EvaluateModel(mlContext, mlModel, testDataView); + + // Save model + SaveModel(mlContext, mlModel, MODEL_FILEPATH); + + Console.WriteLine("=============== End of process, hit any key to finish ==============="); + Console.ReadKey(); + } + + public static IEstimator BuildTrainingPipeline(MLContext mlContext) + { + // Data process configuration with pipeline data transformations + var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) + .AppendCacheCheckpoint(mlContext); + + // Set the training algorithm + var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(numLeaves: 2, labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label"); + var trainingPipeline = dataProcessPipeline.Append(trainer); + + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine("=============== Training model ==============="); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine("=============== End of training process ==============="); + return model; + } + + private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); + IDataView predictions = mlModel.Transform(testDataView); + var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score"); + ConsoleHelper.PrintMultiClassClassificationMetrics(metrics); + } + private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($"=============== Saving the model ==============="); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, fs); + + Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + } +} diff --git a/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs index 56168d620b..6eec3474d6 100644 --- a/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs +++ b/src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using ApprovalTests; using ApprovalTests.Reporters; @@ -18,7 +19,8 @@ namespace mlnet.Test [UseReporter(typeof(DiffReporter))] public class ConsoleCodeGeneratorTests { - private Pipeline pipeline; + private Pipeline mockedPipeline; + private Pipeline mockedOvaPipeline; private ColumnInferenceResults columnInference = default; private string namespaceValue = "TestNamespace"; @@ -46,6 +48,29 @@ public void ConsoleHelperFileContentTest() Approvals.Verify(result.Item3); } + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void TrainProgramCSFileContentOvaTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedOvaPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.MulticlassClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateTrainProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.Item1); + } + [TestMethod] [UseReporter(typeof(DiffReporter))] [MethodImpl(MethodImplOptions.NoInlining)] @@ -70,6 +95,7 @@ public void TrainProgramCSFileContentTest() } + [TestMethod] [UseReporter(typeof(DiffReporter))] [MethodImpl(MethodImplOptions.NoInlining)] @@ -211,7 +237,7 @@ public void PredictProjectFileContentTest() private (Pipeline, ColumnInferenceResults) GetMockedPipelineAndInference() { - if (pipeline == null) + if (mockedPipeline == null) { MLContext context = new MLContext(); // same learners with different hyperparams @@ -224,7 +250,48 @@ public void PredictProjectFileContentTest() var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, true); var inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); - this.pipeline = inferredPipeline1.ToPipeline(); + this.mockedPipeline = inferredPipeline1.ToPipeline(); + var textLoaderArgs = new TextLoader.Options() + { + Columns = new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("col1", DataKind.Single, 1), + new TextLoader.Column("col2", DataKind.Single, 0), + new TextLoader.Column("col3", DataKind.String, 0), + new TextLoader.Column("col4", DataKind.Int32, 0), + new TextLoader.Column("col5", DataKind.UInt32, 0), + }, + AllowQuoting = true, + AllowSparse = true, + HasHeader = true, + Separators = new[] { ',' } + }; + + this.columnInference = new ColumnInferenceResults() + { + TextLoaderOptions = textLoaderArgs, + ColumnInformation = new ColumnInformation() { LabelColumn = "Label" } + }; + } + return (mockedPipeline, columnInference); + } + + private (Pipeline, ColumnInferenceResults) GetMockedOvaPipelineAndInference() + { + if (mockedOvaPipeline == null) + { + MLContext context = new MLContext(); + // same learners with different hyperparams + var hyperparams1 = new Microsoft.ML.Auto.ParameterSet(new List() { new LongParameterValue("NumLeaves", 2) }); + var hyperparams2 = new Microsoft.ML.Auto.ParameterSet(new List() { new LongParameterValue("NumLeaves", 6) }); + var trainer1 = new SuggestedTrainer(context, new FastForestOvaExtension(), new ColumnInformation(), hyperparams1); + var trainer2 = new SuggestedTrainer(context, new FastForestOvaExtension(), new ColumnInformation(), hyperparams2); + var transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + var transforms2 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, true); + var inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + + this.mockedOvaPipeline = inferredPipeline1.ToPipeline(); var textLoaderArgs = new TextLoader.Options() { Columns = new[] { @@ -241,13 +308,15 @@ public void PredictProjectFileContentTest() Separators = new[] { ',' } }; + this.columnInference = new ColumnInferenceResults() { TextLoaderOptions = textLoaderArgs, ColumnInformation = new ColumnInformation() { LabelColumn = "Label" } }; + } - return (pipeline, columnInference); + return (mockedOvaPipeline, columnInference); } } } diff --git a/src/mlnet.Test/CodeGenTests.cs b/src/mlnet.Test/CodeGenTests.cs index 456b8516ac..79ce1827ac 100644 --- a/src/mlnet.Test/CodeGenTests.cs +++ b/src/mlnet.Test/CodeGenTests.cs @@ -51,7 +51,7 @@ public void TrainerGeneratorBasicAdvancedParameterTest() string expectedTrainer = "LightGbm(new Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; string expectedUsing = "using Microsoft.ML.LightGBM;\r\n"; Assert.AreEqual(expectedTrainer, actual.Item1); - Assert.AreEqual(expectedUsing, actual.Item2); + Assert.AreEqual(expectedUsing, actual.Item2[0]); } [TestMethod] @@ -80,7 +80,7 @@ public void TransformGeneratorUsingTest() string expectedTransform = "Categorical.OneHotEncoding(new []{new OneHotEncodingEstimator.ColumnOptions(\"Label\",\"Label\")})"; var expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -130,7 +130,7 @@ public void TrainerComplexParameterTest() string expectedTrainer = "LightGbm(new Options(){Booster=new TreeBooster(){},LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; var expectedUsings = "using Microsoft.ML.LightGBM;\r\n"; Assert.AreEqual(expectedTrainer, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } } } diff --git a/src/mlnet.Test/TrainerGeneratorTests.cs b/src/mlnet.Test/TrainerGeneratorTests.cs index fef0a87227..58faddb268 100644 --- a/src/mlnet.Test/TrainerGeneratorTests.cs +++ b/src/mlnet.Test/TrainerGeneratorTests.cs @@ -52,7 +52,7 @@ public void LightGbmBinaryAdvancedParameterTest() string expectedTrainerString = "LightGbm(new Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; string expectedUsings = "using Microsoft.ML.LightGBM;\r\n"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -90,7 +90,7 @@ public void SymSgdBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.HalLearners;\r\n"; string expectedTrainerString = "SymbolicStochasticGradientDescent(new SymSgdClassificationTrainer.Options(){LearningRate=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -127,7 +127,7 @@ public void StochasticGradientDescentBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "StochasticGradientDescent(new SgdBinaryTrainer.Options(){Shuffle=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -164,7 +164,7 @@ public void SDCABinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "StochasticDualCoordinateAscent(new SdcaBinaryTrainer.Options(){BiasLearningRate=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -201,7 +201,7 @@ public void SDCAMultiAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "StochasticDualCoordinateAscent(new SdcaMultiClassTrainer.Options(){BiasLearningRate=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -238,7 +238,7 @@ public void SDCARegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "StochasticDualCoordinateAscent(new SdcaRegressionTrainer.Options(){BiasLearningRate=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -275,7 +275,7 @@ public void PoissonRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "PoissonRegression(new PoissonRegression.Options(){MaxIterations=1,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -312,7 +312,7 @@ public void OrdinaryLeastSquaresRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.HalLearners;\r\n"; string expectedTrainerString = "OrdinaryLeastSquares(new OlsLinearRegressionTrainer.Options(){L2Weight=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -349,7 +349,7 @@ public void OnlineGradientDescentRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){RecencyGainMulti=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -386,7 +386,7 @@ public void LogisticRegressionBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "LogisticRegression(new LogisticRegression.Options(){DenseOptimizer=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -423,7 +423,7 @@ public void LogisticRegressionMultiAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "LogisticRegression(new MulticlassLogisticRegression.Options(){DenseOptimizer=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -460,7 +460,7 @@ public void LinearSvmBinaryParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n "; string expectedTrainerString = "LinearSupportVectorMachines(new LinearSvmTrainer.Options(){NoBias=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -498,7 +498,7 @@ public void FastTreeTweedieRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; string expectedTrainerString = "OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){Shrinkage=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -536,7 +536,7 @@ public void FastTreeRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; string expectedTrainerString = "FastTree(new FastTreeRegressionTrainer.Options(){Shrinkage=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -574,7 +574,7 @@ public void FastTreeBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; string expectedTrainerString = "FastTree(new FastTreeBinaryClassificationTrainer.Options(){Shrinkage=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -611,7 +611,7 @@ public void FastForestRegressionAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; string expectedTrainerString = "FastForest(new FastForestRegression.Options(){Shrinkage=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -649,7 +649,7 @@ public void FastForestBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; string expectedTrainerString = "FastForest(new FastForestClassification.Options(){Shrinkage=0.1f,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } @@ -687,7 +687,7 @@ public void AveragedPerceptronBinaryAdvancedParameterTest() var expectedUsings = "using Microsoft.ML.Trainers;\r\n "; string expectedTrainerString = "AveragedPerceptron(new AveragedPerceptronTrainer.Options(){Shuffle=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"})"; Assert.AreEqual(expectedTrainerString, actual.Item1); - Assert.AreEqual(expectedUsings, actual.Item2); + Assert.AreEqual(expectedUsings, actual.Item2[0]); } } diff --git a/src/mlnet.Test/TransformGeneratorTests.cs b/src/mlnet.Test/TransformGeneratorTests.cs index eedfececc3..7865a8fbb5 100644 --- a/src/mlnet.Test/TransformGeneratorTests.cs +++ b/src/mlnet.Test/TransformGeneratorTests.cs @@ -21,7 +21,7 @@ public void MissingValueReplacingTest() var expectedTransform = "ReplaceMissingValues(new []{new MissingValueReplacingEstimator.ColumnOptions(\"categorical_column_1\",\"categorical_column_1\")})"; string expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -36,7 +36,7 @@ public void OneHotEncodingTest() string expectedTransform = "Categorical.OneHotEncoding(new []{new OneHotEncodingEstimator.ColumnOptions(\"categorical_column_1\",\"categorical_column_1\")})"; var expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -96,7 +96,7 @@ public void KeyToValueMappingTest() string expectedTransform = "Conversion.MapKeyToValue(\"Label\",\"Label\")"; var expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -126,7 +126,7 @@ public void OneHotHashEncodingTest() string expectedTransform = "Categorical.OneHotHashEncoding(new []{new OneHotHashEncodingEstimator.ColumnOptions(\"Categorical_column_1\",\"Categorical_column_1\")})"; var expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -156,7 +156,7 @@ public void TypeConvertingTest() string expectedTransform = "Conversion.ConvertType(new []{new TypeConvertingEstimator.ColumnOptions(\"R4_column_1\",DataKind.Single,\"I4_column_1\")})"; string expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } [TestMethod] @@ -171,7 +171,7 @@ public void ValueToKeyMappingTest() string expectedTransform = "Conversion.MapValueToKey(\"Label\",\"Label\")"; var expectedUsings = "using Microsoft.ML.Transforms;\r\n"; Assert.AreEqual(expectedTransform, actual[0].Item1); - Assert.AreEqual(expectedUsings, actual[0].Item2); + Assert.AreEqual(expectedUsings, actual[0].Item2[0]); } } diff --git a/src/mlnet/AutoML/AutoMLEngine.cs b/src/mlnet/AutoML/AutoMLEngine.cs index 733bb72f10..2f08bbddf7 100644 --- a/src/mlnet/AutoML/AutoMLEngine.cs +++ b/src/mlnet/AutoML/AutoMLEngine.cs @@ -89,24 +89,14 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation { var optimizationMetric = new MulticlassExperimentSettings().OptimizingMetric; var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric); - - var experimentSettings = new MulticlassExperimentSettings() - { - MaxExperimentTimeInSeconds = settings.MaxExplorationTime, - ProgressHandler = progressReporter, - EnableCaching = this.enableCaching, - OptimizingMetric = optimizationMetric - }; - - // Inclusion list for currently supported learners. Need to remove once we have codegen support for all other learners. - experimentSettings.Trainers.Clear(); - experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LightGbm); - experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LogisticRegression); - experimentSettings.Trainers.Add(MulticlassClassificationTrainer.StochasticDualCoordinateAscent); - var result = context.Auto() - .CreateMulticlassClassificationExperiment(experimentSettings) - .Execute(trainData, validationData, columnInformation); + .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings() + { + MaxExperimentTimeInSeconds = settings.MaxExplorationTime, + ProgressHandler = progressReporter, + EnableCaching = this.enableCaching, + OptimizingMetric = optimizationMetric + }).Execute(trainData, validationData, columnInformation); logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline); var bestIteration = result.Best(); pipeline = bestIteration.Pipeline; diff --git a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs index 3c8f4ab8e9..8672e68e86 100644 --- a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs +++ b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs @@ -79,7 +79,7 @@ public void GenerateOutput() { var result = GenerateTransformsAndTrainers(); - var trainProgramCSFileContent = GenerateTrainProgramCSFileContent(result.Usings, result.Trainer, result.PreTrainerTransforms, result.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name); + var trainProgramCSFileContent = GenerateTrainProgramCSFileContent(result.Usings, result.TrainerMethod, result.PreTrainerTransforms, result.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name); trainProgramCSFileContent = Utils.FormatCode(trainProgramCSFileContent); var trainProjectFileContent = GeneratTrainProjectFileContent(namespaceValue); @@ -108,11 +108,10 @@ public void GenerateOutput() return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent); } - internal (string Usings, string Trainer, List PreTrainerTransforms, List PostTrainerTransforms) GenerateTransformsAndTrainers() + internal (string Usings, string TrainerMethod, List PreTrainerTransforms, List PostTrainerTransforms) GenerateTransformsAndTrainers() { StringBuilder usingsBuilder = new StringBuilder(); var usings = new List(); - var trainerAndUsings = this.GenerateTrainerAndUsings(); // Get pre-trainer transforms var nodes = pipeline.Nodes.TakeWhile(t => t.NodeType == PipelineNodeType.Transform); @@ -125,14 +124,17 @@ public void GenerateOutput() var postTrainerTransformsAndUsings = this.GenerateTransformsAndUsings(nodes); //Get trainer code and its associated usings. - var trainer = trainerAndUsings.Item1; - usings.Add(trainerAndUsings.Item2); + (string trainerMethod, string[] trainerUsings) = this.GenerateTrainerAndUsings(); + if (trainerUsings != null) + { + usings.AddRange(trainerUsings); + } //Get transforms code and its associated (unique) usings. - var preTrainerTransforms = preTrainerTransformsAndUsings.Select(t => t.Item1).ToList(); - var postTrainerTransforms = postTrainerTransformsAndUsings.Select(t => t.Item1).ToList(); - usings.AddRange(preTrainerTransformsAndUsings.Select(t => t.Item2)); - usings.AddRange(postTrainerTransformsAndUsings.Select(t => t.Item2)); + var preTrainerTransforms = preTrainerTransformsAndUsings?.Select(t => t.Item1).ToList(); + var postTrainerTransforms = postTrainerTransformsAndUsings?.Select(t => t.Item1).ToList(); + usings.AddRange(preTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2)); + usings.AddRange(postTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2)); usings = usings.Distinct().ToList(); //Combine all using statements to actual text. @@ -143,14 +145,14 @@ public void GenerateOutput() usingsBuilder.Append(t); }); - return (usingsBuilder.ToString(), trainer, preTrainerTransforms, postTrainerTransforms); + return (usingsBuilder.ToString(), trainerMethod, preTrainerTransforms, postTrainerTransforms); } - internal IList<(string, string)> GenerateTransformsAndUsings(IEnumerable nodes) + internal IList<(string, string[])> GenerateTransformsAndUsings(IEnumerable nodes) { //var nodes = pipeline.Nodes.TakeWhile(t => t.NodeType == PipelineNodeType.Transform); //var nodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Transform); - var results = new List<(string, string)>(); + var results = new List<(string, string[])>(); foreach (var node in nodes) { ITransformGenerator generator = TransformGeneratorFactory.GetInstance(node); @@ -160,9 +162,15 @@ public void GenerateOutput() return results; } - internal (string, string) GenerateTrainerAndUsings() + internal (string, string[]) GenerateTrainerAndUsings() { - ITrainerGenerator generator = TrainerGeneratorFactory.GetInstance(pipeline); + if (pipeline == null) + throw new ArgumentNullException(nameof(pipeline)); + var node = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer).First(); + if (node == null) + throw new ArgumentException($"The trainer was not found."); + + ITrainerGenerator generator = TrainerGeneratorFactory.GetInstance(node); var trainerString = generator.GenerateTrainer(); var trainerUsings = generator.GenerateUsings(); return (trainerString, trainerUsings); @@ -229,7 +237,7 @@ internal IList GenerateClassLabels() #region Train Project private string GenerateTrainProgramCSFileContent(string usings, - string trainer, + string trainerMethod, List preTrainerTransforms, List postTrainerTransforms, string namespaceValue, @@ -244,7 +252,7 @@ private string GenerateTrainProgramCSFileContent(string usings, Separator = columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(), AllowQuoting = columnInferenceResult.TextLoaderOptions.AllowQuoting, AllowSparse = columnInferenceResult.TextLoaderOptions.AllowSparse, - Trainer = trainer, + Trainer = trainerMethod, GeneratedUsings = usings, Path = settings.TrainDataset, TestPath = settings.TestDataset, diff --git a/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorBase.cs b/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorBase.cs index 56285828af..27273d16d5 100644 --- a/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorBase.cs +++ b/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorBase.cs @@ -25,7 +25,7 @@ internal abstract class TrainerGeneratorBase : ITrainerGenerator internal abstract string OptionsName { get; } internal abstract string MethodName { get; } internal abstract IDictionary NamedParameters { get; } - internal abstract string Usings { get; } + internal abstract string[] Usings { get; } /// /// Generates an instance of TrainerGenerator @@ -39,7 +39,10 @@ protected TrainerGeneratorBase(PipelineNode node) private void Initialize(PipelineNode node) { this.node = node; - hasAdvancedSettings = node.Properties.Keys.Any(t => !NamedParameters.ContainsKey(t)); + if (NamedParameters != null) + { + hasAdvancedSettings = node.Properties.Keys.Any(t => !NamedParameters.ContainsKey(t)); + } seperator = hasAdvancedSettings ? "=" : ":"; if (!node.Properties.ContainsKey("LabelColumn")) { @@ -87,11 +90,19 @@ private void Initialize(PipelineNode node) } //more special cases to handle - arguments.Add(hasAdvancedSettings ? kv.Key : NamedParameters[kv.Key], value); + if (NamedParameters != null) + { + arguments.Add(hasAdvancedSettings ? kv.Key : NamedParameters[kv.Key], value); + } + else + { + arguments.Add(kv.Key, value); + } + } } - private static string BuildComplexParameter(string paramName, IDictionary arguments, string seperator) + internal static string BuildComplexParameter(string paramName, IDictionary arguments, string seperator) { StringBuilder sb = new StringBuilder(); sb.Append("new "); @@ -102,7 +113,7 @@ private static string BuildComplexParameter(string paramName, IDictionary arguments, string seperator) + internal static string AppendArguments(IDictionary arguments, string seperator) { if (arguments.Count == 0) return string.Empty; @@ -122,7 +133,7 @@ private static string AppendArguments(IDictionary arguments, str return sb.ToString(); } - public string GenerateTrainer() + public virtual string GenerateTrainer() { StringBuilder sb = new StringBuilder(); sb.Append(MethodName); @@ -140,7 +151,7 @@ public string GenerateTrainer() return sb.ToString(); } - public string GenerateUsings() + public virtual string[] GenerateUsings() { if (hasAdvancedSettings) return Usings; diff --git a/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs b/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs index 6d9d500735..fb29594c9a 100644 --- a/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs +++ b/src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs @@ -12,18 +12,14 @@ namespace Microsoft.ML.CLI.CodeGenerator.CSharp internal interface ITrainerGenerator { string GenerateTrainer(); - string GenerateUsings(); + + string[] GenerateUsings(); } internal static class TrainerGeneratorFactory { - internal static ITrainerGenerator GetInstance(Pipeline pipeline) + internal static ITrainerGenerator GetInstance(PipelineNode node) { - if (pipeline == null) - throw new ArgumentNullException(nameof(pipeline)); - var node = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer).First(); - if (node == null) - throw new ArgumentException($"The trainer was not found."); if (Enum.TryParse(node.Name, out TrainerName trainer)) { switch (trainer) @@ -66,6 +62,8 @@ internal static ITrainerGenerator GetInstance(Pipeline pipeline) return new StochasticGradientDescentClassification(node); case TrainerName.SymSgdBinary: return new SymbolicStochasticGradientDescent(node); + case TrainerName.Ova: + return new OneVersusAll(node); default: throw new ArgumentException($"The trainer '{trainer}' is not handled currently."); } diff --git a/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs b/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs index 3b6a194e1c..6b6d99f524 100644 --- a/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs +++ b/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Linq; +using System.Text; using Microsoft.ML.Auto; namespace Microsoft.ML.CLI.CodeGenerator.CSharp @@ -36,7 +38,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.LightGBM;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.LightGBM;\r\n" }; public LightGbm(PipelineNode node) : base(node) { @@ -70,7 +72,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n "; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n " }; public AveragedPerceptron(PipelineNode node) : base(node) { @@ -80,7 +82,7 @@ public AveragedPerceptron(PipelineNode node) : base(node) #region FastTree internal abstract class FastTreeBase : TrainerGeneratorBase { - internal override string Usings => "using Microsoft.ML.Trainers.FastTree;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers.FastTree;\r\n" }; //The named parameters to the trainer. internal override IDictionary NamedParameters @@ -196,7 +198,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n "; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n " }; public LinearSvm(PipelineNode node) : base(node) { @@ -230,7 +232,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; public LogisticRegressionBase(PipelineNode node) : base(node) { @@ -285,7 +287,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; public OnlineGradientDescentRegression(PipelineNode node) : base(node) { @@ -315,7 +317,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers.HalLearners;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers.HalLearners;\r\n" }; public OrdinaryLeastSquaresRegression(PipelineNode node) : base(node) { @@ -350,7 +352,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; public PoissonRegression(PipelineNode node) : base(node) { @@ -382,7 +384,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; public StochasticDualCoordinateAscentBase(PipelineNode node) : base(node) { @@ -447,7 +449,7 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; public StochasticGradientDescentClassification(PipelineNode node) : base(node) { @@ -477,12 +479,59 @@ internal override IDictionary NamedParameters } } - internal override string Usings => "using Microsoft.ML.Trainers.HalLearners;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers.HalLearners;\r\n" }; public SymbolicStochasticGradientDescent(PipelineNode node) : base(node) { + } } + internal class OneVersusAll : TrainerGeneratorBase + { + private PipelineNode node; + private string[] binaryTrainerUsings = null; + + //ClassName of the trainer + internal override string MethodName => "OneVersusAll"; + + //ClassName of the options to trainer + internal override string OptionsName => null; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters => null; + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public OneVersusAll(PipelineNode node) : base(node) + { + this.node = node; + } + + public override string GenerateTrainer() + { + StringBuilder sb = new StringBuilder(); + sb.Append(MethodName); + sb.Append("("); + sb.Append("mlContext.BinaryClassification.Trainers."); // This is dependent on the name of the MLContext object in template. + var trainerGenerator = TrainerGeneratorFactory.GetInstance((PipelineNode)this.node.Properties["BinaryTrainer"]); + binaryTrainerUsings = trainerGenerator.GenerateUsings(); + sb.Append(trainerGenerator.GenerateTrainer()); + sb.Append(","); + sb.Append("labelColumnName:"); + sb.Append("\""); + sb.Append(node.Properties["LabelColumn"]); + sb.Append("\""); + sb.Append(")"); + return sb.ToString(); + } + + public override string[] GenerateUsings() + { + return binaryTrainerUsings; + } + + } + } } diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs index 3498ae8461..c55342fc78 100644 --- a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs +++ b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs @@ -15,7 +15,7 @@ internal abstract class TransformGeneratorBase : ITransformGenerator //abstract properties internal abstract string MethodName { get; } - internal abstract string Usings { get; } + internal abstract string[] Usings { get; } protected string[] inputColumns; @@ -49,7 +49,7 @@ private void Initialize(PipelineNode node) public abstract string GenerateTransformer(); - public string GenerateUsings() + public string[] GenerateUsings() { return Usings; } diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs index 2b83f6267f..70500b091b 100644 --- a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs +++ b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs @@ -12,7 +12,7 @@ internal interface ITransformGenerator { string GenerateTransformer(); - string GenerateUsings(); + string[] GenerateUsings(); } internal static class TransformGeneratorFactory diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs b/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs index 4612241abd..cc5fbd5ec6 100644 --- a/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs +++ b/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs @@ -17,7 +17,7 @@ public Normalizer(PipelineNode node) : base(node) internal override string MethodName => "Normalize"; - internal override string Usings => null; + internal override string[] Usings => null; public override string GenerateTransformer() { @@ -42,7 +42,7 @@ public OneHotEncoding(PipelineNode node) : base(node) internal override string MethodName => "Categorical.OneHotEncoding"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; private string ArgumentsName = "OneHotEncodingEstimator.ColumnOptions"; @@ -79,7 +79,7 @@ public ColumnConcat(PipelineNode node) : base(node) internal override string MethodName => "Concatenate"; - internal override string Usings => null; + internal override string[] Usings => null; public override string GenerateTransformer() { @@ -111,7 +111,7 @@ public ColumnCopying(PipelineNode node) : base(node) internal override string MethodName => "CopyColumns"; - internal override string Usings => null; + internal override string[] Usings => null; public override string GenerateTransformer() { @@ -136,7 +136,7 @@ public KeyToValueMapping(PipelineNode node) : base(node) internal override string MethodName => "Conversion.MapKeyToValue"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; public override string GenerateTransformer() { @@ -161,7 +161,7 @@ public MissingValueIndicator(PipelineNode node) : base(node) internal override string MethodName => "IndicateMissingValues"; - internal override string Usings => null; + internal override string[] Usings => null; private string ArgumentsName = "ColumnOptions"; @@ -200,7 +200,7 @@ public MissingValueReplacer(PipelineNode node) : base(node) internal override string MethodName => "ReplaceMissingValues"; private string ArgumentsName = "MissingValueReplacingEstimator.ColumnOptions"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; public override string GenerateTransformer() { @@ -235,7 +235,7 @@ public OneHotHashEncoding(PipelineNode node) : base(node) internal override string MethodName => "Categorical.OneHotHashEncoding"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; private string ArgumentsName = "OneHotHashEncodingEstimator.ColumnOptions"; @@ -272,7 +272,7 @@ public TextFeaturizing(PipelineNode node) : base(node) internal override string MethodName => "Text.FeaturizeText"; - internal override string Usings => null; + internal override string[] Usings => null; public override string GenerateTransformer() { @@ -297,7 +297,7 @@ public TypeConverting(PipelineNode node) : base(node) internal override string MethodName => "Conversion.ConvertType"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; private string ArgumentsName = "TypeConvertingEstimator.ColumnOptions"; @@ -336,7 +336,7 @@ public ValueToKeyMapping(PipelineNode node) : base(node) internal override string MethodName => "Conversion.MapValueToKey"; - internal override string Usings => "using Microsoft.ML.Transforms;\r\n"; + internal override string[] Usings => new string[] { "using Microsoft.ML.Transforms;\r\n" }; public override string GenerateTransformer() { diff --git a/src/mlnet/Templates/Console/ConsoleHelper.cs b/src/mlnet/Templates/Console/ConsoleHelper.cs index 813ad14ac9..1a2668cce5 100644 --- a/src/mlnet/Templates/Console/ConsoleHelper.cs +++ b/src/mlnet/Templates/Console/ConsoleHelper.cs @@ -78,72 +78,87 @@ namespace "); "\r\n Console.WriteLine($\"* Auc: {metrics.Auc:P2}\");\r\n " + " Console.WriteLine($\"*******************************************************" + "*****\");\r\n }\r\n\r\n\r\n public static void PrintBinaryClassificationFol" + - "dsAverageMetrics(\r\n TrainCatalogBase.Cro" + - "ssValidationResult[] crossValResults)\r\n {\r\n " + - " var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);\r" + - "\n\r\n var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accurac" + - "y);\r\n var AccuracyAverage = AccuracyValues.Average();\r\n va" + - "r AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues);\r\n " + - " var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyV" + - "alues);\r\n\r\n\r\n Console.WriteLine($\"***********************************" + - "**************************************************************************\");\r\n " + - " Console.WriteLine($\"* Metrics for Binary Classification model " + - " \");\r\n Console.WriteLine($\"*--------------------------------------" + - "----------------------------------------------------------------------\");\r\n " + - " Console.WriteLine($\"* Average Accuracy: {AccuracyAverage:0.###} " + - " - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 9" + - "5%: ({AccuraciesConfidenceInterval95:#.###})\");\r\n Console.WriteLine($" + - "\"*******************************************************************************" + - "******************************\");\r\n\r\n }\r\n\r\n public static void Pri" + - "ntMulticlassClassificationFoldsAverageMetrics(\r\n " + - " TrainCatalogBase.CrossValidationResult[] c" + - "rossValResults)\r\n {\r\n var metricsInMultipleFolds = crossValRes" + - "ults.Select(r => r.Metrics);\r\n\r\n var microAccuracyValues = metricsInM" + - "ultipleFolds.Select(m => m.AccuracyMicro);\r\n var microAccuracyAverage" + - " = microAccuracyValues.Average();\r\n var microAccuraciesStdDeviation =" + - " CalculateStandardDeviation(microAccuracyValues);\r\n var microAccuraci" + - "esConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r\n\r\n" + - " var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.Accur" + - "acyMacro);\r\n var macroAccuracyAverage = macroAccuracyValues.Average()" + - ";\r\n var macroAccuraciesStdDeviation = CalculateStandardDeviation(macr" + - "oAccuracyValues);\r\n var macroAccuraciesConfidenceInterval95 = Calcula" + - "teConfidenceInterval95(macroAccuracyValues);\r\n\r\n var logLossValues = " + - "metricsInMultipleFolds.Select(m => m.LogLoss);\r\n var logLossAverage =" + - " logLossValues.Average();\r\n var logLossStdDeviation = CalculateStanda" + - "rdDeviation(logLossValues);\r\n var logLossConfidenceInterval95 = Calcu" + - "lateConfidenceInterval95(logLossValues);\r\n\r\n var logLossReductionValu" + - "es = metricsInMultipleFolds.Select(m => m.LogLossReduction);\r\n var lo" + - "gLossReductionAverage = logLossReductionValues.Average();\r\n var logLo" + - "ssReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);\r\n " + - " var logLossReductionConfidenceInterval95 = CalculateConfidenceInterva" + - "l95(logLossReductionValues);\r\n\r\n Console.WriteLine($\"****************" + + "dsAverageMetrics(TrainCatalogBase.CrossValidationResult[] crossValResults)\r\n {\r\n var metricsInMultipleFolds = cro" + + "ssValResults.Select(r => r.Metrics);\r\n\r\n var AccuracyValues = metrics" + + "InMultipleFolds.Select(m => m.Accuracy);\r\n var AccuracyAverage = Accu" + + "racyValues.Average();\r\n var AccuraciesStdDeviation = CalculateStandar" + + "dDeviation(AccuracyValues);\r\n var AccuraciesConfidenceInterval95 = Ca" + + "lculateConfidenceInterval95(AccuracyValues);\r\n\r\n\r\n Console.WriteLine(" + + "$\"******************************************************************************" + + "*******************************\");\r\n Console.WriteLine($\"* Metr" + + "ics for Binary Classification model \");\r\n Console.WriteLine($\"*-" + + "--------------------------------------------------------------------------------" + + "---------------------------\");\r\n Console.WriteLine($\"* Average " + + "Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDevia" + + "tion:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###}" + + ")\");\r\n Console.WriteLine($\"******************************************" + + "*******************************************************************\");\r\n\r\n " + + " }\r\n\r\n public static void PrintMultiClassClassificationMetrics(MultiClas" + + "sClassifierMetrics metrics)\r\n {\r\n Console.WriteLine($\"********" + + "****************************************************\");\r\n Console.Wri" + + "teLine($\"* Metrics for multi-class classification model \");\r\n Co" + + "nsole.WriteLine($\"*-----------------------------------------------------------\")" + + ";\r\n Console.WriteLine($\" AccuracyMacro = {metrics.AccuracyMacro:0." + + "####}, a value between 0 and 1, the closer to 1, the better\");\r\n Cons" + + "ole.WriteLine($\" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value betw" + + "een 0 and 1, the closer to 1, the better\");\r\n Console.WriteLine($\" " + + " LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better\");\r\n " + + " Console.WriteLine($\" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.###" + + "#}, the closer to 0, the better\");\r\n Console.WriteLine($\" LogLoss " + + "for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better\")" + + ";\r\n Console.WriteLine($\" LogLoss for class 3 = {metrics.PerClassLo" + + "gLoss[2]:0.####}, the closer to 0, the better\");\r\n Console.WriteLine(" + + "$\"************************************************************\");\r\n }\r\n\r\n" + + " public static void PrintMulticlassClassificationFoldsAverageMetrics(Trai" + + "nCatalogBase.CrossValidationResult[] crossValResult" + + "s)\r\n {\r\n var metricsInMultipleFolds = crossValResults.Select(r" + + " => r.Metrics);\r\n\r\n var microAccuracyValues = metricsInMultipleFolds." + + "Select(m => m.AccuracyMicro);\r\n var microAccuracyAverage = microAccur" + + "acyValues.Average();\r\n var microAccuraciesStdDeviation = CalculateSta" + + "ndardDeviation(microAccuracyValues);\r\n var microAccuraciesConfidenceI" + + "nterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r\n\r\n v" + + "ar macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);\r\n " + + " var macroAccuracyAverage = macroAccuracyValues.Average();\r\n " + + " var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValu" + + "es);\r\n var macroAccuraciesConfidenceInterval95 = CalculateConfidenceI" + + "nterval95(macroAccuracyValues);\r\n\r\n var logLossValues = metricsInMult" + + "ipleFolds.Select(m => m.LogLoss);\r\n var logLossAverage = logLossValue" + + "s.Average();\r\n var logLossStdDeviation = CalculateStandardDeviation(l" + + "ogLossValues);\r\n var logLossConfidenceInterval95 = CalculateConfidenc" + + "eInterval95(logLossValues);\r\n\r\n var logLossReductionValues = metricsI" + + "nMultipleFolds.Select(m => m.LogLossReduction);\r\n var logLossReductio" + + "nAverage = logLossReductionValues.Average();\r\n var logLossReductionSt" + + "dDeviation = CalculateStandardDeviation(logLossReductionValues);\r\n va" + + "r logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossRe" + + "ductionValues);\r\n\r\n Console.WriteLine($\"*****************************" + "********************************************************************************" + - "*************\");\r\n Console.WriteLine($\"* Metrics for Multi-clas" + - "s Classification model \");\r\n Console.WriteLine($\"*--------------" + + "\");\r\n Console.WriteLine($\"* Metrics for Multi-class Classificat" + + "ion model \");\r\n Console.WriteLine($\"*---------------------------" + "--------------------------------------------------------------------------------" + - "--------------\");\r\n Console.WriteLine($\"* Average MicroAccuracy" + - ": {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDev" + - "iation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95" + - ":#.###})\");\r\n Console.WriteLine($\"* Average MacroAccuracy: {" + - "macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation" + - ":#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###" + - "})\");\r\n Console.WriteLine($\"* Average LogLoss: {logLos" + - "sAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confiden" + - "ce Interval 95%: ({logLossConfidenceInterval95:#.###})\");\r\n Console.W" + - "riteLine($\"* Average LogLossReduction: {logLossReductionAverage:#.###} - " + - "Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interva" + - "l 95%: ({logLossReductionConfidenceInterval95:#.###})\");\r\n Console.Wr" + - "iteLine($\"**********************************************************************" + - "***************************************\");\r\n\r\n }\r\n\r\n public static" + - " double CalculateStandardDeviation(IEnumerable values)\r\n {\r\n " + - " double average = values.Average();\r\n double sumOfSquaresOfDiff" + - "erences = values.Select(val => (val - average) * (val - average)).Sum();\r\n " + - " double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.C" + - "ount() - 1));\r\n return standardDeviation;\r\n }\r\n\r\n publi" + - "c static double CalculateConfidenceInterval95(IEnumerable values)\r\n " + - " {\r\n double confidenceInterval95 = 1.96 * CalculateStandardDeviatio" + - "n(values) / Math.Sqrt((values.Count() - 1));\r\n return confidenceInter" + - "val95;\r\n }\r\n }\r\n}\r\n"); + "-\");\r\n Console.WriteLine($\"* Average MicroAccuracy: {microAc" + + "curacyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}" + + ") - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})\");\r\n" + + " Console.WriteLine($\"* Average MacroAccuracy: {macroAccuracy" + + "Average:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - C" + + "onfidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})\");\r\n " + + " Console.WriteLine($\"* Average LogLoss: {logLossAverage:#.##" + + "#} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 9" + + "5%: ({logLossConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"* " + + " Average LogLossReduction: {logLossReductionAverage:#.###} - Standard devi" + + "ation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logL" + + "ossReductionConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"***" + + "********************************************************************************" + + "**************************\");\r\n\r\n }\r\n\r\n public static double Calcu" + + "lateStandardDeviation(IEnumerable values)\r\n {\r\n double" + + " average = values.Average();\r\n double sumOfSquaresOfDifferences = val" + + "ues.Select(val => (val - average) * (val - average)).Sum();\r\n double " + + "standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));" + + "\r\n return standardDeviation;\r\n }\r\n\r\n public static doub" + + "le CalculateConfidenceInterval95(IEnumerable values)\r\n {\r\n " + + " double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / M" + + "ath.Sqrt((values.Count() - 1));\r\n return confidenceInterval95;\r\n " + + " }\r\n }\r\n}\r\n"); return this.GenerationEnvironment.ToString(); } diff --git a/src/mlnet/Templates/Console/ConsoleHelper.tt b/src/mlnet/Templates/Console/ConsoleHelper.tt index 61cdfb5443..15c00f04eb 100644 --- a/src/mlnet/Templates/Console/ConsoleHelper.tt +++ b/src/mlnet/Templates/Console/ConsoleHelper.tt @@ -63,8 +63,7 @@ namespace <#= Namespace #>.Train } - public static void PrintBinaryClassificationFoldsAverageMetrics( - TrainCatalogBase.CrossValidationResult[] crossValResults) + public static void PrintBinaryClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult[] crossValResults) { var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); @@ -82,8 +81,21 @@ namespace <#= Namespace #>.Train } - public static void PrintMulticlassClassificationFoldsAverageMetrics( - TrainCatalogBase.CrossValidationResult[] crossValResults) + public static void PrintMultiClassClassificationMetrics(MultiClassClassifierMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for multi-class classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better"); + Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better"); + Console.WriteLine($"************************************************************"); + } + + public static void PrintMulticlassClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult[] crossValResults) { var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); diff --git a/src/mlnet/Templates/Console/TrainProgram.cs b/src/mlnet/Templates/Console/TrainProgram.cs index 6034645127..4ddf790cd6 100644 --- a/src/mlnet/Templates/Console/TrainProgram.cs +++ b/src/mlnet/Templates/Console/TrainProgram.cs @@ -182,8 +182,8 @@ public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDat this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); this.Write(".Evaluate(predictions, \""); this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); - this.Write("\", \"Score\");\r\n ConsoleHelper.PrintBinaryClassificationMetrics(metrics)" + - ";\r\n"); + this.Write("\", \"Score\");\r\n ConsoleHelper.PrintMultiClassClassificationMetrics(metr" + + "ics);\r\n"); }if("Regression".Equals(TaskType)){ this.Write(" var metrics = mlContext."); this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); diff --git a/src/mlnet/Templates/Console/TrainProgram.tt b/src/mlnet/Templates/Console/TrainProgram.tt index 92e52dea29..98391d6d7c 100644 --- a/src/mlnet/Templates/Console/TrainProgram.tt +++ b/src/mlnet/Templates/Console/TrainProgram.tt @@ -131,7 +131,7 @@ else{#> ConsoleHelper.PrintBinaryClassificationMetrics(metrics); <#} if("MulticlassClassification".Equals(TaskType)){ #> var metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "<#= LabelName #>", "Score"); - ConsoleHelper.PrintBinaryClassificationMetrics(metrics); + ConsoleHelper.PrintMultiClassClassificationMetrics(metrics); <#}if("Regression".Equals(TaskType)){ #> var metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "<#= LabelName #>", "Score"); ConsoleHelper.PrintRegressionMetrics(metrics);