Skip to content

Commit

Permalink
Merge master.
Browse files Browse the repository at this point in the history
  • Loading branch information
codemzs committed Sep 12, 2018
1 parent 8c80530 commit d45bc2c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 60 deletions.
46 changes: 23 additions & 23 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -313,50 +313,50 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo)
{
if (typeNames != null)
{
MetadataUtils.MetadataGetter<VBuffer<DvText>> getter = (int col, ref VBuffer<DvText> dst) =>
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> getter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref dst);
};
var info = new MetadataInfo<VBuffer<DvText>>(typeNames, getter);
var info = new MetadataInfo<VBuffer<ReadOnlyMemory<char>>>(typeNames, getter);
colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info);
}
}
else
{
if (typeNames != null && _types[i].IsKnownSizeVector)
{
MetadataUtils.MetadataGetter<VBuffer<DvText>> getter = (int col, ref VBuffer<DvText> dst) =>
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> getter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
GetSlotNames(i, ref dst);
};
var info = new MetadataInfo<VBuffer<DvText>>(new VectorType(TextType.Instance, _types[i]), getter);
var info = new MetadataInfo<VBuffer<ReadOnlyMemory<char>>>(new VectorType(TextType.Instance, _types[i]), getter);
colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info);
}
}

if (!_parent._columns[i].Bag && srcType.ValueCount > 0)
{
MetadataUtils.MetadataGetter<VBuffer<DvInt4>> getter = (int col, ref VBuffer<DvInt4> dst) =>
MetadataUtils.MetadataGetter<VBuffer<int>> getter = (int col, ref VBuffer<int> dst) =>
{
GetCategoricalSlotRanges(i, ref dst);
};
var info = new MetadataInfo<VBuffer<DvInt4>>(MetadataUtils.GetCategoricalType(_infos[i].TypeSrc.ValueCount), getter);
var info = new MetadataInfo<VBuffer<int>>(MetadataUtils.GetCategoricalType(_infos[i].TypeSrc.ValueCount), getter);
colMetaInfo.Add(MetadataUtils.Kinds.CategoricalSlotRanges, info);
}

if (!_parent._columns[i].Bag || srcType.ValueCount == 1)
{
MetadataUtils.MetadataGetter<DvBool> getter = (int col, ref DvBool dst) =>
MetadataUtils.MetadataGetter<bool> getter = (int col, ref bool dst) =>
{
dst = true;
};
var info = new MetadataInfo<DvBool>(BoolType.Instance, getter);
var info = new MetadataInfo<bool>(BoolType.Instance, getter);
colMetaInfo.Add(MetadataUtils.Kinds.IsNormalized, info);
}
}

// Combines source key names and slot names to produce final slot names.
private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
{
Host.Assert(0 <= iinfo && iinfo < _infos.Length);
Host.Assert(_types[iinfo].IsKnownSizeVector);
Expand All @@ -367,7 +367,7 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
Host.Assert(typeSrc.VectorSize > 1);

// Get the source slot names, defaulting to empty text.
var namesSlotSrc = default(VBuffer<DvText>);
var namesSlotSrc = default(VBuffer<ReadOnlyMemory<char>>);
InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol);
Host.Assert(srcCol >= 0);
var typeSlotSrc = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, srcCol);
Expand All @@ -377,31 +377,31 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
Host.Check(namesSlotSrc.Length == typeSrc.VectorSize);
}
else
namesSlotSrc = VBufferUtils.CreateEmpty<DvText>(typeSrc.VectorSize);
namesSlotSrc = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(typeSrc.VectorSize);

int keyCount = typeSrc.ItemType.ItemType.KeyCount;
int slotLim = _types[iinfo].VectorSize;
Host.Assert(slotLim == (long)typeSrc.VectorSize * keyCount);

// Get the source key names, in an array (since we will use them multiple times).
var namesKeySrc = default(VBuffer<DvText>);
var namesKeySrc = default(VBuffer<ReadOnlyMemory<char>>);
InputSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, srcCol, ref namesKeySrc);
Host.Check(namesKeySrc.Length == keyCount);
var keys = new DvText[keyCount];
var keys = new ReadOnlyMemory<char>[keyCount];
namesKeySrc.CopyTo(keys);

var values = dst.Values;
if (Utils.Size(values) < slotLim)
values = new DvText[slotLim];
values = new ReadOnlyMemory<char>[slotLim];

var sb = new StringBuilder();
int slot = 0;
foreach (var kvpSlot in namesSlotSrc.Items(all: true))
{
Contracts.Assert(slot == (long)kvpSlot.Key * keyCount);
sb.Clear();
if (kvpSlot.Value.HasChars)
kvpSlot.Value.AddToStringBuilder(sb);
if (!kvpSlot.Value.IsEmpty)
ReadOnlyMemoryUtils.AddToStringBuilder(kvpSlot.Value, sb);
else
sb.Append('[').Append(kvpSlot.Key).Append(']');
sb.Append('.');
Expand All @@ -410,24 +410,24 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
foreach (var key in keys)
{
sb.Length = len;
key.AddToStringBuilder(sb);
values[slot++] = new DvText(sb.ToString());
ReadOnlyMemoryUtils.AddToStringBuilder(key, sb);
values[slot++] = sb.ToString().AsMemory();
}
}
Host.Assert(slot == slotLim);

dst = new VBuffer<DvText>(slotLim, values, dst.Indices);
dst = new VBuffer<ReadOnlyMemory<char>>(slotLim, values, dst.Indices);
}

private void GetCategoricalSlotRanges(int iinfo, ref VBuffer<DvInt4> dst)
private void GetCategoricalSlotRanges(int iinfo, ref VBuffer<int> dst)
{
Host.Assert(0 <= iinfo && iinfo < _infos.Length);

var info = _infos[iinfo];

Host.Assert(info.TypeSrc.ValueCount > 0);

DvInt4[] ranges = new DvInt4[info.TypeSrc.ValueCount * 2];
int[] ranges = new int[info.TypeSrc.ValueCount * 2];
int size = info.TypeSrc.ItemType.KeyCount;

ranges[0] = 0;
Expand All @@ -438,7 +438,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer<DvInt4> dst)
ranges[i + 1] = ranges[i] + size - 1;
}

dst = new VBuffer<DvInt4>(ranges.Length, ranges);
dst = new VBuffer<int>(ranges.Length, ranges);
}

protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
Expand Down Expand Up @@ -834,7 +834,7 @@ public OutVectorColumn(Vector<Key<TKey>> input, bool bag)
Input = input;
Bag = bag;
}

public OutVectorColumn(VarVector<Key<TKey>> input)
: base(Reconciler.Inst, input)
{
Expand Down
38 changes: 19 additions & 19 deletions src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,46 +246,46 @@ private void AddMetadata(int i, ColumnMetadataInfo colMetaInfo)
{
if (typeNames != null)
{
MetadataUtils.MetadataGetter<VBuffer<DvText>> getter = (int col, ref VBuffer<DvText> dst) =>
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> getter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
GenerateBitSlotName(i, ref dst);
};
var info = new MetadataInfo<VBuffer<DvText>>(new VectorType(TextType.Instance, _types[i]), getter);
var info = new MetadataInfo<VBuffer<ReadOnlyMemory<char>>>(new VectorType(TextType.Instance, _types[i]), getter);
colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info);
}
MetadataUtils.MetadataGetter<DvBool> normalizeGetter = (int col, ref DvBool dst) =>
MetadataUtils.MetadataGetter<bool> normalizeGetter = (int col, ref bool dst) =>
{
dst = true;
};
var normalizeInfo = new MetadataInfo<DvBool>(BoolType.Instance, normalizeGetter);
var normalizeInfo = new MetadataInfo<bool>(BoolType.Instance, normalizeGetter);
colMetaInfo.Add(MetadataUtils.Kinds.IsNormalized, normalizeInfo);
}
else
{
if (typeNames != null && _types[i].IsKnownSizeVector)
{
MetadataUtils.MetadataGetter<VBuffer<DvText>> getter = (int col, ref VBuffer<DvText> dst) =>
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> getter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
GetSlotNames(i, ref dst);
};
var info = new MetadataInfo<VBuffer<DvText>>(new VectorType(TextType.Instance, _types[i]), getter);
var info = new MetadataInfo<VBuffer<ReadOnlyMemory<char>>>(new VectorType(TextType.Instance, _types[i]), getter);
colMetaInfo.Add(MetadataUtils.Kinds.SlotNames, info);
}
}
}

private void GenerateBitSlotName(int iinfo, ref VBuffer<DvText> dst)
private void GenerateBitSlotName(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
{
const string slotNamePrefix = "Bit";
var bldr = new BufferBuilder<DvText>(TextCombiner.Instance);
var bldr = new BufferBuilder<ReadOnlyMemory<char>>(TextCombiner.Instance);
bldr.Reset(_bitsPerKey[iinfo], true);
for (int i = 0; i < _bitsPerKey[iinfo]; i++)
bldr.AddFeature(i, new DvText(slotNamePrefix + (_bitsPerKey[iinfo] - i - 1)));
bldr.AddFeature(i, (slotNamePrefix + (_bitsPerKey[iinfo] - i - 1)).AsMemory());

bldr.GetResult(ref dst);
}

private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
{
Host.Assert(0 <= iinfo && iinfo < _infos.Length);
Host.Assert(_types[iinfo].IsKnownSizeVector);
Expand All @@ -295,7 +295,7 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
Host.Assert(typeSrc.VectorSize > 1);

// Get the source slot names, defaulting to empty text.
var namesSlotSrc = default(VBuffer<DvText>);
var namesSlotSrc = default(VBuffer<ReadOnlyMemory<char>>);
InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol);
Host.Assert(srcCol >= 0);
var typeSlotSrc = InputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, srcCol);
Expand All @@ -305,25 +305,25 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
Host.Check(namesSlotSrc.Length == typeSrc.VectorSize);
}
else
namesSlotSrc = VBufferUtils.CreateEmpty<DvText>(typeSrc.VectorSize);
namesSlotSrc = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(typeSrc.VectorSize);

int slotLim = _types[iinfo].VectorSize;
Host.Assert(slotLim == (long)typeSrc.VectorSize * _bitsPerKey[iinfo]);

var values = dst.Values;
if (Utils.Size(values) < slotLim)
values = new DvText[slotLim];
values = new ReadOnlyMemory<char>[slotLim];

var sb = new StringBuilder();
int slot = 0;
VBuffer<DvText> bits = default;
VBuffer<ReadOnlyMemory<char>> bits = default;
GenerateBitSlotName(iinfo, ref bits);
foreach (var kvpSlot in namesSlotSrc.Items(all: true))
{
Contracts.Assert(slot == (long)kvpSlot.Key * _bitsPerKey[iinfo]);
sb.Clear();
if (kvpSlot.Value.HasChars)
kvpSlot.Value.AddToStringBuilder(sb);
if (!kvpSlot.Value.IsEmpty)
ReadOnlyMemoryUtils.AddToStringBuilder(kvpSlot.Value, sb);
else
sb.Append('[').Append(kvpSlot.Key).Append(']');
sb.Append('.');
Expand All @@ -332,13 +332,13 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
foreach (var key in bits.Values)
{
sb.Length = len;
key.AddToStringBuilder(sb);
values[slot++] = new DvText(sb.ToString());
ReadOnlyMemoryUtils.AddToStringBuilder(key, sb);
values[slot++] = sb.ToString().AsMemory();
}
}
Host.Assert(slot == slotLim);

dst = new VBuffer<DvText>(slotLim, values, dst.Indices);
dst = new VBuffer<ReadOnlyMemory<char>>(slotLim, values, dst.Indices);
}

protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Tools;
using System;
using System.IO;
using System.Linq;
using Xunit;
Expand Down Expand Up @@ -124,8 +125,8 @@ private void ValidateMetadata(IDataView result)
Assert.True(result.Schema.TryGetColumnIndex("CatD", out int colD));
var types = result.Schema.GetMetadataTypes(colA);
Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.SlotNames });
VBuffer<DvText> slots = default;
DvBool normalized = default;
VBuffer<ReadOnlyMemory<char>> slots = default;
bool normalized = default;
result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colA, ref slots);
Assert.True(slots.Length == 6);
Assert.Equal(slots.Values.Select(x => x.ToString()), new string[6] { "[0].Bit2", "[0].Bit1", "[0].Bit0", "[1].Bit2", "[1].Bit1", "[1].Bit0" });
Expand All @@ -136,15 +137,15 @@ private void ValidateMetadata(IDataView result)
Assert.True(slots.Length == 2);
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "Bit1", "Bit0" });
result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colB, ref normalized);
Assert.True(normalized.IsTrue);
Assert.True(normalized);

types = result.Schema.GetMetadataTypes(colC);
Assert.Equal(types.Select(x => x.Key), new string[0]);

types = result.Schema.GetMetadataTypes(colD);
Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.IsNormalized });
result.Schema.GetMetadata(MetadataUtils.Kinds.IsNormalized, colD, ref normalized);
Assert.True(normalized.IsTrue);
Assert.True(normalized);
}

[Fact]
Expand Down
Loading

0 comments on commit d45bc2c

Please sign in to comment.