Skip to content

Commit

Permalink
Extended contexts to regression and multiclass, added FFM pigstension
Browse files Browse the repository at this point in the history
  • Loading branch information
Zruty0 authored and TomFinley committed Sep 24, 2018
1 parent b88cc09 commit eb26489
Show file tree
Hide file tree
Showing 10 changed files with 432 additions and 128 deletions.
88 changes: 86 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate<T>(
}

/// <summary>
/// Evaluates scored binary classification data.
/// Evaluates scored binary classification data, if the predictions are not calibrated.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="ctx">The binary classification context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
/// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
public static BinaryClassifierEvaluator.Result Evaluate<T>(
Expand All @@ -83,5 +83,89 @@ public static BinaryClassifierEvaluator.Result Evaluate<T>(
var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { });
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
}

/// <summary>
/// Evaluates scored multiclass classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <typeparam name="TKey">The value type for the key label.</typeparam>
/// <param name="ctx">The multiclass classification context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from the prediction of a multiclass classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <param name="topK">If given a positive value, the <see cref="MultiClassClassifierEvaluator.Result.TopKAccuracy"/> will be filled with
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
/// the top-K values as being stored "correctly."</param>
/// <returns>The evaluation metrics.</returns>
public static MultiClassClassifierEvaluator.Result Evaluate<T, TKey>(
this MulticlassClassificationContext ctx,
DataView<T> data,
Func<T, Key<uint, TKey>> label,
Func<T, (Vector<float> score, Key<uint, TKey> predictedLabel)> pred,
int topK = 0)
{
Contracts.CheckValue(data, nameof(data));
var env = StaticPipeUtils.GetEnvironment(data);
Contracts.AssertValue(env);
env.CheckValue(label, nameof(label));
env.CheckValue(pred, nameof(pred));
env.CheckParam(topK >= 0, nameof(topK), "Must not be negative.");

var indexer = StaticPipeUtils.GetIndexer(data);
string labelName = indexer.Get(label(indexer.Indices));
(var scoreCol, var predCol) = pred(indexer.Indices);
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
string scoreName = indexer.Get(scoreCol);
string predName = indexer.Get(predCol);

var args = new MultiClassClassifierEvaluator.Arguments() { };
if (topK > 0)
args.OutputTopKAcc = topK;

var eval = new MultiClassClassifierEvaluator(env, args);
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
}

private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactory
{
private readonly IRegressionLoss _loss;
public TrivialRegressionLossFactory(IRegressionLoss loss) => _loss = loss;
public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss;
}

/// <summary>
/// Evaluates scored multiclass classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="ctx">The regression context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="score">The index delegate for predicted score column.</param>
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param>
/// <returns>The evaluation metrics.</returns>
public static RegressionEvaluator.Result Evaluate<T>(
this RegressionContext ctx,
DataView<T> data,
Func<T, Scalar<float>> label,
Func<T, Scalar<float>> score,
IRegressionLoss loss = null)
{
Contracts.CheckValue(data, nameof(data));
var env = StaticPipeUtils.GetEnvironment(data);
Contracts.AssertValue(env);
env.CheckValue(label, nameof(label));
env.CheckValue(score, nameof(score));

var indexer = StaticPipeUtils.GetIndexer(data);
string labelName = indexer.Get(label(indexer.Indices));
string scoreName = indexer.Get(score(indexer.Indices));

var args = new RegressionEvaluator.Arguments() { };
if (loss != null)
args.LossFunction = new TrivialRegressionLossFactory(loss);
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName);
}
}
}
67 changes: 21 additions & 46 deletions src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -598,62 +598,37 @@ internal Result(IExceptionContext ectx, IRow overallResult, int topK)
}

/// <summary>
/// Evaluates scored regression data.
/// Evaluates scored multiclass classification data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <typeparam name="TKey">The value type for the key label.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from prediction of a multi-class classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <param name="topK">If given a positive value, the <see cref="Result.TopKAccuracy"/> will be filled with
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
/// the top-K values as being stored "correctly."</param>
/// <param name="data">The scored data.</param>
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
public static Result Evaluate<T, TKey>(
DataView<T> data,
Func<T, Key<uint, TKey>> label,
Func<T, (Vector<float> score, Key<uint, TKey> predictedLabel)> pred,
int topK = 0)
public Result Evaluate(IDataView data, string label, string score, string predictedLabel)
{
Contracts.CheckValue(data, nameof(data));
var env = StaticPipeUtils.GetEnvironment(data);
Contracts.AssertValue(env);
env.CheckValue(label, nameof(label));
env.CheckValue(pred, nameof(pred));
env.CheckParam(topK >= 0, nameof(topK), "Must not be negative.");

var indexer = StaticPipeUtils.GetIndexer(data);
string labelName = indexer.Get(label(indexer.Indices));
(var scoreCol, var predCol) = pred(indexer.Indices);
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
string scoreName = indexer.Get(scoreCol);
string predName = indexer.Get(predCol);

var args = new Arguments() { };
if (topK > 0)
args.OutputTopKAcc = topK;

var eval = new MultiClassClassifierEvaluator(env, args);

var roles = new RoleMappedData(data.AsDynamic, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(labelName),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));

var resultDict = eval.Evaluate(roles);
env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
Host.CheckValue(data, nameof(data));
Host.CheckNonEmpty(label, nameof(label));
Host.CheckNonEmpty(score, nameof(score));
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

var roles = new RoleMappedData(data, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(label),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));

var resultDict = Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];

Result result;
using (var cursor = overall.GetRowCursor(i => true))
{
var moved = cursor.MoveNext();
env.Assert(moved);
result = new Result(env, cursor, topK);
Host.Assert(moved);
result = new Result(Host, cursor, _outputTopKAcc ?? 0);
moved = cursor.MoveNext();
env.Assert(!moved);
Host.Assert(!moved);
}
return result;
}
Expand Down
62 changes: 20 additions & 42 deletions src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -219,65 +219,43 @@ internal Result(IExceptionContext ectx, IRow overallResult)
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
L1 = Fetch(RegressionEvaluator.L1);
L2 = Fetch(RegressionEvaluator.L2);
Rms= Fetch(RegressionEvaluator.Rms);
Rms = Fetch(RegressionEvaluator.Rms);
LossFn = Fetch(RegressionEvaluator.Loss);
RSquared = Fetch(RegressionEvaluator.RSquared);
}
}

private sealed class TrivialLossFactory : ISupportRegressionLossFactory
{
private readonly IRegressionLoss _loss;
public TrivialLossFactory(IRegressionLoss loss) => _loss = loss;
public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss;
}

/// <summary>
/// Evaluates scored regression data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="score">The index delegate for the predicted score column.</param>
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
public static Result Evaluate<T>(
DataView<T> data,
Func<T, Scalar<float>> label,
Func<T, Scalar<float>> score,
IRegressionLoss loss = null)
/// <param name="label">The name of the label column.</param>
/// <param name="score">The name of the predicted score column.</param>
/// <returns>The evaluation metrics for these outputs.</returns>
public Result Evaluate(
IDataView data,
string label,
string score)
{
Contracts.CheckValue(data, nameof(data));
var env = StaticPipeUtils.GetEnvironment(data);
Contracts.AssertValue(env);
env.CheckValue(label, nameof(label));
env.CheckValue(score, nameof(score));

var indexer = StaticPipeUtils.GetIndexer(data);
string labelName = indexer.Get(label(indexer.Indices));
string scoreName = indexer.Get(score(indexer.Indices));

var args = new Arguments() { };
if (loss != null)
args.LossFunction = new TrivialLossFactory(loss);
var eval = new RegressionEvaluator(env, args);

var roles = new RoleMappedData(data.AsDynamic, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(labelName),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName));

var resultDict = eval.Evaluate(roles);
env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
Host.CheckValue(data, nameof(data));
Host.CheckNonEmpty(label, nameof(label));
Host.CheckNonEmpty(score, nameof(score));
var roles = new RoleMappedData(data, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(label),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));

var resultDict = Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];

Result result;
using (var cursor = overall.GetRowCursor(i => true))
{
var moved = cursor.MoveNext();
env.Assert(moved);
result = new Result(env, cursor);
Host.Assert(moved);
result = new Result(Host, cursor);
moved = cursor.MoveNext();
env.Assert(!moved);
Host.Assert(!moved);
}
return result;
}
Expand Down
Loading

0 comments on commit eb26489

Please sign in to comment.