From 2b55653e68a26946cedbc855bf7eb5d5fe3138de Mon Sep 17 00:00:00 2001
From: Zeeshan Ahmed <38438266+zeahmed@users.noreply.github.com>
Date: Fri, 11 May 2018 11:09:40 -0700
Subject: [PATCH] Fixed exception: "InvalidOperationException: Source column
 'Label' is required but not found." (#121)

* Checking for both ColumnAttribute and ColumnNameAttribute when creating schema in CreateBatchPredictionEngine.

* Addressed reviewers' comments.
---
 src/Microsoft.ML.Api/SchemaDefinition.cs      |   5 +-
 ...PlantClassificationWithStringLabelTests.cs | 136 ++++++++++++++++++
 2 files changed, 139 insertions(+), 2 deletions(-)
 create mode 100644 test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs

diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs
index c93403b59eb..5f847126250 100644
--- a/src/Microsoft.ML.Api/SchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/SchemaDefinition.cs
@@ -332,8 +332,9 @@ public static SchemaDefinition Create(Type userType)
                 if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null)
                     continue;
 
-                var mappingAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
-                var name = mappingAttr == null ? fieldInfo.Name : (mappingAttr.Name ?? fieldInfo.Name);
+                var mappingAttr = fieldInfo.GetCustomAttribute<ColumnAttribute>();
+                var mappingNameAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
+                string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name;
                 // Disallow duplicate names, because the field enumeration order is not actually
                 // well defined, so we are not gauranteed to have consistent "hiding" from run to
                 // run, across different .NET versions.
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
new file mode 100644
index 00000000000..79cc2fc1377
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
@@ -0,0 +1,136 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Models;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
+using Xunit;
+
+namespace Microsoft.ML.Scenarios
+{
+    public partial class ScenariosTests
+    {
+        [Fact]
+        public void TrainAndPredictIrisModelWithStringLabelTest()
+        {
+            string dataPath = GetDataPath("iris.data");
+
+            var pipeline = new LearningPipeline();
+
+            pipeline.Add(new TextLoader<IrisDataWithStringLabel>(dataPath, useHeader: false, separator: ","));
+
+            pipeline.Add(new Dictionarizer("Label")); // The IrisPlantType field is exposed as column "Label" through the name argument of its [Column] attribute.
+
+            pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
+                "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+
+            pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+
+            PredictionModel<IrisDataWithStringLabel, IrisPrediction> model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
+
+            IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
+            {
+                SepalLength = 3.3f,
+                SepalWidth = 1.6f,
+                PetalLength = 0.2f,
+                PetalWidth = 5.1f,
+            });
+
+            Assert.Equal(1, prediction.PredictedLabels[0], 2);
+            Assert.Equal(0, prediction.PredictedLabels[1], 2);
+            Assert.Equal(0, prediction.PredictedLabels[2], 2);
+
+            prediction = model.Predict(new IrisDataWithStringLabel()
+            {
+                SepalLength = 3.1f,
+                SepalWidth = 5.5f,
+                PetalLength = 2.2f,
+                PetalWidth = 6.4f,
+            });
+
+            Assert.Equal(0, prediction.PredictedLabels[0], 2);
+            Assert.Equal(0, prediction.PredictedLabels[1], 2);
+            Assert.Equal(1, prediction.PredictedLabels[2], 2);
+
+            prediction = model.Predict(new IrisDataWithStringLabel()
+            {
+                SepalLength = 3.1f,
+                SepalWidth = 2.5f,
+                PetalLength = 1.2f,
+                PetalWidth = 4.4f,
+            });
+
+            Assert.Equal(.2, prediction.PredictedLabels[0], 1);
+            Assert.Equal(.8, prediction.PredictedLabels[1], 1);
+            Assert.Equal(0, prediction.PredictedLabels[2], 2);
+
+            // Note: Testing against the same data set as a simple way to test evaluation.
+            // This isn't appropriate in real-world scenarios.
+            string testDataPath = GetDataPath("iris.data");
+            var testData = new TextLoader<IrisDataWithStringLabel>(testDataPath, useHeader: false, separator: ",");
+
+            var evaluator = new ClassificationEvaluator();
+            evaluator.OutputTopKAcc = 3;
+            ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+
+            Assert.Equal(.98, metrics.AccuracyMacro, 2);
+            Assert.Equal(.98, metrics.AccuracyMicro, 2);
+            Assert.Equal(.06, metrics.LogLoss, 2);
+            Assert.InRange(metrics.LogLossReduction, 94, 96);
+            Assert.Equal(1, metrics.TopKAccuracy);
+
+            Assert.Equal(3, metrics.PerClassLogLoss.Length);
+            Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
+            Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
+            Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);
+
+            ConfusionMatrix matrix = metrics.ConfusionMatrix;
+            Assert.Equal(3, matrix.Order);
+            Assert.Equal(3, matrix.ClassNames.Count);
+            Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
+            Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
+            Assert.Equal("Iris-virginica", matrix.ClassNames[2]);
+
+            Assert.Equal(50, matrix[0, 0]);
+            Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
+            Assert.Equal(0, matrix[0, 1]);
+            Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
+            Assert.Equal(0, matrix[0, 2]);
+            Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);
+
+            Assert.Equal(0, matrix[1, 0]);
+            Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
+            Assert.Equal(48, matrix[1, 1]);
+            Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
+            Assert.Equal(2, matrix[1, 2]);
+            Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);
+
+            Assert.Equal(0, matrix[2, 0]);
+            Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
+            Assert.Equal(1, matrix[2, 1]);
+            Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
+            Assert.Equal(49, matrix[2, 2]);
+            Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
+        }
+
+        public class IrisDataWithStringLabel
+        {
+            [Column("0")]
+            public float PetalWidth;
+
+            [Column("1")]
+            public float SepalLength;
+
+            [Column("2")]
+            public float SepalWidth;
+
+            [Column("3")]
+            public float PetalLength;
+
+            [Column("4", name: "Label")]
+            public string IrisPlantType;
+        }
+    }
+}