Skip to content

Commit

Permalink
Create CalibratedPredictor instead of SchemaBindableCalibratedPredict…
Browse files Browse the repository at this point in the history
…or (dotnet#338)

`CalibratorUtils.TrainCalibrator` and `TrainCalibratorIfNeeded` now creates `CalibratedPredictor` instead of `SchemaBindableCalibratedPredictor` whenever the predictor implements `IValueMapper`.
  • Loading branch information
yaeldMS authored and eerhardt committed Jul 27, 2018
1 parent 6abc988 commit 2023d09
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 13 deletions.
18 changes: 6 additions & 12 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat
/// <param name="trainer">The trainer used to train the predictor.</param>
/// <param name="predictor">The predictor that needs calibration.</param>
/// <param name="data">The examples to used for calibrator training.</param>
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
/// <returns>The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
Expand All @@ -763,7 +760,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema))
return predictor;

return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper);
return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data);
}

/// <summary>
Expand All @@ -775,13 +772,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
/// <param name="maxRows">The maximum rows to use for calibrator training.</param>
/// <param name="predictor">The predictor that needs calibration.</param>
/// <param name="data">The examples to used for calibrator training.</param>
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
/// <returns>The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
int maxRows, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
Expand Down Expand Up @@ -834,10 +828,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal
}
}
var cali = caliTrainer.FinishTraining(ch);
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali, needValueMapper);
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali);
}

public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali, bool needValueMapper = false)
public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali)
{
Contracts.Assert(predictor != null);
if (cali == null)
Expand All @@ -853,7 +847,7 @@ public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironm
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights<Float>;
if (predWithFeatureScores != null && predictor is IParameterMixer<Float> && cali is IParameterMixer)
return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali);
if (needValueMapper)
if (predictor is IValueMapper)
return new CalibratedPredictor(env, predictor, cali);
return new SchemaBindableCalibratedPredictor(env, predictor, cali);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
else
calibrator = Args.Calibrator.CreateInstance(Host);
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
trainer, predictor, td, true);
trainer, predictor, td);
predictor = res as TScalarPredictor;
Host.Check(predictor != null, "Calibrated predictor does not implement the expected interface");
}
Expand Down
59 changes: 59 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -798,5 +798,64 @@ public void TestOvaMacro()
}
}
}

[Fact]
public void TestOvaMacroWithUncalibratedLearner()
{
var dataPath = GetDataPath(@"iris.txt");
using (var env = new TlcEnvironment(42))
{
// Specify subgraph for OVA
var subGraph = env.CreateExperiment();
var learnerInput = new Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
var learnerOutput = subGraph.Add(learnerInput);
// Create pipeline with OVA and multiclass scoring.
var experiment = env.CreateExperiment();
var importInput = new ML.Data.TextLoader(dataPath);
importInput.Arguments.Column = new TextLoaderColumn[]
{
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
};
var importOutput = experiment.Add(importInput);
var oneVersusAll = new Models.OneVersusAll
{
TrainingData = importOutput.Data,
Nodes = subGraph,
UseProbabilities = true,
};
var ovaOutput = experiment.Add(oneVersusAll);
var scoreInput = new ML.Transforms.DatasetScorer
{
Data = importOutput.Data,
PredictorModel = ovaOutput.PredictorModel
};
var scoreOutput = experiment.Add(scoreInput);
var evalInput = new ML.Models.ClassificationEvaluator
{
Data = scoreOutput.ScoredData
};
var evalOutput = experiment.Add(evalInput);
experiment.Compile();
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
experiment.Run();

var data = experiment.GetOutput(evalOutput.OverallMetrics);
var schema = data.Schema;
var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
Assert.True(b);
using (var cursor = data.GetRowCursor(col => col == accCol))
{
var getter = cursor.GetGetter<double>(accCol);
b = cursor.MoveNext();
Assert.True(b);
double acc = 0;
getter(ref acc);
Assert.Equal(0.71, acc, 2);
b = cursor.MoveNext();
Assert.False(b);
}
}
}
}
}

0 comments on commit 2023d09

Please sign in to comment.