Skip to content

Commit

Permalink
WIP #3
Browse files Browse the repository at this point in the history
  • Loading branch information
eerhardt committed Jan 10, 2019
1 parent e56114f commit 204a1b5
Show file tree
Hide file tree
Showing 15 changed files with 167 additions and 147 deletions.
22 changes: 10 additions & 12 deletions src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,18 @@ private static void PrintSchema(TextWriter writer, Arguments args, Schema schema

if (!args.ShowSlots)
continue;
if (!type.IsKnownSizeVector)
if (!type.IsKnownSizeVector())
continue;
ColumnType typeNames;
if ((typeNames = schema[col].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type) == null)
continue;
if (typeNames.VectorSize != type.VectorSize || !(typeNames.ItemType is TextType))
if (typeNames.VectorSize() != type.VectorSize() || !(typeNames.ItemType() is TextType))
{
Contracts.Assert(false, "Unexpected slot names type");
continue;
}
schema[col].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names);
if (names.Length != type.VectorSize)
if (names.Length != type.VectorSize())
{
Contracts.Assert(false, "Unexpected length of slot names vector");
continue;
Expand Down Expand Up @@ -193,10 +193,10 @@ private static void ShowMetadata(IndentedTextWriter itw, Schema schema, int col,
itw.Write("Metadata '{0}': {1}", metaColumn.Name, type);
if (showVals)
{
if (!type.IsVector)
if (!(type is VectorType vectorType))
ShowMetadataValue(itw, schema, col, metaColumn.Name, type);
else
ShowMetadataValueVec(itw, schema, col, metaColumn.Name, type);
ShowMetadataValueVec(itw, schema, col, metaColumn.Name, vectorType);
}
itw.WriteLine();
}
Expand All @@ -210,7 +210,7 @@ private static void ShowMetadataValue(IndentedTextWriter itw, Schema schema, int
Contracts.Assert(0 <= col && col < schema.Count);
Contracts.AssertNonEmpty(kind);
Contracts.AssertValue(type);
Contracts.Assert(!type.IsVector);
Contracts.Assert(!(type is VectorType));

if (!type.IsStandardScalar() && !(type is KeyType))
{
Expand All @@ -230,7 +230,7 @@ private static void ShowMetadataValue<T>(IndentedTextWriter itw, Schema schema,
Contracts.Assert(0 <= col && col < schema.Count);
Contracts.AssertNonEmpty(kind);
Contracts.AssertValue(type);
Contracts.Assert(!type.IsVector);
Contracts.Assert(!(type is VectorType));
Contracts.Assert(type.RawType == typeof(T));

var conv = Conversions.Instance.GetStringConversion<T>(type);
Expand All @@ -243,34 +243,32 @@ private static void ShowMetadataValue<T>(IndentedTextWriter itw, Schema schema,
itw.Write(": '{0}'", sb);
}

private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type)
private static void ShowMetadataValueVec(IndentedTextWriter itw, Schema schema, int col, string kind, VectorType type)
{
Contracts.AssertValue(itw);
Contracts.AssertValue(schema);
Contracts.Assert(0 <= col && col < schema.Count);
Contracts.AssertNonEmpty(kind);
Contracts.AssertValue(type);
Contracts.Assert(type.IsVector);

if (!type.ItemType.IsStandardScalar() && !(type.ItemType is KeyType))
{
itw.Write(": Can't display value of this type");
return;
}

Action<IndentedTextWriter, Schema, int, string, ColumnType> del = ShowMetadataValueVec<int>;
Action<IndentedTextWriter, Schema, int, string, VectorType> del = ShowMetadataValueVec<int>;
var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType);
meth.Invoke(null, new object[] { itw, schema, col, kind, type });
}

private static void ShowMetadataValueVec<T>(IndentedTextWriter itw, Schema schema, int col, string kind, ColumnType type)
private static void ShowMetadataValueVec<T>(IndentedTextWriter itw, Schema schema, int col, string kind, VectorType type)
{
Contracts.AssertValue(itw);
Contracts.AssertValue(schema);
Contracts.Assert(0 <= col && col < schema.Count);
Contracts.AssertNonEmpty(kind);
Contracts.AssertValue(type);
Contracts.Assert(type.IsVector);
Contracts.Assert(type.ItemType.RawType == typeof(T));

var conv = Conversions.Instance.GetStringConversion<T>(type.ItemType);
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/Data/RowCursorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ public static Delegate GetVecGetterAs(PrimitiveType typeDst, Row row, int col)
Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col));
Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active");

var typeSrc = row.Schema[col].Type;
Contracts.Check(typeSrc.IsVector, "Source column type must be vector");
var typeSrc = row.Schema[col].Type as VectorType;
Contracts.Check(typeSrc != null, "Source column type must be vector");

Func<VectorType, PrimitiveType, GetterFactory, ValueGetter<VBuffer<int>>> del = GetVecGetterAsCore<int, int>;
var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.ItemType.RawType, typeDst.RawType);
Expand All @@ -170,8 +170,8 @@ public static ValueGetter<VBuffer<TDst>> GetVecGetterAs<TDst>(PrimitiveType type
Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col));
Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active");

var typeSrc = row.Schema[col].Type;
Contracts.Check(typeSrc.IsVector, "Source column type must be vector");
var typeSrc = row.Schema[col].Type as VectorType;
Contracts.Check(typeSrc != null, "Source column type must be vector");

Func<VectorType, PrimitiveType, GetterFactory, ValueGetter<VBuffer<TDst>>> del = GetVecGetterAsCore<int, TDst>;
var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.ItemType.RawType, typeof(TDst));
Expand Down Expand Up @@ -259,7 +259,7 @@ private static ValueGetter<VBuffer<TDst>> GetVecGetterAsCore<TSrc, TDst>(VectorT
return (ValueGetter<VBuffer<TDst>>)(Delegate)getter;
}

int size = typeSrc.VectorSize;
int size = typeSrc.Size;
var src = default(VBuffer<TSrc>);
return (ref VBuffer<TDst> dst) =>
{
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ private ColInfo(string name, ColumnType colType, Segment[] segs, int isegVar, in
Contracts.Assert(isegVar >= -1);

Name = name;
Kind = colType.ItemType.RawKind;
Kind = colType.ItemType().RawKind;
Contracts.Assert(Kind != 0);
ColType = colType;
Segments = segs;
Expand Down Expand Up @@ -845,7 +845,7 @@ public void Save(ModelSaveContext ctx)
{
var info = Infos[iinfo];
ctx.SaveNonEmptyString(info.Name);
var type = info.ColType.ItemType;
var type = info.ColType.ItemType();
Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind);
ctx.Writer.Write((byte)type.RawKind);
ctx.Writer.WriteBoolByte(type is KeyType);
Expand Down Expand Up @@ -899,7 +899,7 @@ public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
var names = _slotNames[col];
if (names.Length > 0)
{
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
Contracts.Assert(Infos[col].ColType.VectorSize() == names.Length);
yield return MetadataUtils.GetSlotNamesPair(names.Length);
}
}
Expand All @@ -915,7 +915,7 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col)
var names = _slotNames[col];
if (names.Length == 0)
return null;
Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
Contracts.Assert(Infos[col].ColType.VectorSize() == names.Length);
return MetadataUtils.GetNamesType(names.Length);

default:
Expand Down Expand Up @@ -947,7 +947,7 @@ private void GetSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
if (names.Length == 0)
throw MetadataUtils.ExceptGetMetadata();

Contracts.Assert(Infos[col].ColType.VectorSize == names.Length);
Contracts.Assert(Infos[col].ColType.VectorSize() == names.Length);
names.CopyTo(ref dst);
}
}
Expand Down
19 changes: 11 additions & 8 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -666,18 +666,21 @@ public Parser(TextLoader parent)
_creator[i] = cache.GetCreatorOne(keyType);
continue;
}
else if (info.ColType is VectorType vectorType && vectorType.ItemType is KeyType vectorKeyType)

VectorType vectorType = info.ColType as VectorType;
if (vectorType?.ItemType is KeyType vectorKeyType)
{
_creator[i] = cache.GetCreatorVec(vectorKeyType);
continue;
}

DataKind kind = info.ColType.ItemType.RawKind;
ColumnType itemType = vectorType?.ItemType ?? info.ColType;
DataKind kind = itemType.RawKind;
Contracts.Assert(kind != 0);
var map = info.ColType.IsVector ? mapVec : mapOne;
var map = vectorType != null ? mapVec : mapOne;
if (!map.TryGetValue(info.Kind, out _creator[i]))
{
var fn = info.ColType.IsVector ?
var fn = vectorType != null ?
cache.GetCreatorVec(info.Kind) :
cache.GetCreatorOne(info.Kind);
map.Add(info.Kind, fn);
Expand Down Expand Up @@ -750,7 +753,7 @@ public static void ParseSlotNames(TextLoader parent, ReadOnlyMemory<char> textHe
for (int iinfo = 0; iinfo < infos.Length; iinfo++)
{
var info = infos[iinfo];
if (!info.ColType.IsKnownSizeVector)
if (!info.ColType.IsKnownSizeVector())
continue;
bldr.Reset(info.SizeBase, false);
int ivDst = 0;
Expand Down Expand Up @@ -1282,7 +1285,7 @@ private void ProcessItems(RowSet rows, int irow, bool[] active, FieldSet fields,
var v = rows.Pipes[iinfo];
Contracts.Assert(v != null);

if (!info.ColType.IsVector)
if (!(info.ColType is VectorType))
ProcessOne(fields, info, v, irow, line);
else
ProcessVec(srcLim, fields, info, v, irow, line);
Expand All @@ -1292,7 +1295,7 @@ private void ProcessItems(RowSet rows, int irow, bool[] active, FieldSet fields,
private void ProcessVec(int srcLim, FieldSet fields, ColInfo info, ColumnPipe v, int irow, long line)
{
Contracts.Assert(srcLim >= 0);
Contracts.Assert(info.ColType.IsVector);
Contracts.Assert(info.ColType is VectorType);
Contracts.Assert(info.SizeBase > 0 || info.IsegVariable >= 0);

int sizeVar = 0;
Expand Down Expand Up @@ -1348,7 +1351,7 @@ private void ProcessVec(int srcLim, FieldSet fields, ColInfo info, ColumnPipe v,

private void ProcessOne(FieldSet vs, ColInfo info, ColumnPipe v, int irow, long line)
{
Contracts.Assert(!info.ColType.IsVector);
Contracts.Assert(!(info.ColType is VectorType));
Contracts.Assert(Utils.Size(info.Segments) == 1);
Contracts.Assert(info.Segments[0].Lim == info.Segments[0].Min + 1);

Expand Down
44 changes: 23 additions & 21 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ public static ValueWriter Create(RowCursor cursor, int col, char sep)

ColumnType type = cursor.Schema[col].Type;
Type writePipeType;
if (type.IsVector)
writePipeType = typeof(VecValueWriter<>).MakeGenericType(type.ItemType.RawType);
if (type is VectorType vectorType)
writePipeType = typeof(VecValueWriter<>).MakeGenericType(vectorType.ItemType.RawType);
else
writePipeType = typeof(ValueWriter<>).MakeGenericType(type.RawType);

Expand Down Expand Up @@ -151,15 +151,15 @@ public VecValueWriter(RowCursor cursor, VectorType type, int source, char sep)
: base(type.ItemType, source, sep)
{
_getSrc = cursor.GetGetter<VBuffer<T>>(source);
ColumnType typeNames;
if (type.IsKnownSizeVector &&
(typeNames = cursor.Schema[source].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type) != null &&
typeNames.VectorSize == type.VectorSize && typeNames.ItemType is TextType)
VectorType typeNames;
if (type.IsKnownSize
&& (typeNames = cursor.Schema[source].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type as VectorType) != null
&& typeNames.Size == type.Size && typeNames.ItemType is TextType)
{
cursor.Schema[source].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames);
Contracts.Check(_slotNames.Length == typeNames.VectorSize, "Unexpected slot names length");
Contracts.Check(_slotNames.Length == typeNames.Size, "Unexpected slot names length");
}
_slotCount = type.VectorSize;
_slotCount = type.Size;
}

public override void WriteData(Action<StringBuilder, int> appendItem, out int length)
Expand Down Expand Up @@ -313,7 +313,7 @@ public static string SeparatorCharToString(char separator)

public bool IsColumnSavable(ColumnType type)
{
var item = type.ItemType;
var item = type.ItemType();
return item.IsStandardScalar() || item is KeyType;
}

Expand Down Expand Up @@ -387,7 +387,7 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data,
for (int i = 0; i < cols.Length; i++)
{
ch.Check(0 <= cols[i] && cols[i] < active.Length);
ch.Check(data.Schema[cols[i]].Type.ItemType.RawKind != 0);
ch.Check(data.Schema[cols[i]].Type.ItemType().RawKind != 0);
active[cols[i]] = true;
}

Expand All @@ -399,15 +399,15 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data,
if (hasHeader)
continue;
var type = data.Schema[cols[i]].Type;
if (!type.IsVector)
if (!(type is VectorType vectorType))
{
hasHeader = true;
continue;
}
if (!type.IsKnownSizeVector)
if (!vectorType.IsKnownSize)
continue;
var typeNames = data.Schema[cols[i]].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type;
if (typeNames != null && typeNames.VectorSize == type.VectorSize && typeNames.ItemType is TextType)
var typeNames = data.Schema[cols[i]].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type as VectorType;
if (typeNames != null && typeNames.Size == vectorType.Size && typeNames.ItemType is TextType)
hasHeader = true;
}
}
Expand Down Expand Up @@ -471,13 +471,13 @@ private string CreateLoaderArguments(Schema schema, ValueWriter[] pipes, bool ha
var settings = CmdParser.GetSettings(_host, column, new TextLoader.Column());
CmdQuoter.QuoteValue(settings, sb, true);
}
if (type.IsVector && !type.IsKnownSizeVector && i != pipes.Length - 1)
if (type is VectorType vectorType && !vectorType.IsKnownSize && i != pipes.Length - 1)
{
ch.Warning("Column '{0}' is variable length, so it must be the last, or the file will be unreadable. Consider switching to binary format or use xf=Choose to make '{0}' the last column.", name);
index = null;
}

index += type.ValueCount;
index += type.ValueCount();
}

return sb.ToString();
Expand All @@ -487,7 +487,9 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
{
DataKind? kind;
KeyRange keyRange = null;
if (type.ItemType is KeyType key)
VectorType vectorType = type as VectorType;
ColumnType itemType = vectorType?.ItemType ?? type;
if (itemType is KeyType key)
{
if (!key.Contiguous)
keyRange = new KeyRange(key.Min, contiguous: false);
Expand All @@ -501,15 +503,15 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
kind = key.RawKind;
}
else
kind = type.ItemType.RawKind;
kind = itemType.RawKind;

TextLoader.Range[] source = null;

TextLoader.Range range = null;
int minValue = start ?? -1;
if (type.IsKnownSizeVector)
range = new TextLoader.Range { Min = minValue, Max = minValue + type.ValueCount - 1, ForceVector = true };
else if (type.IsVector)
if (vectorType?.IsKnownSize == true)
range = new TextLoader.Range { Min = minValue, Max = minValue + vectorType.Size - 1, ForceVector = true };
else if (vectorType != null)
range = new TextLoader.Range { Min = minValue, VariableEnd = true };
else
range = new TextLoader.Range { Min = minValue };
Expand Down
Loading

0 comments on commit 204a1b5

Please sign in to comment.