Skip to content

Commit

Permalink
fixed path bug and regression metrics correction (#3504)
Browse files Browse the repository at this point in the history
  • Loading branch information
srsaggam authored Apr 22, 2019
1 parent da3a403 commit c832e27
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 46 deletions.
28 changes: 14 additions & 14 deletions src/mlnet/Templates/Console/ModelBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,21 @@ public static string GetAbsolutePath(string relativePath)
"s)\r\n {\r\n var L1 = crossValidationResults.Select(r => r.Metrics" +
".MeanAbsoluteError);\r\n var L2 = crossValidationResults.Select(r => r." +
"Metrics.MeanSquaredError);\r\n var RMS = crossValidationResults.Select(" +
"r => r.Metrics.MeanAbsoluteError);\r\n var lossFunction = crossValidati" +
"onResults.Select(r => r.Metrics.LossFunction);\r\n var R2 = crossValida" +
"tionResults.Select(r => r.Metrics.RSquared);\r\n\r\n Console.WriteLine($\"" +
"********************************************************************************" +
"*****************************\");\r\n Console.WriteLine($\"* Metric" +
"s for Regression model \");\r\n Console.WriteLine($\"*--------------" +
"r => r.Metrics.RootMeanSquaredError);\r\n var lossFunction = crossValid" +
"ationResults.Select(r => r.Metrics.LossFunction);\r\n var R2 = crossVal" +
"idationResults.Select(r => r.Metrics.RSquared);\r\n\r\n Console.WriteLine" +
"($\"*****************************************************************************" +
"********************************\");\r\n Console.WriteLine($\"* Met" +
"rics for Regression model \");\r\n Console.WriteLine($\"*-----------" +
"--------------------------------------------------------------------------------" +
"--------------\");\r\n Console.WriteLine($\"* Average L1 Loss: {" +
"L1.Average():0.###} \");\r\n Console.WriteLine($\"* Average L2 Loss" +
": {L2.Average():0.###} \");\r\n Console.WriteLine($\"* Average " +
"RMS: {RMS.Average():0.###} \");\r\n Console.WriteLine($\"* " +
" Average Loss Function: {lossFunction.Average():0.###} \");\r\n Consol" +
"e.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\r\n " +
"Console.WriteLine($\"************************************************************" +
"*************************************************\");\r\n }\r\n");
"-----------------\");\r\n Console.WriteLine($\"* Average L1 Loss: " +
" {L1.Average():0.###} \");\r\n Console.WriteLine($\"* Average L" +
"2 Loss: {L2.Average():0.###} \");\r\n Console.WriteLine($\"* " +
" Average RMS: {RMS.Average():0.###} \");\r\n Console.WriteLin" +
"e($\"* Average Loss Function: {lossFunction.Average():0.###} \");\r\n " +
" Console.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");" +
"\r\n Console.WriteLine($\"**********************************************" +
"***************************************************************\");\r\n }\r\n");
} if("BinaryClassification".Equals(TaskType)){
this.Write(" public static void PrintBinaryClassificationMetrics(BinaryClassificationM" +
"etrics metrics)\r\n {\r\n Console.WriteLine($\"********************" +
Expand Down
10 changes: 5 additions & 5 deletions src/mlnet/Templates/Console/ModelBuilder.tt
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,18 @@ else{#>
{
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for Regression model ");
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
Console.WriteLine($"*************************************************************************************************************");
}
<# } if("BinaryClassification".Equals(TaskType)){ #>
Expand Down
51 changes: 31 additions & 20 deletions src/mlnet/Templates/Console/PredictProgram.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,19 @@ public virtual string TransformText()
//*****************************************************************************************
using System;
using System.IO;
using System.Linq;
using Microsoft.ML;
using ");

#line 17 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));

#line default
#line hidden
this.Write(".Model.DataModels;\r\n\r\n\r\nnamespace ");

#line 20 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 21 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));

#line default
Expand All @@ -57,35 +58,35 @@ public virtual string TransformText()
"d and use for predictions\r\n private const string MODEL_FILEPATH = @\"MLMod" +
"el.zip\";\r\n\r\n //Dataset to use for predictions \r\n");

#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 29 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
if(string.IsNullOrEmpty(TestDataPath)){

#line default
#line hidden
this.Write(" private const string DATA_FILEPATH = @\"");

#line 29 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 30 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath));

#line default
#line hidden
this.Write("\";\r\n");

#line 30 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
} else{

#line default
#line hidden
this.Write(" private const string DATA_FILEPATH = @\"");

#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 32 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath));

#line default
#line hidden
this.Write("\";\r\n");

#line 32 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 33 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
}

#line default
Expand All @@ -98,7 +99,7 @@ static void Main(string[] args)
// Training code used by ML.NET CLI and AutoML to generate the model
//ModelBuilder.CreateModel();
ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
// Create sample data to do a single prediction with it
Expand All @@ -109,50 +110,50 @@ static void Main(string[] args)
");

#line 50 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 51 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
if("BinaryClassification".Equals(TaskType)){

#line default
#line hidden
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");

#line 51 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 52 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));

#line default
#line hidden
this.Write("} | Predicted value: {predictionResult.Prediction}\");\r\n");

#line 52 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 53 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
}else if("Regression".Equals(TaskType)){

#line default
#line hidden
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");

#line 53 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 54 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));

#line default
#line hidden
this.Write("} | Predicted value: {predictionResult.Score}\");\r\n");

#line 54 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 55 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
} else if("MulticlassClassification".Equals(TaskType)){

#line default
#line hidden
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");

#line 55 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 56 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));

#line default
#line hidden
this.Write("} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.J" +
"oin(\",\", predictionResult.Score)}]\");\r\n");

#line 56 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 57 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
}

#line default
Expand All @@ -171,28 +172,28 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
path: dataFilePath,
hasHeader : ");

#line 69 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 70 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));

#line default
#line hidden
this.Write(",\r\n separatorChar : \'");

#line 70 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 71 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));

#line default
#line hidden
this.Write("\',\r\n allowQuoting : ");

#line 71 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 72 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));

#line default
#line hidden
this.Write(",\r\n allowSparse: ");

#line 72 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 73 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));

#line default
Expand All @@ -204,13 +205,23 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
.First();
return sampleForPrediction;
}
public static string GetAbsolutePath(string relativePath)
{
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
string assemblyFolderPath = _dataRoot.Directory.FullName;
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
return fullPath;
}
}
}
");
return this.GenerationEnvironment.ToString();
}

#line 81 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
#line 92 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"

public string TaskType {get;set;}
public string Namespace {get;set;}
Expand Down
13 changes: 12 additions & 1 deletion src/mlnet/Templates/Console/PredictProgram.tt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//*****************************************************************************************

using System;
using System.IO;
using System.Linq;
using Microsoft.ML;
using <#= Namespace #>.Model.DataModels;
Expand All @@ -38,7 +39,7 @@ namespace <#= Namespace #>.ConsoleApp
// Training code used by ML.NET CLI and AutoML to generate the model
//ModelBuilder.CreateModel();

ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);

// Create sample data to do a single prediction with it
Expand Down Expand Up @@ -76,6 +77,16 @@ namespace <#= Namespace #>.ConsoleApp
.First();
return sampleForPrediction;
}

public static string GetAbsolutePath(string relativePath)
{
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
string assemblyFolderPath = _dataRoot.Directory.FullName;

string fullPath = Path.Combine(assemblyFolderPath, relativePath);

return fullPath;
}
}
}
<#+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ namespace TestNamespace.ConsoleApp
{
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for Regression model ");
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
Console.WriteLine($"*************************************************************************************************************");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//*****************************************************************************************

using System;
using System.IO;
using System.Linq;
using Microsoft.ML;
using TestNamespace.Model.DataModels;
Expand All @@ -27,7 +28,7 @@ namespace TestNamespace.ConsoleApp
// Training code used by ML.NET CLI and AutoML to generate the model
//ModelBuilder.CreateModel();

ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);

// Create sample data to do a single prediction with it
Expand Down Expand Up @@ -59,5 +60,15 @@ namespace TestNamespace.ConsoleApp
.First();
return sampleForPrediction;
}

public static string GetAbsolutePath(string relativePath)
{
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
string assemblyFolderPath = _dataRoot.Directory.FullName;

string fullPath = Path.Combine(assemblyFolderPath, relativePath);

return fullPath;
}
}
}

0 comments on commit c832e27

Please sign in to comment.