Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pipelineitem for Ova #363

Merged
merged 4 commits into from
Jun 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Microsoft.ML/Models/CrossValidator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using Microsoft.ML.Runtime;
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand Down
74 changes: 74 additions & 0 deletions src/Microsoft.ML/Models/OneVersusAll.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using static Microsoft.ML.Runtime.EntryPoints.CommonInputs;

namespace Microsoft.ML.Models
{
/// <summary>
/// Entry point for building a One-Versus-All (OVA) multiclass trainer on top of an
/// arbitrary binary trainer, for use as an item in a LearningPipeline.
/// </summary>
public sealed partial class OneVersusAll
{
    /// <summary>
    /// Create OneVersusAll multiclass trainer.
    /// </summary>
    /// <param name="trainer">Underlying binary trainer</param>
    /// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category</param>
    /// <returns>A pipeline item wrapping <paramref name="trainer"/> in an OVA macro node.</returns>
    public static ILearningPipelineItem With(ITrainerInputWithLabel trainer, bool useProbabilities = true)
    {
        return new OvaPipelineItem(trainer, useProbabilities);
    }

    /// <summary>
    /// Pipeline item that adds a OneVersusAll entry-point node, whose subgraph contains
    /// the wrapped binary trainer, to the experiment being compiled.
    /// </summary>
    private class OvaPipelineItem : ILearningPipelineItem
    {
        private Var<IDataView> _data;            // training data captured from the previous pipeline step
        private ITrainerInputWithLabel _trainer; // binary trainer placed into the OVA subgraph
        private bool _useProbabilities;          // whether OVA picks the top class via probabilities vs. raw scores

        public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities)
        {
            _trainer = trainer;
            _useProbabilities = useProbabilities;
        }

        /// <summary>
        /// Builds a subgraph holding the binary trainer and appends a OneVersusAll node
        /// to <paramref name="experiment"/>, wired to the previous step's output data.
        /// </summary>
        /// <param name="previousStep">Preceding pipeline step; must be an <see cref="ILearningPipelineDataStep"/> when non-null.</param>
        /// <param name="experiment">Experiment graph the OVA node is added to.</param>
        /// <returns>A predictor step exposing the OVA predictor model variable.</returns>
        /// <exception cref="InvalidOperationException">If <paramref name="previousStep"/> is not a data step.</exception>
        public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
        {
            // The temporary environment exists only to create the subgraph experiment
            // that hosts the binary trainer node.
            using (var env = new TlcEnvironment())
            {
                var subgraph = env.CreateExperiment();
                subgraph.Add(_trainer);
                var ova = new OneVersusAll();
                if (previousStep != null)
                {
                    if (!(previousStep is ILearningPipelineDataStep dataStep))
                    {
                        throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
                    }

                    _data = dataStep.Data;
                    ova.TrainingData = dataStep.Data;
                    ova.UseProbabilities = _useProbabilities;
                    // NOTE(review): Nodes and UseProbabilities are assigned only when
                    // previousStep is non-null; a null previousStep produces an OVA node
                    // with no subgraph or training data — confirm that is intentional.
                    ova.Nodes = subgraph;
                }
                Output output = experiment.Add(ova);
                return new OvaPipelineStep(output);
            }
        }

        // Exposes the captured training data so the pipeline can connect loader input.
        public Var<IDataView> GetInputData() => _data;
    }

    /// <summary>
    /// Predictor step returned by <see cref="OvaPipelineItem"/>; carries the predictor
    /// model variable produced by the OneVersusAll node.
    /// </summary>
    private class OvaPipelineStep : ILearningPipelinePredictorStep
    {
        public OvaPipelineStep(Output output)
        {
            Model = output.PredictorModel;
        }

        public Var<IPredictorModel> Model { get; }
    }
}
}
22 changes: 13 additions & 9 deletions src/Microsoft.ML/Models/TrainTestEvaluator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using Microsoft.ML.Runtime;
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand Down Expand Up @@ -110,7 +114,7 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
Inputs.Data = firstTransform.GetInputData();
Outputs.PredictorModel = null;
Outputs.TransformModel = lastTransformModel;
var crossValidateOutput = experiment.Add(this);
var trainTestNodeOutput = experiment.Add(this);
experiment.Compile();
foreach (ILearningPipelineLoader loader in loaders)
loader.SetInput(environment, experiment);
Expand All @@ -124,35 +128,35 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
{
trainTestOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics(
environment,
experiment.GetOutput(crossValidateOutput.OverallMetrics),
experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault();
experiment.GetOutput(trainTestNodeOutput.OverallMetrics),
experiment.GetOutput(trainTestNodeOutput.ConfusionMatrix)).FirstOrDefault();
}
else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
{
trainTestOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
environment,
experiment.GetOutput(crossValidateOutput.OverallMetrics),
experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault();
experiment.GetOutput(trainTestNodeOutput.OverallMetrics),
experiment.GetOutput(trainTestNodeOutput.ConfusionMatrix)).FirstOrDefault();
}
else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer)
{
trainTestOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics(
environment,
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
experiment.GetOutput(trainTestNodeOutput.OverallMetrics)).FirstOrDefault();
}
else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
{
trainTestOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
environment,
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
experiment.GetOutput(trainTestNodeOutput.OverallMetrics)).FirstOrDefault();
}
else
{
//Implement metrics for ranking, clustering and anomaly detection.
throw Contracts.Except($"{Kind.ToString()} is not supported at the moment.");
}

ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel);
ITransformModel model = experiment.GetOutput(trainTestNodeOutput.TransformModel);
BatchPredictionEngine<TInput, TOutput> predictor;
using (var memoryStream = new MemoryStream())
{
Expand Down
35 changes: 33 additions & 2 deletions test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest()
{
string dataPath = GetDataPath("iris.txt");

var pipeline = new LearningPipeline(seed:1, conc:1);
var pipeline = new LearningPipeline(seed: 1, conc: 1);

pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false));
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
Expand All @@ -33,7 +33,7 @@ public void TrainAndPredictIrisModelTest()
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth= 5.1f,
PetalWidth = 5.1f,
});

Assert.Equal(1, prediction.PredictedLabels[0], 2);
Expand Down Expand Up @@ -136,6 +136,37 @@ public class IrisPrediction
[ColumnName("Score")]
public float[] PredictedLabels;
}

[Fact]
public void TrainOva()
{
    // Build a pipeline over the iris data set with a fixed seed so results are reproducible.
    string irisDataPath = GetDataPath("iris.txt");

    var learningPipeline = new LearningPipeline(seed: 1, conc: 1);
    learningPipeline.Add(new TextLoader(irisDataPath).CreateFrom<IrisData>(useHeader: false));
    learningPipeline.Add(new ColumnConcatenator(outputColumn: "Features",
        "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));

    // Wrap a binary SDCA classifier in a One-Versus-All multiclass trainer.
    learningPipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier()));

    var ovaModel = learningPipeline.Train<IrisData, IrisPrediction>();

    // Evaluate the trained model directly on the same data.
    var evaluationData = new TextLoader(irisDataPath).CreateFrom<IrisData>(useHeader: false);
    var classificationEvaluator = new ClassificationEvaluator();
    ClassificationMetrics directMetrics = classificationEvaluator.Evaluate(ovaModel, evaluationData);
    CheckMetrics(directMetrics);

    // Also exercise the train/test macro path and expect the same quality metrics.
    var trainTestResult = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate<IrisData, IrisPrediction>(learningPipeline, evaluationData);
    CheckMetrics(trainTestResult.ClassificationMetrics);
}

// Asserts the multiclass quality metrics expected from the seeded (seed: 1, conc: 1)
// OVA run over iris.txt, within the stated precision/tolerance.
private void CheckMetrics(ClassificationMetrics actual)
{
    Assert.Equal(0.96, actual.AccuracyMacro, 2);
    Assert.Equal(0.96, actual.AccuracyMicro, 2);
    Assert.Equal(0.19, actual.LogLoss, 1);
    Assert.InRange(actual.LogLossReduction, 80, 84);
}
}
}