Skip to content

Commit

Permalink
validate AutoFit 'Features' column must be of type R4 (dotnet#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored and Dmitry-A committed Aug 22, 2019
1 parent 3ddd5c4 commit 355366f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.IO;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
Expand Down Expand Up @@ -45,6 +46,11 @@ private static void ValidateTrainData(IDataView trainData)
{
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}

if (trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType() != NumberType.R4)
{
throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type Single", nameof(trainData));
}
}

private static void ValidateLabel(IDataView trainData, string label)
Expand Down Expand Up @@ -174,4 +180,4 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
}
}
}
}
12 changes: 12 additions & 0 deletions src/Test/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,17 @@ public void ValidateInferColsPath()
{
UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.DownloadUciAdultDataset());
}

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateFeaturesColInvalidType()
{
var schemaBuilder = new SchemaBuilder();
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberType.R8);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberType.R4);
var schema = schemaBuilder.GetSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateAutoFitArgs(dataView, DefaultColumnNames.Label, null, null, null);
}
}
}

0 comments on commit 355366f

Please sign in to comment.