Skip to content

Commit

Permalink
InferColumns API: Validate all columns specified in column info exist…
Browse files Browse the repository at this point in the history
… in inferred data view (#3599)
  • Loading branch information
daholste authored Apr 26, 2019
1 parent fe77854 commit 4dc47aa
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
var textLoader = context.Data.CreateTextLoader(typedLoaderOptions);
var dataView = textLoader.Load(path);

// Validate all columns specified in column info exist in inferred data view
ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView);

var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo);

// start building result objects
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML.Auto
{
internal static class ColumnInferenceValidationUtil
{
/// <summary>
/// Validate all columns specified in column info exist in inferred data view.
/// </summary>
public static void ValidateSpecifiedColumnsExist(ColumnInformation columnInfo,
IDataView dataView)
{
var columnNames = ColumnInformationUtil.GetColumnNames(columnInfo);
foreach (var columnName in columnNames)
{
if (dataView.Schema.GetColumnOrNull(columnName) == null)
{
throw new ArgumentException($"Specified column {columnName} " +
$"is not found in the dataset.");
}
}
}
}
}
33 changes: 33 additions & 0 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,38 @@ public static ColumnInformation BuildColumnInfo(IEnumerable<DatasetColumnInfo> c
{
return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose)));
}

/// <summary>
/// Get all column names that are in <paramref name="columnInformation"/>.
/// </summary>
/// <param name="columnInformation">Column information.</param>
public static IEnumerable<string> GetColumnNames(ColumnInformation columnInformation)
{
var columnNames = new List<string>();
AddStringToListIfNotNull(columnNames, columnInformation.LabelColumnName);
AddStringToListIfNotNull(columnNames, columnInformation.ExampleWeightColumnName);
AddStringToListIfNotNull(columnNames, columnInformation.SamplingKeyColumnName);
AddStringsToListIfNotNull(columnNames, columnInformation.CategoricalColumnNames);
AddStringsToListIfNotNull(columnNames, columnInformation.IgnoredColumnNames);
AddStringsToListIfNotNull(columnNames, columnInformation.NumericColumnNames);
AddStringsToListIfNotNull(columnNames, columnInformation.TextColumnNames);
return columnNames;
}

private static void AddStringsToListIfNotNull(List<string> list, IEnumerable<string> strings)
{
foreach (var str in strings)
{
AddStringToListIfNotNull(list, str);
}
}

private static void AddStringToListIfNotNull(List<string> list, string str)
{
if (str != null)
{
list.Add(str);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.ML.Auto.Test
{
[TestClass]
public class ColumnInferenceValidationUtilTests
{
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateColumnNotContainedInData()
{
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
var columnInfo = new ColumnInformation();
columnInfo.CategoricalColumnNames.Add("Categorical");
ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView);
}
}
}
23 changes: 22 additions & 1 deletion test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.ML.Auto.Test
Expand Down Expand Up @@ -32,5 +33,25 @@ public void GetColumnPurpose()
Assert.AreEqual(ColumnPurpose.Ignore, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Ignored"));
Assert.AreEqual(null, ColumnInformationUtil.GetColumnPurpose(columnInfo, "NonExistent"));
}

[TestMethod]
public void GetColumnNamesTest()
{
var columnInfo = new ColumnInformation()
{
LabelColumnName = "Label",
SamplingKeyColumnName = "SamplingKey",
};
columnInfo.CategoricalColumnNames.Add("Cat1");
columnInfo.CategoricalColumnNames.Add("Cat2");
columnInfo.NumericColumnNames.Add("Num");
var columnNames = ColumnInformationUtil.GetColumnNames(columnInfo);
Assert.AreEqual(5, columnNames.Count());
Assert.IsTrue(columnNames.Contains("Label"));
Assert.IsTrue(columnNames.Contains("SamplingKey"));
Assert.IsTrue(columnNames.Contains("Cat1"));
Assert.IsTrue(columnNames.Contains("Cat2"));
Assert.IsTrue(columnNames.Contains("Num"));
}
}
}
}

0 comments on commit 4dc47aa

Please sign in to comment.