forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ova Multi class codegen support (dotnet#321)
* dummy * multiova implementation * fix tests * remove inclusion list * fix tests and console helper
- Loading branch information
Showing
18 changed files
with
452 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 110 additions & 0 deletions
110
...est/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
//***************************************************************************************** | ||
//* * | ||
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * | ||
//* * | ||
//***************************************************************************************** | ||
|
||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
using Microsoft.Data.DataView; | ||
using TestNamespace.Model.DataModels; | ||
|
||
namespace TestNamespace.Train | ||
{ | ||
class Program | ||
{ | ||
private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv"; | ||
private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; | ||
private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip"; | ||
|
||
static void Main(string[] args) | ||
{ | ||
// Create MLContext to be shared across the model creation workflow objects | ||
// Set a random seed for repeatable/deterministic results across multiple trainings. | ||
MLContext mlContext = new MLContext(seed: 1); | ||
|
||
// Load Data | ||
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>( | ||
path: TRAIN_DATA_FILEPATH, | ||
hasHeader: true, | ||
separatorChar: ',', | ||
allowQuoting: true, | ||
allowSparse: true); | ||
|
||
IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>( | ||
path: TEST_DATA_FILEPATH, | ||
hasHeader: true, | ||
separatorChar: ',', | ||
allowQuoting: true, | ||
allowSparse: true); | ||
// Build training pipeline | ||
IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext); | ||
|
||
// Train Model | ||
ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); | ||
|
||
// Evaluate quality of Model | ||
EvaluateModel(mlContext, mlModel, testDataView); | ||
|
||
// Save model | ||
SaveModel(mlContext, mlModel, MODEL_FILEPATH); | ||
|
||
Console.WriteLine("=============== End of process, hit any key to finish ==============="); | ||
Console.ReadKey(); | ||
} | ||
|
||
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext) | ||
{ | ||
// Data process configuration with pipeline data transformations | ||
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) | ||
.AppendCacheCheckpoint(mlContext); | ||
|
||
// Set the training algorithm | ||
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(numLeaves: 2, labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label"); | ||
var trainingPipeline = dataProcessPipeline.Append(trainer); | ||
|
||
return trainingPipeline; | ||
} | ||
|
||
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline) | ||
{ | ||
Console.WriteLine("=============== Training model ==============="); | ||
|
||
ITransformer model = trainingPipeline.Fit(trainingDataView); | ||
|
||
Console.WriteLine("=============== End of training process ==============="); | ||
return model; | ||
} | ||
|
||
private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) | ||
{ | ||
// Evaluate the model and show accuracy stats | ||
Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); | ||
IDataView predictions = mlModel.Transform(testDataView); | ||
var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score"); | ||
ConsoleHelper.PrintMultiClassClassificationMetrics(metrics); | ||
} | ||
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath) | ||
{ | ||
// Save/persist the trained model to a .ZIP file | ||
Console.WriteLine($"=============== Saving the model ==============="); | ||
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) | ||
mlContext.Model.Save(mlModel, fs); | ||
|
||
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); | ||
} | ||
|
||
public static string GetAbsolutePath(string relativePath) | ||
{ | ||
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); | ||
string assemblyFolderPath = _dataRoot.Directory.FullName; | ||
|
||
string fullPath = Path.Combine(assemblyFolderPath, relativePath); | ||
|
||
return fullPath; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.