diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
index 7fc82434b4..da227b16ff 100644
--- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv
+++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
@@ -3,6 +3,7 @@ Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runt
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
+Data.TransformModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput
Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput
Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput
Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output]
diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json
index d8b3fdccec..ab2771065c 100644
--- a/ZBaselines/Common/EntryPoints/core_manifest.json
+++ b/ZBaselines/Common/EntryPoints/core_manifest.json
@@ -469,6 +469,35 @@
"ILearningPipelineLoader"
]
},
+ {
+ "Name": "Data.TransformModelArrayConverter",
+ "Desc": "Create and array variable",
+ "FriendlyName": null,
+ "ShortName": null,
+ "Inputs": [
+ {
+ "Name": "TransformModel",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "TransformModel"
+ },
+ "Desc": "The models",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputModel",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "TransformModel"
+ },
+ "Desc": "The model array"
+ }
+ ]
+ },
{
"Name": "Models.AnomalyDetectionEvaluator",
"Desc": "Evaluates an anomaly detection scored dataset.",
@@ -1411,9 +1440,28 @@
"Name": "Model",
"Type": "PredictorModel",
"Desc": "The model",
- "Required": true,
+ "Required": false,
"SortOrder": 1.0,
- "IsNullable": false
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "TransformModel",
+ "Type": "TransformModel",
+ "Desc": "The transform model",
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "UseTransformModel",
+ "Type": "Bool",
+ "Desc": "Indicates to use transform model instead of predictor model.",
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": false
}
]
},
@@ -1476,6 +1524,14 @@
},
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
},
+ {
+ "Name": "TransformModel",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "TransformModel"
+ },
+ "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
+ },
{
"Name": "Warnings",
"Type": "DataView",
@@ -3002,9 +3058,28 @@
"Name": "Model",
"Type": "PredictorModel",
"Desc": "The model",
- "Required": true,
+ "Required": false,
"SortOrder": 1.0,
- "IsNullable": false
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "TransformModel",
+ "Type": "TransformModel",
+ "Desc": "Transform model",
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "UseTransformModel",
+ "Type": "Bool",
+ "Desc": "Indicates to use transform model instead of predictor model.",
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": false
}
]
},
@@ -3058,6 +3133,11 @@
"Type": "PredictorModel",
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
},
+ {
+ "Name": "TransformModel",
+ "Type": "TransformModel",
+ "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
+ },
{
"Name": "Warnings",
"Type": "DataView",
diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs
index ec249ce768..ccb65d43ab 100644
--- a/src/Microsoft.ML.Core/Data/ITransformModel.cs
+++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs
@@ -25,8 +25,6 @@ public interface ITransformModel
///
ISchema InputSchema { get; }
- IDataView Data { get; }
-
///
/// Apply the transform(s) in the model to the given input data.
///
diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
index acbce34b24..b840529e77 100644
--- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
@@ -44,11 +44,6 @@ public ISchema InputSchema
get { return _schemaRoot; }
}
- public IDataView Data
- {
- get { return _chain; }
- }
-
///
/// Create a TransformModel containing the transforms from "result" back to "input".
///
diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs
index 8ef1ae4206..10c078a891 100644
--- a/src/Microsoft.ML/CSharpApi.cs
+++ b/src/Microsoft.ML/CSharpApi.cs
@@ -11582,11 +11582,6 @@ public sealed class Output
///
public Var OutputModel { get; set; } = new Var();
- ///
- /// Data
- ///
- public Var Data { get; set; } = new Var();
-
}
}
}
diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs
index e0a4eae826..b916a9ab57 100644
--- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs
+++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Transforms;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -23,7 +24,7 @@ public sealed partial class BinaryClassificationEvaluator
///
/// A BinaryClassificationMetrics instance that describes how well the model performed against the test data.
///
- public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
+ public List Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
{
diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs
index aa3a94f3a9..fe95eac312 100644
--- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs
+++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -18,7 +19,7 @@ private BinaryClassificationMetrics()
{
}
- internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
+ internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
{
Contracts.AssertValue(env);
env.AssertValue(overallMetrics);
@@ -31,28 +32,37 @@ internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, ID
throw env.Except("The overall RegressionMetrics didn't have any rows.");
}
- SerializationClass metrics = enumerator.Current;
-
- if (enumerator.MoveNext())
- {
- throw env.Except("The overall RegressionMetrics contained more than 1 row.");
- }
-
- return new BinaryClassificationMetrics()
+ List metrics = new List();
+ var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
+ do
{
- Auc = metrics.Auc,
- Accuracy = metrics.Accuracy,
- PositivePrecision = metrics.PositivePrecision,
- PositiveRecall = metrics.PositiveRecall,
- NegativePrecision = metrics.NegativePrecision,
- NegativeRecall = metrics.NegativeRecall,
- LogLoss = metrics.LogLoss,
- LogLossReduction = metrics.LogLossReduction,
- Entropy = metrics.Entropy,
- F1Score = metrics.F1Score,
- Auprc = metrics.Auprc,
- ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
- };
+ SerializationClass metric = enumerator.Current;
+
+ if (!confusionMatrices.MoveNext())
+ {
+ throw env.Except("Confusion matrices didn't have enough matrices.");
+ }
+
+ metrics.Add(
+ new BinaryClassificationMetrics()
+ {
+ Auc = metric.Auc,
+ Accuracy = metric.Accuracy,
+ PositivePrecision = metric.PositivePrecision,
+ PositiveRecall = metric.PositiveRecall,
+ NegativePrecision = metric.NegativePrecision,
+ NegativeRecall = metric.NegativeRecall,
+ LogLoss = metric.LogLoss,
+ LogLossReduction = metric.LogLossReduction,
+ Entropy = metric.Entropy,
+ F1Score = metric.F1Score,
+ Auprc = metric.Auprc,
+ ConfusionMatrix = confusionMatrices.Current,
+ });
+
+ } while (enumerator.MoveNext());
+
+ return metrics;
}
///
diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
index c8bec8642f..33799c0c78 100644
--- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs
+++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
@@ -5,6 +5,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Transforms;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -23,7 +24,7 @@ public sealed partial class ClassificationEvaluator
///
/// A ClassificationMetrics instance that describes how well the model performed against the test data.
///
- public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
+ public List Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
{
diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs
index 81c0f91d7b..0fbbba602e 100644
--- a/src/Microsoft.ML/Models/ClassificationMetrics.cs
+++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs
@@ -5,6 +5,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -17,7 +18,7 @@ private ClassificationMetrics()
{
}
- internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
+ internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
{
Contracts.AssertValue(env);
env.AssertValue(overallMetrics);
@@ -29,24 +30,32 @@ internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataVie
{
throw env.Except("The overall RegressionMetrics didn't have any rows.");
}
-
- SerializationClass metrics = enumerator.Current;
-
- if (enumerator.MoveNext())
- {
- throw env.Except("The overall RegressionMetrics contained more than 1 row.");
- }
-
- return new ClassificationMetrics()
+
+ List metrics = new List();
+ var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
+ do
{
- AccuracyMicro = metrics.AccuracyMicro,
- AccuracyMacro = metrics.AccuracyMacro,
- LogLoss = metrics.LogLoss,
- LogLossReduction = metrics.LogLossReduction,
- TopKAccuracy = metrics.TopKAccuracy,
- PerClassLogLoss = metrics.PerClassLogLoss,
- ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
- };
+ if (!confusionMatrices.MoveNext())
+ {
+ throw env.Except("Confusion matrices didn't have enough matrices.");
+ }
+
+ SerializationClass metric = enumerator.Current;
+ metrics.Add(
+ new ClassificationMetrics()
+ {
+ AccuracyMicro = metric.AccuracyMicro,
+ AccuracyMacro = metric.AccuracyMacro,
+ LogLoss = metric.LogLoss,
+ LogLossReduction = metric.LogLossReduction,
+ TopKAccuracy = metric.TopKAccuracy,
+ PerClassLogLoss = metric.PerClassLogLoss,
+ ConfusionMatrix = confusionMatrices.Current
+ });
+
+ } while (enumerator.MoveNext());
+
+ return metrics;
}
///
diff --git a/src/Microsoft.ML/Models/ConfusionMatrix.cs b/src/Microsoft.ML/Models/ConfusionMatrix.cs
index 2040fc8331..72aa5061dc 100644
--- a/src/Microsoft.ML/Models/ConfusionMatrix.cs
+++ b/src/Microsoft.ML/Models/ConfusionMatrix.cs
@@ -41,7 +41,7 @@ private ConfusionMatrix(double[,] elements, string[] classNames)
});
}
- internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusionMatrix)
+ internal static List Create(IHostEnvironment env, IDataView confusionMatrix)
{
Contracts.AssertValue(env);
env.AssertValue(confusionMatrix);
@@ -51,18 +51,28 @@ internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusion
env.Except($"ConfusionMatrix data view did not contain a {nameof(MetricKinds.ColumnNames.Count)} column.");
}
+ IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn);
+ var slots = default(VBuffer);
+ confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots);
+ string[] classNames = new string[slots.Count];
+ for (int i = 0; i < slots.Count; i++)
+ {
+ classNames[i] = slots.Values[i].ToString();
+ }
+
ColumnType type = confusionMatrix.Schema.GetColumnType(countColumn);
env.Assert(type.IsVector);
-
- double[,] elements = new double[type.VectorSize, type.VectorSize];
-
- IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn);
ValueGetter> countGetter = cursor.GetGetter>(countColumn);
VBuffer countValues = default;
-
+ List confusionMatrices = new List();
+
int valuesRowIndex = 0;
+ double[,] elements = null;
while (cursor.MoveNext())
{
+ if(valuesRowIndex == 0)
+ elements = new double[type.VectorSize, type.VectorSize];
+
countGetter(ref countValues);
for (int i = 0; i < countValues.Length; i++)
{
@@ -70,17 +80,15 @@ internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusion
}
valuesRowIndex++;
- }
- var slots = default(VBuffer);
- confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots);
- string[] classNames = new string[slots.Count];
- for (int i = 0; i < slots.Count; i++)
- {
- classNames[i] = slots.Values[i].ToString();
+ if(valuesRowIndex == type.VectorSize)
+ {
+ valuesRowIndex = 0;
+ confusionMatrices.Add(new ConfusionMatrix(elements, classNames));
+ }
}
- return new ConfusionMatrix(elements, classNames);
+ return confusionMatrices;
}
///
diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs
index c6aa52fb32..2ef99993cb 100644
--- a/src/Microsoft.ML/Models/CrossValidator.cs
+++ b/src/Microsoft.ML/Models/CrossValidator.cs
@@ -11,7 +11,7 @@ namespace Microsoft.ML.Models
{
public sealed partial class CrossValidator
{
- public CrossValidationOutput[] CrossValidate(LearningPipeline pipeline)
+ public CrossValidationOutput CrossValidate(LearningPipeline pipeline)
where TInput : class
where TOutput : class, new()
{
@@ -109,31 +109,35 @@ public CrossValidationOutput[] CrossValidate(L
experiment.Run();
- CrossValidationOutput[] cvo = new CrossValidationOutput[NumFolds];
+ CrossValidationOutput cvo = new CrossValidationOutput();
+ cvo.PredictorModels = new PredictionModel[NumFolds];
for (int Index = 0; Index < NumFolds; Index++)
{
- cvo[Index] = new CrossValidationOutput();
if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer)
{
- cvo[Index].BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics(
+ cvo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics(
environment,
- experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]),
- experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index]));
+ experiment.GetOutput(crossValidateOutput.OverallMetrics),
+ experiment.GetOutput(crossValidateOutput.ConfusionMatrix));
}
else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
{
- cvo[Index].ClassificationMetrics = ClassificationMetrics.FromMetrics(
+ cvo.ClassificationMetrics = ClassificationMetrics.FromMetrics(
environment,
- experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]),
- experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index]));
+ experiment.GetOutput(crossValidateOutput.OverallMetrics),
+ experiment.GetOutput(crossValidateOutput.ConfusionMatrix));
}
else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer)
{
- cvo[Index].RegressionMetrics = RegressionMetrics.FromOverallMetrics(
+ cvo.RegressionMetrics = RegressionMetrics.FromOverallMetrics(
environment,
- experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]));
+ experiment.GetOutput(crossValidateOutput.OverallMetrics));
+ }
+ else
+ {
+ //Implement metrics for ranking, clustering and anomaly detection.
}
ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[Index]);
@@ -146,7 +150,7 @@ public CrossValidationOutput[] CrossValidate(L
predictor = environment.CreateBatchPredictionEngine(memoryStream);
- cvo[Index].PredictorModel = new PredictionModel(predictor, memoryStream);
+ cvo.PredictorModels[Index] = new PredictionModel(predictor, memoryStream);
}
}
@@ -159,11 +163,12 @@ public class CrossValidationOutput
where TInput : class
where TOutput : class, new()
{
- public BinaryClassificationMetrics BinaryClassificationMetrics;
- public ClassificationMetrics ClassificationMetrics;
- public RegressionMetrics RegressionMetrics;
- public PredictionModel PredictorModel;
-
- //REVIEW: Add warnings and per instance results.
+ public List BinaryClassificationMetrics;
+ public List ClassificationMetrics;
+ public List RegressionMetrics;
+ public PredictionModel[] PredictorModels;
+
+ //REVIEW: Add warnings and per instance results and implement
+ //metrics for ranking, clustering and anomaly detection.
}
}
diff --git a/src/Microsoft.ML/Models/RegressionEvaluator.cs b/src/Microsoft.ML/Models/RegressionEvaluator.cs
index 8c2daa53f0..c55f4f3335 100644
--- a/src/Microsoft.ML/Models/RegressionEvaluator.cs
+++ b/src/Microsoft.ML/Models/RegressionEvaluator.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Transforms;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -23,7 +24,7 @@ public sealed partial class RegressionEvaluator
///
/// A RegressionMetrics instance that describes how well the model performed against the test data.
///
- public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
+ public List Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
{
diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs
index f5a5122242..43eabc3bec 100644
--- a/src/Microsoft.ML/Models/RegressionMetrics.cs
+++ b/src/Microsoft.ML/Models/RegressionMetrics.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
+using System.Collections.Generic;
namespace Microsoft.ML.Models
{
@@ -18,7 +19,7 @@ private RegressionMetrics()
{
}
- internal static RegressionMetrics FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
+ internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
{
Contracts.AssertValue(env);
env.AssertValue(overallMetrics);
@@ -30,21 +31,22 @@ internal static RegressionMetrics FromOverallMetrics(IHostEnvironment env, IData
throw env.Except("The overall RegressionMetrics didn't have any rows.");
}
- SerializationClass metrics = enumerator.Current;
-
- if (enumerator.MoveNext())
- {
- throw env.Except("The overall RegressionMetrics contained more than 1 row.");
- }
-
- return new RegressionMetrics()
+ List metrics = new List();
+ do
{
- L1 = metrics.L1,
- L2 = metrics.L2,
- Rms = metrics.Rms,
- LossFn = metrics.LossFn,
- RSquared = metrics.RSquared,
- };
+ SerializationClass metric = enumerator.Current;
+ metrics.Add(new RegressionMetrics()
+ {
+ L1 = metric.L1,
+ L2 = metric.L2,
+ Rms = metric.Rms,
+ LossFn = metric.LossFn,
+ RSquared = metric.RSquared,
+ });
+
+ } while (enumerator.MoveNext());
+
+ return metrics;
}
///
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs
index 9f7cbb727b..fa34cfd7ac 100644
--- a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs
+++ b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs
@@ -23,9 +23,6 @@ public sealed class CombineTransformModelsOutput
{
[TlcModule.Output(Desc = "Combined model", SortOrder = 1)]
public ITransformModel OutputModel;
-
- [TlcModule.Output(Desc = "Data", SortOrder = 2)]
- public IDataView Data;
}
public sealed class PredictorModelInput
@@ -92,7 +89,7 @@ public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironme
for (int i = input.Models.Length - 2; i >= 0; i--)
model = model.Apply(env, input.Models[i]);
- return new CombineTransformModelsOutput { OutputModel = model, Data = model.Data };
+ return new CombineTransformModelsOutput { OutputModel = model };
}
[TlcModule.EntryPoint(Name = "Transforms.ManyHeterogeneousModelCombiner", Desc = "Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel.")]
diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
index adfa42e50d..e31f6311cd 100644
--- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
@@ -53,7 +53,7 @@ public void Setup()
var testData = new TextLoader(s_dataPath).CreateFrom(useHeader: true);
var evaluator = new ClassificationEvaluator();
- s_metrics = evaluator.Evaluate(s_trainedModel, testData);
+ s_metrics = evaluator.Evaluate(s_trainedModel, testData).FirstOrDefault();
s_batches = new IrisData[s_batchSizes.Length][];
for (int i = 0; i < s_batches.Length; i++)
diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs
index 316d7eab55..c647110702 100644
--- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs
+++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs
@@ -15,8 +15,7 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output)
{
}
- //[Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")]
- [Fact]
+ [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")]
public void GenerateCSharpAPI()
{
var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs");
diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs
index 31fc4fdd6d..85955b1c06 100644
--- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs
@@ -8,6 +8,7 @@
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
+using System.Linq;
using Xunit;
using Xunit.Abstractions;
@@ -65,7 +66,7 @@ public void TrainAndPredictHousePriceModelTest()
var testData = new TextLoader(testDataPath).CreateFrom(useHeader: true, separator: ',');
var evaluator = new RegressionEvaluator();
- RegressionMetrics metrics = evaluator.Evaluate(model, testData);
+ RegressionMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault();
Assert.InRange(metrics.L1, 85_000, 89_000);
Assert.InRange(metrics.L2, 17_000_000_000, 19_000_000_000);
Assert.InRange(metrics.Rms, 130_500, 135_000);
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
index 5dcbf3a588..cb6cad9548 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
@@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
+using System.Linq;
using Xunit;
namespace Microsoft.ML.Scenarios
@@ -71,7 +72,7 @@ public void TrainAndPredictIrisModelTest()
var evaluator = new ClassificationEvaluator();
evaluator.OutputTopKAcc = 3;
- ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+ ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault();
Assert.Equal(.98, metrics.AccuracyMacro);
Assert.Equal(.98, metrics.AccuracyMicro, 2);
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
index ebddc33b03..10d23b62b7 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
@@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
+using System.Linq;
using Xunit;
namespace Microsoft.ML.Scenarios
@@ -74,7 +75,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
var evaluator = new ClassificationEvaluator();
evaluator.OutputTopKAcc = 3;
- ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+ ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); ;
Assert.Equal(.98, metrics.AccuracyMacro);
Assert.Equal(.98, metrics.AccuracyMicro, 2);
diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
index a27a85488a..165d89f372 100644
--- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
@@ -65,7 +65,7 @@ public void TrainAndPredictSentimentModelTest()
pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 });
pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" });
-
+ //var cv = new CrossValidator().CrossValidate(pipeline);
PredictionModel model = pipeline.Train();
IEnumerable sentiments = new[]
{
@@ -111,7 +111,7 @@ public void TrainAndPredictSentimentModelTest()
}
};
var evaluator = new BinaryClassificationEvaluator();
- BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+ BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault();
Assert.Equal(.5556, metrics.Accuracy, 4);
Assert.Equal(.8, metrics.Auc, 1);