remove unused methods in consolehelper and nit picks in generated code (dotnet#261)

* nit picks

* change in console helper

* fix tests

* add space

* fix tests
srsaggam authored Mar 4, 2019
1 parent 3b1d0ac commit 3acd887
Showing 6 changed files with 147 additions and 485 deletions.
@@ -1,8 +1,13 @@
using System;
//*****************************************************************************************
//* *
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
//* *
//*****************************************************************************************

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Core.Data;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace MyNamespace
@@ -47,32 +52,15 @@ namespace MyNamespace
Console.WriteLine($"************************************************************");
}

public static void PrintMultiClassClassificationMetrics(string name, MultiClassClassifierMetrics metrics)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
Console.WriteLine($"************************************************************");
}


public static void PrintRegressionFoldsAverageMetrics(string algorithmName,
(RegressionMetrics metrics,
ITransformer model,
IDataView scoredTestData)[] crossValidationResults
TrainCatalogBase.CrossValidationResult<RegressionMetrics>[] crossValidationResults
)
{
var L1 = crossValidationResults.Select(r => r.metrics.L1);
var L2 = crossValidationResults.Select(r => r.metrics.L2);
var RMS = crossValidationResults.Select(r => r.metrics.L1);
var lossFunction = crossValidationResults.Select(r => r.metrics.LossFn);
var R2 = crossValidationResults.Select(r => r.metrics.RSquared);
var L1 = crossValidationResults.Select(r => r.Metrics.L1);
var L2 = crossValidationResults.Select(r => r.Metrics.L2);
var RMS = crossValidationResults.Select(r => r.Metrics.L1);
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFn);
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for {algorithmName} Regression model ");
@@ -87,12 +75,9 @@ namespace MyNamespace

public static void PrintBinaryClassificationFoldsAverageMetrics(
string algorithmName,
(BinaryClassificationMetrics metrics,
ITransformer model,
IDataView scoredTestData)[] crossValResults
)
TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>[] crossValResults)
{
var metricsInMultipleFolds = crossValResults.Select(r => r.metrics);
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy);
var AccuracyAverage = AccuracyValues.Average();
@@ -108,45 +93,6 @@ namespace MyNamespace

}

public static void PrintMulticlassClassificationFoldsAverageMetrics(
string algorithmName,
(MultiClassClassifierMetrics metrics,
ITransformer model,
IDataView scoredTestData)[] crossValResults
)
{
var metricsInMultipleFolds = crossValResults.Select(r => r.metrics);

var microAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMicro);
var microAccuracyAverage = microAccuracyValues.Average();
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);
var macroAccuracyAverage = macroAccuracyValues.Average();
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
var logLossAverage = logLossValues.Average();
var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
var logLossReductionAverage = logLossReductionValues.Average();
var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for {algorithmName} Multi-class Classification model ");
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
Console.WriteLine($"*************************************************************************************************************");

}

public static double CalculateStandardDeviation(IEnumerable<double> values)
{
@@ -162,16 +108,6 @@ namespace MyNamespace
return confidenceInterval95;
}

public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
{
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Metrics for {name} clustering model ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* AvgMinScore: {metrics.AvgMinScore}");
Console.WriteLine($"* DBI is: {metrics.Dbi}");
Console.WriteLine($"*************************************************");
}

public static void ConsoleWriteHeader(params string[] lines)
{
var defaultColor = Console.ForegroundColor;
@@ -185,59 +121,5 @@ namespace MyNamespace
Console.WriteLine(new string('#', maxLength));
Console.ForegroundColor = defaultColor;
}

public static void ConsoleWriterSection(params string[] lines)
{
var defaultColor = Console.ForegroundColor;
Console.ForegroundColor = ConsoleColor.Blue;
Console.WriteLine(" ");
foreach (var line in lines)
{
Console.WriteLine(line);
}
var maxLength = lines.Select(x => x.Length).Max();
Console.WriteLine(new string('-', maxLength));
Console.ForegroundColor = defaultColor;
}

public static void ConsolePressAnyKey()
{
var defaultColor = Console.ForegroundColor;
Console.ForegroundColor = ConsoleColor.Green;
Console.WriteLine(" ");
Console.WriteLine("Press any key to finish.");
Console.ReadKey();
}

public static void ConsoleWriteException(params string[] lines)
{
var defaultColor = Console.ForegroundColor;
Console.ForegroundColor = ConsoleColor.Red;
const string exceptionTitle = "EXCEPTION";
Console.WriteLine(" ");
Console.WriteLine(exceptionTitle);
Console.WriteLine(new string('#', exceptionTitle.Length));
Console.ForegroundColor = defaultColor;
foreach (var line in lines)
{
Console.WriteLine(line);
}
}

public static void ConsoleWriteWarning(params string[] lines)
{
var defaultColor = Console.ForegroundColor;
Console.ForegroundColor = ConsoleColor.DarkMagenta;
const string warningTitle = "WARNING";
Console.WriteLine(" ");
Console.WriteLine(warningTitle);
Console.WriteLine(new string('#', warningTitle.Length));
Console.ForegroundColor = defaultColor;
foreach (var line in lines)
{
Console.WriteLine(line);
}
}

}
}
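
For context, a minimal sketch of how the reworked fold-metrics helpers are meant to be consumed after this change, since the tuple-typed parameters are replaced by TrainCatalogBase.CrossValidationResult<TMetrics>[] (the type the preview-era CrossValidate catalog methods return). The pipeline and data variables, the trainer name, and the using list (mirrored from the generated file) are assumptions for illustration, not code from this commit.

// Illustrative sketch only, not part of this commit.
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Core.Data;

namespace MyNamespace
{
    public static class CrossValidationReportingSketch
    {
        // Assumed inputs: an MLContext, a loaded IDataView and a training pipeline
        // built elsewhere in the generated project.
        public static void Report(MLContext mlContext,
                                  IDataView trainingDataView,
                                  IEstimator<ITransformer> trainingPipeline)
        {
            // In this preview, CrossValidate returns CrossValidationResult<RegressionMetrics>[],
            // which now matches the helper's parameter type with no repackaging into tuples.
            var crossValidationResults =
                mlContext.Regression.CrossValidate(trainingDataView, trainingPipeline);

            ConsoleHelper.PrintRegressionFoldsAverageMetrics("SdcaRegression", crossValidationResults);

            // The binary-classification helper is called the same way with a
            // CrossValidationResult<BinaryClassificationMetrics>[] from its catalog.
        }
    }
}
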
@@ -21,37 +21,26 @@ namespace MyNamespace
private static string TestDataPath = @"x:\dummypath\dummy_test.csv";
private static string ModelPath = @"x:\models\model.zip";

// Set this flag to enable the training process.
private static bool EnableTraining = false;

static void Main(string[] args)
{
// Create MLContext to be shared across the model creation workflow objects
// Set a random seed for repeatable/deterministic results across multiple trainings.
var mlContext = new MLContext(seed: 1);
var mlContext = new MLContext();

if (EnableTraining)
{
// Create, Train, Evaluate and Save a model
BuildTrainEvaluateAndSaveModel(mlContext);
ConsoleHelper.ConsoleWriteHeader("=============== End of training process ===============");
}
else
{
ConsoleHelper.ConsoleWriteHeader("Skipping the training process. Please set the flag : 'EnableTraining' to 'true' to enable the training process.");
}
// (Optional step) Create, Train, Evaluate and Save the model.zip file
TrainEvaluateAndSaveModel(mlContext);

// Make a single test prediction loading the model from .ZIP file
TestSinglePrediction(mlContext);
// Make a single test prediction loading the model from model.zip file
Predict(mlContext);

ConsoleHelper.ConsoleWriteHeader("=============== End of process, hit any key to finish ===============");
Console.ReadKey();

}

private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext)
{
// Data loading
// Load data
Console.WriteLine("=============== Loading data ===============");
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
path: TrainDataPath,
hasHeader: true,
@@ -88,12 +77,13 @@ namespace MyNamespace
mlContext.Model.Save(trainedModel, fs);

Console.WriteLine("The model is saved to {0}", ModelPath);
ConsoleHelper.ConsoleWriteHeader("=============== End of training process ===============");

return trainedModel;
}

// (OPTIONAL) Try/test a single prediction by loading the model from the file, first.
private static void TestSinglePrediction(MLContext mlContext)
// Try/test a single prediction by loading the model from the file, first.
private static void Predict(MLContext mlContext)
{
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
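
As a companion sketch (again, not code from this commit), this is the shape of the flow the regenerated entry point now follows: train, evaluate and save model.zip unconditionally, then reload the file and score one observation. SamplePrediction, the Score column, and the model.CreatePredictionEngine call are assumptions about the preview-era API; the generated file builds its own pipeline and output class.

// Illustrative sketch only, not part of this commit.
using System;
using System.IO;
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.Core.Data;

namespace MyNamespace
{
    internal static class ProgramFlowSketch
    {
        private static string ModelPath = @"x:\models\model.zip";

        // Assumed inputs: the pipeline and training data come from the generated
        // TrainEvaluateAndSaveModel path; shown inline to keep the sketch self-contained.
        public static void Run(MLContext mlContext,
                               IEstimator<ITransformer> trainingPipeline,
                               IDataView trainingDataView)
        {
            // Train and persist the model without the removed EnableTraining gate.
            ITransformer trainedModel = trainingPipeline.Fit(trainingDataView);
            using (var fs = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                mlContext.Model.Save(trainedModel, fs);

            // Reload model.zip and make a single prediction, as Predict(mlContext) does.
            ITransformer loadedModel;
            using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
                loadedModel = mlContext.Model.Load(stream);

            var predictionEngine =
                loadedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
            var prediction = predictionEngine.Predict(new SampleObservation());
            Console.WriteLine($"Predicted score: {prediction.Score}");
        }
    }

    // Placeholder output class for the sketch; the generated project defines its own.
    public class SamplePrediction
    {
        public float Score { get; set; }
    }
}
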
