diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv index 22c2767d7a..7fc82434b4 100644 --- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv +++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv @@ -8,6 +8,7 @@ Models.BinaryClassificationEvaluator Evaluates a binary classification scored da Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output] Models.ClassificationEvaluator Evaluates a multi class classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate MultiClass Microsoft.ML.Runtime.Data.MultiClassMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.ClusterEvaluator Evaluates a clustering scored dataset. Microsoft.ML.Runtime.Data.Evaluate Clustering Microsoft.ML.Runtime.Data.ClusteringMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput +Models.CrossValidationResultsCombiner Combine the metric data views returned from cross validation. Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro CombineMetrics Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+CombineMetricsInput Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+CombinedOutput Models.CrossValidator Cross validation for general learning Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro CrossValidate Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+Output] Models.CrossValidatorDatasetSplitter Split the dataset into the specified number of cross-validation folds (train and test sets) Microsoft.ML.Runtime.EntryPoints.CVSplit Split Microsoft.ML.Runtime.EntryPoints.CVSplit+Input Microsoft.ML.Runtime.EntryPoints.CVSplit+Output Models.DatasetTransformer Applies a TransformModel to a dataset. 
Microsoft.ML.Runtime.EntryPoints.ModelOperations Apply Microsoft.ML.Runtime.EntryPoints.ModelOperations+ApplyTransformModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+ApplyTransformModelOutput diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index d705b26010..05b7965842 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -1238,6 +1238,116 @@ "IEvaluatorOutput" ] }, + { + "Name": "Models.CrossValidationResultsCombiner", + "Desc": "Combine the metric data views returned from cross validation.", + "FriendlyName": null, + "ShortName": null, + "Inputs": [ + { + "Name": "OverallMetrics", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Overall metrics datasets", + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "PerInstanceMetrics", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Per instance metrics datasets", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "ConfusionMatrix", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Confusion matrix datasets", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Warnings", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Warning datasets", + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "The label column name", + "Aliases": [ + "Label" + ], + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "Kind", + "Type": { + "Kind": "Enum", + "Values": [ + "SignatureBinaryClassifierTrainer", + "SignatureMultiClassClassifierTrainer", + "SignatureRankerTrainer", + "SignatureRegressorTrainer", + "SignatureMultiOutputRegressorTrainer", + "SignatureAnomalyDetectorTrainer", + "SignatureClusteringTrainer" + ] + }, + "Desc": "Specifies the trainer kind, which determines the evaluator to be used.", + "Required": true, + "SortOrder": 7.0, + "IsNullable": false, + "Default": "SignatureBinaryClassifierTrainer" + } + ], + "Outputs": [ + { + "Name": "Warnings", + "Type": "DataView", + "Desc": "Warning dataset" + }, + { + "Name": "OverallMetrics", + "Type": "DataView", + "Desc": "Overall metrics dataset" + }, + { + "Name": "PerInstanceMetrics", + "Type": "DataView", + "Desc": "Per instance metrics dataset" + }, + { + "Name": "ConfusionMatrix", + "Type": "DataView", + "Desc": "Confusion matrix dataset" + } + ] + }, { "Name": "Models.CrossValidator", "Desc": "Cross validation for general learning", @@ -1368,34 +1478,22 @@ }, { "Name": "Warnings", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Warning dataset" }, { "Name": "OverallMetrics", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Overall metrics dataset" }, { "Name": "PerInstanceMetrics", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Per instance metrics dataset" }, { "Name": "ConfusionMatrix", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Confusion matrix dataset" } ] diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 
c430abf85a..fc78e72c53 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -211,38 +211,34 @@ private void RunCore(IChannel ch, string cmd) } // Print the overall results. - eval.PrintOverallResults(ch, Args.SummaryFilename, tasks.Select(t => t.Result.Metrics).ToArray()); + if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList)) + throw ch.Except("No overall metrics found"); + + var overall = eval.GetOverallResults(overallList.ToArray()); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds); + eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray()); Dictionary[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray(); SendTelemetryMetric(metricValues); // Save the per-instance results. if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { - Func, int, IDataView> getPerInstance = - (task, i) => - { - if (!Args.OutputExampleFoldIndex) - return task.Result.PerInstanceResults; - - // If the fold index is requested, add a column containing it. We use the first column in the data view - // as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = task.Result.PerInstanceResults.Schema.GetColumnName(0); - var inputColType = task.Result.PerInstanceResults.Schema.GetColumnType(0); - return Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, - task.Result.PerInstanceResults, inputColName, MetricKinds.ColumnNames.FoldIndex, - inputColType, Args.NumFolds, i + 1, "FoldIndex", default(ValueGetter>)); - }; - - var foldDataViews = tasks.Select(getPerInstance).ToArray(); + var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics, + Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames); + if (variableSizeVectorColumnNames.Length > 0) + { + ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.", + string.Join(", ", variableSizeVectorColumnNames)); + } if (Args.CollateMetrics) { - var perInst = AppendPerInstanceDataViews(foldDataViews, ch); - MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInst); + ch.Assert(perInstance.Length == 1); + MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]); } else { int i = 0; - foreach (var idv in foldDataViews) + foreach (var idv in perInstance) { MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv); i++; @@ -251,166 +247,6 @@ private void RunCore(IChannel ch, string cmd) } } - private IDataView AppendPerInstanceDataViews(IEnumerable foldDataViews, IChannel ch) - { - // Make sure there are no variable size vector columns. - // This is a dictionary from the column name to its vector size. 
- var vectorSizes = new Dictionary(); - var firstDvSlotNames = new Dictionary>(); - var firstDvKeyColumns = new List(); - var firstDvVectorKeyColumns = new List(); - var variableSizeVectorColumnNames = new List(); - var list = new List(); - int dvNumber = 0; - foreach (var dv in foldDataViews) - { - var hidden = new List(); - for (int i = 0; i < dv.Schema.ColumnCount; i++) - { - if (dv.Schema.IsHidden(i)) - { - hidden.Add(i); - continue; - } - - var type = dv.Schema.GetColumnType(i); - var name = dv.Schema.GetColumnName(i); - if (type.IsVector) - { - if (dvNumber == 0) - { - if (dv.Schema.HasKeyNames(i, type.ItemType.KeyCount)) - firstDvVectorKeyColumns.Add(name); - // Store the slot names of the 1st idv and use them as baseline. - if (dv.Schema.HasSlotNames(i, type.VectorSize)) - { - VBuffer slotNames = default(VBuffer); - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); - firstDvSlotNames.Add(name, slotNames); - } - } - - int cachedSize; - if (vectorSizes.TryGetValue(name, out cachedSize)) - { - VBuffer slotNames; - // In the event that no slot names were recorded here, then slotNames will be - // the default, length 0 vector. - firstDvSlotNames.TryGetValue(name, out slotNames); - if (!VerifyVectorColumnsMatch(cachedSize, i, dv, type, ref slotNames)) - variableSizeVectorColumnNames.Add(name); - } - else - vectorSizes.Add(name, type.VectorSize); - } - else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount)) - { - // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. - firstDvKeyColumns.Add(name); - } - } - var idv = dv; - if (hidden.Count > 0) - { - var args = new ChooseColumnsByIndexTransform.Arguments(); - args.Drop = true; - args.Index = hidden.ToArray(); - idv = new ChooseColumnsByIndexTransform(Host, args, idv); - } - list.Add(idv); - dvNumber++; - } - - if (variableSizeVectorColumnNames.Count == 0 && firstDvKeyColumns.Count == 0) - return AppendRowsDataView.Create(Host, null, list.ToArray()); - - var views = list.ToArray(); - foreach (var keyCol in firstDvKeyColumns) - EvaluateUtils.ReconcileKeyValues(Host, views, keyCol); - foreach (var vectorKeyCol in firstDvVectorKeyColumns) - EvaluateUtils.ReconcileVectorKeyValues(Host, views, vectorKeyCol); - - Func keyToValue = - (idv, i) => - { - foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns)) - { - idv = new KeyToValueTransform(Host, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv); - var hidden = FindHiddenColumns(idv.Schema, keyCol); - idv = new ChooseColumnsByIndexTransform(Host, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv); - } - return idv; - }; - - Func selectDropNonVarLenthCol = - (idv) => - { - foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNames) - { - int index; - idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index); - var type = idv.Schema.GetColumnType(index); - - idv = Utils.MarshalInvoke(AddVarLengthColumn, type.ItemType.RawType, Host, idv, - variableSizeVectorColumnName, type); - - // Drop the old column that does not have variable length. - idv = new DropColumnsTransform(Host, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv); - } - return idv; - }; - - if (variableSizeVectorColumnNames.Count > 0) - ch.Warning("Detected columns of variable length: {0}. 
Consider setting collateMetrics- for meaningful per-Folds results.", string.Join(", ", variableSizeVectorColumnNames)); - return AppendRowsDataView.Create(Host, null, views.Select(keyToValue).Select(selectDropNonVarLenthCol).ToArray()); - } - - private static IEnumerable FindHiddenColumns(ISchema schema, string colName) - { - for (int i = 0; i < schema.ColumnCount; i++) - { - if (schema.IsHidden(i) && schema.GetColumnName(i) == colName) - yield return i; - } - } - - private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView dv, - ColumnType type, ref VBuffer firstDvSlotNames) - { - if (cachedSize != type.VectorSize) - return false; - - // If we detect mismatch it a sign that slots reshuffling has happened. - if (dv.Schema.HasSlotNames(col, type.VectorSize)) - { - // Verify that slots match with slots from 1st idv. - VBuffer currSlotNames = default(VBuffer); - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref currSlotNames); - - if (currSlotNames.Length != firstDvSlotNames.Length) - return false; - else - { - var result = true; - VBufferUtils.ForEachEitherDefined(ref currSlotNames, ref firstDvSlotNames, - (slot, val1, val2) => result = result && DvText.Identical(val1, val2)); - return result; - } - } - else - { - // If we don't have slot names, then the first dataview should not have had slot names either. - return firstDvSlotNames.Length == 0; - } - } - - private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) - { - return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, - variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), - (ref VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); - } - /// /// Callback from the CV method to apply the transforms from the train data to the test and/or validation data. 
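Note: the helpers removed above (AppendPerInstanceDataViews, VerifyVectorColumnsMatch, AddVarLengthColumn) move into EvaluateUtils so the command and the new entry point share one code path. The core of the collation step is simple: optionally tag every per-fold result row with its fold index, then append the folds into one view. Below is a minimal standalone sketch of that idea using plain .NET collections; the row type and field names are illustrative, not ML.NET types.

using System.Collections.Generic;
using System.Linq;

// Illustrative stand-in for one per-instance result row.
sealed class ResultRow
{
    public string Instance;
    public double Score;
    public int FoldIndex; // populated only when the fold index is requested
}

static class CollateSketch
{
    // Mirrors the collate/outputFoldIndex switch in ConcatenatePerInstanceDataViews:
    // either return a single combined list (collate) or one list per fold.
    public static List<ResultRow>[] Collate(List<ResultRow>[] perFold, bool collate, bool outputFoldIndex)
    {
        if (outputFoldIndex)
        {
            for (int i = 0; i < perFold.Length; i++)
                foreach (var row in perFold[i])
                    row.FoldIndex = i + 1; // fold indices are 1-based in the key column
        }
        if (!collate)
            return perFold;
        return new[] { perFold.SelectMany(rows => rows).ToList() };
    }
}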
/// @@ -504,16 +340,32 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output return stratificationColumn; } + private bool TryGetOverallMetrics(Dictionary[] metrics, out List overallList) + { + Host.AssertNonEmpty(metrics); + + overallList = new List(); + for (int i = 0; i < metrics.Length; i++) + { + var dict = metrics[i]; + IDataView idv; + if (!dict.TryGetValue(MetricKinds.OverallMetrics, out idv)) + return false; + overallList.Add(idv); + } + return true; + } + private sealed class FoldHelper { public struct FoldResult { public readonly Dictionary Metrics; public readonly ISchema ScoreSchema; - public readonly IDataView PerInstanceResults; + public readonly RoleMappedData PerInstanceResults; public readonly RoleMappedSchema TrainSchema; - public FoldResult(Dictionary metrics, ISchema scoreSchema, IDataView perInstance, RoleMappedSchema trainSchema) + public FoldResult(Dictionary metrics, ISchema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema) { Metrics = metrics; ScoreSchema = scoreSchema; @@ -735,12 +587,11 @@ private FoldResult RunFold(int fold) var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames()); var dict = eval.Evaluate(dataEval); - IDataView perInstance = null; + RoleMappedData perInstance = null; if (_savePerInstance) { var perInst = eval.GetPerInstanceMetrics(dataEval); - var perInstData = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames()); - perInstance = eval.GetPerInstanceDataViewToSave(perInstData); + perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames()); } ch.Done(); return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema); diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs index 533dfdcc41..d0e066d789 100644 --- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs +++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs @@ -240,7 +240,11 @@ private void RunCore(IChannel ch) var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(data); diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index 73ca98f66e..79e7bd5458 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -118,7 +118,11 @@ private void RunCore(IChannel ch) var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); if 
(!string.IsNullOrWhiteSpace(Args.OutputDataFile)) diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index dda298e6b2..f6ffa772f9 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -194,7 +194,11 @@ private void RunCore(IChannel ch, string cmd) var metrics = evaluator.Evaluate(dataEval); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index adf4ee4fba..39a5f31c38 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -746,14 +746,8 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - var args = new DropColumnsTransform.Arguments(); args.Column = new[] { @@ -762,8 +756,7 @@ protected override void PrintOverallResultsCore(IChannel ch, string filename, Di AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP, AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos }; - overall = new DropColumnsTransform(Host, args, overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return new DropColumnsTransform(Host, args, overall); } protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 673d497073..90078da9ee 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1171,18 +1171,16 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - var args = new DropColumnsTransform.Arguments(); args.Column = new[] { BinaryClassifierEvaluator.Entropy }; - overall = new DropColumnsTransform(Host, args, overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return new DropColumnsTransform(Host, args, overall); + } + + protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + { + ch.AssertNonEmpty(metrics); if (!string.IsNullOrEmpty(_prFileName)) { @@ -1228,14 +1226,7 @@ private bool TryGetPrMetrics(Dictionary[] metrics, out IDataV if (!dict.TryGetValue(BinaryClassifierEvaluator.PrCurve, out idv)) return false; if (metrics.Length != 1) - { - // We use the first column in 
the data view as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, idv, - inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, metrics.Length, i + 1, "FoldIndex", - default(ValueGetter>)); - } + idv = EvaluateUtils.AddFoldIndex(Host, idv, i, metrics.Length); else pr = idv; prList.Add(idv); diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 9cc25761f3..0e2de21530 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -10,7 +10,6 @@ using System.Text; using System.Threading; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data.Conversion; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; @@ -18,6 +17,13 @@ namespace Microsoft.ML.Runtime.Data { public static class EvaluateUtils { + public struct AggregatedMetric + { + public double Sum; + public double SumSq; + public string Name; + } + private static class DefaultEvaluatorTable { private static volatile Dictionary _knownEvaluatorLoadNames; @@ -200,7 +206,7 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, ISchem return null; } - public static bool IsScoreColumnKind(IExceptionContext ectx, ISchema schema, int col, string kind) + private static bool IsScoreColumnKind(IExceptionContext ectx, ISchema schema, int col, string kind) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); @@ -359,7 +365,7 @@ public static IEnumerable> GetMetrics(IDataView met } } - public static IDataView AddTextColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, + private static IDataView AddTextColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, ColumnType typeSrc, string value, string registrationName) { Contracts.Check(typeSrc.RawType == typeof(TSrc)); @@ -367,7 +373,33 @@ public static IDataView AddTextColumn(IHostEnvironment env, IDataView inpu (ref TSrc src, ref DvText dst) => dst = new DvText(value)); } - public static IDataView AddKeyColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, + /// + /// Add a text column containing a fold index to a data view. + /// + /// The host environment. + /// The data view to which we add the column + /// The current fold this data view belongs to. + /// The input data view with an additional text column containing the current fold index. + public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + env.CheckParam(curFold >= 0, nameof(curFold)); + + // We use the first column in the data view as an input column to the LambdaColumnMapper, + // because it must have an input. 
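Note: the comment above is the reason an input column is needed at all: LambdaColumnMapper requires a source column, so AddFoldIndex scans for the first non-hidden column and uses it purely as a carrier. A small standalone sketch of that scan, with the schema reduced to a hidden-column mask (illustrative, not the ML.NET schema API):

using System;

static class FirstVisibleColumn
{
    // Mirrors the inputCol loop: skip hidden columns and take the first visible one.
    public static int Find(bool[] hidden)
    {
        int col = 0;
        while (col < hidden.Length && hidden[col])
            col++;
        if (col == hidden.Length)
            throw new InvalidOperationException("No visible column to use as the mapper input.");
        return col;
    }
}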
+ int inputCol = 0; + while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + inputCol++; + env.Assert(inputCol < input.Schema.ColumnCount); + + var inputColName = input.Schema.GetColumnName(0); + var inputColType = input.Schema.GetColumnType(0); + return Utils.MarshalInvoke(AddTextColumn, inputColType.RawType, env, + input, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, $"Fold {curFold}", "FoldName"); + } + + private static IDataView AddKeyColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, ColumnType typeSrc, int keyCount, int value, string registrationName, ValueGetter> keyValueGetter) { Contracts.Check(typeSrc.RawType == typeof(TSrc)); @@ -381,6 +413,35 @@ public static IDataView AddKeyColumn(IHostEnvironment env, IDataView input }, keyValueGetter); } + /// + /// Add a key type column containing a fold index to a data view. + /// + /// The host environment. + /// The data view to which we add the column + /// The current fold this data view belongs to. + /// The total number of folds. + /// The input data view with an additional key type column containing the current fold index. + public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold, int numFolds) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + env.CheckParam(curFold >= 0, nameof(curFold)); + env.CheckParam(numFolds > 0, nameof(numFolds)); + + // We use the first column in the data view as an input column to the LambdaColumnMapper, + // because it must have an input. + int inputCol = 0; + while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + inputCol++; + env.Assert(inputCol < input.Schema.ColumnCount); + + var inputColName = input.Schema.GetColumnName(inputCol); + var inputColType = input.Schema.GetColumnType(inputCol); + return Utils.MarshalInvoke(AddKeyColumn, inputColType.RawType, env, + input, inputColName, MetricKinds.ColumnNames.FoldIndex, + inputColType, numFolds, curFold + 1, "FoldIndex", default(ValueGetter>)); + } + /// /// This method takes an array of data views and a specified input vector column, and adds a new output column to each of the data views. /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each @@ -639,6 +700,584 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi } } + /// + /// This method gets the per-instance metrics from multiple scored data views and either returns them as an + /// array or combines them into a single data view, based on user specifications. + /// + /// A host environment. + /// The evaluator to use for getting the per-instance metrics. + /// If true, data views are combined into a single data view. Otherwise, data views + /// are returned as an array. + /// If true, a column containing the fold index is added to the returned data views. + /// The array of scored data views to evaluate. These are passed as + /// so that the evaluator can know the role mappings it needs. + /// A list of column names that are not included in the combined data view + /// since their types do not match. 
+ /// + public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env, IMamlEvaluator eval, bool collate, bool outputFoldIndex, RoleMappedData[] perInstance, out string[] variableSizeVectorColumnNames) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(eval, nameof(eval)); + env.CheckNonEmpty(perInstance, nameof(perInstance)); + + Func getPerInstance = + (rmd, i) => + { + var perInst = eval.GetPerInstanceDataViewToSave(rmd); + + if (!outputFoldIndex) + return perInst; + + // If the fold index is requested, add a column containing it. We use the first column in the data view + // as an input column to the LambdaColumnMapper, because it must have an input. + return AddFoldIndex(env, perInst, i, perInstance.Length); + }; + + var foldDataViews = perInstance.Select(getPerInstance).ToArray(); + if (collate) + { + var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames); + return new[] { combined }; + } + else + { + variableSizeVectorColumnNames = new string[0]; + return foldDataViews.ToArray(); + } + } + + /// + /// Create an output data view that is the vertical concatenation of the metric data views. + /// + public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataView[] metrics) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckNonEmpty(metrics, nameof(metrics)); + + if (metrics.Length == 1) + return metrics[0]; + + var overallList = new List(); + for (int i = 0; i < metrics.Length; i++) + { + // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. + var idv = AddFoldIndex(env, metrics[i], i); + overallList.Add(idv); + } + return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray()); + } + + private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable foldDataViews, out string[] variableSizeVectorColumnNames) + { + Contracts.AssertValue(env); + env.AssertValue(foldDataViews); + + // Make sure there are no variable size vector columns. + // This is a dictionary from the column name to its vector size. + var vectorSizes = new Dictionary(); + var firstDvSlotNames = new Dictionary>(); + var firstDvKeyColumns = new List(); + var firstDvVectorKeyColumns = new List(); + var variableSizeVectorColumnNamesList = new List(); + var list = new List(); + int dvNumber = 0; + foreach (var dv in foldDataViews) + { + var hidden = new List(); + for (int i = 0; i < dv.Schema.ColumnCount; i++) + { + if (dv.Schema.IsHidden(i)) + { + hidden.Add(i); + continue; + } + + var type = dv.Schema.GetColumnType(i); + var name = dv.Schema.GetColumnName(i); + if (type.IsVector) + { + if (dvNumber == 0) + { + if (dv.Schema.HasKeyNames(i, type.ItemType.KeyCount)) + firstDvVectorKeyColumns.Add(name); + // Store the slot names of the 1st idv and use them as baseline. + if (dv.Schema.HasSlotNames(i, type.VectorSize)) + { + VBuffer slotNames = default(VBuffer); + dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); + firstDvSlotNames.Add(name, slotNames); + } + } + + int cachedSize; + if (vectorSizes.TryGetValue(name, out cachedSize)) + { + VBuffer slotNames; + // In the event that no slot names were recorded here, then slotNames will be + // the default, length 0 vector. 
+ firstDvSlotNames.TryGetValue(name, out slotNames); + if (!VerifyVectorColumnsMatch(cachedSize, i, dv, type, ref slotNames)) + variableSizeVectorColumnNamesList.Add(name); + } + else + vectorSizes.Add(name, type.VectorSize); + } + else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount)) + { + // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. + firstDvKeyColumns.Add(name); + } + } + var idv = dv; + if (hidden.Count > 0) + { + var args = new ChooseColumnsByIndexTransform.Arguments(); + args.Drop = true; + args.Index = hidden.ToArray(); + idv = new ChooseColumnsByIndexTransform(env, args, idv); + } + list.Add(idv); + dvNumber++; + } + + variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray(); + if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0) + return AppendRowsDataView.Create(env, null, list.ToArray()); + + var views = list.ToArray(); + foreach (var keyCol in firstDvKeyColumns) + ReconcileKeyValues(env, views, keyCol); + foreach (var vectorKeyCol in firstDvVectorKeyColumns) + ReconcileVectorKeyValues(env, views, vectorKeyCol); + + Func keyToValue = + (idv, i) => + { + foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns)) + { + idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv); + var hidden = FindHiddenColumns(idv.Schema, keyCol); + idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv); + } + return idv; + }; + + Func selectDropNonVarLenthCol = + (idv) => + { + foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNamesList) + { + int index; + idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index); + var type = idv.Schema.GetColumnType(index); + + idv = Utils.MarshalInvoke(AddVarLengthColumn, type.ItemType.RawType, env, idv, + variableSizeVectorColumnName, type); + + // Drop the old column that does not have variable length. + idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv); + } + return idv; + }; + + return AppendRowsDataView.Create(env, null, views.Select(keyToValue).Select(selectDropNonVarLenthCol).ToArray()); + } + + private static IEnumerable FindHiddenColumns(ISchema schema, string colName) + { + for (int i = 0; i < schema.ColumnCount; i++) + { + if (schema.IsHidden(i) && schema.GetColumnName(i) == colName) + yield return i; + } + } + + private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView dv, + ColumnType type, ref VBuffer firstDvSlotNames) + { + if (cachedSize != type.VectorSize) + return false; + + // If we detect mismatch it a sign that slots reshuffling has happened. + if (dv.Schema.HasSlotNames(col, type.VectorSize)) + { + // Verify that slots match with slots from 1st idv. + VBuffer currSlotNames = default(VBuffer); + dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref currSlotNames); + + if (currSlotNames.Length != firstDvSlotNames.Length) + return false; + else + { + var result = true; + VBufferUtils.ForEachEitherDefined(ref currSlotNames, ref firstDvSlotNames, + (slot, val1, val2) => result = result && DvText.Identical(val1, val2)); + return result; + } + } + else + { + // If we don't have slot names, then the first dataview should not have had slot names either. 
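Note: VerifyVectorColumnsMatch treats two vector columns as compatible only if they have the same length and, when slot names are present, identical slot names; otherwise the column is flagged as variable-length and excluded from averaging. A standalone sketch of the same check over plain arrays (names here are illustrative; a null array stands in for "no slot names recorded"):

using System.Linq;

static class VectorColumnCheck
{
    // cachedSize/firstSlotNames come from the first fold; size/slotNames from the current fold.
    public static bool Matches(int cachedSize, string[] firstSlotNames, int size, string[] slotNames)
    {
        if (cachedSize != size)
            return false;
        if (slotNames == null)
            return firstSlotNames == null; // neither fold should have slot names
        if (firstSlotNames == null || firstSlotNames.Length != slotNames.Length)
            return false;
        // A name mismatch is a sign that slot reshuffling has happened.
        return firstSlotNames.SequenceEqual(slotNames);
    }
}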
+ return firstDvSlotNames.Length == 0; + } + } + + private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) + { + return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, + variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), + (ref VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); + } + + private static List GetMetricNames(IChannel ch, ISchema schema, IRow row, Func ignoreCol, + ValueGetter[] getters, ValueGetter>[] vBufferGetters) + { + ch.AssertValue(schema); + ch.AssertValue(row); + ch.Assert(Utils.Size(getters) == schema.ColumnCount); + ch.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount); + + // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns + // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. + VBuffer names = default(VBuffer); + int metricCount = 0; + var metricNames = new List(); + for (int i = 0; i < schema.ColumnCount; i++) + { + if (schema.IsHidden(i) || ignoreCol(i)) + continue; + + var type = schema.GetColumnType(i); + var metricName = row.Schema.GetColumnName(i); + if (type.IsNumber) + { + getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, row, i); + metricNames.Add(metricName); + metricCount++; + } + else if (type.IsVector && type.ItemType == NumberType.R8) + { + if (type.VectorSize == 0) + { + ch.Warning("Vector metric '{0}' has different lengths in different folds and will not be averaged for overall results.", metricName); + continue; + } + + vBufferGetters[i] = row.GetGetter>(i); + metricCount += type.VectorSize; + var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); + if (slotNamesType != null && slotNamesType.VectorSize == type.VectorSize && slotNamesType.ItemType.IsText) + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); + else + { + var namesArray = names.Values; + if (Utils.Size(namesArray) < type.VectorSize) + namesArray = new DvText[type.VectorSize]; + for (int j = 0; j < type.VectorSize; j++) + namesArray[j] = new DvText(string.Format("Label_{0}", j)); + names = new VBuffer(type.VectorSize, namesArray); + } + foreach (var name in names.Items(all: true)) + metricNames.Add(string.Format("{0}{1}", metricName, name.Value)); + } + } + ch.Assert(metricNames.Count == metricCount); + return metricNames; + } + + internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView data, int numFolds, out AggregatedMetric[] agg, + out AggregatedMetric[] weightedAgg) + { + agg = ComputeMetricsSum(env, data, numFolds, out int isWeightedCol, out int stratCol, out int stratVal, out int foldCol, out weightedAgg); + + var nonAveragedCols = new List(); + var avgMetrics = GetAverageToDataView(env, data.Schema, agg, weightedAgg, numFolds, stratCol, stratVal, + isWeightedCol, foldCol, numFolds > 1, nonAveragedCols); + + var idvList = new List() { avgMetrics }; + + var hasStrat = stratCol >= 0; + if (numFolds > 1 || hasStrat) + { + if (Utils.Size(nonAveragedCols) > 0) + { + var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() }; + data = new DropColumnsTransform(env, dropArgs, data); + } + idvList.Add(data); + } + + var overall = AppendRowsDataView.Create(env, avgMetrics.Schema, idvList.ToArray()); + + // If there are stratified results, apply a KeyToValue transform to get the 
stratification column + // names from the key column. + if (hasStrat) + { + var args = new KeyToValueTransform.Arguments(); + args.Column = new[] { new KeyToValueTransform.Column() { Source = MetricKinds.ColumnNames.StratCol }, }; + overall = new KeyToValueTransform(env, args, overall); + } + return overall; + } + + internal static AggregatedMetric[] ComputeMetricsSum(IHostEnvironment env, IDataView data, int numFolds, out int isWeightedCol, + out int stratCol, out int stratVal, out int foldCol, out AggregatedMetric[] weightedAgg) + { + var hasWeighted = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out int wcol); + var hasStrats = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int scol); + var hasStratVals = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out int svalcol); + env.Assert(hasStrats == hasStratVals); + var hasFoldCol = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int fcol); + + isWeightedCol = hasWeighted ? wcol : -1; + stratCol = hasStrats ? scol : -1; + stratVal = hasStratVals ? svalcol : -1; + foldCol = hasFoldCol ? fcol : -1; + + // We currently have only double valued or vector of double valued metrics. + int colCount = data.Schema.ColumnCount; + var getters = new ValueGetter[colCount]; + var vBufferGetters = new ValueGetter>[colCount]; + int numResults = 0; + int numWeightedResults = 0; + AggregatedMetric[] agg; + using (var cursor = data.GetRowCursor(col => true)) + { + DvBool isWeighted = DvBool.False; + ValueGetter isWeightedGetter; + if (hasWeighted) + isWeightedGetter = cursor.GetGetter(isWeightedCol); + else + isWeightedGetter = (ref DvBool dst) => dst = DvBool.False; + + ValueGetter stratColGetter; + if (hasStrats) + { + var type = cursor.Schema.GetColumnType(stratCol); + stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); + } + else + stratColGetter = (ref uint dst) => dst = 0; + + // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns + // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. + List metricNames; + using (var ch = env.Register("GetMetricsAsString").Start("Get Metric Names")) + { + metricNames = GetMetricNames(ch, data.Schema, cursor, + i => hasWeighted && i == wcol || hasStrats && (i == scol || i == svalcol) || + hasFoldCol && i == fcol, getters, vBufferGetters); + ch.Done(); + } + agg = new AggregatedMetric[metricNames.Count]; + + Double metricVal = 0; + VBuffer metricVals = default(VBuffer); + if (hasWeighted) + weightedAgg = new AggregatedMetric[metricNames.Count]; + else + weightedAgg = null; + uint strat = 0; + while (cursor.MoveNext()) + { + stratColGetter(ref strat); + // REVIEW: how to print stratified results? + if (strat > 0) + continue; + + isWeightedGetter(ref isWeighted); + if (isWeighted.IsTrue) + { + // If !average, we should have only one relevant row. + if (numWeightedResults > numFolds) + throw Contracts.Except("Multiple weighted rows found in metrics data view."); + + numWeightedResults++; + UpdateSums(isWeightedCol, stratCol, stratVal, weightedAgg, numFolds > 1, metricNames, hasWeighted, + hasStrats, colCount, getters, vBufferGetters, ref metricVal, ref metricVals); + } + else + { + // If !average, we should have only one relevant row. 
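Note: ComputeMetricsSum only accumulates two quantities per metric, Sum and SumSq, which is enough to recover both the average and the standard deviation across folds later (stddev = sqrt(SumSq/n - mean^2)). A standalone sketch of that aggregation, assuming plain doubles in place of IDataView rows:

using System;
using System.Collections.Generic;

struct AggregatedMetric
{
    public double Sum;
    public double SumSq;
    public string Name;
}

static class FoldAggregation
{
    public static AggregatedMetric Accumulate(string name, IEnumerable<double> perFoldValues)
    {
        var agg = new AggregatedMetric { Name = name };
        foreach (var v in perFoldValues)
        {
            agg.Sum += v;
            agg.SumSq += v * v;
        }
        return agg;
    }

    public static (double Average, double Stddev) Summarize(AggregatedMetric agg, int numFolds)
    {
        double avg = agg.Sum / numFolds;
        // Population standard deviation over the folds, used for the "(stddev)" column.
        double stddev = numFolds == 1 ? 0 : Math.Sqrt(agg.SumSq / numFolds - avg * avg);
        return (avg, stddev);
    }
}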
+ if (numResults > numFolds) + throw Contracts.Except("Multiple unweighted rows found in metrics data view."); + + numResults++; + UpdateSums(isWeightedCol, stratCol, stratVal, agg, numFolds > 1, metricNames, hasWeighted, hasStrats, + colCount, getters, vBufferGetters, ref metricVal, ref metricVals); + } + + if (numResults == numFolds && (!hasWeighted || numWeightedResults == numFolds)) + break; + } + } + return agg; + } + + private static void UpdateSums(int isWeightedCol, int stratCol, int stratVal, AggregatedMetric[] aggregated, bool hasStdev, List metricNames, bool hasWeighted, bool hasStrats, int colCount, ValueGetter[] getters, ValueGetter>[] vBufferGetters, ref double metricVal, ref VBuffer metricVals) + { + int iMetric = 0; + for (int i = 0; i < colCount; i++) + { + if (hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal)) + continue; + + if (getters[i] == null && vBufferGetters[i] == null) + { + // REVIEW: What to do with metrics that are not doubles? + continue; + } + if (getters[i] != null) + { + getters[i](ref metricVal); + aggregated[iMetric].Sum += metricVal; + if (hasStdev) + aggregated[iMetric].SumSq += metricVal * metricVal; + aggregated[iMetric].Name = metricNames[iMetric]; + iMetric++; + } + else + { + Contracts.AssertValue(vBufferGetters[i]); + vBufferGetters[i](ref metricVals); + foreach (var metric in metricVals.Items(all: true)) + { + aggregated[iMetric].Sum += metric.Value; + if (hasStdev) + aggregated[iMetric].SumSq += metric.Value * metric.Value; + aggregated[iMetric].Name = metricNames[iMetric]; + iMetric++; + } + } + } + Contracts.Assert(iMetric == metricNames.Count); + } + + internal static IDataView GetAverageToDataView(IHostEnvironment env, ISchema schema, AggregatedMetric[] agg, AggregatedMetric[] weightedAgg, + int numFolds, int stratCol, int stratVal, int isWeightedCol, int foldCol, bool hasStdev, List nonAveragedCols = null) + { + Contracts.AssertValue(env); + + int colCount = schema.ColumnCount; + + var dvBldr = new ArrayDataViewBuilder(env); + var weightedDvBldr = isWeightedCol >= 0 ? new ArrayDataViewBuilder(env) : null; + + int iMetric = 0; + for (int i = 0; i < colCount; i++) + { + if (schema.IsHidden(i)) + continue; + + var type = schema.GetColumnType(i); + var name = schema.GetColumnName(i); + if (i == stratCol) + { + var keyValuesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); + if (keyValuesType == null || !keyValuesType.ItemType.IsText || + keyValuesType.VectorSize != type.KeyCount) + { + throw env.Except("Column '{0}' must have key values metadata", + MetricKinds.ColumnNames.StratCol); + } + + ValueGetter> getKeyValues = + (ref VBuffer dst) => + { + schema.GetMetadata(MetadataUtils.Kinds.KeyValues, stratCol, ref dst); + Contracts.Assert(dst.IsDense); + }; + + var keys = foldCol >= 0 ? new uint[] { 0, 0 } : new uint[] { 0 }; + dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, keys); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, keys); + } + else if (i == stratVal) + { + var stratVals = foldCol >= 0 ? new[] { DvText.NA, DvText.NA } : new[] { DvText.NA }; + dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, stratVals); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, stratVals); + } + else if (i == isWeightedCol) + { + env.AssertValue(weightedDvBldr); + dvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, foldCol >= 0 ? 
new[] { DvBool.False, DvBool.False } : new[] { DvBool.False }); + weightedDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, foldCol >= 0 ? new[] { DvBool.True, DvBool.True } : new[] { DvBool.True }); + } + else if (i == foldCol) + { + var foldVals = new[] { new DvText("Average"), new DvText("Standard Deviation") }; + dvBldr.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, foldVals); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, foldVals); + } + else if (type.IsNumber) + { + dvBldr.AddScalarColumn(schema, agg, hasStdev, numFolds, iMetric); + weightedDvBldr?.AddScalarColumn(schema, weightedAgg, hasStdev, numFolds, iMetric); + iMetric++; + } + else if (type.IsKnownSizeVector && type.ItemType == NumberType.R8) + { + dvBldr.AddVectorColumn(env, schema, agg, hasStdev, numFolds, iMetric, i, type, name); + weightedDvBldr?.AddVectorColumn(env, schema, weightedAgg, hasStdev, numFolds, iMetric, i, type, name); + iMetric += type.VectorSize; + } + else + nonAveragedCols?.Add(name); + } + var idv = dvBldr.GetDataView(); + if (weightedDvBldr != null) + idv = AppendRowsDataView.Create(env, idv.Schema, idv, weightedDvBldr.GetDataView()); + return idv; + } + + private static void AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvironment env, ISchema schema, + AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric, int i, ColumnType type, string columnName) + { + var vectorMetrics = new double[type.VectorSize]; + env.Assert(vectorMetrics.Length > 0); + for (int j = 0; j < vectorMetrics.Length; j++) + vectorMetrics[j] = agg[iMetric + j].Sum / numFolds; + double[] vectorStdevMetrics = null; + if (hasStdev) + { + vectorStdevMetrics = new double[type.VectorSize]; + for (int j = 0; j < vectorStdevMetrics.Length; j++) + vectorStdevMetrics[j] = Math.Sqrt(agg[iMetric + j].SumSq / numFolds - vectorMetrics[j] * vectorMetrics[j]); + } + var names = new DvText[type.VectorSize]; + for (int j = 0; j < names.Length; j++) + names[j] = new DvText(agg[iMetric + j].Name); + var slotNames = new VBuffer(type.VectorSize, names); + ValueGetter> getSlotNames = (ref VBuffer dst) => dst = slotNames; + if (vectorStdevMetrics != null) + { + env.AssertValue(vectorStdevMetrics); + dvBldr.AddColumn(columnName, getSlotNames, NumberType.R8, new[] { vectorMetrics, vectorStdevMetrics }); + } + else + dvBldr.AddColumn(columnName, getSlotNames, NumberType.R8, new[] { vectorMetrics }); + } + + private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, ISchema schema, AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric) + { + Contracts.AssertValue(dvBldr); + + var avg = agg[iMetric].Sum / numFolds; + if (hasStdev) + dvBldr.AddColumn(agg[iMetric].Name, NumberType.R8, avg, Math.Sqrt(agg[iMetric].SumSq / numFolds - avg * avg)); + else + dvBldr.AddColumn(agg[iMetric].Name, NumberType.R8, avg); + } + + /// + /// Takes a data view containing one or more rows of metrics, and returns a data view containing additional + /// rows with the average and the standard deviation of the metrics in the input data view. + /// + public static IDataView CombineFoldMetricsDataViews(IHostEnvironment env, IDataView data, int numFolds) + { + return GetOverallMetricsData(env, data, numFolds, out var _, out var _); + } } public static class MetricWriter @@ -791,286 +1430,56 @@ private static double[][] GetConfusionTableAsArray(IDataView confusionDataView, /// metrics. Otherwise it is assigned null. 
public static string GetPerFoldResults(IHostEnvironment env, IDataView fold, out string weightedMetrics) { - IDataView avgMetrics; - int isWeightedCol; - if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol)) - weightedMetrics = GetMetricsAsString(env, fold, true, 1, out avgMetrics); - else - weightedMetrics = null; - return GetMetricsAsString(env, fold, false, 1, out avgMetrics); + return GetFoldMetricsAsString(env, fold, out weightedMetrics); } - // This method returns a string representation of a set of metrics. If there are stratification columns, it looks for columns named - // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks - // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column. - // If nonAveragedCols is non-null, it computes the average and standard deviation over all the relevant rows and populates - // nonAveragedCols with columns that are either hidden, or are not of a type that we can display (i.e., either a numeric column, - // or a known length vector of doubles). - // If average is false, no averaging is done, and instead we check that there is exactly one relevant row. Otherwise, we - // add the vector columns of variable length of the list of non-averagable columns if nonAveragedCols is not null. - private static string GetMetricsAsString(IHostEnvironment env, IDataView data, bool weighted, - int numFolds, out IDataView avgMetricsDataView, bool average = false, List nonAveragedCols = null) + private static string GetOverallMetricsAsString(double[] sumMetrics, double[] sumSqMetrics, int numFolds, bool weighted, bool average, List metricNames) { - int isWeightedCol; - bool hasWeighted = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol); - // If the IsWeighted column is not present, weighted must be false. - Contracts.Assert(hasWeighted || !weighted); - - int stratCol; - bool hasStrats = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol); - int stratVal; - bool hasStratVals = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal); - Contracts.Assert(hasStrats == hasStratVals); - - int foldCol; - bool hasFoldCol = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol); - - // We currently have only double valued or vector of double valued metrics. - var colCount = data.Schema.ColumnCount; - var getters = new ValueGetter[colCount]; - var vBufferGetters = new ValueGetter>[colCount]; - - double[] avgMetrics; - double[] sumSqMetrics; - List metricNames; - int numResults = 0; - using (var cursor = data.GetRowCursor(col => true)) - { - DvBool isWeighted = DvBool.False; - ValueGetter isWeightedGetter; - if (hasWeighted) - isWeightedGetter = cursor.GetGetter(isWeightedCol); - else - isWeightedGetter = (ref DvBool dst) => dst = DvBool.False; - - ValueGetter stratColGetter; - if (hasStrats) - { - var type = cursor.Schema.GetColumnType(stratCol); - stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); - } - else - stratColGetter = (ref uint dst) => dst = 0; - - // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns - // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. 
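Note: the metric-name convention referenced in the comment above is: a scalar R8 column contributes its column name, while an R8 vector column contributes one name per slot, built from the column name followed by the slot name when slot-name metadata exists, or "Label_{i}" otherwise. A standalone sketch of that naming rule (the inputs are illustrative):

using System.Collections.Generic;

static class MetricNaming
{
    // vectorSize == 0 stands in for a scalar metric column in this sketch.
    public static List<string> Build(string columnName, int vectorSize, string[] slotNames)
    {
        var names = new List<string>();
        if (vectorSize == 0)
        {
            // Scalar metric: the metric name is just the column name.
            names.Add(columnName);
            return names;
        }
        for (int j = 0; j < vectorSize; j++)
        {
            // Vector metric: column name followed by the slot name, or "Label_j" when no slot names exist.
            var slot = slotNames != null && j < slotNames.Length ? slotNames[j] : $"Label_{j}";
            names.Add(columnName + slot);
        }
        return names;
    }
}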
- using (var ch = env.Register("GetMetricsAsString").Start("Get Metric Names")) - { - metricNames = GetMetricNames(ch, data.Schema, cursor, - i => hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal) || - hasFoldCol && i == foldCol, getters, vBufferGetters); - ch.Done(); - } - - Double metricVal = 0; - VBuffer metricVals = default(VBuffer); - avgMetrics = new double[metricNames.Count]; - sumSqMetrics = new double[metricNames.Count]; - uint strat = 0; - while (cursor.MoveNext()) - { - isWeightedGetter(ref isWeighted); - if (isWeighted.IsTrue != weighted) - continue; - - stratColGetter(ref strat); - // REVIEW: how to print stratified results? - if (strat > 0) - continue; - - // If !average, we should have only one relevant row. - if (!average && numResults > 0) - throw Contracts.Except("Multiple {0} rows found in metrics data view.", weighted ? "weighted" : "unweighted"); - - numResults++; - int iMetric = 0; - for (int i = 0; i < colCount; i++) - { - if (hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal)) - continue; - - // REVIEW: What to do with metrics that are not doubles? - if (getters[i] != null) - { - getters[i](ref metricVal); - avgMetrics[iMetric] += metricVal; - if (sumSqMetrics != null) - sumSqMetrics[iMetric] += metricVal * metricVal; - iMetric++; - } - else if (vBufferGetters[i] != null) - { - vBufferGetters[i](ref metricVals); - foreach (var metric in metricVals.Items(all: true)) - { - avgMetrics[iMetric] += metric.Value; - if (sumSqMetrics != null) - sumSqMetrics[iMetric] += metric.Value * metric.Value; - iMetric++; - } - } - } - Contracts.Assert(iMetric == metricNames.Count); - - if (numResults == numFolds) - break; - } - } - var sb = new StringBuilder(); for (int i = 0; i < metricNames.Count; i++) { - avgMetrics[i] /= numResults; + var avg = sumMetrics[i] / numFolds; sb.Append(string.Format("{0}{1}: ", weighted ? "Weighted " : "", metricNames[i]).PadRight(20)); - sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", avgMetrics[i])); + sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", avg)); if (average) { - Contracts.AssertValue(sumSqMetrics); - sb.AppendLine(string.Format(" ({0:N4})", numResults == 1 ? 0 : - Math.Sqrt(sumSqMetrics[i] / numResults - avgMetrics[i] * avgMetrics[i]))); + Contracts.Assert(sumSqMetrics != null || numFolds == 1); + sb.AppendLine(string.Format(" ({0:N4})", numFolds == 1 ? 
0 : + Math.Sqrt(sumSqMetrics[i] / numFolds - avg * avg))); } else sb.AppendLine(); } - - if (average) - { - var dvBldr = new ArrayDataViewBuilder(env); - int iMetric = 0; - for (int i = 0; i < colCount; i++) - { - if (hasStrats && i == stratCol) - { - var type = data.Schema.GetColumnType(i); - var keyValuesType = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); - if (keyValuesType == null || !keyValuesType.ItemType.IsText || - keyValuesType.VectorSize != type.KeyCount) - { - throw env.Except("Column '{0}' must have key values metadata", - MetricKinds.ColumnNames.StratCol); - } - - ValueGetter> getKeyValues = - (ref VBuffer dst) => - { - data.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, stratCol, ref dst); - Contracts.Assert(dst.IsDense); - }; - - dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, (uint)0); - } - else if (hasStratVals && i == stratVal) - dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, DvText.NA); - else if (hasWeighted && i == isWeightedCol) - dvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, weighted ? DvBool.True : DvBool.False); - else if (hasFoldCol && i == foldCol) - { - var avg = new DvText("Average"); - dvBldr.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, avg); - } - else if (getters[i] != null) - { - dvBldr.AddColumn(data.Schema.GetColumnName(i), NumberType.R8, avgMetrics[iMetric]); - iMetric++; - } - else if (vBufferGetters[i] != null) - { - var type = data.Schema.GetColumnType(i); - var vectorMetrics = new double[type.VectorSize]; - env.Assert(vectorMetrics.Length > 0); - Array.Copy(avgMetrics, iMetric, vectorMetrics, 0, vectorMetrics.Length); - var slotNamesType = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); - var name = data.Schema.GetColumnName(i); - var slotNames = default(VBuffer); - if (slotNamesType != null && slotNamesType.ItemType.IsText && - slotNamesType.VectorSize == type.VectorSize) - { - data.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); - Contracts.Assert(slotNames.IsDense); - var values = slotNames.Values; - for (int j = 0; j < values.Length; j++) - values[j] = new DvText(name + values[j]); - slotNames = new VBuffer(slotNames.Length, values, slotNames.Indices); - } - else - { - var values = slotNames.Values; - if (Utils.Size(values) < type.VectorSize) - values = new DvText[type.VectorSize]; - for (int j = 0; j < type.VectorSize; j++) - values[j] = new DvText(name + j); - slotNames = new VBuffer(type.VectorSize, values, slotNames.Indices); - } - ValueGetter> getSlotNames = (ref VBuffer dst) => dst = slotNames; - dvBldr.AddColumn(name, getSlotNames, NumberType.R8, new[] { vectorMetrics }); - iMetric += vectorMetrics.Length; - } - else - nonAveragedCols?.Add(data.Schema.GetColumnName(i)); - } - Contracts.Assert(iMetric == metricNames.Count); - avgMetricsDataView = dvBldr.GetDataView(); - } - else - avgMetricsDataView = null; - return sb.ToString(); } - private static List GetMetricNames(IChannel ch, ISchema schema, IRow row, Func ignoreCol, - ValueGetter[] getters, ValueGetter>[] vBufferGetters) + // This method returns a string representation of a set of metrics. If there are stratification columns, it looks for columns named + // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks + // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column. 
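Note: GetOverallMetricsAsString pads each metric name to a fixed width, prints the average with the invariant culture, and appends the standard deviation in parentheses when averaging across folds. A standalone sketch of that formatting, reusing the same format strings:

using System.Globalization;
using System.Text;

static class MetricFormatting
{
    public static string Line(string name, double average, double? stddev, bool weighted)
    {
        var sb = new StringBuilder();
        sb.Append(string.Format("{0}{1}: ", weighted ? "Weighted " : "", name).PadRight(20));
        sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", average));
        if (stddev.HasValue)
            sb.Append(string.Format(" ({0:N4})", stddev.Value)); // only shown when averaging over folds
        return sb.ToString();
    }
}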
+ // If nonAveragedCols is non-null, it computes the average and standard deviation over all the relevant rows and populates + // nonAveragedCols with columns that are either hidden, or are not of a type that we can display (i.e., either a numeric column, + // or a known length vector of doubles). + // If average is false, no averaging is done, and instead we check that there is exactly one relevant row. Otherwise, we + // add the vector columns of variable length of the list of non-averagable columns if nonAveragedCols is not null. + private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView data, out string weightedMetricsString) { - Contracts.AssertValue(schema); - Contracts.AssertValue(row); - Contracts.Assert(Utils.Size(getters) == schema.ColumnCount); - Contracts.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount); + var metrics = EvaluateUtils.ComputeMetricsSum(env, data, 1, out int isWeightedCol, out int stratCol, + out int stratVal, out int foldCol, out var weightedMetrics); - // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns - // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. - VBuffer names = default(VBuffer); - int metricCount = 0; - var metricNames = new List(); - for (int i = 0; i < schema.ColumnCount; i++) + var sb = new StringBuilder(); + var weightedSb = isWeightedCol >= 0 ? new StringBuilder() : null; + for (int i = 0; i < metrics.Length; i++) { - if (schema.IsHidden(i) || ignoreCol(i)) - continue; - - var type = schema.GetColumnType(i); - var metricName = row.Schema.GetColumnName(i); - if (type.IsNumber) - { - getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, row, i); - metricNames.Add(metricName); - metricCount++; - } - else if (type.IsVector && type.ItemType == NumberType.R8) - { - if (type.VectorSize == 0) - { - ch.Warning("Vector metric '{0}' has different lengths in different folds and will not be averaged for overall results.", metricName); - continue; - } - - vBufferGetters[i] = row.GetGetter>(i); - metricCount += type.VectorSize; - var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); - if (slotNamesType != null && slotNamesType.VectorSize == type.VectorSize && slotNamesType.ItemType.IsText) - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); - else - { - var namesArray = names.Values; - if (Utils.Size(namesArray) < type.VectorSize) - namesArray = new DvText[type.VectorSize]; - for (int j = 0; j < type.VectorSize; j++) - namesArray[j] = new DvText(string.Format("Label_{0}", j)); - names = new VBuffer(type.VectorSize, namesArray); - } - foreach (var name in names.Items(all: true)) - metricNames.Add(string.Format("{0} {1}", metricName, name.Value)); - } + sb.Append($"{metrics[i].Name}: ".PadRight(20)); + sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", metrics[i].Sum)); + weightedSb?.Append($"Weighted {weightedMetrics[i].Name}: ".PadRight(20)); + weightedSb?.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", weightedMetrics[i].Sum)); + sb.AppendLine(); + weightedSb?.AppendLine(); } - Contracts.Assert(metricNames.Count == metricCount); - return metricNames; + + weightedMetricsString = weightedSb?.ToString(); + return sb.ToString(); } // Get a string representation of a confusion table. 
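Note on the arithmetic introduced above: the overall summary derives each metric's average and standard deviation from the per-fold sum and sum of squares, using the population formula stdev = sqrt(E[x^2] - E[x]^2) and reporting 0 rather than NaN when there is only one fold. A minimal standalone sketch of that computation (the helper below is illustrative only and not part of this change):

```csharp
using System;

static class FoldStats
{
    // Average and population standard deviation from running sums across folds.
    public static (double Avg, double StdDev) FromSums(double sum, double sumSq, int numFolds)
    {
        double avg = sum / numFolds;
        // With a single fold the deviation is defined as 0 instead of NaN.
        double stdDev = numFolds == 1 ? 0 : Math.Sqrt(sumSq / numFolds - avg * avg);
        return (avg, stdDev);
    }
}

// Example: FoldStats.FromSums(3.0, 5.0, 2) yields (1.5, 0.5).
```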
@@ -1181,58 +1590,26 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl /// public static void PrintOverallMetrics(IHostEnvironment env, IChannel ch, string filename, IDataView overall, int numFolds) { + var overallWithAvg = EvaluateUtils.GetOverallMetricsData(env, overall, numFolds, out var agg, out var weightedAgg); + var sb = new StringBuilder(); sb.AppendLine(); sb.AppendLine("OVERALL RESULTS"); sb.AppendLine("---------------------------------------"); - int isWeighted; - IDataView weightedAvgMetrics = null; var nonAveragedCols = new List(); - if (overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeighted)) - sb.Append(GetMetricsAsString(env, overall, true, numFolds, out weightedAvgMetrics, true)); - IDataView avgMetrics; - sb.AppendLine(GetMetricsAsString(env, overall, false, numFolds, out avgMetrics, true, nonAveragedCols)); - env.AssertValue(avgMetrics); - sb.AppendLine("---------------------------------------"); + if (weightedAgg != null) + sb.Append(GetOverallMetricsAsString(weightedAgg.Select(x => x.Sum).ToArray(), weightedAgg.Select(x => x.SumSq).ToArray(), numFolds, true, true, weightedAgg.Select(x => x.Name).ToList())); + sb.Append(GetOverallMetricsAsString(agg.Select(x => x.Sum).ToArray(), agg.Select(x => x.SumSq).ToArray(), numFolds, false, true, agg.Select(x => x.Name).ToList())); + sb.AppendLine("\n---------------------------------------"); ch.Info(sb.ToString()); if (!string.IsNullOrEmpty(filename)) { using (var file = env.CreateOutputFile(filename)) { - // idvList will contain all the dataviews that should be appended with AppendRowsDataView. - // If numResults=1, then we just save the average metrics. Otherwise, we remove all the non-metric columns - // (except for the IsWeighted column and FoldIndex column if present), and append to the average results. - var idvList = new List() { avgMetrics }; - if (weightedAvgMetrics != null) - idvList.Add(weightedAvgMetrics); - - int stratCol; - var hasStrat = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol); - if (numFolds > 1 || hasStrat) - { - if (Utils.Size(nonAveragedCols) > 0) - { - var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() }; - overall = new DropColumnsTransform(env, dropArgs, overall); - } - idvList.Add(overall); - } - - var summary = AppendRowsDataView.Create(env, avgMetrics.Schema, idvList.ToArray()); - - // If there are stratified results, apply a KeyToValue transform to get the stratification column - // names from the key column. 
- if (hasStrat) - { - var args = new KeyToValueTransform.Arguments(); - args.Column = new[] { new KeyToValueTransform.Column() { Source = MetricKinds.ColumnNames.StratCol }, }; - summary = new KeyToValueTransform(env, args, summary); - } - var saverArgs = new TextSaver.Arguments() { Dense = true, Silent = true }; - DataSaverUtils.SaveDataView(ch, new TextSaver(env, saverArgs), summary, file); + DataSaverUtils.SaveDataView(ch, new TextSaver(env, saverArgs), overallWithAvg, file); } } } diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index ada69d2d47..bbb53ba631 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -6,7 +6,6 @@ using System.Linq; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; namespace Microsoft.ML.Runtime.Data { @@ -25,13 +24,18 @@ public interface IMamlEvaluator : IEvaluator void PrintFoldResults(IChannel ch, Dictionary metrics); /// - /// Combine the aggregate metrics from multiple folds and print them to the console. If filename is not null then - /// also save the results to the specified file. If results are from multiple folds, the file will contain - /// the average results first, and then each fold result. - /// Also handle any custom kinds of custom metrics, such as p/r curves for binary classification, or group summary results - /// for ranking. + /// Combine the overall metrics from multiple folds into a single data view. /// - void PrintOverallResults(IChannel ch, string filename, params Dictionary[] metrics); + /// + /// + IDataView GetOverallResults(params IDataView[] metrics); + + /// + /// Handles custom metrics (such as p/r curves for binary classification, or group summary results for ranking) from one + /// or more folds. Implementations of this method typically creates a single data view for the custom metric and saves it + /// to a user specified file. + /// + void PrintAdditionalMetrics(IChannel ch, params Dictionary[] metrics); /// /// Create a data view containing only the columns that are saved as per-instance results by Maml commands. @@ -162,57 +166,36 @@ protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + public IDataView GetOverallResults(params IDataView[] metrics) { - Host.CheckValue(ch, nameof(ch)); Host.CheckNonEmpty(metrics, nameof(metrics)); - PrintOverallResultsCore(ch, filename, metrics); + var overall = CombineOverallMetricsCore(metrics); + return GetOverallResultsCore(overall); } - /// - /// This method simply prints the overall metrics using EvaluateUtils.PrintOverallMetrics. - /// Override if something else is needed. 
- /// - protected virtual void PrintOverallResultsCore(IChannel ch, string filename, Dictionary[] metrics) + protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return EvaluateUtils.ConcatenateOverallMetrics(Host, metrics); } - protected bool TryGetOverallMetrics(Dictionary[] metrics, out IDataView overall) + protected virtual IDataView GetOverallResultsCore(IDataView overall) { - Host.AssertNonEmpty(metrics); + return overall; + } - if (metrics.Length == 1) - return metrics[0].TryGetValue(MetricKinds.OverallMetrics, out overall); + public void PrintAdditionalMetrics(IChannel ch, params Dictionary[] metrics) + { + Host.CheckValue(ch, nameof(ch)); + Host.CheckNonEmpty(metrics, nameof(metrics)); + PrintAdditionalMetricsCore(ch, metrics); + } - overall = null; - var overallList = new List(); - for (int i = 0; i < metrics.Length; i++) - { - var dict = metrics[i]; - IDataView idv; - if (!dict.TryGetValue(MetricKinds.OverallMetrics, out idv)) - return false; - - // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - // We use DvText.NA as the value of this column since for any stratified row the value will be non empty, so we can uniquely identify - // the overall row using this column. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddTextColumn, inputColType.RawType, Host, - idv, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, string.Format("Fold {0}", i), "FoldName"); - - overallList.Add(idv); - } - overall = AppendRowsDataView.Create(Host, overallList[0].Schema, overallList.ToArray()); - return true; + /// + /// This method simply prints the overall metrics using EvaluateUtils.PrintOverallMetrics. + /// Override if something else is needed. + /// + protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + { } public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData) diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 0251f0a0e7..5507176fb4 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -853,7 +853,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView CombineOverallMetricsCore(IDataView[] metrics) { - ch.AssertNonEmpty(metrics); - var overallList = new List(); for (int i = 0; i < metrics.Length; i++) { - var dict = metrics[i]; - if (!dict.TryGetValue(MetricKinds.OverallMetrics, out IDataView idv)) - throw ch.Except("No overall metrics found"); - - // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. 
- // We use DvText.NA as the value of this column since for any stratified row the value will be non empty, so we can uniquely identify - // the overall row using this column. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddTextColumn, inputColType.RawType, Host, - idv, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, string.Format("Fold {0}", i), "FoldName"); - + var idv = metrics[i]; if (!_outputPerClass) idv = DropPerClassColumn(idv); @@ -925,14 +911,15 @@ protected override void PrintOverallResultsCore(IChannel ch, string filename, Di views[i] = idv; } } + return base.CombineOverallMetricsCore(views); + } - var overall = AppendRowsDataView.Create(Host, views[0].Schema, views.ToArray()); - + protected override IDataView GetOverallResultsCore(IDataView overall) + { // Change the name of the Top-k-accuracy column. if (_outputTopKAcc != null) overall = ChangeTopKAccColumnName(overall); - - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return overall; } private IDataView ChangeTopKAccColumnName(IDataView input) diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index fa450286f1..6d61f6b965 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -491,17 +491,9 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - - // Show only the metrics for the requested index. - overall = ExtractRelevantIndex(overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return ExtractRelevantIndex(overall); } private IDataView ExtractRelevantIndex(IDataView data) @@ -516,6 +508,8 @@ private IDataView ExtractRelevantIndex(IDataView data) var index = _index ?? 
type.VectorSize / 2; output = LambdaColumnMapper.Create(Host, "Quantile Regression", output, name, name, type, NumberType.R8, (ref VBuffer src, ref Double dst) => dst = src.GetItemOrDefault(index)); + output = new ChooseColumnsByIndexTransform(Host, + new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { i } }, output); } } return output; diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 420f875cc6..a383f835fd 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -854,9 +854,10 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args) return cols.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, groupIdCol)); } - protected override void PrintOverallResultsCore(IChannel ch, string filename, Dictionary[] metrics) + protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { - base.PrintOverallResultsCore(ch, filename, metrics); + ch.AssertNonEmpty(metrics); + if (!string.IsNullOrEmpty(_groupSummaryFilename)) { IDataView gs; @@ -887,12 +888,7 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics, if (!metrics[i].TryGetValue(RankerEvaluator.GroupSummary, out idv)) return false; - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, idv, - inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, metrics.Length, i + 1, "FoldIndex", - default(ValueGetter>)); + idv = EvaluateUtils.AddFoldIndex(Host, idv, i, metrics.Length); gsList.Add(idv); } gs = AppendRowsDataView.Create(Host, gsList[0].Schema, gsList.ToArray()); diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index f1ea0a2d4b..5bb4782599 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -142,6 +142,18 @@ public void Add(Microsoft.ML.Models.ClusterEvaluator input, Microsoft.ML.Models. _jsonNodes.Add(Serialize("Models.ClusterEvaluator", input, output)); } + public Microsoft.ML.Models.CrossValidationResultsCombiner.Output Add(Microsoft.ML.Models.CrossValidationResultsCombiner input) + { + var output = new Microsoft.ML.Models.CrossValidationResultsCombiner.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Models.CrossValidationResultsCombiner input, Microsoft.ML.Models.CrossValidationResultsCombiner.Output output) + { + _jsonNodes.Add(Serialize("Models.CrossValidationResultsCombiner", input, output)); + } + public Microsoft.ML.Models.CrossValidator.Output Add(Microsoft.ML.Models.CrossValidator input) { var output = new Microsoft.ML.Models.CrossValidator.Output(); @@ -2080,6 +2092,73 @@ public enum MacroUtilsTrainerKinds } + /// + /// Combine the metric data views returned from cross validation. 
+ /// + public sealed partial class CrossValidationResultsCombiner + { + + + /// + /// Overall metrics datasets + /// + public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + + /// + /// Per instance metrics datasets + /// + public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); + + /// + /// Confusion matrix datasets + /// + public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + + /// + /// Warning datasets + /// + public ArrayVar Warnings { get; set; } = new ArrayVar(); + + /// + /// The label column name + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Specifies the trainer kind, which determines the evaluator to be used. + /// + public Models.MacroUtilsTrainerKinds Kind { get; set; } = Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + + + public sealed class Output + { + /// + /// Warning dataset + /// + public Var Warnings { get; set; } = new Var(); + + /// + /// Overall metrics dataset + /// + public Var OverallMetrics { get; set; } = new Var(); + + /// + /// Per instance metrics dataset + /// + public Var PerInstanceMetrics { get; set; } = new Var(); + + /// + /// Confusion matrix dataset + /// + public Var ConfusionMatrix { get; set; } = new Var(); + + } + } + } + + namespace Models + { + public sealed partial class CrossValidationMacroSubGraphInput { /// @@ -2156,22 +2235,22 @@ public sealed class Output /// /// Warning dataset /// - public ArrayVar Warnings { get; set; } = new ArrayVar(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); + public Var PerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset /// - public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + public Var ConfusionMatrix { get; set; } = new Var(); } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 1cb950f939..f39dd2ec3f 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
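The C# API surface above changes the CrossValidator outputs from ArrayVar to Var and introduces the CrossValidationResultsCombiner node; the macro wires that node in automatically (see CrossValidationMacro.cs below), so typical callers never construct it directly. If one did invoke it by hand through the experiment API, it would look roughly like this sketch (variable names and bindings are assumed, not taken from this change):

```csharp
// Illustrative only: the combiner is normally added by the CrossValidator macro itself.
var combiner = new Microsoft.ML.Models.CrossValidationResultsCombiner
{
    Kind = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer,
    LabelColumn = "Label"
    // OverallMetrics, PerInstanceMetrics, ConfusionMatrix and Warnings would be bound
    // to the per-fold data view arrays produced inside the cross-validation sub-graph.
};
var combined = experiment.Add(combiner);
// Each output (combined.OverallMetrics, combined.PerInstanceMetrics, ...) is a single
// data view spanning all folds, rather than an array of per-fold views.
```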
-using System; using System.Collections.Generic; using System.Linq; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(CrossValidationMacro), null, typeof(SignatureEntryPointModule), "CrossValidationMacro")] @@ -83,16 +83,52 @@ public sealed class Output public IPredictorModel[] PredictorModel; [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] - public IDataView[] Warnings; + public IDataView Warnings; [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] - public IDataView[] OverallMetrics; + public IDataView OverallMetrics; [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] - public IDataView[] PerInstanceMetrics; + public IDataView PerInstanceMetrics; [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + public IDataView ConfusionMatrix; + } + + public sealed class CombineMetricsInput + { + [Argument(ArgumentType.Multiple, HelpText = "Overall metrics datasets", SortOrder = 1)] + public IDataView[] OverallMetrics; + + [Argument(ArgumentType.Multiple, HelpText = "Per instance metrics datasets", SortOrder = 2)] + public IDataView[] PerInstanceMetrics; + + [Argument(ArgumentType.Multiple, HelpText = "Confusion matrix datasets", SortOrder = 3)] public IDataView[] ConfusionMatrix; + + [Argument(ArgumentType.Multiple, HelpText = "Warning datasets", SortOrder = 4)] + public IDataView[] Warnings; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 6)] + public string LabelColumn = DefaultColumnNames.Label; + + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 7)] + public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; + } + + public sealed class CombinedOutput + { + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + public IDataView Warnings; + + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + public IDataView OverallMetrics; + + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + public IDataView PerInstanceMetrics; + + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + public IDataView ConfusionMatrix; } [TlcModule.EntryPoint(Desc = "Cross validation for general learning", Name = "Models.CrossValidator")] @@ -206,6 +242,7 @@ public static CommonOutputs.MacroOutput CrossValidate( exp.Reset(); + // Convert predictors from all folds into an array of predictors. var outModels = new ML.Data.PredictorModelArrayConverter { Model = new ArrayVar(predModelVars) @@ -214,45 +251,159 @@ public static CommonOutputs.MacroOutput CrossValidate( outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); exp.Add(outModels, outModelsOutput); + // Convert warnings data views from all folds into an array of data views. var warnings = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(warningsVars) }; var warningsOutput = new ML.Data.IDataViewArrayConverter.Output(); - warningsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.Warnings)); exp.Add(warnings, warningsOutput); + // Convert overall metrics data views from all folds into an array of data views. 
var overallMetrics = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(overallMetricsVars) }; var overallMetricsOutput = new ML.Data.IDataViewArrayConverter.Output(); - overallMetricsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics)); exp.Add(overallMetrics, overallMetricsOutput); + // Convert per instance data views from all folds into an array of data views. var instanceMetrics = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(instanceMetricsVars) }; var instanceMetricsOutput = new ML.Data.IDataViewArrayConverter.Output(); - instanceMetricsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); exp.Add(instanceMetrics, instanceMetricsOutput); + ML.Data.IDataViewArrayConverter.Output confusionMatricesOutput = null; if (input.Kind == MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer || input.Kind == MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer) { + // Convert confusion matrix data views from all folds into an array of data views. var confusionMatrices = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(confusionMatrixVars) }; - var confusionMatricesOutput = new ML.Data.IDataViewArrayConverter.Output(); - confusionMatricesOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + confusionMatricesOutput = new ML.Data.IDataViewArrayConverter.Output(); exp.Add(confusionMatrices, confusionMatricesOutput); } - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + var combineArgs = new CombineMetricsInput(); + combineArgs.Kind = input.Kind; + + // Set the input bindings for the CombineMetrics entry point. + var combineInputBindingMap = new Dictionary>(); + var combineInputMap = new Dictionary(); + var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics)); + combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List { overallArray }); + combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName)); + var combinePerInstArray = new SimpleParameterBinding(nameof(combineArgs.PerInstanceMetrics)); + combineInputBindingMap.Add(nameof(combineArgs.PerInstanceMetrics), new List { combinePerInstArray }); + combineInputMap.Add(combinePerInstArray, new SimpleVariableBinding(instanceMetricsOutput.OutputData.VarName)); + if (confusionMatricesOutput != null) + { + var combineConfArray = new SimpleParameterBinding(nameof(combineArgs.ConfusionMatrix)); + combineInputBindingMap.Add(nameof(combineArgs.ConfusionMatrix), new List { combineConfArray }); + combineInputMap.Add(combineConfArray, new SimpleVariableBinding(confusionMatricesOutput.OutputData.VarName)); + } + + var combineOutputMap = new Dictionary(); + var combineWarningVar = new Var(); + combineWarningVar.VarName = node.GetOutputVariableName(nameof(Output.Warnings)); + combineOutputMap.Add(nameof(Output.Warnings), combineWarningVar.VarName); + var combineOverallMetric = new Var(); + combineOverallMetric.VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics)); + combineOutputMap.Add(nameof(Output.OverallMetrics), combineOverallMetric.VarName); + var combineInstanceMetric = new Var(); + combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); + combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName); + var combineConfusionMatrix = new Var(); + combineConfusionMatrix.VarName = 
node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); + subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); return new CommonOutputs.MacroOutput() { Nodes = subGraphNodes }; } + + [TlcModule.EntryPoint(Desc = "Combine the metric data views returned from cross validation.", Name = "Models.CrossValidationResultsCombiner")] + public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetricsInput input) + { + var eval = GetEvaluator(env, input.Kind); + var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select( + idv => RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn))).ToArray(), + out var variableSizeVectorColumnNames); + + var warnings = input.Warnings != null ? new List(input.Warnings) : new List(); + if (variableSizeVectorColumnNames.Length > 0) + { + var dvBldr = new ArrayDataViewBuilder(env); + var warn = $"Detected columns of variable length: {string.Join(", ", variableSizeVectorColumnNames)}." + + $" Consider setting collateMetrics- for meaningful per-Folds results."; + dvBldr.AddColumn(MetricKinds.ColumnNames.WarningText, TextType.Instance, new DvText(warn)); + warnings.Add(dvBldr.GetDataView()); + } + + env.Assert(Utils.Size(perInst) == 1); + + var overall = eval.GetOverallResults(input.OverallMetrics); + overall = EvaluateUtils.CombineFoldMetricsDataViews(env, overall, input.OverallMetrics.Length); + + IDataView conf = null; + if (Utils.Size(input.ConfusionMatrix) > 0) + { + EvaluateUtils.ReconcileSlotNames(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8); + + for (int i = 0; i < input.ConfusionMatrix.Length; i++) + { + var idv = input.ConfusionMatrix[i]; + // Find the old Count column and drop it. + for (int col = 0; col < idv.Schema.ColumnCount; col++) + { + if (idv.Schema.IsHidden(col) && + idv.Schema.GetColumnName(col).Equals(MetricKinds.ColumnNames.Count)) + { + input.ConfusionMatrix[i] = new ChooseColumnsByIndexTransform(env, + new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { col } }, idv); + break; + } + } + } + conf = EvaluateUtils.ConcatenateOverallMetrics(env, input.ConfusionMatrix); + } + + var warningsIdv = warnings.Count > 0 ? 
AppendRowsDataView.Create(env, warnings[0].Schema, warnings.ToArray()) : null; + + return new CombinedOutput() + { + PerInstanceMetrics = perInst[0], + OverallMetrics = overall, + ConfusionMatrix = conf, + Warnings = warningsIdv + }; + } + + private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.TrainerKinds kind) + { + switch (kind) + { + case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: + return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: + return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRegressorTrainer: + return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRankerTrainer: + return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: + return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureClusteringTrainer: + return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: + return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); + default: + throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); + } + } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 2feef7c60e..c7c199f2d1 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -2,24 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
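The test updates below show the consumer-side effect of these changes: each metrics output is now read as a single data view whose rows are tagged by a text "Fold Index" column ("Average", "Standard Deviation", then "Fold 0", "Fold 1", ...). A condensed sketch of the new consumption pattern, assuming a compiled and run experiment as in the tests (names assumed):

```csharp
// Previously: experiment.GetOutput(crossValidateOutput.OverallMetrics[0]);
var overall = experiment.GetOutput(crossValidateOutput.OverallMetrics);

// The combined view carries an "Average" row, a "Standard Deviation" row and one
// row per fold, distinguished by the "Fold Index" text column.
bool found = overall.Schema.TryGetColumnIndex("Fold Index", out int foldIndexCol);
```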
-using ML = Microsoft.ML; -using Microsoft.ML.Runtime; +using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.TestFramework; using Xunit; using Xunit.Abstractions; -/*using Categorical = Microsoft.ML.Transforms; -using Commands = Microsoft.ML.Transforms; -using Evaluate = Microsoft.ML; -using ImportTextData = Microsoft.ML.Data; -using LogisticRegression = Microsoft.ML.Trainers; -using ModelOperations = Microsoft.ML.Transforms; -using Normalize = Microsoft.ML.Transforms; -using SchemaManipulation = Microsoft.ML.Transforms; -using ScoreModel = Microsoft.ML.Transforms; -using Sdca = Microsoft.ML.Trainers;*/ namespace Microsoft.ML.Runtime.RunTests { @@ -249,7 +238,7 @@ public void TestCrossValidationBinaryMacro() var crossValidateOutput = experiment.Add(crossValidateBinary); experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + importInput.SetInput(env, experiment); experiment.Run(); var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); @@ -274,7 +263,7 @@ public void TestCrossValidationBinaryMacro() public void TestCrossValidationMacro() { var dataPath = GetDataPath(TestDatasets.winequality.trainFilename); - using (var env = new TlcEnvironment()) + using (var env = new TlcEnvironment(42)) { var subGraph = env.CreateExperiment(); @@ -334,32 +323,186 @@ public void TestCrossValidationMacro() var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + importInput.SetInput(env, experiment); experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); var schema = data.Schema; var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol)) + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) { var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the average. b = cursor.MoveNext(); Assert.True(b); + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation.
+ b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(0.0013, stdev, 4); + + double sum = 0; double val = 0; - getter(ref val); - Assert.Equal(0.58, val, 1); + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + } + Assert.Equal(avg, sum / 2); b = cursor.MoveNext(); Assert.False(b); } } } + [Fact] + public void TestCrossValidationMacroWithMultiClass() + { + var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); + using (var env = new TlcEnvironment(42)) + { + var subGraph = env.CreateExperiment(); + + var nop = new ML.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); + + var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); + + var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(nopOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); + + var crossValidate = new ML.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + Kind = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, + TransformModel = null + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the average. + b = cursor.MoveNext(); + Assert.True(b); + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation.
+ b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(0.025, stdev, 3); + + double sum = 0; + double val = 0; + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + } + Assert.Equal(avg, sum / 2); + b = cursor.MoveNext(); + Assert.False(b); + } + + var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix); + schema = confusion.Schema; + b = schema.TryGetColumnIndex("Count", out int countCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out foldCol); + Assert.True(b); + var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); + Assert.True(type != null && type.ItemType.IsText && type.VectorSize == 10); + var slotNames = default(VBuffer); + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames); + Assert.True(slotNames.Values.Select((s, i) => s.EqualsStr(i.ToString())).All(x => x)); + using (var curs = confusion.GetRowCursor(col => true)) + { + var countGetter = curs.GetGetter>(countCol); + var foldGetter = curs.GetGetter(foldCol); + var confCount = default(VBuffer); + var foldIndex = default(DvText); + int rowCount = 0; + var foldCur = "Fold 0"; + while (curs.MoveNext()) + { + countGetter(ref confCount); + foldGetter(ref foldIndex); + rowCount++; + Assert.True(foldIndex.EqualsStr(foldCur)); + if (rowCount == 10) + { + rowCount = 0; + foldCur = "Fold 1"; + } + } + Assert.Equal(0, rowCount); + } + } + } + [Fact] public void TestCrossValidationMacroWithStratification() { var dataPath = GetDataPath(@"breast-cancer.txt"); - using (var env = new TlcEnvironment()) + using (var env = new TlcEnvironment(42)) { var subGraph = env.CreateExperiment(); @@ -400,23 +543,51 @@ public void TestCrossValidationMacroWithStratification() crossValidate.Inputs.Data = nop.Data; crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; var crossValidateOutput = experiment.Add(crossValidate); - experiment.Compile(); experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); var schema = data.Schema; var b = schema.TryGetColumnIndex("AUC", out int metricCol); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol)) + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) { var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the average. b = cursor.MoveNext(); Assert.True(b); + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation.
+ b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(0.00485, stdev, 5); + + double sum = 0; double val = 0; - getter(ref val); - Assert.Equal(0.99, val, 2); + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + } + Assert.Equal(avg, sum / 2); b = cursor.MoveNext(); Assert.False(b); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 0ef21b3ef8..34b9317176 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2366,9 +2366,9 @@ public void EntryPointChainedCrossValMacros() model = runner.GetOutput("model2"); Assert.NotNull(model[0]); - var metrics = runner.GetOutput("OverallMetrics"); - Assert.NotNull(metrics[0]); - using (var cursor = metrics[0].GetRowCursor(col => true)) + var metrics = runner.GetOutput("OverallMetrics"); + Assert.NotNull(metrics); + using (var cursor = metrics.GetRowCursor(col => true)) { Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); var aucGetter = cursor.GetGetter(aucCol); @@ -2378,9 +2378,9 @@ public void EntryPointChainedCrossValMacros() Assert.True(auc > 0.99); } - metrics = runner.GetOutput("OverallMetrics2"); - Assert.NotNull(metrics[0]); - using (var cursor = metrics[0].GetRowCursor(col => true)) + metrics = runner.GetOutput("OverallMetrics2"); + Assert.NotNull(metrics); + using (var cursor = metrics.GetRowCursor(col => true)) { Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); var aucGetter = cursor.GetGetter(aucCol); diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index bf7bf57052..f0042014ba 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -5,6 +5,7 @@ +