Skip to content

Commit

Permalink
User input column type validation (dotnet#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored and Dmitry-A committed Aug 22, 2019
1 parent bf4ece8 commit 4e15684
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 17 deletions.
72 changes: 55 additions & 17 deletions src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ namespace Microsoft.ML.Auto
{
internal static class UserInputValidationUtil
{
// column purpose names
private const string LabelColumnPurposeName = "label";
private const string WeightColumnPurposeName = "weight";
private const string NumericColumnPurposeName = "numeric";
private const string CategoricalColumnPurposeName = "categorical";
private const string TextColumnPurposeName = "text";
private const string IgnoredColumnPurposeName = "ignored";

public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
IDataView validationData)
{
Expand Down Expand Up @@ -55,22 +63,25 @@ private static void ValidateTrainData(IDataView trainData)
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
{
ValidateColumnInformation(columnInformation);
ValidateTrainDataColumnExists(trainData, columnInformation.LabelColumn);
ValidateTrainDataColumnExists(trainData, columnInformation.WeightColumn);
ValidateTrainDataColumnsExist(trainData, columnInformation.CategoricalColumns);
ValidateTrainDataColumnsExist(trainData, columnInformation.NumericColumns);
ValidateTrainDataColumnsExist(trainData, columnInformation.TextColumns);
ValidateTrainDataColumnsExist(trainData, columnInformation.IgnoredColumns);
ValidateTrainDataColumn(trainData, columnInformation.LabelColumn, LabelColumnPurposeName);
ValidateTrainDataColumn(trainData, columnInformation.WeightColumn, WeightColumnPurposeName);
ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumns, CategoricalColumnPurposeName,
new DataViewType[] { NumberDataViewType.Single, TextDataViewType.Instance });
ValidateTrainDataColumns(trainData, columnInformation.NumericColumns, NumericColumnPurposeName,
new DataViewType[] { NumberDataViewType.Single, BooleanDataViewType.Instance });
ValidateTrainDataColumns(trainData, columnInformation.TextColumns, TextColumnPurposeName,
new DataViewType[] { TextDataViewType.Instance });
ValidateTrainDataColumns(trainData, columnInformation.IgnoredColumns, IgnoredColumnPurposeName);
}

private static void ValidateColumnInformation(ColumnInformation columnInformation)
{
ValidateLabelColumn(columnInformation.LabelColumn);

ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumns, "categorical");
ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumns, "numeric");
ValidateColumnInfoEnumerationProperty(columnInformation.TextColumns, "text");
ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumns, "ignored");
ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumns, CategoricalColumnPurposeName);
ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumns, NumericColumnPurposeName);
ValidateColumnInfoEnumerationProperty(columnInformation.TextColumns, TextColumnPurposeName);
ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumns, IgnoredColumnPurposeName);

// keep a list of all columns, to detect duplicates
var allColumns = new List<string>();
Expand All @@ -88,11 +99,11 @@ private static void ValidateColumnInformation(ColumnInformation columnInformatio
}
}

private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string propertyName)
private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string columnPurpose)
{
if (columns?.Contains(null) == true)
{
throw new ArgumentException($"Null column string was specified as {propertyName} in column information");
throw new ArgumentException($"Null column string was specified as {columnPurpose} in column information");
}
}

Expand Down Expand Up @@ -155,7 +166,8 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
}
}

private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerable<string> columnNames)
private static void ValidateTrainDataColumns(IDataView trainData, IEnumerable<string> columnNames, string columnPurpose,
IEnumerable<DataViewType> allowedTypes = null)
{
if (columnNames == null)
{
Expand All @@ -164,15 +176,41 @@ private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerab

foreach (var columnName in columnNames)
{
ValidateTrainDataColumnExists(trainData, columnName);
ValidateTrainDataColumn(trainData, columnName, columnPurpose, allowedTypes);
}
}

private static void ValidateTrainDataColumnExists(IDataView trainData, string columnName)
private static void ValidateTrainDataColumn(IDataView trainData, string columnName, string columnPurpose, IEnumerable<DataViewType> allowedTypes = null)
{
if (columnName != null && trainData.Schema.GetColumnOrNull(columnName) == null)
if (columnName == null)
{
return;
}

var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
if (nullableColumn == null)
{
throw new ArgumentException($"Provided {columnPurpose} column {columnName} '{columnName}' not found in training data.");
}

if(allowedTypes == null)
{
throw new ArgumentException($"Provided column '{columnName}' not found in training data.");
return;
}
var column = nullableColumn.Value;
var itemType = column.Type.GetItemType();
if (!allowedTypes.Contains(itemType))
{
if (allowedTypes.Count() == 1)
{
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
$"but only type {allowedTypes.First()} is allowed.");
}
else
{
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
$"but only types {string.Join(", ", allowedTypes)} are allowed.");
}
}
}

Expand Down
16 changes: 16 additions & 0 deletions src/Test/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,21 @@ public void ValidateFeaturesColInvalidType()
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null);
}

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateTextColumnNotText()
{
const string TextPurposeColName = "TextColumn";
var schemaBuilder = new SchemaBuilder();
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
schemaBuilder.AddColumn(TextPurposeColName, NumberDataViewType.Double);
var schema = schemaBuilder.GetSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView,
new ColumnInformation() { TextColumns = new[] { TextPurposeColName } },
null);
}
}
}

0 comments on commit 4e15684

Please sign in to comment.