Skip to content

Commit

Permalink
EvaluatorUtils to handle label column of type key without text key va…
Browse files Browse the repository at this point in the history
…lues (#394)

* Fix EvaluatorUtils to handle label column of type key without text key values.
  • Loading branch information
yaeldMS authored and shauheen committed Jun 27, 2018
1 parent 6c4470f commit bca008b
Show file tree
Hide file tree
Showing 10 changed files with 888 additions and 650 deletions.
111 changes: 83 additions & 28 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -552,25 +552,27 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
}
}

private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
out int[] indices, out Dictionary<DvText, int> reconciledKeyNames)
private static int[][] MapKeys<T>(ISchema[] schemas, string columnName, bool isVec,
int[] indices, Dictionary<DvText, int> reconciledKeyNames)
{
Contracts.AssertValue(indices);
Contracts.AssertValue(reconciledKeyNames);

var dvCount = schemas.Length;
var keyValueMappers = new int[dvCount][];
var keyNamesCur = default(VBuffer<DvText>);
indices = new int[dvCount];
reconciledKeyNames = new Dictionary<DvText, int>();
var keyNamesCur = default(VBuffer<T>);
for (int i = 0; i < dvCount; i++)
{
var schema = schemas[i];
if (!schema.TryGetColumnIndex(columnName, out indices[i]))
throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'");

var type = schema.GetColumnType(indices[i]);
var keyValueType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, indices[i]);
if (type.IsVector != isVec)
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type");
if (!schema.HasKeyNames(indices[i], type.ItemType.KeyCount))
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have text key values");
if (keyValueType == null || keyValueType.ItemType.RawType != typeof(T))
throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values");
if (!type.ItemType.IsKey || type.ItemType.RawKind != DataKind.U4)
throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{type.ItemType}'");

Expand All @@ -580,7 +582,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
foreach (var kvp in keyNamesCur.Items(true))
{
var key = kvp.Key;
var name = kvp.Value;
var name = new DvText(kvp.Value.ToString());
if (!reconciledKeyNames.ContainsKey(name))
reconciledKeyNames[name] = reconciledKeyNames.Count;
keyValueMappers[i][key] = reconciledKeyNames[name];
Expand All @@ -595,17 +597,18 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
/// corresponding to the key value in the original column.
/// </summary>
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var dvCount = views.Length;

Dictionary<DvText, int> keyNames;
int[] indices;
// Create mappings from the original key types to the reconciled key type.
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, false, out indices, out keyNames);
var indices = new int[dvCount];
var keyNames = new Dictionary<DvText, int>();
// We use MarshalInvoke so that we can call MapKeys with the correct generic: keyValueType.RawType.
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, false, indices, keyNames);
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter<VBuffer<DvText>> keyValueGetter =
Expand All @@ -629,20 +632,51 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
}
}

/// <summary>
/// This method takes an array of data views and a specified input key column, and adds a new output column to each of the data views.
/// First, we find the union set of the key values in the different data views. Next we define a new key column for each
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
/// corresponding to the key value in the original column.
/// </summary>
public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, int keyCount)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var keyType = new KeyType(DataKind.U4, 0, keyCount);

// For each input data view, create the reconciled key column by wrapping it in a LambdaColumnMapper.
for (int i = 0; i < views.Length; i++)
{
if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");
ValueMapper<uint, uint> mapper =
(ref uint src, ref uint dst) =>
{
if (src == 0 || src > keyCount)
dst = 0;
else
dst = src + 1;
};
views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
views[i].Schema.GetColumnType(index), keyType, mapper);
}
}

/// <summary>
/// This method is similar to <see cref="ReconcileKeyValues"/>, but it reconciles the key values over vector
/// input columns.
/// </summary>
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName)
public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName, ColumnType keyValueType)
{
Contracts.CheckNonEmpty(views, nameof(views));
Contracts.CheckNonEmpty(columnName, nameof(columnName));

var dvCount = views.Length;

Dictionary<DvText, int> keyNames;
int[] columnIndices;
var keyValueMappers = MapKeys(views.Select(view => view.Schema).ToArray(), columnName, true, out columnIndices, out keyNames);
var keyNames = new Dictionary<DvText, int>();
var columnIndices = new int[dvCount];
var keyValueMappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, views.Select(view => view.Schema).ToArray(), columnName, true, columnIndices, keyNames);
var keyType = new KeyType(DataKind.U4, 0, keyNames.Count);
var keyNamesVBuffer = new VBuffer<DvText>(keyNames.Count, keyNames.Keys.ToArray());
ValueGetter<VBuffer<DvText>> keyValueGetter =
Expand Down Expand Up @@ -736,7 +770,7 @@ public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env,
var foldDataViews = perInstance.Select(getPerInstance).ToArray();
if (collate)
{
var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames);
var combined = AppendPerInstanceDataViews(env, perInstance[0].Schema.Label?.Name, foldDataViews, out variableSizeVectorColumnNames);
return new[] { combined };
}
else
Expand Down Expand Up @@ -767,7 +801,8 @@ public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataVie
return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray());
}

private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string labelColName,
IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
{
Contracts.AssertValue(env);
env.AssertValue(foldDataViews);
Expand All @@ -776,7 +811,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
// This is a dictionary from the column name to its vector size.
var vectorSizes = new Dictionary<string, int>();
var firstDvSlotNames = new Dictionary<string, VBuffer<DvText>>();
var firstDvKeyColumns = new List<string>();
ColumnType labelColKeyValuesType = null;
var firstDvKeyWithNamesColumns = new List<string>();
var firstDvKeyNoNamesColumns = new Dictionary<string, int>();
var firstDvVectorKeyColumns = new List<string>();
var variableSizeVectorColumnNamesList = new List<string>();
var list = new List<IDataView>();
Expand Down Expand Up @@ -822,10 +859,20 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
else
vectorSizes.Add(name, type.VectorSize);
}
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
else if (dvNumber == 0 && name == labelColName)
{
// The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
firstDvKeyColumns.Add(name);
labelColKeyValuesType = dv.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i);
}
else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
firstDvKeyWithNamesColumns.Add(name);
else if (type.KeyCount > 0 && name != labelColName)
{
// For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
if (!firstDvKeyNoNamesColumns.ContainsKey(name))
firstDvKeyNoNamesColumns[name] = type.KeyCount;
if (firstDvKeyNoNamesColumns[name] < type.KeyCount)
firstDvKeyNoNamesColumns[name] = type.KeyCount;
}
}
var idv = dv;
Expand All @@ -839,26 +886,34 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
list.Add(idv);
dvNumber++;
}

variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray();
if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0)
return AppendRowsDataView.Create(env, null, list.ToArray());

var views = list.ToArray();
foreach (var keyCol in firstDvKeyColumns)
ReconcileKeyValues(env, views, keyCol);
foreach (var keyCol in firstDvKeyWithNamesColumns)
ReconcileKeyValues(env, views, keyCol, TextType.Instance);
if (labelColKeyValuesType != null)
ReconcileKeyValues(env, views, labelColName, labelColKeyValuesType.ItemType);
foreach (var keyCol in firstDvKeyNoNamesColumns)
ReconcileKeyValuesWithNoNames(env, views, keyCol.Key, keyCol.Value);
foreach (var vectorKeyCol in firstDvVectorKeyColumns)
ReconcileVectorKeyValues(env, views, vectorKeyCol);
ReconcileVectorKeyValues(env, views, vectorKeyCol, TextType.Instance);

Func<IDataView, int, IDataView> keyToValue =
(idv, i) =>
{
foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns))
foreach (var keyCol in firstDvVectorKeyColumns.Prepend(labelColName))
{
if (keyCol == labelColName && labelColKeyValuesType == null)
continue;
idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv);
var hidden = FindHiddenColumns(idv.Schema, keyCol);
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
}
foreach (var keyCol in firstDvKeyNoNamesColumns)
{
var hidden = FindHiddenColumns(idv.Schema, keyCol.Key);
idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv);
}
return idv;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
maml.exe CV tr=FastRankRanking{t=1} strat=Strat threads=- norm=Warn prexf=rangefilter{col=Label min=20 max=25} prexf=term{col=Strat:Label} dout=%Output% loader=text{col=Features:R4:10-14 col=Label:R4:9 col=GroupId:TX:1 header+} data=%Data% out=%Output% xf=term{col=Label} xf=hash{col=GroupId}
Not adding a normalizer.
Making per-feature arrays
Changing data from row-wise to column-wise
Processed 40 instances
Binning and forming Feature objects
Reserved memory for tree learner: 10764 bytes
Starting to train ...
Not training a calibrator because it is not needed.
Not adding a normalizer.
Making per-feature arrays
Changing data from row-wise to column-wise
Processed 32 instances
Binning and forming Feature objects
Reserved memory for tree learner: 6396 bytes
Starting to train ...
Not training a calibrator because it is not needed.
NDCG@1: 0.000000
NDCG@2: 0.000000
NDCG@3: 0.000000
DCG@1: 0.000000
DCG@2: 0.000000
DCG@3: 0.000000
NDCG@1: 0.000000
NDCG@2: 0.000000
NDCG@3: 0.000000
DCG@1: 0.000000
DCG@2: 0.000000
DCG@3: 0.000000

OVERALL RESULTS
---------------------------------------
NDCG@1: 0.000000 (0.0000)
NDCG@2: 0.000000 (0.0000)
NDCG@3: 0.000000 (0.0000)
DCG@1: 0.000000 (0.0000)
DCG@2: 0.000000 (0.0000)
DCG@3: 0.000000 (0.0000)

---------------------------------------
Physical memory usage(MB): %Number%
Virtual memory usage(MB): %Number%
%DateTime% Time elapsed(s): %Number%

Loading

0 comments on commit bca008b

Please sign in to comment.