Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EvaluatorUtils to handle label column of type key without text key values #394

Merged
merged 11 commits into from
Jun 27, 2018
Merged
111 changes: 83 additions & 28 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -552,25 +552,27 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
}
}

private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
out int[] indices, out Dictionary<DvText, int> reconciledKeyNames)
private static int[][] MapKeys<T>(ISchema[] schemas, string columnName, bool isVec,
int[] indices, Dictionary<DvText, int> reconciledKeyNames)
{
Contracts.AssertValue(indices);
Contracts.AssertValue(reconciledKeyNames);

var dvCount = schemas.Length;
var keyValueMappers = new int[dvCount][];
var keyNamesCur = default(VBuffer<DvText>);
indices = new int[dvCount];
reconciledKeyNames = new Dictionary<DvText, int>();
var keyNamesCur = default(VBuffer<T>);
for (int i = 0; i < dvCount; i++)
{
var schema = schemas[i];
if (!schema.TryGetColumnIndex(columnName, out indices[i]))
throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'");

var type = schema.GetColumnType(indices[i]);
var keyValueType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, indices[i]);
if (type.IsVector != isVec)
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type");
if (!schema.HasKeyNames(indices[i], type.ItemType.KeyCount))
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have text key values");
if (keyValueType == null || keyValueType.ItemType.RawType != typeof(T))
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values");
if (!type.ItemType.IsKey || type.ItemType.RawKind != DataKind.U4)
throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{type.ItemType}'");

Expand All @@ -580,7 +582,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
foreach (var kvp in keyNamesCur.Items(true))
{
var key = kvp.Key;
var name = kvp.Value;
var name = new DvText(kvp.Value.ToString());
if (!reconciledKeyNames.ContainsKey(name))
reconciledKeyNames[name] = reconciledKeyNames.Count;
keyValueMappers[i][key] = reconciledKeyNames[name];
Expand All @@ -595,17 +597,18 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
/// corresponding to the key value in the original column.
/// </summary>
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var dvCount = views.Length;

Dictionary<DvText, int> keyNames;
int[] indices;
// Create mappings from the original key types to the reconciled key type.
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, false, out indices, out keyNames);
var indices = new int[dvCount];
var keyNames = new Dictionary<DvText, int>();
// We use MarshalInvoke so that we can call MapKeys with the correct generic: keyValueType.RawType.
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, false, indices, keyNames);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utils.MarshalInvoke [](start = 34, length = 19)

can you add comment what you use MarshalInvoke to pass keyValueType.RawType as generic type to function?

var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter<VBuffer<DvText>> keyValueGetter =
Expand All @@ -629,20 +632,51 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
}
}

/// <summary>
/// This method takes an array of data views and a specified input key column, and adds a new output column to each of the data views.
/// First, we find the union set of the key values in the different data views. Next we define a new key column for each
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
/// corresponding to the key value in the original column.
/// </summary>
public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, int keyCount)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var keyType = new KeyType(DataKind.U4, 0, keyCount);

// For each input data view, create the reconciled key column by wrapping it in a LambdaColumnMapper.
for (int i = 0; i < views.Length; i++)
{
if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");
ValueMapper<uint, uint> mapper =
(ref uint src, ref uint dst) =>
{
if (src == 0 || src > keyCount)
dst = 0;
else
dst = src + 1;
};
views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
views[i].Schema.GetColumnType(index), keyType, mapper);
}
}

/// <summary>
/// This method is similar to <see cref="ReconcileKeyValues"/>, but it reconciles the key values over vector
/// input columns.
/// </summary>
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var dvCount = views.Length;

Dictionary<DvText, int> keyNames;
int[] columnIndices;
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, true, out columnIndices, out keyNames);
var keyNames = new Dictionary<DvText, int>();
var columnIndices = new int[dvCount];
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, true, columnIndices, keyNames);
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter<VBuffer<DvText>> keyValueGetter =
Expand Down Expand Up @@ -736,7 +770,7 @@ public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env,
var foldDataViews = perInstance.Select(getPerInstance).ToArray();
if (collate)
{
var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames);
var combined = AppendPerInstanceDataViews(env, perInstance[0].Schema.Label?.Name, foldDataViews, out variableSizeVectorColumnNames);
return new[] { combined };
}
else
Expand Down Expand Up @@ -767,7 +801,8 @@ public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataVie
return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray());
}

private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string labelColName,
IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
{
Contracts.AssertValue(env);
env.AssertValue(foldDataViews);
Expand All @@ -776,7 +811,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
// This is a dictionary from the column name to its vector size.
var vectorSizes = new Dictionary<string, int>();
var firstDvSlotNames = new Dictionary<string, VBuffer<DvText>>();
var firstDvKeyColumns = new List<string>();
ColumnType labelColKeyValuesType = null;
var firstDvKeyWithNamesColumns = new List<string>();
var firstDvKeyNoNamesColumns = new Dictionary<string, int>();
var firstDvVectorKeyColumns = new List<string>();
var variableSizeVectorColumnNamesList = new List<string>();
var list = new List<IDataView>();
Expand Down Expand Up @@ -822,10 +859,20 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
else
vectorSizes.Add(name, type.VectorSize);
}
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
else if (dvNumber == 0 && name == labelColName)
{
// The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
firstDvKeyColumns.Add(name);
labelColKeyValuesType = dv.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i);
}
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
firstDvKeyWithNamesColumns.Add(name);
else if (type.KeyCount > 0 && name != labelColName)
{
// For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
if (!firstDvKeyNoNamesColumns.ContainsKey(name))
firstDvKeyNoNamesColumns[name] = type.KeyCount;
if (firstDvKeyNoNamesColumns[name] < type.KeyCount)
firstDvKeyNoNamesColumns[name] = type.KeyCount;
}
}
var idv = dv;
Expand All @@ -839,26 +886,34 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
list.Add(idv);
dvNumber++;
}

variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray();
if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0)
return AppendRowsDataView.Create(env, null, list.ToArray());

var views = list.ToArray();
foreach (var keyCol in firstDvKeyColumns)
ReconcileKeyValues(env, views, keyCol);
foreach (var keyCol in firstDvKeyWithNamesColumns)
ReconcileKeyValues(env, views, keyCol, TextType.Instance);
if (labelColKeyValuesType != null)
ReconcileKeyValues(env, views, labelColName, labelColKeyValuesType.ItemType);
foreach (var keyCol in firstDvKeyNoNamesColumns)
ReconcileKeyValuesWithNoNames(env, views, keyCol.Key, keyCol.Value);
foreach (var vectorKeyCol in firstDvVectorKeyColumns)
ReconcileVectorKeyValues(env, views, vectorKeyCol);
ReconcileVectorKeyValues(env, views, vectorKeyCol, TextType.Instance);

Func<IDataView, int, IDataView> keyToValue =
(idv, i) =>
{
foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns))
foreach (var keyCol in firstDvVectorKeyColumns.Prepend(labelColName))
{
if (keyCol == labelColName && labelColKeyValuesType == null)
continue;
idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv);
var hidden = FindHiddenColumns(idv.Schema, keyCol);
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
}
foreach (var keyCol in firstDvKeyNoNamesColumns)
{
var hidden = FindHiddenColumns(idv.Schema, keyCol.Key);
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
}
return idv;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
maml.exe CV tr=FastRankRanking{t=1} strat=Strat threads=- norm=Warn prexf=rangefilter{col=Label min=20 max=25} prexf=term{col=Strat:Label} dout=%Output% loader=text{col=Features:R4:10-14 col=Label:R4:9 col=GroupId:TX:1} data=%Data% out=%Output% xf=term{col=Label} xf=hash{col=GroupId}
Bad value at line 9 in column Label
Processed 501 rows with 1 bad values and 0 format errors
Bad value at line 9 in column Label
Processed 501 rows with 1 bad values and 0 format errors
Not adding a normalizer.
Making per-feature arrays
Changing data from row-wise to column-wise
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Processed 40 instances
Binning and forming Feature objects
Reserved memory for tree learner: 10764 bytes
Starting to train ...
Not training a calibrator because it is not needed.
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Bad value at line 9 in column Label
Processed 501 rows with 1 bad values and 0 format errors
Not adding a normalizer.
Making per-feature arrays
Changing data from row-wise to column-wise
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Processed 32 instances
Binning and forming Feature objects
Reserved memory for tree learner: 6396 bytes
Starting to train ...
Not training a calibrator because it is not needed.
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
NDCG@1: 0.000000
NDCG@2: 0.000000
NDCG@3: 0.000000
DCG@1: 0.000000
DCG@2: 0.000000
DCG@3: 0.000000
NDCG@1: 0.000000
NDCG@2: 0.000000
NDCG@3: 0.000000
DCG@1: 0.000000
DCG@2: 0.000000
DCG@3: 0.000000

OVERALL RESULTS
---------------------------------------
NDCG@1: 0.000000 (0.0000)
NDCG@2: 0.000000 (0.0000)
NDCG@3: 0.000000 (0.0000)
DCG@1: 0.000000 (0.0000)
DCG@2: 0.000000 (0.0000)
DCG@3: 0.000000 (0.0000)

---------------------------------------
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Bad value at line 9 in column Features at slot 0
Bad value at line 9 in column Features at slot 1
Bad value at line 9 in column Features at slot 2
Bad value at line 9 in column Features at slot 3
Bad value at line 9 in column Features at slot 4
Bad value at line 9 in column Label
Processed 501 rows with 6 bad values and 0 format errors
Physical memory usage(MB): %Number%
Virtual memory usage(MB): %Number%
%DateTime% Time elapsed(s): %Number%

Loading