Skip to content

Commit

Permalink
Enabling Ranking Cross Validation (#5263)
Browse files Browse the repository at this point in the history
Enabling CrossValidation in ML.NET and ranking compatibility with the AutoML API
  • Loading branch information
Lynx1820 authored Jul 10, 2020
1 parent bc10d60 commit f87a3bb
Show file tree
Hide file tree
Showing 26 changed files with 256 additions and 89 deletions.
3 changes: 2 additions & 1 deletion src/Microsoft.ML.AutoML/API/ColumnInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ public sealed class ColumnInformation
public string UserIdColumnName { get; set; }

/// <summary>
/// The dataset column to use as a group ID for computation.
/// The dataset column to use as a group ID for computation in a Ranking Task.
/// If a SamplingKeyColumnName is provided, then it should be the same as this column.
/// </summary>
public string GroupIdColumnName { get; set; }

Expand Down
58 changes: 46 additions & 12 deletions src/Microsoft.ML.AutoML/API/ExperimentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,24 @@ internal ExperimentBase(MLContext context,
public ExperimentResult<TMetrics> Execute(IDataView trainData, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation()
ColumnInformation columnInformation;
if (_task == TaskKind.Ranking)
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn
};
columnInformation = new ColumnInformation()
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId,
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId // For ranking, SamplingKeyColumnName and GroupIdColumnName must refer to the same column.
};
}
else
{
columnInformation = new ColumnInformation()
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn
};
}
return Execute(trainData, columnInformation, preFeaturizer, progressHandler);
}

Expand Down Expand Up @@ -102,19 +115,28 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
const int crossValRowCountThreshold = 15000;

var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName);
if (rowCount < crossValRowCountThreshold)
{
const int numCrossValFolds = 10;
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName);
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, samplingKeyColumnName);
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
}
else
{
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, samplingKeyColumnName);
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
}
}

/// <summary>
/// Picks the column name that is handed to the data-splitting utilities as the sampling key.
/// </summary>
/// <param name="groupIdColumnName">Group ID column name from the column information; may be null.</param>
/// <param name="samplingKeyColumnName">Sampling key column name from the column information; may be null.</param>
/// <returns>
/// For a ranking task, <paramref name="groupIdColumnName"/>, or <c>DefaultColumnNames.GroupId</c> when it is null.
/// For every other task, <paramref name="samplingKeyColumnName"/> unchanged.
/// </returns>
private string GetSamplingKey(string groupIdColumnName, string samplingKeyColumnName)
{
    // Validation always runs first, regardless of task kind.
    // NOTE(review): presumably this rejects a ranking sampling key that differs from the group ID column - confirm in UserInputValidationUtil.
    UserInputValidationUtil.ValidateSamplingKey(samplingKeyColumnName, groupIdColumnName, _task);

    var isRankingTask = _task == TaskKind.Ranking;
    return isRankingTask
        ? (groupIdColumnName ?? DefaultColumnNames.GroupId)
        : samplingKeyColumnName;
}

/// <summary>
/// Executes an AutoML experiment.
/// </summary>
Expand All @@ -136,7 +158,10 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
/// </remarks>
public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validationData, string labelColumnName = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation() { LabelColumnName = labelColumnName };
var columnInformation = (_task == TaskKind.Ranking) ?
new ColumnInformation() { LabelColumnName = labelColumnName, GroupIdColumnName = DefaultColumnNames.GroupId } :
new ColumnInformation() { LabelColumnName = labelColumnName };

return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
}

Expand Down Expand Up @@ -194,7 +219,8 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, ui
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
{
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName);
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, samplingKeyColumnName);
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
}

Expand Down Expand Up @@ -223,7 +249,15 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData,
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null,
Progress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation()
var columnInformation = (_task == TaskKind.Ranking) ?
new ColumnInformation()
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId,
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId // For ranking, SamplingKeyColumnName and GroupIdColumnName must refer to the same column.
}
:
new ColumnInformation()
{
LabelColumnName = labelColumnName,
SamplingKeyColumnName = samplingKeyColumn
Expand Down Expand Up @@ -253,7 +287,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
validationData = preprocessorTransform.Transform(validationData);
}

var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumnName, MetricsAgent,
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, MetricsAgent,
preFeaturizer, preprocessorTransform, _logger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
Expand All @@ -273,7 +307,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);

var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
preprocessorTransforms, columnInfo.LabelColumnName, _logger);
preprocessorTransforms, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, _logger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);

// Execute experiment & get all pipelines run
Expand All @@ -300,7 +334,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);

var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
preprocessorTransforms, columnInfo.LabelColumnName, OptimizingMetricInfo, _logger);
preprocessorTransforms, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, OptimizingMetricInfo, _logger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
}
Expand Down
19 changes: 5 additions & 14 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// <value>The default value is <see cref="RankingMetric.Ndcg" />.</value>
public RankingMetric OptimizingMetric { get; set; }

/// <summary>
/// Name for the GroupId column.
/// </summary>
/// <value>The default value is GroupId.</value>
public string GroupIdColumnName { get; set; }

/// <summary>
/// Collection of trainers the AutoML experiment can leverage.
/// </summary>
Expand All @@ -34,7 +28,6 @@ public sealed class RankingExperimentSettings : ExperimentSettings
public ICollection<RankingTrainer> Trainers { get; }
public RankingExperimentSettings()
{
GroupIdColumnName = "GroupId";
OptimizingMetric = RankingMetric.Ndcg;
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
}
Expand Down Expand Up @@ -75,11 +68,10 @@ public static class RankingExperimentResultExtensions
/// </summary>
/// <param name="results">Enumeration of AutoML experiment run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
{
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var metricsAgent = new RankingMetricsAgent(null, metric);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -89,11 +81,10 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
/// </summary>
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
{
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var metricsAgent = new RankingMetricsAgent(null, metric);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -112,7 +103,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
{
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
: base(context,
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.GroupIdColumnName),
new RankingMetricsAgent(context, settings.OptimizingMetric),
new OptimizingMetricInfo(settings.OptimizingMetric),
settings,
TaskKind.Ranking,
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.AutoML/ColumnInference/ColumnInformationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ internal static class ColumnInformationUtil
return ColumnPurpose.Weight;
}

if (columnName == columnInfo.GroupIdColumnName)
{
return ColumnPurpose.GroupId;
}

if (columnName == columnInfo.SamplingKeyColumnName)
{
return ColumnPurpose.SamplingKey;
Expand Down Expand Up @@ -51,11 +56,6 @@ internal static class ColumnInformationUtil
return ColumnPurpose.UserId;
}

if (columnName == columnInfo.GroupIdColumnName)
{
return ColumnPurpose.GroupId;
}

if (columnName == columnInfo.ItemIdColumnName)
{
return ColumnPurpose.ItemId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public bool IsModelPerfect(double score)
}
}

public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
{
return _mlContext.BinaryClassification.EvaluateNonCalibrated(data, labelColumn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ internal interface IMetricsAgent<T>

bool IsModelPerfect(double score);

T EvaluateMetrics(IDataView data, string labelColumn);
// GroupId is a parameter used only in RankingMetricsAgent
T EvaluateMetrics(IDataView data, string labelColumn, string groupId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public bool IsModelPerfect(double score)
}
}

public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
{
return _mlContext.MulticlassClassification.Evaluate(data, labelColumn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
{
private readonly MLContext _mlContext;
private readonly RankingMetric _optimizingMetric;
private readonly string _groupIdColumnName;

public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName)
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
{
_mlContext = mlContext;
_optimizingMetric = optimizingMetric;
_groupIdColumnName = groupIdColumnName;
}

// Optimizing metric used: NDCG@10 and DCG@10
Expand Down Expand Up @@ -59,9 +57,9 @@ public bool IsModelPerfect(double score)
}
}

public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn)
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
{
return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName);
return _mlContext.Ranking.Evaluate(data, labelColumn, groupIdColumn);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public bool IsModelPerfect(double score)
}
}

public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn)
public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
{
return _mlContext.Regression.Evaluate(data, labelColumn);
}
Expand Down
5 changes: 4 additions & 1 deletion src/Microsoft.ML.AutoML/Experiment/Runners/CrossValRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal class CrossValRunner<TMetrics> : IRunner<CrossValidationRunDetail<TMetr
private readonly IMetricsAgent<TMetrics> _metricsAgent;
private readonly IEstimator<ITransformer> _preFeaturizer;
private readonly ITransformer[] _preprocessorTransforms;
private readonly string _groupIdColumn;
private readonly string _labelColumn;
private readonly IChannel _logger;
private readonly DataViewSchema _modelInputSchema;
Expand All @@ -29,6 +30,7 @@ public CrossValRunner(MLContext context,
IMetricsAgent<TMetrics> metricsAgent,
IEstimator<ITransformer> preFeaturizer,
ITransformer[] preprocessorTransforms,
string groupIdColumn,
string labelColumn,
IChannel logger)
{
Expand All @@ -38,6 +40,7 @@ public CrossValRunner(MLContext context,
_metricsAgent = metricsAgent;
_preFeaturizer = preFeaturizer;
_preprocessorTransforms = preprocessorTransforms;
_groupIdColumn = groupIdColumn;
_labelColumn = labelColumn;
_logger = logger;
_modelInputSchema = trainDatasets[0].Schema;
Expand All @@ -52,7 +55,7 @@ public CrossValRunner(MLContext context,
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
_labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
trainResults.Add(new SuggestedPipelineTrainResult<TMetrics>(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal class CrossValSummaryRunner<TMetrics> : IRunner<RunDetail<TMetrics>>
private readonly IMetricsAgent<TMetrics> _metricsAgent;
private readonly IEstimator<ITransformer> _preFeaturizer;
private readonly ITransformer[] _preprocessorTransforms;
private readonly string _groupIdColumn;
private readonly string _labelColumn;
private readonly OptimizingMetricInfo _optimizingMetricInfo;
private readonly IChannel _logger;
Expand All @@ -31,6 +32,7 @@ public CrossValSummaryRunner(MLContext context,
IMetricsAgent<TMetrics> metricsAgent,
IEstimator<ITransformer> preFeaturizer,
ITransformer[] preprocessorTransforms,
string groupIdColumn,
string labelColumn,
OptimizingMetricInfo optimizingMetricInfo,
IChannel logger)
Expand All @@ -41,6 +43,7 @@ public CrossValSummaryRunner(MLContext context,
_metricsAgent = metricsAgent;
_preFeaturizer = preFeaturizer;
_preprocessorTransforms = preprocessorTransforms;
_groupIdColumn = groupIdColumn;
_labelColumn = labelColumn;
_optimizingMetricInfo = optimizingMetricInfo;
_logger = logger;
Expand All @@ -56,7 +59,7 @@ public CrossValSummaryRunner(MLContext context,
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
_labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_logger);
trainResults.Add(trainResult);
}
Expand Down
Loading

0 comments on commit f87a3bb

Please sign in to comment.