From 9ba4dabcdf2c4d092142d22726432b5b5b370b24 Mon Sep 17 00:00:00 2001 From: Gal Oshri Date: Tue, 7 Aug 2018 09:31:37 -0700 Subject: [PATCH 01/37] trigger build. --- docs/release-notes/0.4/release-0.4.md | 88 +++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 docs/release-notes/0.4/release-0.4.md diff --git a/docs/release-notes/0.4/release-0.4.md b/docs/release-notes/0.4/release-0.4.md new file mode 100644 index 0000000000..41c436f14e --- /dev/null +++ b/docs/release-notes/0.4/release-0.4.md @@ -0,0 +1,88 @@ +# ML.NET 0.4 Release Notes + +Today we are releasing ML.NET 0.4. During this release we have started +exploring new APIs for ML.NET that enable functionality that is missing from +the current APIs. We welcome feedback and contributions to the +conversation (relevant issues can be found [here](https://github.com/dotnet/machinelearning/projects/4)). While the +focus has been on designing the new APIs, we have also moved several +components from the internal codebase to ML.NET. + +### Installation + +ML.NET supports Windows, MacOS, and Linux. See [supported OS versions of .NET +Core +2.0](https://github.com/dotnet/core/blob/master/release-notes/2.0/2.0-supported-os.md) +for more details. + +You can install ML.NET NuGet from the CLI using: +``` +dotnet add package Microsoft.ML +``` + +From package manager: +``` +Install-Package Microsoft.ML +``` + +### Release Notes + +Below are some of the highlights from this release. + +* Added SymSGD learner for binary classification + ([#624](https://github.com/dotnet/machinelearning/pull/624)) + + * [SymSGD](https://arxiv.org/abs/1705.08030) is a technique for + parallelizing + [SGD](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) + (Stochastic Gradient Descent). This enables it to sometimes perform + faster than existing SGD implementations (e.g. 
[Hogwild + SGD](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.trainers.stochasticgradientdescentbinaryclassifier?view=ml-dotnet)). + * SymSGD is available for binary classification, but can be used in + multiclass classification with + [One-Versus-All](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.models.oneversusall?view=ml-dotnet). + * SymSGD requires adding the Microsoft.ML.HalLearners NuGet package to your project. + * The current implementation in ML.NET does not yet have multi-threading + enabled due to build system limitations (tracked by + [#655](https://github.com/dotnet/machinelearning/issues/655)), but + SymSGD can still be helpful in scenarios where you want to try many + different learners and limit each of them to a single thread. + * Documentation can be found + [here](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.trainers.symsgdbinaryclassifier?view=ml-dotnet). + +* Added Word Embeddings Transform for text scenarios + ([#545](https://github.com/dotnet/machinelearning/pull/545)) + + * [Word embeddings](https://en.wikipedia.org/wiki/Word_embedding) is a + technique for mapping words or phrases to numeric vectors of relatively low + dimension (in comparison with the high-dimensional n-gram extraction). + These numeric vectors are intended to capture some of the meaning of the + words so they can be used for training a better model. As an example, + SSWE (Sentiment-Specific Word Embedding) can be useful for sentiment + related tasks. + * This transform enables using pretrained models to get the embeddings + (i.e. the embeddings are already trained and available for use). + * Several options for pretrained embeddings are available: + [GloVe](https://nlp.stanford.edu/projects/glove/), + [fastText](https://en.wikipedia.org/wiki/FastText), and + [SSWE](http://anthology.aclweb.org/P/P14/P14-1146.pdf). The pretrained model is downloaded automatically on first use. 
+ * Documentation can be found + [here](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.transforms.wordembeddings?view=ml-dotnet). + +* Improved support for F# by allowing use of property-based row classes ([#616](https://github.com/dotnet/machinelearning/pull/616)) + + * ML.NET now supports F# record types. + * The ML.NET samples repository is being updated to include F# samples as part of [#36](https://github.com/dotnet/machinelearning-samples/pull/36). + +Additional issues closed in this milestone can be found +[here](https://github.com/dotnet/machinelearning/milestone/3?closed=1). + +### Acknowledgements + +Shoutout to [dsyme](https://github.com/dsyme), +[SolyarA](https://github.com/SolyarA), +[dan-drews](https://github.com/dan-drews), +[bojanmisic](https://github.com/bojanmisic), +[jwood803](https://github.com/jwood803), +[sharwell](https://github.com/sharwell), +[JoshuaLight](https://github.com/JoshuaLight), and the ML.NET team for their +contributions as part of this release! \ No newline at end of file From f7a0fffb5a53b44c22f1166639f8fe2676226781 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 14 Aug 2018 16:45:58 -0700 Subject: [PATCH 02/37] Datakinds, ColumnTypes and NAHandle. 
--- src/Microsoft.ML.Core/Data/ColumnType.cs | 76 +++- src/Microsoft.ML.Core/Data/DataKind.cs | 77 +++- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 396 +++++++++--------- 3 files changed, 323 insertions(+), 226 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 96764d68f1..b754c0e767 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -385,6 +385,17 @@ public static NumberType I1 } } + private static volatile NumberType _instNI1; + public static NumberType NI1 + { + get + { + if (_instNI1 == null) + Interlocked.CompareExchange(ref _instNI1, new NumberType(DataKind.NI1, "NI1"), null); + return _instNI1; + } + } + private static volatile NumberType _instU1; public static NumberType U1 { @@ -407,6 +418,17 @@ public static NumberType I2 } } + private static volatile NumberType _instNI2; + public static NumberType NI2 + { + get + { + if (_instNI2 == null) + Interlocked.CompareExchange(ref _instNI2, new NumberType(DataKind.NI2, "NI2"), null); + return _instNI2; + } + } + private static volatile NumberType _instU2; public static NumberType U2 { @@ -429,6 +451,17 @@ public static NumberType I4 } } + private static volatile NumberType _instNI4; + public static NumberType NI4 + { + get + { + if (_instNI4 == null) + Interlocked.CompareExchange(ref _instNI4, new NumberType(DataKind.NI4, "NI4"), null); + return _instNI4; + } + } + private static volatile NumberType _instU4; public static NumberType U4 { @@ -451,6 +484,17 @@ public static NumberType I8 } } + private static volatile NumberType _instNI8; + public static NumberType NI8 + { + get + { + if (_instNI8 == null) + Interlocked.CompareExchange(ref _instNI8, new NumberType(DataKind.NI8, "NI8"), null); + return _instNI8; + } + } + private static volatile NumberType _instU8; public static NumberType U8 { @@ -506,18 +550,26 @@ public static NumberType Float { case DataKind.I1: return I1; + case DataKind.NI1: 
+ return NI1; case DataKind.U1: return U1; case DataKind.I2: return I2; + case DataKind.NI2: + return NI2; case DataKind.U2: return U2; case DataKind.I4: return I4; + case DataKind.NI4: + return NI4; case DataKind.U4: return U4; case DataKind.I8: return I8; + case DataKind.NI8: + return NI8; case DataKind.U8: return U8; case DataKind.R4: @@ -567,14 +619,30 @@ public static BoolType Instance get { if (_instance == null) - Interlocked.CompareExchange(ref _instance, new BoolType(), null); + Interlocked.CompareExchange(ref _instance, new BoolType(DataKind.BL, "Bool"), null); return _instance; } } - private BoolType() - : base(typeof(DvBool), DataKind.BL) + private static volatile BoolType _ninstance; + public static BoolType NInstance + { + get + { + if (_ninstance == null) + Interlocked.CompareExchange(ref _ninstance, new BoolType(DataKind.NBL, "NBool"), null); + return _ninstance; + } + } + + private readonly string _name; + + private BoolType(DataKind kind, string name) + : base(kind.ToType(), kind) { + Contracts.AssertNonEmpty(name); + _name = name; + Contracts.Assert(IsNumber); } public override bool Equals(ColumnType other) @@ -587,7 +655,7 @@ public override bool Equals(ColumnType other) public override string ToString() { - return "Bool"; + return _name; } } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 32325f44a1..d8f3b1bef0 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -46,6 +46,13 @@ public enum DataKind : byte UG = 16, // Unsigned 16-byte integer. 
U16 = UG, +#pragma warning restore MSML_GeneralName + NI1 = 17, + NI2 = 18, + NI4 = 19, + NI8 = 20, +#pragma warning disable MSML_GeneralName + NBL = 21, #pragma warning restore MSML_GeneralName } @@ -55,7 +62,7 @@ public enum DataKind : byte public static class DataKindExtensions { public const DataKind KindMin = DataKind.I1; - public const DataKind KindLim = DataKind.UG + 1; + public const DataKind KindLim = DataKind.NBL + 1; public const int KindCount = KindLim - KindMin; /// @@ -84,18 +91,22 @@ public static ulong ToMaxInt(this DataKind kind) switch (kind) { case DataKind.I1: + case DataKind.NI1: return (ulong)sbyte.MaxValue; case DataKind.U1: return byte.MaxValue; case DataKind.I2: + case DataKind.NI2: return (ulong)short.MaxValue; case DataKind.U2: return ushort.MaxValue; case DataKind.I4: + case DataKind.NI4: return int.MaxValue; case DataKind.U4: return uint.MaxValue; case DataKind.I8: + case DataKind.NI8: return long.MaxValue; case DataKind.U8: return ulong.MaxValue; @@ -113,18 +124,22 @@ public static long ToMinInt(this DataKind kind) switch (kind) { case DataKind.I1: + case DataKind.NI1: return sbyte.MinValue; case DataKind.U1: return byte.MinValue; case DataKind.I2: + case DataKind.NI2: return short.MinValue; case DataKind.U2: return ushort.MinValue; case DataKind.I4: + case DataKind.NI4: return int.MinValue; case DataKind.U4: return uint.MinValue; case DataKind.I8: + case DataKind.NI8: return long.MinValue; case DataKind.U8: return 0; @@ -141,19 +156,27 @@ public static Type ToType(this DataKind kind) switch (kind) { case DataKind.I1: - return typeof(DvInt1); + return typeof(sbyte); + case DataKind.NI1: + return typeof(sbyte?); case DataKind.U1: return typeof(byte); case DataKind.I2: - return typeof(DvInt2); + return typeof(short); + case DataKind.NI2: + return typeof(short?); case DataKind.U2: return typeof(ushort); case DataKind.I4: - return typeof(DvInt4); + return typeof(Int32); + case DataKind.NI4: + return typeof(Int32?); case DataKind.U4: return 
typeof(uint); case DataKind.I8: - return typeof(DvInt8); + return typeof(Int64); + case DataKind.NI8: + return typeof(Int64?); case DataKind.U8: return typeof(ulong); case DataKind.R4: @@ -163,7 +186,9 @@ public static Type ToType(this DataKind kind) case DataKind.TX: return typeof(DvText); case DataKind.BL: - return typeof(DvBool); + return typeof(bool); + case DataKind.NBL: + return typeof(bool?); case DataKind.TS: return typeof(DvTimeSpan); case DataKind.DT: @@ -185,30 +210,40 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) Contracts.CheckValueOrNull(type); // REVIEW: Make this more efficient. Should we have a global dictionary? - if (type == typeof(DvInt1) || type == typeof(sbyte) || type == typeof(sbyte?)) + if (type == typeof(sbyte)) kind = DataKind.I1; + else if (type == typeof(sbyte?)) + kind = DataKind.NI1; else if (type == typeof(byte) || type == typeof(byte?)) kind = DataKind.U1; - else if (type == typeof(DvInt2)|| type== typeof(short) || type == typeof(short?)) + else if (type == typeof(short)) kind = DataKind.I2; - else if (type == typeof(ushort)|| type == typeof(ushort?)) + else if (type == typeof(short?)) + kind = DataKind.NI2; + else if (type == typeof(ushort) || type == typeof(ushort?)) kind = DataKind.U2; - else if (type == typeof(DvInt4) || type == typeof(int)|| type == typeof(int?)) + else if (type == typeof(int)) kind = DataKind.I4; - else if (type == typeof(uint)|| type == typeof(uint?)) + else if (type == typeof(int?)) + kind = DataKind.NI4; + else if (type == typeof(uint) || type == typeof(uint?)) kind = DataKind.U4; - else if (type == typeof(DvInt8) || type==typeof(long)|| type == typeof(long?)) + else if (type == typeof(long)) kind = DataKind.I8; - else if (type == typeof(ulong)|| type == typeof(ulong?)) + else if (type == typeof(long?)) + kind = DataKind.NI8; + else if (type == typeof(ulong) || type == typeof(ulong?)) kind = DataKind.U8; - else if (type == typeof(Single)|| type == typeof(Single?)) + else if (type 
== typeof(Single) || type == typeof(Single?)) kind = DataKind.R4; - else if (type == typeof(Double)|| type == typeof(Double?)) + else if (type == typeof(Double) || type == typeof(Double?)) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) + else if (type == typeof(bool)) kind = DataKind.BL; + else if (type == typeof(bool?)) + kind = DataKind.NBL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; else if (type == typeof(DvDateTime)) @@ -236,12 +271,20 @@ public static string GetString(this DataKind kind) { case DataKind.I1: return "I1"; + case DataKind.NI1: + return "NI1"; case DataKind.I2: return "I2"; + case DataKind.NI2: + return "NI2"; case DataKind.I4: return "I4"; + case DataKind.NI4: + return "NI4"; case DataKind.I8: return "I8"; + case DataKind.NI8: + return "NI8"; case DataKind.U1: return "U1"; case DataKind.U2: @@ -256,6 +299,8 @@ public static string GetString(this DataKind kind) return "R8"; case DataKind.BL: return "BL"; + case DataKind.NBL: + return "NBL"; case DataKind.TX: return "TX"; case DataKind.TS: diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index 2340f9b413..9756c34542 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -22,14 +22,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MeanAggregatorOne(ch, cursor, col); - case DataKind.I2: + case DataKind.NI2: return new I2.MeanAggregatorOne(ch, cursor, col); - case DataKind.I4: + case DataKind.NI4: return new I4.MeanAggregatorOne(ch, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorOne(ch, type, cursor, col); + case DataKind.NI8: + return new Long.MeanAggregatorOne(ch, type, cursor, col); case DataKind.R4: return new 
R4.MeanAggregatorOne(ch, cursor, col); case DataKind.R8: @@ -46,14 +46,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: + case DataKind.NI2: return new I2.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I4: + case DataKind.NI4: return new I4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); + case DataKind.NI8: + return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: @@ -78,14 +78,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I2: + case DataKind.NI2: return new I2.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I4: + case DataKind.NI4: return new I4.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); + case DataKind.NI8: + return new Long.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R8: @@ -102,14 +102,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: + case DataKind.NI2: return new I2.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case 
DataKind.I4: + case DataKind.NI4: return new I4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); + case DataKind.NI8: + return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: @@ -130,14 +130,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I2: + case DataKind.NI2: return new I2.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I4: + case DataKind.NI4: return new I4.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); + case DataKind.NI8: + return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorAcrossSlots(ch, cursor, col); case DataKind.R8: @@ -154,14 +154,14 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: + case DataKind.NI1: return new I1.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: + case DataKind.NI2: return new I2.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I4: + case DataKind.NI4: return new I4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); + case DataKind.NI8: + return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new 
R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: @@ -503,17 +503,17 @@ private void AssertValid(long valMax) Contracts.Assert(_cna >= 0); } - public void Update(long val, long valMax) + public void Update(long? val, long valMax) { AssertValid(valMax); - Contracts.Assert(-valMax - 1 <= val && val <= valMax); + Contracts.Assert(!val.HasValue || -valMax <= val && val <= valMax); - if (val >= 0) + if (!val.HasValue) + _cna++; + else if (val >= 0) IntUtils.Add(ref _sumHi, ref _sumLo, (ulong)val); - else if (val >= -valMax) - IntUtils.Sub(ref _sumHi, ref _sumLo, (ulong)(-val)); else - _cna++; + IntUtils.Sub(ref _sumHi, ref _sumLo, (ulong)(-val)); AssertValid(valMax); } @@ -935,73 +935,73 @@ private static class I1 private const long MaxVal = sbyte.MaxValue; - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(ref DvInt1 val) + protected override void ProcessRow(ref sbyte? val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt1)(sbyte)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (sbyte)val; } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(ref DvInt1 val) + protected override void ProcessValue(ref sbyte? 
val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt1)(sbyte)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (sbyte)val; } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(ref DvInt1 val, int slot) + protected override void ProcessValue(ref sbyte? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); + Stat[slot].Update(val, MaxVal); } public override object GetStat() { - DvInt1[] stat = new DvInt1[Stat.Length]; + sbyte[] stat = new sbyte[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) { long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt1)(sbyte)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + stat[slot] = (sbyte)val; } return stat; } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1009,27 +1009,27 @@ public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnM Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt1 val) + protected override void ProcessValueMin(ref sbyte? 
val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt1.RawNA) - Stat = raw; + var raw = val; + if (raw.HasValue && raw < Stat) + Stat = raw.Value; } - protected override void ProcessValueMax(ref DvInt1 val) + protected override void ProcessValueMax(ref sbyte? val) { - var raw = val.RawValue; + var raw = val; if (raw > Stat) - Stat = raw; + Stat = raw.Value; } public override object GetStat() { - return (DvInt1)Stat; + return (sbyte)Stat; } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1037,18 +1037,16 @@ public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt1 val) + protected override void ProcessValueMin(ref sbyte? val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt1.RawNA) - Stat = raw; + if (val.HasValue && val < Stat) + Stat = val.Value; } - protected override void ProcessValueMax(ref DvInt1 val) + protected override void ProcessValueMax(ref sbyte? val) { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; + if (val.HasValue && val > Stat) + Stat = val.Value; } public override object GetStat() @@ -1056,14 +1054,14 @@ public override object GetStat() // If sparsity occurred, fold in a zero. 
if (ValueCount > (ulong)ValuesProcessed) { - var def = default(DvInt1); + var def = default(sbyte?); ProcValueDelegate(ref def); } - return (DvInt1)Stat; + return Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) @@ -1073,34 +1071,32 @@ public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, i Stat[i] = bound; } - protected override void ProcessValueMin(ref DvInt1 val, int slot) + protected override void ProcessValueMin(ref sbyte? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt1.RawNA) - Stat[slot] = raw; + if (val.HasValue && val.Value < Stat[slot]) + Stat[slot] = val.Value; } - protected override void ProcessValueMax(ref DvInt1 val, int slot) + protected override void ProcessValueMax(ref sbyte? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; + if (val.HasValue && val.Value > Stat[slot]) + Stat[slot] = val.Value; } public override object GetStat() { - DvInt1[] stat = new DvInt1[Stat.Length]; + sbyte[] stat = new sbyte[Stat.Length]; // Account for defaults resulting from sparsity. 
for (int slot = 0; slot < Stat.Length; slot++) { if (GetValuesProcessed(slot) < RowCount) { - var def = default(DvInt1); + var def = default(sbyte?); ProcValueDelegate(ref def, slot); } - stat[slot] = (DvInt1)Stat[slot]; + stat[slot] = Stat[slot]; } return stat; } @@ -1111,73 +1107,73 @@ private static class I2 { private const long MaxVal = short.MaxValue; - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(ref DvInt2 val) + protected override void ProcessRow(ref short? val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt2)(short)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (short)val; } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(ref DvInt2 val) + protected override void ProcessValue(ref short? 
val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt2)(short)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (short)val; } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(ref DvInt2 val, int slot) + protected override void ProcessValue(ref short? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); + Stat[slot].Update(val, MaxVal); } public override object GetStat() { - DvInt2[] stat = new DvInt2[Stat.Length]; + short[] stat = new short[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) { long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt2)(short)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + stat[slot] = (short)val; } return stat; } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1185,27 +1181,25 @@ public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnM Stat = (short)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt2 val) + protected override void ProcessValueMin(ref short? val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt2.RawNA) - Stat = raw; + if (val.HasValue && val < Stat) + Stat = val.Value; } - protected override void ProcessValueMax(ref DvInt2 val) + protected override void ProcessValueMax(ref short? 
val) { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; + if (val.HasValue && val > Stat) + Stat = val.Value; } public override object GetStat() { - return (DvInt2)Stat; + return Stat; } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1213,18 +1207,16 @@ public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool Stat = (short)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt2 val) + protected override void ProcessValueMin(ref short? val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt2.RawNA) - Stat = raw; + if (val.HasValue && val < Stat) + Stat = val.Value; } - protected override void ProcessValueMax(ref DvInt2 val) + protected override void ProcessValueMax(ref short? val) { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; + if (val.HasValue && val > Stat) + Stat = val.Value; } public override object GetStat() @@ -1232,14 +1224,14 @@ public override object GetStat() // If sparsity occurred, fold in a zero. if (ValueCount > (ulong)ValuesProcessed) { - var def = default(DvInt2); + var def = default(short?); ProcValueDelegate(ref def); } - return (DvInt2)Stat; + return Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) @@ -1249,34 +1241,32 @@ public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, i Stat[i] = bound; } - protected override void ProcessValueMin(ref DvInt2 val, int slot) + protected override void ProcessValueMin(ref short? 
val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt2.RawNA) - Stat[slot] = raw; + if (val.HasValue && val < Stat[slot]) + Stat[slot] = val.Value; } - protected override void ProcessValueMax(ref DvInt2 val, int slot) + protected override void ProcessValueMax(ref short? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; + if (val.HasValue && val > Stat[slot]) + Stat[slot] = val.Value; } public override object GetStat() { - DvInt2[] stat = new DvInt2[Stat.Length]; + short[] stat = new short[Stat.Length]; // Account for defaults resulting from sparsity. for (int slot = 0; slot < Stat.Length; slot++) { if (GetValuesProcessed(slot) < RowCount) { - var def = default(DvInt2); + var def = default(short?); ProcValueDelegate(ref def, slot); } - stat[slot] = (DvInt2)Stat[slot]; + stat[slot] = Stat[slot]; } return stat; } @@ -1287,73 +1277,73 @@ private static class I4 { private const long MaxVal = int.MaxValue; - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(ref DvInt4 val) + protected override void ProcessRow(ref Int32? 
val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt4)(int)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (int)val; } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(ref DvInt4 val) + protected override void ProcessValue(ref Int32? val) { - Stat.Update(val.RawValue, MaxVal); + Stat.Update(val, MaxVal); } public override object GetStat() { long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt4)(int)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + return (int)val; } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(ref DvInt4 val, int slot) + protected override void ProcessValue(ref Int32? 
val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); + Stat[slot].Update(val, MaxVal); } public override object GetStat() { - DvInt4[] stat = new DvInt4[Stat.Length]; + Int32[] stat = new Int32[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) { long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt4)(int)val; + Ch.Assert(-MaxVal <= val && val <= MaxVal); + stat[slot] = (int)val; } return stat; } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1361,27 +1351,25 @@ public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnM Stat = (int)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt4 val) + protected override void ProcessValueMin(ref Int32? val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt4.RawNA) - Stat = raw; + if (val.HasValue && val < Stat) + Stat = val.Value; } - protected override void ProcessValueMax(ref DvInt4 val) + protected override void ProcessValueMax(ref Int32? val) { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; + if (val.HasValue && val > Stat) + Stat = val.Value; } public override object GetStat() { - return (DvInt4)Stat; + return Stat; } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1389,18 +1377,16 @@ public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool Stat = (int)(ReturnMax ? 
-MaxVal : MaxVal); } - protected override void ProcessValueMin(ref DvInt4 val) + protected override void ProcessValueMin(ref Int32? val) { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt4.RawNA) - Stat = raw; + if (val.HasValue && val < Stat) + Stat = val.Value; } - protected override void ProcessValueMax(ref DvInt4 val) + protected override void ProcessValueMax(ref Int32? val) { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; + if (val.HasValue && val > Stat) + Stat = val.Value; } public override object GetStat() @@ -1408,14 +1394,14 @@ public override object GetStat() // If sparsity occurred, fold in a zero. if (ValueCount > (ulong)ValuesProcessed) { - var def = default(DvInt4); + var def = default(Int32?); ProcValueDelegate(ref def); } - return (DvInt4)Stat; + return Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) @@ -1425,20 +1411,18 @@ public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, i Stat[i] = bound; } - protected override void ProcessValueMin(ref DvInt4 val, int slot) + protected override void ProcessValueMin(ref Int32? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt4.RawNA) - Stat[slot] = raw; + if (val.HasValue && val < Stat[slot]) + Stat[slot] = val.Value; } - protected override void ProcessValueMax(ref DvInt4 val, int slot) + protected override void ProcessValueMax(ref Int32? 
val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; + if (val.HasValue && val > Stat[slot]) + Stat[slot] = val.Value; } public override object GetStat() @@ -1449,10 +1433,10 @@ public override object GetStat() { if (GetValuesProcessed(slot) < RowCount) { - var def = default(DvInt4); + var def = default(Int32?); ProcValueDelegate(ref def, slot); } - stat[slot] = (DvInt4)Stat[slot]; + stat[slot] = Stat[slot]; } return stat; } @@ -1550,15 +1534,15 @@ public MinMaxAggregatorOne(IChannel ch, ColumnType type, IRowCursor cursor, int protected override void ProcessValueMin(ref TItem val) { var raw = _converter.ToLong(val); - if (raw < Stat && -MaxVal <= raw) - Stat = raw; + if (raw.HasValue && raw < Stat) + Stat = raw.Value; } protected override void ProcessValueMax(ref TItem val) { var raw = _converter.ToLong(val); - if (raw > Stat) - Stat = raw; + if (raw.HasValue && raw > Stat) + Stat = raw.Value; } public override object GetStat() @@ -1581,15 +1565,15 @@ public MinMaxAggregatorAcrossSlots(IChannel ch, ColumnType type, IRowCursor curs protected override void ProcessValueMin(ref TItem val) { var raw = _converter.ToLong(val); - if (raw < Stat && -MaxVal <= raw) - Stat = raw; + if (raw.HasValue && raw < Stat) + Stat = raw.Value; } protected override void ProcessValueMax(ref TItem val) { var raw = _converter.ToLong(val); - if (raw > Stat) - Stat = raw; + if (raw.HasValue && raw > Stat) + Stat = raw.Value; } public override object GetStat() @@ -1622,16 +1606,16 @@ protected override void ProcessValueMin(ref TItem val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); var raw = _converter.ToLong(val); - if (raw < Stat[slot] && -MaxVal <= raw) - Stat[slot] = raw; + if (raw.HasValue && raw < Stat[slot]) + Stat[slot] = raw.Value; } protected override void ProcessValueMax(ref TItem val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); var raw = _converter.ToLong(val); - if (raw > 
Stat[slot]) - Stat[slot] = raw; + if (raw.HasValue && raw > Stat[slot]) + Stat[slot] = raw.Value; } public override object GetStat() @@ -1677,46 +1661,46 @@ private abstract class Converter private abstract class Converter : Converter { - public abstract long ToLong(T val); - public abstract T FromLong(long val); + public abstract long? ToLong(T val); + public abstract T FromLong(long? val); } - private sealed class I8Converter : Converter + private sealed class I8Converter : Converter { - public override long ToLong(DvInt8 val) + public override long? ToLong(long? val) { - return val.RawValue; + return val; } - public override DvInt8 FromLong(long val) + public override long? FromLong(long? val) { - Contracts.Assert(DvInt8.RawNA != val); - return (DvInt8)val; + Contracts.Assert(val.HasValue); + return val.Value; } } private sealed class TSConverter : Converter { - public override long ToLong(DvTimeSpan val) + public override long? ToLong(DvTimeSpan val) { return val.Ticks.RawValue; } - public override DvTimeSpan FromLong(long val) + public override DvTimeSpan FromLong(long? val) { - Contracts.Assert(DvInt8.RawNA != val); + Contracts.Assert(val.HasValue); return new DvTimeSpan(val); } } private sealed class DTConverter : Converter { - public override long ToLong(DvDateTime val) + public override long? ToLong(DvDateTime val) { return val.Ticks.RawValue; } - public override DvDateTime FromLong(long val) + public override DvDateTime FromLong(long? val) { Contracts.Assert(0 <= val && val <= DvDateTime.MaxTicks); return new DvDateTime(val); From 1fdfe0dabaa4313dd6f75852cc717a122ac48038 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 14 Aug 2018 17:47:26 -0700 Subject: [PATCH 03/37] Change metadata for categorical indices from DvInt4 to Int32. 
--- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 16 +++---- .../Transforms/ConcatTransform.cs | 8 ++-- .../Transforms/DropSlotsTransform.cs | 42 +++++++++---------- .../Transforms/KeyToVectorTransform.cs | 8 ++-- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 2 +- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 116d521756..81882fa946 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -404,9 +404,9 @@ public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, return isValid; var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex); - if (type?.RawType == typeof(VBuffer)) + if (type?.RawType == typeof(VBuffer)) { - VBuffer catIndices = default(VBuffer); + VBuffer catIndices = default(VBuffer); schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices); VBufferUtils.Densify(ref catIndices); int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore; @@ -416,19 +416,19 @@ public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, isValid = true; for (int i = 0; i < catIndices.Values.Length; i += 2) { - if (catIndices.Values[i].RawValue > catIndices.Values[i + 1].RawValue || - catIndices.Values[i].RawValue <= previousEndIndex || - catIndices.Values[i].RawValue >= columnSlotsCount || - catIndices.Values[i + 1].RawValue >= columnSlotsCount) + if (catIndices.Values[i] > catIndices.Values[i + 1] || + catIndices.Values[i] <= previousEndIndex || + catIndices.Values[i] >= columnSlotsCount || + catIndices.Values[i + 1] >= columnSlotsCount) { isValid = false; break; } - previousEndIndex = catIndices.Values[i + 1].RawValue; + previousEndIndex = catIndices.Values[i + 1]; } if (isValid) - categoricalFeatures = catIndices.Values.Select(val => val.RawValue).ToArray(); + 
categoricalFeatures = catIndices.Values.Select(val => val).ToArray(); } } diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index b2024cc18c..c4ca1a0039 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -404,7 +404,7 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal if (_typesCategoricals[iinfo] == null) throw MetadataUtils.ExceptGetMetadata(); - MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); + MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); break; case MetadataUtils.Kinds.IsNormalized: if (!_isNormalized[iinfo]) @@ -417,9 +417,9 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal } } - private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) { - List allValues = new List(); + List allValues = new List(); int slotCount = 0; for (int i = 0; i < Infos[iiinfo].SrcIndices.Length; i++) { @@ -440,7 +440,7 @@ private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) Contracts.Assert(allValues.Count > 0); - dst = new VBuffer(allValues.Count, allValues.ToArray()); + dst = new VBuffer(allValues.Count, allValues.ToArray()); } private void IsNormalized(int iinfo, ref DvBool dst) diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 230cfbe680..dfe61cddbb 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -393,14 +393,14 @@ private void ComputeType(ISchema input, int[] slotsMin, int[] slotsMax, int iinf { if (MetadataUtils.TryGetCategoricalFeatureIndices(Source.Schema, Infos[iinfo].Source, out categoricalRanges)) { - VBuffer dst = default(VBuffer); + VBuffer dst = default(VBuffer); 
GetCategoricalSlotRangesCore(iinfo, slotDropper.SlotsMin, slotDropper.SlotsMax, categoricalRanges, ref dst); // REVIEW: cache dst as opposed to caculating it again. if (dst.Length > 0) { Contracts.Assert(dst.Length % 2 == 0); - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(dst.Length / 2), GetCategoricalSlotRanges); } } @@ -443,7 +443,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst) infoEx.SlotDropper.DropSlots(ref names, ref dst); } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { if (_exes[iinfo].CategoricalRanges != null) { @@ -452,7 +452,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) } } - private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) + private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); Host.Assert(slotsMax != null && slotsMin != null); @@ -467,9 +467,9 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots int previousDropSlotsIndex = 0; int droppedSlotsCount = 0; bool combine = false; - DvInt4 min = -1; - DvInt4 max = -1; - List newCategoricalSlotRanges = new List(); + Int32 min = -1; + Int32 max = -1; + List newCategoricalSlotRanges = new List(); // Six possible ways a drop slot range interacts with categorical slots range. 
// @@ -498,7 +498,7 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots } else { - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); min = ranges[rangesIndex] - droppedSlotsCount; max = ranges[rangesIndex + 1] - droppedSlotsCount; } @@ -515,14 +515,14 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots rangesIndex += 2; if (combine) { - Contracts.Assert(min.RawValue >= 0 && min.RawValue <= max.RawValue); + Contracts.Assert(min >= 0 && min <= max); newCategoricalSlotRanges.Add(min); newCategoricalSlotRanges.Add(max); min = max = -1; combine = false; } - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); } else if (slotsMin[dropSlotsIndex] > ranges[rangesIndex] && @@ -535,7 +535,7 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots } else { - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); min = ranges[rangesIndex] - droppedSlotsCount; max = slotsMin[dropSlotsIndex] - 1 - droppedSlotsCount; @@ -576,28 +576,28 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots min = max = -1; } - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); for (int i = rangesIndex; i < ranges.Length; i++) newCategoricalSlotRanges.Add(ranges[i] - droppedSlotsCount); Contracts.Assert(newCategoricalSlotRanges.Count % 2 == 0); - Contracts.Assert(newCategoricalSlotRanges.TrueForAll(x => x.RawValue >= 0)); + Contracts.Assert(newCategoricalSlotRanges.TrueForAll(x => x >= 0)); Contracts.Assert(0 <= droppedSlotsCount && droppedSlotsCount <= slotsMax[slotsMax.Length - 1] + 1); if (newCategoricalSlotRanges.Count > 0) - dst = new VBuffer(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray()); + dst = new VBuffer(newCategoricalSlotRanges.Count, 
newCategoricalSlotRanges.ToArray()); } private void CombineRanges( - DvInt4 minRange1, DvInt4 maxRange1, DvInt4 minRange2, DvInt4 maxRange2, - out DvInt4 newRangeMin, out DvInt4 newRangeMax) + Int32 minRange1, Int32 maxRange1, Int32 minRange2, Int32 maxRange2, + out Int32 newRangeMin, out Int32 newRangeMax) { - Contracts.Assert(minRange2.RawValue >= 0 && maxRange2.RawValue >= 0); - Contracts.Assert(minRange2.RawValue <= maxRange2.RawValue); - Contracts.Assert(minRange1.RawValue >= 0 && maxRange1.RawValue >= 0); - Contracts.Assert(minRange1.RawValue <= maxRange1.RawValue); - Contracts.Assert(maxRange1.RawValue + 1 == minRange2.RawValue); + Contracts.Assert(minRange2 >= 0 && maxRange2 >= 0); + Contracts.Assert(minRange2 <= maxRange2); + Contracts.Assert(minRange1 >= 0 && maxRange1 >= 0); + Contracts.Assert(minRange1 <= maxRange1); + Contracts.Assert(maxRange1 + 1 == minRange2); newRangeMin = minRange1; newRangeMax = maxRange2; diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 0f4b616a49..0a1f4ce283 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -319,7 +319,7 @@ private static void ComputeType(KeyToVectorTransform trans, ISchema input, int i if (!bag && info.TypeSrc.ValueCount > 0) { - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(info.TypeSrc.ValueCount), trans.GetCategoricalSlotRanges); } @@ -334,7 +334,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _types[iinfo]; } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -342,7 +342,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) 
Host.Assert(info.TypeSrc.ValueCount > 0); - DvInt4[] ranges = new DvInt4[info.TypeSrc.ValueCount * 2]; + Int32[] ranges = new Int32[info.TypeSrc.ValueCount * 2]; int size = info.TypeSrc.ItemType.KeyCount; ranges[0] = 0; @@ -353,7 +353,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) ranges[i + 1] = ranges[i] + size - 1; } - dst = new VBuffer(ranges.Length, ranges); + dst = new VBuffer(ranges.Length, ranges); } // Used for slot names when appropriate. diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index 9756c34542..c33e3e05f9 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -1427,7 +1427,7 @@ protected override void ProcessValueMax(ref Int32? val, int slot) public override object GetStat() { - DvInt4[] stat = new DvInt4[Stat.Length]; + Int32[] stat = new Int32[Stat.Length]; // Account for defaults resulting from sparsity. for (int slot = 0; slot < Stat.Length; slot++) { From 734540fbfeb6ad4c3cf60dc5bddce6cfe0a96925 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 15 Aug 2018 15:37:41 -0700 Subject: [PATCH 04/37] Conversions class. 
--- src/Microsoft.ML.Data/Data/Conversion.cs | 385 +++++++++--------- .../DataLoadSave/Text/TextLoaderParser.cs | 6 +- 2 files changed, 205 insertions(+), 186 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 0a9833064a..dbad9c7946 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -17,16 +17,16 @@ namespace Microsoft.ML.Runtime.Data.Conversion using BL = DvBool; using DT = DvDateTime; using DZ = DvDateTimeZone; - using I1 = DvInt1; - using I2 = DvInt2; - using I4 = DvInt4; - using I8 = DvInt8; + using NI1 = Nullable; + using NI2 = Nullable; + using NI4 = Nullable; + using NI8 = Nullable; using R4 = Single; using R8 = Double; - using RawI1 = SByte; - using RawI2 = Int16; - using RawI4 = Int32; - using RawI8 = Int64; + using I1 = SByte; + using I2 = Int16; + using I4 = Int32; + using I8 = Int64; using SB = StringBuilder; using TS = DvTimeSpan; using TX = DvText; @@ -119,37 +119,37 @@ private Conversions() // !!! WARNING !!!: Do NOT add any standard conversions without clearing from the IDV Type System // design committee. Any changes also require updating the IDV Type System Specification. 
- AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); AddStd(Convert); AddStd(Convert); @@ -202,13 +202,13 @@ private Conversions() AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddStd(Convert); @@ -220,34 +220,34 @@ private Conversions() AddStd(Convert); AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); 
AddAux(Convert); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); + AddIsNA(IsNA); + AddIsNA(IsNA); + AddIsNA(IsNA); + AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); @@ -256,10 +256,10 @@ private Conversions() AddIsNA
(IsNA); AddIsNA(IsNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); + AddGetNA(GetNA); + AddGetNA(GetNA); + AddGetNA(GetNA); + AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); @@ -268,10 +268,10 @@ private Conversions() AddGetNA
(GetNA); AddGetNA(GetNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); + AddHasNA(HasNA); + AddHasNA(HasNA); + AddHasNA(HasNA); + AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); @@ -280,10 +280,10 @@ private Conversions() AddHasNA
(HasNA); AddHasNA(HasNA); - AddIsDef(IsDefault); - AddIsDef(IsDefault); - AddIsDef(IsDefault); - AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); AddIsDef(IsDefault); AddIsDef(IsDefault); AddIsDef(IsDefault); @@ -302,10 +302,10 @@ private Conversions() AddHasZero(HasZero); AddHasZero(HasZero); - AddTryParse(TryParse); - AddTryParse(TryParse); - AddTryParse(TryParse); - AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); AddTryParse(TryParse); AddTryParse(TryParse); AddTryParse(TryParse); @@ -846,10 +846,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) // The IsNA methods are for efficient delegates (instance instead of static). #region IsNA - private bool IsNA(ref I1 src) => src.IsNA; - private bool IsNA(ref I2 src) => src.IsNA; - private bool IsNA(ref I4 src) => src.IsNA; - private bool IsNA(ref I8 src) => src.IsNA; + private bool IsNA(ref NI1 src) => !src.HasValue; + private bool IsNA(ref NI2 src) => !src.HasValue; + private bool IsNA(ref NI4 src) => !src.HasValue; + private bool IsNA(ref NI8 src) => !src.HasValue; private bool IsNA(ref R4 src) => src.IsNA(); private bool IsNA(ref R8 src) => src.IsNA(); private bool IsNA(ref BL src) => src.IsNA; @@ -860,10 +860,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion IsNA #region HasNA - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } + private bool HasNA(ref VBuffer src) { for 
(int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } + private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } + private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } + private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } @@ -874,10 +874,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasNA #region IsDefault - private bool IsDefault(ref I1 src) => src.RawValue == 0; - private bool IsDefault(ref I2 src) => src.RawValue == 0; - private bool IsDefault(ref I4 src) => src.RawValue == 0; - private bool IsDefault(ref I8 src) => src.RawValue == 0; + private bool IsDefault(ref NI1 src) => src == default(NI1); + private bool IsDefault(ref NI2 src) => src == default(NI2); + private bool IsDefault(ref NI4 src) => src == default(NI4); + private bool IsDefault(ref NI8 src) => src == default(NI8); private bool IsDefault(ref R4 src) => src == 0; private bool IsDefault(ref R8 src) => src == 0; private bool IsDefault(ref TX src) => src.IsEmpty; @@ -900,10 +900,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasZero #region GetNA - private void GetNA(ref I1 value) => value = I1.NA; - private void GetNA(ref I2 value) => value = I2.NA; - private void GetNA(ref I4 value) => value = I4.NA; - private void GetNA(ref I8 value) => value = I8.NA; + private void GetNA(ref 
NI1 value) => value = default; + private void GetNA(ref NI2 value) => value = default; + private void GetNA(ref NI4 value) => value = default; + private void GetNA(ref NI8 value) => value = default; private void GetNA(ref R4 value) => value = R4.NaN; private void GetNA(ref R8 value) => value = R8.NaN; private void GetNA(ref BL value) => value = BL.NA; @@ -914,35 +914,35 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion GetNA #region ToI1 - public void Convert(ref I1 src, ref I1 dst) => dst = src; - public void Convert(ref I2 src, ref I1 dst) => dst = (I1)src; - public void Convert(ref I4 src, ref I1 dst) => dst = (I1)src; - public void Convert(ref I8 src, ref I1 dst) => dst = (I1)src; + public void Convert(ref NI1 src, ref NI1 dst) => dst = src; + public void Convert(ref NI2 src, ref NI1 dst) => dst = (NI1)src; + public void Convert(ref NI4 src, ref NI1 dst) => dst = (NI1)src; + public void Convert(ref NI8 src, ref NI1 dst) => dst = (NI1)src; #endregion ToI1 #region ToI2 - public void Convert(ref I1 src, ref I2 dst) => dst = src; - public void Convert(ref I2 src, ref I2 dst) => dst = src; - public void Convert(ref I4 src, ref I2 dst) => dst = (I2)src; - public void Convert(ref I8 src, ref I2 dst) => dst = (I2)src; + public void Convert(ref NI1 src, ref NI2 dst) => dst = src; + public void Convert(ref NI2 src, ref NI2 dst) => dst = src; + public void Convert(ref NI4 src, ref NI2 dst) => dst = (NI2)src; + public void Convert(ref NI8 src, ref NI2 dst) => dst = (NI2)src; #endregion ToI2 #region ToI4 - public void Convert(ref I1 src, ref I4 dst) => dst = src; - public void Convert(ref I2 src, ref I4 dst) => dst = src; - public void Convert(ref I4 src, ref I4 dst) => dst = src; - public void Convert(ref I8 src, ref I4 dst) => dst = (I4)src; + public void Convert(ref NI1 src, ref NI4 dst) => dst = src; + public void Convert(ref NI2 src, ref NI4 dst) => dst = src; + public void Convert(ref NI4 src, ref NI4 dst) => dst = src; + public void Convert(ref 
NI8 src, ref NI4 dst) => dst = (NI4)src; #endregion ToI4 #region ToI8 - public void Convert(ref I1 src, ref I8 dst) => dst = src; - public void Convert(ref I2 src, ref I8 dst) => dst = src; - public void Convert(ref I4 src, ref I8 dst) => dst = src; - public void Convert(ref I8 src, ref I8 dst) => dst = src; - - public void Convert(ref TS src, ref I8 dst) => dst = (I8)src.Ticks; - public void Convert(ref DT src, ref I8 dst) => dst = (I8)src.Ticks; - public void Convert(ref DZ src, ref I8 dst) => dst = (I8)src.UtcDateTime.Ticks; + public void Convert(ref NI1 src, ref NI8 dst) => dst = src; + public void Convert(ref NI2 src, ref NI8 dst) => dst = src; + public void Convert(ref NI4 src, ref NI8 dst) => dst = src; + public void Convert(ref NI8 src, ref NI8 dst) => dst = src; + + public void Convert(ref TS src, ref NI8 dst) => dst = (NI8)src.Ticks; + public void Convert(ref DT src, ref NI8 dst) => dst = (NI8)src.Ticks; + public void Convert(ref DZ src, ref NI8 dst) => dst = (NI8)src.UtcDateTime.Ticks; #endregion ToI8 #region ToU1 @@ -986,10 +986,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToUG #region ToR4 - public void Convert(ref I1 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref I2 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref I4 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref I8 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref NI1 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref NI2 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref NI4 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref NI8 src, ref R4 dst) => dst = (R4)src; public void Convert(ref U1 src, ref R4 dst) => dst = src; public void Convert(ref U2 src, ref R4 dst) => dst = src; public void Convert(ref U4 src, ref R4 dst) => dst = src; @@ -1004,10 +1004,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToR4 #region ToR8 - public void Convert(ref I1 src, ref R8 dst) => 
dst = (R8)src; - public void Convert(ref I2 src, ref R8 dst) => dst = (R8)src; - public void Convert(ref I4 src, ref R8 dst) => dst = (R8)src; - public void Convert(ref I8 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref NI1 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref NI2 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref NI4 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref NI8 src, ref R8 dst) => dst = (R8)src; public void Convert(ref U1 src, ref R8 dst) => dst = src; public void Convert(ref U2 src, ref R8 dst) => dst = src; public void Convert(ref U4 src, ref R8 dst) => dst = src; @@ -1022,10 +1022,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToR8 #region ToStringBuilder - public void Convert(ref I1 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I2 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I4 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I8 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } + public void Convert(ref NI1 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } + public void Convert(ref NI2 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } + public void Convert(ref NI4 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } + public void Convert(ref NI8 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } public void Convert(ref U1 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U2 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U4 src, ref SB dst) => ClearDst(ref dst).Append(src); @@ -1303,13 +1303,16 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// This produces zero 
for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. ///
- public bool TryParse(ref TX src, out I1 dst) - { - long res; - bool f = TryParseSigned(RawI1.MaxValue, ref src, out res); - Contracts.Assert(f || res == I1.RawNA); - Contracts.Assert((RawI1)res == res); - dst = (RawI1)res; + public bool TryParse(ref TX src, out NI1 dst) + { + long? res; + bool f = TryParseSigned(I1.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + Contracts.Assert(!res.HasValue || (I1)res == res); + if (res.HasValue) + dst = (I1)res; + else + dst = null; return f; } @@ -1317,13 +1320,16 @@ public bool TryParse(ref TX src, out I1 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. /// - public bool TryParse(ref TX src, out I2 dst) - { - long res; - bool f = TryParseSigned(RawI2.MaxValue, ref src, out res); - Contracts.Assert(f || res == I2.RawNA); - Contracts.Assert((RawI2)res == res); - dst = (RawI2)res; + public bool TryParse(ref TX src, out NI2 dst) + { + long? res; + bool f = TryParseSigned(I2.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + Contracts.Assert(!res.HasValue || (I2)res == res); + if (res.HasValue) + dst = (I2)res; + else + dst = null; return f; } @@ -1331,13 +1337,16 @@ public bool TryParse(ref TX src, out I2 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. /// - public bool TryParse(ref TX src, out I4 dst) - { - long res; - bool f = TryParseSigned(RawI4.MaxValue, ref src, out res); - Contracts.Assert(f || res == I4.RawNA); - Contracts.Assert((RawI4)res == res); - dst = (RawI4)res; + public bool TryParse(ref TX src, out NI4 dst) + { + long? 
res; + bool f = TryParseSigned(I4.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + Contracts.Assert(!res.HasValue || (I4)res == res); + if (res.HasValue) + dst = (I4)res; + else + dst = null; return f; } @@ -1345,12 +1354,15 @@ public bool TryParse(ref TX src, out I4 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. /// - public bool TryParse(ref TX src, out I8 dst) + public bool TryParse(ref TX src, out NI8 dst) { - long res; - bool f = TryParseSigned(RawI8.MaxValue, ref src, out res); - Contracts.Assert(f || res == I8.RawNA); - dst = res; + long? res; + bool f = TryParseSigned(I8.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + if (res.HasValue) + dst = (I8)res; + else + dst = null; return f; } @@ -1389,11 +1401,11 @@ private bool TryParseNonNegative(string text, int ich, int lim, out long result) /// /// This produces zero for empty. It returns false if the text is not parsable as a signed integer - /// or the result overflows. The min legal value is -max. The NA value is -max - 1. + /// or the result overflows. The min legal value is -max. The NA value is null. /// When it returns false, result is set to the NA value. The result can be NA on true return, /// since some representations of NA are not considered parse failure. /// - private bool TryParseSigned(long max, ref TX span, out long result) + private bool TryParseSigned(long max, ref TX span, out long? 
result) { Contracts.Assert(max > 0); Contracts.Assert((max & (max + 1)) == 0); @@ -1401,7 +1413,7 @@ private bool TryParseSigned(long max, ref TX span, out long result) if (!span.HasChars) { if (span.IsNA) - result = -max - 1; + result = null; else result = 0; return true; @@ -1418,7 +1430,7 @@ private bool TryParseSigned(long max, ref TX span, out long result) !TryParseNonNegative(text, ichMin + 1, ichLim, out val) || val > max) { - result = -max - 1; + result = null; return false; } Contracts.Assert(val >= 0); @@ -1430,14 +1442,14 @@ private bool TryParseSigned(long max, ref TX span, out long result) if (!TryParseNonNegative(text, ichMin, ichLim, out val)) { // Check for acceptable NA forms: ? NaN NA and N/A. - result = -max - 1; + result = null; return IsStdMissing(ref span); } Contracts.Assert(val >= 0); if (val > max) { - result = -max - 1; + result = null; return false; } @@ -1530,41 +1542,48 @@ public bool TryParse(ref TX src, out DZ dst) return IsStdMissing(ref src); } - // These map unparsable and overflow values to "NA", which is the value Ix.MinValue. Note that this NA - // value is the "evil" value - the non-zero value, x, such that x == -x. Note also, that for I4, this - // matches R's representation of NA. - private I1 ParseI1(ref TX src) + // These map unparsable and overflow values to "NA", which is null. + private NI1 ParseI1(ref TX src) { - long res; - bool f = TryParseSigned(RawI1.MaxValue, ref src, out res); - Contracts.Assert(f || res == I1.RawNA); - Contracts.Assert((RawI1)res == res); - return (RawI1)res; + long? 
res; + bool f = TryParseSigned(I1.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + if (!res.HasValue) + return null; + + Contracts.Assert((I1)res == res); + return (I1)res; } - private I2 ParseI2(ref TX src) + private NI2 ParseI2(ref TX src) { - long res; - bool f = TryParseSigned(RawI2.MaxValue, ref src, out res); - Contracts.Assert(f || res == I2.RawNA); - Contracts.Assert((RawI2)res == res); - return (RawI2)res; + long? res; + bool f = TryParseSigned(I2.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + if (!res.HasValue) + return null; + + Contracts.Assert((I2)res == res); + return (I2)res; } - private I4 ParseI4(ref TX src) + private NI4 ParseI4(ref TX src) { - long res; - bool f = TryParseSigned(RawI4.MaxValue, ref src, out res); - Contracts.Assert(f || res == I4.RawNA); - Contracts.Assert((RawI4)res == res); - return (RawI4)res; + long? res; + bool f = TryParseSigned(I4.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); + if (!res.HasValue) + return null; + + Contracts.Assert((I4)res == res); + return (I4)res; } - private I8 ParseI8(ref TX src) + private NI8 ParseI8(ref TX src) { - long res; - bool f = TryParseSigned(RawI8.MaxValue, ref src, out res); - Contracts.Assert(f || res == I8.RawNA); + long? 
res; + bool f = TryParseSigned(I8.MaxValue, ref src, out res); + Contracts.Assert(f || !res.HasValue); return res; } @@ -1736,7 +1755,7 @@ private bool TryParse(ref TX src, out TX dst) return true; } - public void Convert(ref TX span, ref I1 value) + public void Convert(ref TX span, ref NI1 value) { value = ParseI1(ref span); } @@ -1744,7 +1763,7 @@ public void Convert(ref TX span, ref U1 value) { value = ParseU1(ref span); } - public void Convert(ref TX span, ref I2 value) + public void Convert(ref TX span, ref NI2 value) { value = ParseI2(ref span); } @@ -1752,7 +1771,7 @@ public void Convert(ref TX span, ref U2 value) { value = ParseU2(ref span); } - public void Convert(ref TX span, ref I4 value) + public void Convert(ref TX span, ref NI4 value) { value = ParseI4(ref span); } @@ -1760,7 +1779,7 @@ public void Convert(ref TX span, ref U4 value) { value = ParseU4(ref span); } - public void Convert(ref TX span, ref I8 value) + public void Convert(ref TX span, ref NI8 value) { value = ParseI8(ref span); } @@ -1822,10 +1841,10 @@ public void Convert(ref TX span, ref DZ value) #endregion FromTX #region FromBL - public void Convert(ref BL src, ref I1 dst) => dst = (I1)src; - public void Convert(ref BL src, ref I2 dst) => dst = (I2)src; - public void Convert(ref BL src, ref I4 dst) => dst = (I4)src; - public void Convert(ref BL src, ref I8 dst) => dst = (I8)src; + public void Convert(ref BL src, ref NI1 dst) => dst = (NI1)src; + public void Convert(ref BL src, ref NI2 dst) => dst = (NI2)src; + public void Convert(ref BL src, ref NI4 dst) => dst = (NI4)src; + public void Convert(ref BL src, ref NI8 dst) => dst = (NI8)src; public void Convert(ref BL src, ref R4 dst) => dst = (R4)src; public void Convert(ref BL src, ref R8 dst) => dst = (R8)src; public void Convert(ref BL src, ref BL dst) => dst = src; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index 582d81b546..b8962dcc59 100644 
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -993,15 +993,15 @@ public int GatherFields(DvText lineSpan, string path = null, long line = 0) var spanT = Fields.Spans[Fields.Count - 1]; // Note that Convert produces NA if the text is unparsable. - DvInt4 csrc = default(DvInt4); + int? csrc = default; Conversion.Conversions.Instance.Convert(ref spanT, ref csrc); - csrcSparse = csrc.RawValue; - if (csrcSparse <= 0) + if (!csrc.HasValue || csrc.Value <= 0) { _stats.LogBadFmt(ref scan, "Bad dimensionality or ambiguous sparse item. Use sparse=- for non-sparse file, and/or quote the value."); break; } + csrcSparse = csrc.Value; srcLimFixed = Fields.Indices[--Fields.Count]; if (csrcSparse >= SrcLim - srcLimFixed) csrcSparse = SrcLim - srcLimFixed - 1; From 10d753c7a1e7cffa5d84a99eff08bcec40a3225c Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 15 Aug 2018 22:16:29 -0700 Subject: [PATCH 05/37] peak, poke, codecs, transforms, evaluators. 
--- src/Microsoft.ML.Api/ApiUtils.cs | 3 +- .../DataViewConstructionUtils.cs | 64 +++++++++---------- src/Microsoft.ML.Api/TypedCursor.cs | 48 +++++++------- .../DataLoadSave/Binary/CodecFactory.cs | 8 +-- .../DataLoadSave/Binary/UnsafeTypeOps.cs | 56 ---------------- .../Evaluators/AnomalyDetectionEvaluator.cs | 6 +- .../MulticlassClassifierEvaluator.cs | 4 +- .../Training/TrainerUtils.cs | 2 +- .../Transforms/GenerateNumberTransform.cs | 4 +- src/Microsoft.ML.Parquet/ParquetLoader.cs | 18 +++--- .../Standard/ModelStatistics.cs | 2 +- .../MutualInformationFeatureSelection.cs | 50 ++++++++++++--- src/Microsoft.ML/Data/TextLoader.cs | 8 +-- 13 files changed, 125 insertions(+), 148 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 96e821f16e..5210e4267c 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -19,8 +19,7 @@ private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. 
- if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || - t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || + if (t == typeof(DvText) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index c50e48e16f..f3738bc3cd 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -134,42 +134,42 @@ private Delegate CreateGetter(int index) else if (outputType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(int?)) { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt4.NA); + Ch.Assert(colType.ItemType == NumberType.NI4); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(long?)) { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x ?? 
DvInt8.NA); + Ch.Assert(colType.ItemType == NumberType.NI8); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(short?)) { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt2.NA); + Ch.Assert(colType.ItemType == NumberType.NI2); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(sbyte?)) { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt1.NA); + Ch.Assert(colType.ItemType == NumberType.NI1); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(bool)) { @@ -222,51 +222,51 @@ private Delegate CreateGetter(int index) } else if (outputType == typeof(int)) { - // int -> DvInt4 + // int -> int Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(int?)) { - // int? -> DvInt4 - Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt4.NA); + // int? -> int? 
+ Ch.Assert(colType == NumberType.NI4); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(short)) { - // short -> DvInt2 + // short -> short Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(short?)) { - // short? -> DvInt2 - Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt2.NA); + // short? -> short? + Ch.Assert(colType == NumberType.NI2); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(long)) { - // long -> DvInt8 + // long -> long Ch.Assert(colType == NumberType.I8); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(long?)) { - // long? -> DvInt8 - Ch.Assert(colType == NumberType.I8); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt8.NA); + // long? -> long? + Ch.Assert(colType == NumberType.NI8); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(sbyte)) { - // sbyte -> DvInt1 + // sbyte -> sbyte Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(sbyte?)) { - // sbyte? -> DvInt1 - Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt1.NA); + // sbyte? -> sbyte? 
+ Ch.Assert(colType == NumberType.NI1); + return CreateConvertingGetterDelegate(index, x => x); } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 19f9a7cf72..5fb4d6950d 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -292,42 +292,42 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit else if (fieldType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); } else if (fieldType.GetElementType() == typeof(int?)) { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int?)x); + Ch.Assert(colType.ItemType == NumberType.NI4); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(short?)) { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short?)x); + Ch.Assert(colType.ItemType == NumberType.NI2); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(long?)) { - Ch.Assert(colType.ItemType == 
NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long?)x); + Ch.Assert(colType.ItemType == NumberType.NI8); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(sbyte?)) { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte?)x); + Ch.Assert(colType.ItemType == NumberType.NI1); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } // VBuffer -> T[] @@ -373,49 +373,49 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (int)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(int?)) { - Ch.Assert(colType == NumberType.I4); + Ch.Assert(colType == NumberType.NI4); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (int?)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(short)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (short)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(short?)) { - Ch.Assert(colType == NumberType.I2); + Ch.Assert(colType == NumberType.NI2); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (short?)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(long)) { Ch.Assert(colType == 
NumberType.I8); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (long)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(long?)) { - Ch.Assert(colType == NumberType.I8); + Ch.Assert(colType == NumberType.NI8); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (long?)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(sbyte)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (sbyte)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(sbyte?)) { - Ch.Assert(colType == NumberType.I1); + Ch.Assert(colType == NumberType.NI1); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (sbyte?)x); + return CreateConvertingActionSetter(input, index, poke, x => x); } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index d04adaf099..fef1c82a63 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -44,13 +44,13 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) _loadNameToCodecCreator = new Dictionary(); _simpleCodecTypeMap = new Dictionary(); // Register the current codecs. 
- RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs index 026228d6be..9930ce2974 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs @@ -33,16 +33,12 @@ static UnsafeTypeOpsFactory() { _type2ops = new Dictionary(); _type2ops[typeof(SByte)] = new SByteUnsafeTypeOps(); - _type2ops[typeof(DvInt1)] = new DvI1UnsafeTypeOps(); _type2ops[typeof(Byte)] = new ByteUnsafeTypeOps(); _type2ops[typeof(Int16)] = new Int16UnsafeTypeOps(); - _type2ops[typeof(DvInt2)] = new DvI2UnsafeTypeOps(); _type2ops[typeof(UInt16)] = new UInt16UnsafeTypeOps(); _type2ops[typeof(Int32)] = new Int32UnsafeTypeOps(); - _type2ops[typeof(DvInt4)] = new DvI4UnsafeTypeOps(); _type2ops[typeof(UInt32)] = new UInt32UnsafeTypeOps(); _type2ops[typeof(Int64)] = new Int64UnsafeTypeOps(); - _type2ops[typeof(DvInt8)] = new DvI8UnsafeTypeOps(); _type2ops[typeof(UInt64)] = new UInt64UnsafeTypeOps(); _type2ops[typeof(Single)] = new SingleUnsafeTypeOps(); _type2ops[typeof(Double)] = new DoubleUnsafeTypeOps(); @@ -67,19 +63,6 @@ public override unsafe void Apply(SByte[] array, Action func) public override SByte Read(BinaryReader reader) { return reader.ReadSByte(); } } - private 
sealed class DvI1UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(SByte); } } - public override unsafe void Apply(DvInt1[] array, Action func) - { - fixed (DvInt1* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt1 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt1 Read(BinaryReader reader) { return reader.ReadSByte(); } - } - private sealed class ByteUnsafeTypeOps : UnsafeTypeOps { public override int Size { get { return sizeof(Byte); } } @@ -104,19 +87,6 @@ public override unsafe void Apply(Int16[] array, Action func) public override Int16 Read(BinaryReader reader) { return reader.ReadInt16(); } } - private sealed class DvI2UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int16); } } - public override unsafe void Apply(DvInt2[] array, Action func) - { - fixed (DvInt2* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt2 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt2 Read(BinaryReader reader) { return reader.ReadInt16(); } - } - private sealed class UInt16UnsafeTypeOps : UnsafeTypeOps { public override int Size { get { return sizeof(UInt16); } } @@ -141,19 +111,6 @@ public override unsafe void Apply(Int32[] array, Action func) public override Int32 Read(BinaryReader reader) { return reader.ReadInt32(); } } - private sealed class DvI4UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int32); } } - public override unsafe void Apply(DvInt4[] array, Action func) - { - fixed (DvInt4* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt4 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt4 Read(BinaryReader reader) { return reader.ReadInt32(); } - } - private sealed class UInt32UnsafeTypeOps : UnsafeTypeOps { public override int Size { get { return sizeof(UInt32); } } @@ -178,19 +135,6 @@ 
public override unsafe void Apply(Int64[] array, Action func) public override Int64 Read(BinaryReader reader) { return reader.ReadInt64(); } } - private sealed class DvI8UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int64); } } - public override unsafe void Apply(DvInt8[] array, Action func) - { - fixed (DvInt8* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt8 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt8 Read(BinaryReader reader) { return reader.ReadInt64(); } - } - private sealed class UInt64UnsafeTypeOps : UnsafeTypeOps { public override int Size { get { return sizeof(UInt64); } } diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 8e4f3be56c..253619249e 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -136,7 +136,7 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A var thresholdAtK = new List(); var thresholdAtP = new List(); var thresholdAtNumAnomalies = new List(); - var numAnoms = new List(); + var numAnoms = new List(); var scores = new List(); var labels = new List(); @@ -678,11 +678,11 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary col == numAnomIndex || (hasStrat && col == stratCol))) { - var numAnomGetter = cursor.GetGetter(numAnomIndex); + var numAnomGetter = cursor.GetGetter(numAnomIndex); ValueGetter stratGetter = null; if (hasStrat) { diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index fd23e7c3b0..d97e4e6593 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -998,8 +998,8 @@ protected override 
IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa if (labelType.IsKey && (!perInst.Schema.HasKeyNames(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) { perInst = LambdaColumnMapper.Create(Host, "ConvertToLong", perInst, schema.Label.Name, - schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.I8, - (ref uint src, ref DvInt8 dst) => dst = src == 0 ? DvInt8.NA : src - 1 + (long)labelType.AsKey.Min); + schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.NI8, + (ref uint src, ref Int64? dst) => dst = src == 0 ? null : src - 1 + (long?)labelType.AsKey.Min); } var perInstSchema = perInst.Schema; diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 33d3d1490d..b86daaace4 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -593,7 +593,7 @@ public void Signal(CursOpt opt) } /// - /// This supports Weight (Float), Group (ulong), and Id (DvInt8) columns. + /// This supports Weight (Float), Group (ulong), and Id (UInt128) columns. 
/// public class StandardScalarCursor : TrainingCursorBase { diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index cacd681141..3d8823d5a1 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -430,9 +430,9 @@ public ValueGetter GetGetter(int col) return fn; } - private ValueGetter MakeGetter() + private ValueGetter MakeGetter() { - return (ref DvInt8 value) => + return (ref Int64 value) => { Ch.Check(IsGood); value = Input.Position; diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 503debae65..792eafe17a 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -499,21 +499,21 @@ private Delegate CreateGetterDelegate(int col) case DataType.Byte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.SignedByte: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedByte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Short: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedShort: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int16: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedInt16: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int32: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int64: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return 
CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int96: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.ByteArray: @@ -678,17 +678,17 @@ public ParquetConversions(IChannel channel) public void Conv(ref byte[] src, ref VBuffer dst) => dst = src != null ? new VBuffer(src.Length, src) : new VBuffer(0, new byte[0]); - public void Conv(ref sbyte? src, ref DvInt1 dst) => dst = src ?? DvInt1.NA; + public void Conv(ref sbyte? src, ref SByte? dst) => dst = src; public void Conv(ref byte src, ref byte dst) => dst = src; - public void Conv(ref short? src, ref DvInt2 dst) => dst = src ?? DvInt2.NA; + public void Conv(ref short? src, ref Int16? dst) => dst = src; public void Conv(ref ushort src, ref ushort dst) => dst = src; - public void Conv(ref int? src, ref DvInt4 dst) => dst = src ?? DvInt4.NA; + public void Conv(ref int? src, ref Int32? dst) => dst = src; - public void Conv(ref long? src, ref DvInt8 dst) => dst = src ?? DvInt8.NA; + public void Conv(ref long? src, ref Int64? dst) => dst = src; public void Conv(ref float? src, ref Single dst) => dst = src ?? 
Single.NaN; diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index 91874291b0..dd2a117051 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -414,7 +414,7 @@ public void AddStatsColumns(List list, LinearBinaryPredictor parent, Ro _env.AssertValueOrNull(parent); _env.AssertValue(schema); - DvInt8 count = _trainingExampleCount; + Int64 count = _trainingExampleCount; list.Add(RowColumnUtils.GetColumn("Count of training examples", NumberType.I8, ref count)); var dev = _deviance; list.Add(RowColumnUtils.GetColumn("Residual Deviance", NumberType.R4, ref dev)); diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 0af833a046..e0d6137a5c 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -407,7 +407,14 @@ private void GetLabels(Transposer trans, ColumnType labelType, int labelCol) // Note: NAs have their own separate bin. 
if (labelType == NumberType.I4) { - var tmp = default(VBuffer); + var tmp = default(VBuffer); trans.GetSingleSlotValue(labelCol, ref tmp); + BinInts(ref tmp, ref labels, _numBins, out min, out lim); + _numLabels = lim - min; + } + else if (labelType == NumberType.NI4) + { + var tmp = default(VBuffer); trans.GetSingleSlotValue(labelCol, ref tmp); BinInts(ref tmp, ref labels, _numBins, out min, out lim); _numLabels = lim - min; @@ -486,7 +493,15 @@ private Single[] ComputeMutualInformation(Transposer trans, int col) if (type.ItemType == NumberType.I4) { return ComputeMutualInformation(trans, col, - (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => + (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => + { + BinInts(ref src, ref dst, _numBins, out min, out lim); + }); + } + else if (type.ItemType == NumberType.NI4) + { + return ComputeMutualInformation(trans, col, + (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => { BinInts(ref src, ref dst, _numBins, out min, out lim); }); @@ -674,9 +689,28 @@ private static ValueMapper, VBuffer> BinKeys(ColumnType colTy } /// - /// Maps from DvInt4 to ints. NaNs (and only NaNs) are mapped to the first bin. + /// Maps Ints. + /// + private void BinInts(ref VBuffer input, ref VBuffer output, + int numBins, out int min, out int lim) + { + Contracts.Assert(_singles.Count == 0); + + var bounds = _binFinder.FindBins(numBins, _singles, input.Length - input.Count); + min = -1 - bounds.FindIndexSorted(0); + lim = min + bounds.Length + 1; + int offset = min; + ValueMapper mapper = + (ref Int32 src, ref int dst) => + dst = offset + 1 + bounds.FindIndexSorted((Single)src); + mapper.MapVector(ref input, ref output); + _singles.Clear(); + } + + /// + /// Maps from Int32? to ints. Null values (and only null values) are mapped to the first bin. 
/// - private void BinInts(ref VBuffer input, ref VBuffer output, + private void BinInts(ref VBuffer input, ref VBuffer output, int numBins, out int min, out int lim) { Contracts.Assert(_singles.Count == 0); @@ -685,7 +719,7 @@ private void BinInts(ref VBuffer input, ref VBuffer output, for (int i = 0; i < input.Count; i++) { var val = input.Values[i]; - if (!val.IsNA) + if (!val.HasValue) _singles.Add((Single)val); } } @@ -694,9 +728,9 @@ private void BinInts(ref VBuffer input, ref VBuffer output, min = -1 - bounds.FindIndexSorted(0); lim = min + bounds.Length + 1; int offset = min; - ValueMapper mapper = - (ref DvInt4 src, ref int dst) => - dst = src.IsNA ? offset : offset + 1 + bounds.FindIndexSorted((Single)src); + ValueMapper mapper = + (ref Int32? src, ref int dst) => + dst = !src.HasValue ? offset : offset + 1 + bounds.FindIndexSorted((Single)src); mapper.MapVector(ref input, ref output); _singles.Clear(); } diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 330412185e..f7eec8ac52 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -160,19 +160,19 @@ private static bool TryGetDataKind(Type type, out DataKind kind) Contracts.AssertValue(type); // REVIEW: Make this more efficient. Should we have a global dictionary? 
- if (type == typeof(DvInt1) || type == typeof(sbyte)) + if (type == typeof(sbyte)) kind = DataKind.I1; else if (type == typeof(byte) || type == typeof(char)) kind = DataKind.U1; - else if (type == typeof(DvInt2) || type == typeof(short)) + else if (type == typeof(short)) kind = DataKind.I2; else if (type == typeof(ushort)) kind = DataKind.U2; - else if (type == typeof(DvInt4) || type == typeof(int)) + else if ( type == typeof(int)) kind = DataKind.I4; else if (type == typeof(uint)) kind = DataKind.U4; - else if (type == typeof(DvInt8) || type == typeof(long)) + else if (type == typeof(long)) kind = DataKind.I8; else if (type == typeof(ulong)) kind = DataKind.U8; From 46299b87ca5ff536ff6458ced2c33a9efacdeb2d Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 15 Aug 2018 22:34:21 -0700 Subject: [PATCH 06/37] test. --- .../UnitTests/CoreBaseTestClass.cs | 16 +++---- .../UnitTests/DvTypes.cs | 42 ++++--------------- .../TestTransposer.cs | 28 ++++++------- .../DataPipe/TestDataPipeBase.cs | 16 +++---- .../TestSparseDataView.cs | 5 ++- 5 files changed, 40 insertions(+), 67 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs index 35859783ad..813e56e545 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs @@ -153,19 +153,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return 
GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -196,19 +196,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs index a3f5d8231b..d332ade37c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs @@ -11,12 +11,12 @@ namespace Microsoft.ML.Runtime.RunTests public sealed class DvTypeTests { [Fact] - public void TestComparableDvInt4() + public void TestComparableInt32() { const int count = 100; var 
rand = RandomUtils.Create(42); - var values = new DvInt4[2 * count]; + var values = new Int32?[2 * count]; for (int i = 0; i < count; i++) { var v = values[i] = rand.Next(); @@ -28,41 +28,13 @@ public void TestComparableDvInt4() int iv2 = rand.Next(values.Length - 1); if (iv2 >= iv1) iv2++; - values[iv1] = DvInt4.NA; - values[iv2] = DvInt4.NA; + values[iv1] = null; + values[iv2] = null; Array.Sort(values); - Assert.True(values[0].IsNA); - Assert.True(values[1].IsNA); - Assert.True(!values[2].IsNA); - - Assert.True((values[0] == values[1]).IsNA); - Assert.True((values[0] != values[1]).IsNA); - Assert.True((values[0] <= values[1]).IsNA); - Assert.True(values[0].Equals(values[1])); - Assert.True(values[0].CompareTo(values[1]) == 0); - - Assert.True((values[1] == values[2]).IsNA); - Assert.True((values[1] != values[2]).IsNA); - Assert.True((values[1] <= values[2]).IsNA); - Assert.True(!values[1].Equals(values[2])); - Assert.True(values[1].CompareTo(values[2]) < 0); - - for (int i = 3; i < values.Length; i++) - { - DvBool eq = values[i - 1] == values[i]; - DvBool ne = values[i - 1] != values[i]; - DvBool le = values[i - 1] <= values[i]; - bool feq = values[i - 1].Equals(values[i]); - int cmp = values[i - 1].CompareTo(values[i]); - Assert.True(!eq.IsNA); - Assert.True(!ne.IsNA); - Assert.True(eq.IsTrue == ne.IsFalse); - Assert.True(le.IsTrue); - Assert.True(feq == eq.IsTrue); - Assert.True(cmp <= 0); - Assert.True(feq == (cmp == 0)); - } + Assert.True(!values[0].HasValue); + Assert.True(!values[1].HasValue); + Assert.True(values[2].HasValue); } [Fact] diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index ed1780c6d7..35661bbb15 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -148,19 +148,19 @@ public void TransposerTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a 
sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (DvInt4)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. + var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (Int32)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (DvInt2)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (Int16)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. - builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (DvInt4)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (Int32)rgen.Next())); IDataView view = builder.GetDataView(); @@ -181,11 +181,11 @@ public void TransposerTest() } // Check the contents Assert.Null(trans.TransposeSchema.GetSlotType(2)); // C check to see that it's not transposable. - TransposeCheckHelper(view, 0, trans); // A check. + TransposeCheckHelper(view, 0, trans); // A check. TransposeCheckHelper(view, 1, trans); // B check. TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. 
+ TransposeCheckHelper(view, 5, trans); // F check. } // Force save. Recheck columns that would have previously been passthrough columns. @@ -200,7 +200,7 @@ public void TransposerTest() Assert.Null(trans.TransposeSchema.GetSlotType(2)); TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. + TransposeCheckHelper(view, 5, trans); // F check. } } @@ -213,19 +213,19 @@ public void TransposerSaverLoaderTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (DvInt4)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. + var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (Int32)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (DvInt2)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (Int16)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. 
- builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (DvInt4)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (Int32)rgen.Next())); IDataView view = builder.GetDataView(); @@ -240,12 +240,12 @@ public void TransposerSaverLoaderTest() // First check whether this as an IDataView yields the same values. CheckSameValues(view, loader); - TransposeCheckHelper(view, 0, loader); // A + TransposeCheckHelper(view, 0, loader); // A TransposeCheckHelper(view, 1, loader); // B - TransposeCheckHelper(view, 2, loader); // C + TransposeCheckHelper(view, 2, loader); // C TransposeCheckHelper(view, 3, loader); // D TransposeCheckHelper(view, 4, loader); // E - TransposeCheckHelper(view, 5, loader); // F + TransposeCheckHelper(view, 5, loader); // F Done(); } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index b53062c1a8..e53bc1e449 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -879,19 +879,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == 
y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -922,19 +922,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs index 08c9e17a28..f3bf775921 100644 --- a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs +++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using System; using Xunit; using Xunit.Abstractions; @@ -34,7 +35,7 @@ private class SparseExample public void SparseDataView() { GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericSparseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 }); + GenericSparseDataView(new Int32[] { 1, 2, 3 }, new Int32[] { 1, 10, 100 }); 
GenericSparseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericSparseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, @@ -76,7 +77,7 @@ private void GenericSparseDataView(T[] v1, T[] v2) public void DenseDataView() { GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericDenseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 }); + GenericDenseDataView(new Int32[] { 1, 2, 3 }, new Int32[] { 1, 10, 100 }); GenericDenseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericDenseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, From a99831b485f4e500b7f3ead2abe618b412c98834 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 15 Aug 2018 22:45:47 -0700 Subject: [PATCH 07/37] undo DvBool. 
--- src/Microsoft.ML.Core/Data/ColumnType.cs | 11 ----------- src/Microsoft.ML.Core/Data/DataKind.cs | 17 ++++------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index b754c0e767..6e6c528560 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -624,17 +624,6 @@ public static BoolType Instance } } - private static volatile BoolType _ninstance; - public static BoolType NInstance - { - get - { - if (_ninstance == null) - Interlocked.CompareExchange(ref _ninstance, new BoolType(DataKind.NBL, "NBool"), null); - return _ninstance; - } - } - private readonly string _name; private BoolType(DataKind kind, string name) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index d8f3b1bef0..c016cb2228 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -50,10 +50,7 @@ public enum DataKind : byte NI1 = 17, NI2 = 18, NI4 = 19, - NI8 = 20, -#pragma warning disable MSML_GeneralName - NBL = 21, -#pragma warning restore MSML_GeneralName + NI8 = 20 } /// @@ -62,7 +59,7 @@ public enum DataKind : byte public static class DataKindExtensions { public const DataKind KindMin = DataKind.I1; - public const DataKind KindLim = DataKind.NBL + 1; + public const DataKind KindLim = DataKind.NI8 + 1; public const int KindCount = KindLim - KindMin; /// @@ -186,9 +183,7 @@ public static Type ToType(this DataKind kind) case DataKind.TX: return typeof(DvText); case DataKind.BL: - return typeof(bool); - case DataKind.NBL: - return typeof(bool?); + return typeof(DvBool); case DataKind.TS: return typeof(DvTimeSpan); case DataKind.DT: @@ -240,10 +235,8 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(bool)) + else if (type == typeof(bool) || 
type == typeof(bool?)) kind = DataKind.BL; - else if (type == typeof(bool?)) - kind = DataKind.NBL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; else if (type == typeof(DvDateTime)) @@ -299,8 +292,6 @@ public static string GetString(this DataKind kind) return "R8"; case DataKind.BL: return "BL"; - case DataKind.NBL: - return "NBL"; case DataKind.TX: return "TX"; case DataKind.TS: From a0ceabb1e9c9533c950e4aa863f6ccda752d39e0 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 15 Aug 2018 22:48:50 -0700 Subject: [PATCH 08/37] undo DvBool. --- src/Microsoft.ML.Core/Data/DataKind.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index c016cb2228..2fa69507fc 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -235,7 +235,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(bool) || type == typeof(bool?)) + else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; From f30dd7ee2f27c6fb3fb1f8341f10fba44ac3cc62 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 16 Aug 2018 14:26:19 -0700 Subject: [PATCH 09/37] Replace DvInt* with .NET standard types and remove missing value support for ints. 
--- src/Microsoft.ML.Api/ApiUtils.cs | 2 +- .../DataViewConstructionUtils.cs | 44 --- src/Microsoft.ML.Api/TypedCursor.cs | 44 --- src/Microsoft.ML.Core/Data/ColumnType.cs | 53 ---- src/Microsoft.ML.Core/Data/DataKind.cs | 42 +-- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 4 +- src/Microsoft.ML.Core/Utilities/Stream.cs | 2 +- src/Microsoft.ML.Data/Data/Conversion.cs | 279 ++++++++---------- .../DataLoadSave/Binary/CodecFactory.cs | 2 +- .../DataLoadSave/Text/TextLoaderParser.cs | 8 +- .../MulticlassClassifierEvaluator.cs | 4 +- src/Microsoft.ML.Parquet/ParquetLoader.cs | 18 +- .../MutualInformationFeatureSelection.cs | 43 --- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 48 --- .../UnitTests/TestEntryPoints.cs | 14 +- .../CollectionDataSourceTests.cs | 72 +---- 16 files changed, 143 insertions(+), 536 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 5210e4267c..6061d7474c 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -19,7 +19,7 @@ private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. 
- if (t == typeof(DvText) || t == typeof(string) || t.IsArray || + if (t == typeof(DvText) || t == typeof(DvBool) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index f3738bc3cd..98c0128102 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -136,41 +136,21 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.ItemType == NumberType.I4); return CreateConvertingArrayGetterDelegate(index, x => x); } - else if (outputType.GetElementType() == typeof(int?)) - { - Ch.Assert(colType.ItemType == NumberType.NI4); - return CreateConvertingArrayGetterDelegate(index, x => x); - } else if (outputType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); return CreateConvertingArrayGetterDelegate(index, x => x); } - else if (outputType.GetElementType() == typeof(long?)) - { - Ch.Assert(colType.ItemType == NumberType.NI8); - return CreateConvertingArrayGetterDelegate(index, x => x); - } else if (outputType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); return CreateConvertingArrayGetterDelegate(index, x => x); } - else if (outputType.GetElementType() == typeof(short?)) - { - Ch.Assert(colType.ItemType == NumberType.NI2); - return CreateConvertingArrayGetterDelegate(index, x => x); - } else if (outputType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); return CreateConvertingArrayGetterDelegate(index, x => x); } - else if (outputType.GetElementType() == typeof(sbyte?)) - { - Ch.Assert(colType.ItemType == NumberType.NI1); - return 
CreateConvertingArrayGetterDelegate(index, x => x); - } else if (outputType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); @@ -226,48 +206,24 @@ private Delegate CreateGetter(int index) Ch.Assert(colType == NumberType.I4); return CreateConvertingGetterDelegate(index, x => x); } - else if (outputType == typeof(int?)) - { - // int? -> int? - Ch.Assert(colType == NumberType.NI4); - return CreateConvertingGetterDelegate(index, x => x); - } else if (outputType == typeof(short)) { // short -> short Ch.Assert(colType == NumberType.I2); return CreateConvertingGetterDelegate(index, x => x); } - else if (outputType == typeof(short?)) - { - // short? -> short? - Ch.Assert(colType == NumberType.NI2); - return CreateConvertingGetterDelegate(index, x => x); - } else if (outputType == typeof(long)) { // long -> long Ch.Assert(colType == NumberType.I8); return CreateConvertingGetterDelegate(index, x => x); } - else if (outputType == typeof(long?)) - { - // long? -> long? - Ch.Assert(colType == NumberType.NI8); - return CreateConvertingGetterDelegate(index, x => x); - } else if (outputType == typeof(sbyte)) { // sbyte -> sbyte Ch.Assert(colType == NumberType.I1); return CreateConvertingGetterDelegate(index, x => x); } - else if (outputType == typeof(sbyte?)) - { - // sbyte? -> sbyte? 
- Ch.Assert(colType == NumberType.NI1); - return CreateConvertingGetterDelegate(index, x => x); - } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType)); diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 5fb4d6950d..437236e63c 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -294,41 +294,21 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(colType.ItemType == NumberType.I4); return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); } - else if (fieldType.GetElementType() == typeof(int?)) - { - Ch.Assert(colType.ItemType == NumberType.NI4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } - else if (fieldType.GetElementType() == typeof(short?)) - { - Ch.Assert(colType.ItemType == NumberType.NI2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } else if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } - else if (fieldType.GetElementType() == typeof(long?)) - { - Ch.Assert(colType.ItemType == NumberType.NI8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } - else if (fieldType.GetElementType() == typeof(sbyte?)) - { - Ch.Assert(colType.ItemType == NumberType.NI1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } // VBuffer -> T[] if 
(fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) @@ -375,48 +355,24 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => x); } - else if (fieldType == typeof(int?)) - { - Ch.Assert(colType == NumberType.NI4); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } else if (fieldType == typeof(short)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => x); } - else if (fieldType == typeof(short?)) - { - Ch.Assert(colType == NumberType.NI2); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } else if (fieldType == typeof(long)) { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => x); } - else if (fieldType == typeof(long?)) - { - Ch.Assert(colType == NumberType.NI8); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } else if (fieldType == typeof(sbyte)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => x); } - else if (fieldType == typeof(sbyte?)) - { - Ch.Assert(colType == NumberType.NI1); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType)); diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 6e6c528560..12d0509352 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -385,17 +385,6 @@ public static NumberType I1 } } - private 
static volatile NumberType _instNI1; - public static NumberType NI1 - { - get - { - if (_instNI1 == null) - Interlocked.CompareExchange(ref _instNI1, new NumberType(DataKind.NI1, "NI1"), null); - return _instNI1; - } - } - private static volatile NumberType _instU1; public static NumberType U1 { @@ -418,17 +407,6 @@ public static NumberType I2 } } - private static volatile NumberType _instNI2; - public static NumberType NI2 - { - get - { - if (_instNI2 == null) - Interlocked.CompareExchange(ref _instNI2, new NumberType(DataKind.NI2, "NI2"), null); - return _instNI2; - } - } - private static volatile NumberType _instU2; public static NumberType U2 { @@ -451,17 +429,6 @@ public static NumberType I4 } } - private static volatile NumberType _instNI4; - public static NumberType NI4 - { - get - { - if (_instNI4 == null) - Interlocked.CompareExchange(ref _instNI4, new NumberType(DataKind.NI4, "NI4"), null); - return _instNI4; - } - } - private static volatile NumberType _instU4; public static NumberType U4 { @@ -484,17 +451,6 @@ public static NumberType I8 } } - private static volatile NumberType _instNI8; - public static NumberType NI8 - { - get - { - if (_instNI8 == null) - Interlocked.CompareExchange(ref _instNI8, new NumberType(DataKind.NI8, "NI8"), null); - return _instNI8; - } - } - private static volatile NumberType _instU8; public static NumberType U8 { @@ -550,26 +506,18 @@ public static NumberType Float { case DataKind.I1: return I1; - case DataKind.NI1: - return NI1; case DataKind.U1: return U1; case DataKind.I2: return I2; - case DataKind.NI2: - return NI2; case DataKind.U2: return U2; case DataKind.I4: return I4; - case DataKind.NI4: - return NI4; case DataKind.U4: return U4; case DataKind.I8: return I8; - case DataKind.NI8: - return NI8; case DataKind.U8: return U8; case DataKind.R4: @@ -631,7 +579,6 @@ private BoolType(DataKind kind, string name) { Contracts.AssertNonEmpty(name); _name = name; - Contracts.Assert(IsNumber); } public override bool 
Equals(ColumnType other) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 2fa69507fc..da0a4eaf7a 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -47,10 +47,6 @@ public enum DataKind : byte UG = 16, // Unsigned 16-byte integer. U16 = UG, #pragma warning restore MSML_GeneralName - NI1 = 17, - NI2 = 18, - NI4 = 19, - NI8 = 20 } /// @@ -59,7 +55,7 @@ public enum DataKind : byte public static class DataKindExtensions { public const DataKind KindMin = DataKind.I1; - public const DataKind KindLim = DataKind.NI8 + 1; + public const DataKind KindLim = DataKind.U16 + 1; public const int KindCount = KindLim - KindMin; /// @@ -88,22 +84,18 @@ public static ulong ToMaxInt(this DataKind kind) switch (kind) { case DataKind.I1: - case DataKind.NI1: return (ulong)sbyte.MaxValue; case DataKind.U1: return byte.MaxValue; case DataKind.I2: - case DataKind.NI2: return (ulong)short.MaxValue; case DataKind.U2: return ushort.MaxValue; case DataKind.I4: - case DataKind.NI4: return int.MaxValue; case DataKind.U4: return uint.MaxValue; case DataKind.I8: - case DataKind.NI8: return long.MaxValue; case DataKind.U8: return ulong.MaxValue; @@ -121,22 +113,18 @@ public static long ToMinInt(this DataKind kind) switch (kind) { case DataKind.I1: - case DataKind.NI1: return sbyte.MinValue; case DataKind.U1: return byte.MinValue; case DataKind.I2: - case DataKind.NI2: return short.MinValue; case DataKind.U2: return ushort.MinValue; case DataKind.I4: - case DataKind.NI4: return int.MinValue; case DataKind.U4: return uint.MinValue; case DataKind.I8: - case DataKind.NI8: return long.MinValue; case DataKind.U8: return 0; @@ -154,26 +142,18 @@ public static Type ToType(this DataKind kind) { case DataKind.I1: return typeof(sbyte); - case DataKind.NI1: - return typeof(sbyte?); case DataKind.U1: return typeof(byte); case DataKind.I2: return typeof(short); - case DataKind.NI2: - return typeof(short?); case 
DataKind.U2: return typeof(ushort); case DataKind.I4: - return typeof(Int32); - case DataKind.NI4: - return typeof(Int32?); + return typeof(int); case DataKind.U4: return typeof(uint); case DataKind.I8: - return typeof(Int64); - case DataKind.NI8: - return typeof(Int64?); + return typeof(long); case DataKind.U8: return typeof(ulong); case DataKind.R4: @@ -207,26 +187,18 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) // REVIEW: Make this more efficient. Should we have a global dictionary? if (type == typeof(sbyte)) kind = DataKind.I1; - else if (type == typeof(sbyte?)) - kind = DataKind.NI1; else if (type == typeof(byte) || type == typeof(byte?)) kind = DataKind.U1; else if (type == typeof(short)) kind = DataKind.I2; - else if (type == typeof(short?)) - kind = DataKind.NI2; else if (type == typeof(ushort) || type == typeof(ushort?)) kind = DataKind.U2; else if (type == typeof(int)) kind = DataKind.I4; - else if (type == typeof(int?)) - kind = DataKind.NI4; else if (type == typeof(uint) || type == typeof(uint?)) kind = DataKind.U4; else if (type == typeof(long)) kind = DataKind.I8; - else if (type == typeof(long?)) - kind = DataKind.NI8; else if (type == typeof(ulong) || type == typeof(ulong?)) kind = DataKind.U8; else if (type == typeof(Single) || type == typeof(Single?)) @@ -264,20 +236,12 @@ public static string GetString(this DataKind kind) { case DataKind.I1: return "I1"; - case DataKind.NI1: - return "NI1"; case DataKind.I2: return "I2"; - case DataKind.NI2: - return "NI2"; case DataKind.I4: return "I4"; - case DataKind.NI4: - return "NI4"; case DataKind.I8: return "I8"; - case DataKind.NI8: - return "NI8"; case DataKind.U1: return "U1"; case DataKind.U2: diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 81882fa946..7331db7416 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -404,9 +404,9 @@ public static bool 
TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, return isValid; var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex); - if (type?.RawType == typeof(VBuffer)) + if (type?.RawType == typeof(VBuffer)) { - VBuffer catIndices = default(VBuffer); + VBuffer catIndices = default(VBuffer); schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices); VBufferUtils.Densify(ref catIndices); int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore; diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 41c794e17f..832d8abfdf 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -449,7 +449,7 @@ public static long WriteByteStream(this BinaryWriter writer, IEnumerable e return c; } - public static long WriteIntStream(this BinaryWriter writer, IEnumerable e) + public static long WriteIntStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index dbad9c7946..16641275cd 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -17,10 +17,6 @@ namespace Microsoft.ML.Runtime.Data.Conversion using BL = DvBool; using DT = DvDateTime; using DZ = DvDateTimeZone; - using NI1 = Nullable; - using NI2 = Nullable; - using NI4 = Nullable; - using NI8 = Nullable; using R4 = Single; using R8 = Double; using I1 = SByte; @@ -119,37 +115,37 @@ private Conversions() // !!! WARNING !!!: Do NOT add any standard conversions without clearing from the IDV Type System // design committee. Any changes also require updating the IDV Type System Specification. 
- AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); - - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddAux(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); + + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddAux(Convert); AddStd(Convert); AddStd(Convert); @@ -202,13 +198,13 @@ private Conversions() AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddStd(Convert); @@ -220,34 +216,30 @@ private Conversions() AddStd(Convert); AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); - AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); AddAux(Convert); - AddStd(Convert); + AddStd(Convert); AddStd(Convert); AddStd(Convert); 
AddAux(Convert); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); @@ -256,10 +248,6 @@ private Conversions() AddIsNA
(IsNA); AddIsNA(IsNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); @@ -268,10 +256,6 @@ private Conversions() AddGetNA
(GetNA); AddGetNA(GetNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); @@ -280,10 +264,10 @@ private Conversions() AddHasNA
(HasNA); AddHasNA(HasNA); - AddIsDef(IsDefault); - AddIsDef(IsDefault); - AddIsDef(IsDefault); - AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); + AddIsDef(IsDefault); AddIsDef(IsDefault); AddIsDef(IsDefault); AddIsDef(IsDefault); @@ -302,10 +286,10 @@ private Conversions() AddHasZero(HasZero); AddHasZero(HasZero); - AddTryParse(TryParse); - AddTryParse(TryParse); - AddTryParse(TryParse); - AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); + AddTryParse(TryParse); AddTryParse(TryParse); AddTryParse(TryParse); AddTryParse(TryParse); @@ -846,10 +830,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) // The IsNA methods are for efficient delegates (instance instead of static). #region IsNA - private bool IsNA(ref NI1 src) => !src.HasValue; - private bool IsNA(ref NI2 src) => !src.HasValue; - private bool IsNA(ref NI4 src) => !src.HasValue; - private bool IsNA(ref NI8 src) => !src.HasValue; private bool IsNA(ref R4 src) => src.IsNA(); private bool IsNA(ref R8 src) => src.IsNA(); private bool IsNA(ref BL src) => src.IsNA; @@ -860,10 +840,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion IsNA #region HasNA - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (!src.Values[i].HasValue) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; 
i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } @@ -874,10 +850,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasNA #region IsDefault - private bool IsDefault(ref NI1 src) => src == default(NI1); - private bool IsDefault(ref NI2 src) => src == default(NI2); - private bool IsDefault(ref NI4 src) => src == default(NI4); - private bool IsDefault(ref NI8 src) => src == default(NI8); + private bool IsDefault(ref I1 src) => src == default(I1); + private bool IsDefault(ref I2 src) => src == default(I2); + private bool IsDefault(ref I4 src) => src == default(I4); + private bool IsDefault(ref I8 src) => src == default(I8); private bool IsDefault(ref R4 src) => src == 0; private bool IsDefault(ref R8 src) => src == 0; private bool IsDefault(ref TX src) => src.IsEmpty; @@ -900,10 +876,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasZero #region GetNA - private void GetNA(ref NI1 value) => value = default; - private void GetNA(ref NI2 value) => value = default; - private void GetNA(ref NI4 value) => value = default; - private void GetNA(ref NI8 value) => value = default; private void GetNA(ref R4 value) => value = R4.NaN; private void GetNA(ref R8 value) => value = R8.NaN; private void GetNA(ref BL value) => value = BL.NA; @@ -914,35 +886,35 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion GetNA #region ToI1 - public void Convert(ref NI1 src, ref NI1 dst) => dst = src; - public void Convert(ref NI2 src, ref NI1 dst) => dst = (NI1)src; - public void Convert(ref NI4 src, ref NI1 dst) => dst = (NI1)src; - public void Convert(ref NI8 src, ref NI1 dst) => dst = (NI1)src; + public void Convert(ref I1 src, ref I1 dst) => dst = src; + public void Convert(ref I2 src, ref I1 dst) => dst = (I1)src; + public void Convert(ref I4 src, ref I1 dst) => dst = (I1)src; + 
public void Convert(ref I8 src, ref I1 dst) => dst = (I1)src; #endregion ToI1 #region ToI2 - public void Convert(ref NI1 src, ref NI2 dst) => dst = src; - public void Convert(ref NI2 src, ref NI2 dst) => dst = src; - public void Convert(ref NI4 src, ref NI2 dst) => dst = (NI2)src; - public void Convert(ref NI8 src, ref NI2 dst) => dst = (NI2)src; + public void Convert(ref I1 src, ref I2 dst) => dst = src; + public void Convert(ref I2 src, ref I2 dst) => dst = src; + public void Convert(ref I4 src, ref I2 dst) => dst = (I2)src; + public void Convert(ref I8 src, ref I2 dst) => dst = (I2)src; #endregion ToI2 #region ToI4 - public void Convert(ref NI1 src, ref NI4 dst) => dst = src; - public void Convert(ref NI2 src, ref NI4 dst) => dst = src; - public void Convert(ref NI4 src, ref NI4 dst) => dst = src; - public void Convert(ref NI8 src, ref NI4 dst) => dst = (NI4)src; + public void Convert(ref I1 src, ref I4 dst) => dst = src; + public void Convert(ref I2 src, ref I4 dst) => dst = src; + public void Convert(ref I4 src, ref I4 dst) => dst = src; + public void Convert(ref I8 src, ref I4 dst) => dst = (I4)src; #endregion ToI4 #region ToI8 - public void Convert(ref NI1 src, ref NI8 dst) => dst = src; - public void Convert(ref NI2 src, ref NI8 dst) => dst = src; - public void Convert(ref NI4 src, ref NI8 dst) => dst = src; - public void Convert(ref NI8 src, ref NI8 dst) => dst = src; - - public void Convert(ref TS src, ref NI8 dst) => dst = (NI8)src.Ticks; - public void Convert(ref DT src, ref NI8 dst) => dst = (NI8)src.Ticks; - public void Convert(ref DZ src, ref NI8 dst) => dst = (NI8)src.UtcDateTime.Ticks; + public void Convert(ref I1 src, ref I8 dst) => dst = src; + public void Convert(ref I2 src, ref I8 dst) => dst = src; + public void Convert(ref I4 src, ref I8 dst) => dst = src; + public void Convert(ref I8 src, ref I8 dst) => dst = src; + + public void Convert(ref TS src, ref I8 dst) => dst = (I8)src.Ticks; + public void Convert(ref DT src, ref I8 dst) => dst = 
(I8)src.Ticks; + public void Convert(ref DZ src, ref I8 dst) => dst = (I8)src.UtcDateTime.Ticks; #endregion ToI8 #region ToU1 @@ -986,10 +958,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToUG #region ToR4 - public void Convert(ref NI1 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref NI2 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref NI4 src, ref R4 dst) => dst = (R4)src; - public void Convert(ref NI8 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref I1 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref I2 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref I4 src, ref R4 dst) => dst = (R4)src; + public void Convert(ref I8 src, ref R4 dst) => dst = (R4)src; public void Convert(ref U1 src, ref R4 dst) => dst = src; public void Convert(ref U2 src, ref R4 dst) => dst = src; public void Convert(ref U4 src, ref R4 dst) => dst = src; @@ -1004,10 +976,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToR4 #region ToR8 - public void Convert(ref NI1 src, ref R8 dst) => dst = (R8)src; - public void Convert(ref NI2 src, ref R8 dst) => dst = (R8)src; - public void Convert(ref NI4 src, ref R8 dst) => dst = (R8)src; - public void Convert(ref NI8 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref I1 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref I2 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref I4 src, ref R8 dst) => dst = (R8)src; + public void Convert(ref I8 src, ref R8 dst) => dst = (R8)src; public void Convert(ref U1 src, ref R8 dst) => dst = src; public void Convert(ref U2 src, ref R8 dst) => dst = src; public void Convert(ref U4 src, ref R8 dst) => dst = src; @@ -1022,10 +994,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToR8 #region ToStringBuilder - public void Convert(ref NI1 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } - public void Convert(ref NI2 src, ref SB 
dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } - public void Convert(ref NI4 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } - public void Convert(ref NI8 src, ref SB dst) { ClearDst(ref dst); if (src.HasValue) dst.Append(src.Value); } + public void Convert(ref I1 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I2 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I4 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I8 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } public void Convert(ref U1 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U2 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U4 src, ref SB dst) => ClearDst(ref dst).Append(src); @@ -1303,16 +1275,13 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. ///
- public bool TryParse(ref TX src, out NI1 dst) + public bool TryParse(ref TX src, out I1 dst) { long? res; bool f = TryParseSigned(I1.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); - Contracts.Assert(!res.HasValue || (I1)res == res); - if (res.HasValue) - dst = (I1)res; - else - dst = null; + Contracts.Check(f && res.HasValue); + Contracts.Assert((I1)res == res); + dst = (I1)res; return f; } @@ -1320,16 +1289,13 @@ public bool TryParse(ref TX src, out NI1 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. ///
- public bool TryParse(ref TX src, out NI2 dst) + public bool TryParse(ref TX src, out I2 dst) { long? res; bool f = TryParseSigned(I2.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); + Contracts.Check(f && res.HasValue); Contracts.Assert(!res.HasValue || (I2)res == res); - if (res.HasValue) - dst = (I2)res; - else - dst = null; + dst = (I2)res; return f; } @@ -1337,16 +1303,13 @@ public bool TryParse(ref TX src, out NI2 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. ///
- public bool TryParse(ref TX src, out NI4 dst) + public bool TryParse(ref TX src, out I4 dst) { long? res; bool f = TryParseSigned(I4.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); + Contracts.Check(f && res.HasValue); Contracts.Assert(!res.HasValue || (I4)res == res); - if (res.HasValue) - dst = (I4)res; - else - dst = null; + dst = (I4)res; return f; } @@ -1354,15 +1317,12 @@ public bool TryParse(ref TX src, out NI4 dst) /// This produces zero for empty. It returns false if the text is not parsable or overflows. /// On failure, it sets dst to the NA value. ///
- public bool TryParse(ref TX src, out NI8 dst) + public bool TryParse(ref TX src, out I8 dst) { long? res; bool f = TryParseSigned(I8.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); - if (res.HasValue) - dst = (I8)res; - else - dst = null; + Contracts.Check(f && res.HasValue); + dst = (I8)res; return f; } @@ -1543,48 +1503,39 @@ public bool TryParse(ref TX src, out DZ dst) } // These map unparsable and overflow values to "NA", which is null. - private NI1 ParseI1(ref TX src) + private I1 ParseI1(ref TX src) { long? res; bool f = TryParseSigned(I1.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); - if (!res.HasValue) - return null; - + Contracts.Check(f && res.HasValue); Contracts.Assert((I1)res == res); return (I1)res; } - private NI2 ParseI2(ref TX src) + private I2 ParseI2(ref TX src) { long? res; bool f = TryParseSigned(I2.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); - if (!res.HasValue) - return null; - + Contracts.Check(f && res.HasValue); Contracts.Assert((I2)res == res); return (I2)res; } - private NI4 ParseI4(ref TX src) + private I4 ParseI4(ref TX src) { long? res; bool f = TryParseSigned(I4.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); - if (!res.HasValue) - return null; - + Contracts.Check(f && res.HasValue); Contracts.Assert((I4)res == res); return (I4)res; } - private NI8 ParseI8(ref TX src) + private I8 ParseI8(ref TX src) { long? res; bool f = TryParseSigned(I8.MaxValue, ref src, out res); Contracts.Assert(f || !res.HasValue); - return res; + return res.Value; } // These map unparsable and overflow values to zero. The unsigned integer types do not have an NA value. 
@@ -1755,7 +1706,7 @@ private bool TryParse(ref TX src, out TX dst) return true; } - public void Convert(ref TX span, ref NI1 value) + public void Convert(ref TX span, ref I1 value) { value = ParseI1(ref span); } @@ -1763,7 +1714,7 @@ public void Convert(ref TX span, ref U1 value) { value = ParseU1(ref span); } - public void Convert(ref TX span, ref NI2 value) + public void Convert(ref TX span, ref I2 value) { value = ParseI2(ref span); } @@ -1771,7 +1722,7 @@ public void Convert(ref TX span, ref U2 value) { value = ParseU2(ref span); } - public void Convert(ref TX span, ref NI4 value) + public void Convert(ref TX span, ref I4 value) { value = ParseI4(ref span); } @@ -1779,7 +1730,7 @@ public void Convert(ref TX span, ref U4 value) { value = ParseU4(ref span); } - public void Convert(ref TX span, ref NI8 value) + public void Convert(ref TX span, ref I8 value) { value = ParseI8(ref span); } @@ -1841,10 +1792,10 @@ public void Convert(ref TX span, ref DZ value) #endregion FromTX #region FromBL - public void Convert(ref BL src, ref NI1 dst) => dst = (NI1)src; - public void Convert(ref BL src, ref NI2 dst) => dst = (NI2)src; - public void Convert(ref BL src, ref NI4 dst) => dst = (NI4)src; - public void Convert(ref BL src, ref NI8 dst) => dst = (NI8)src; + public void Convert(ref BL src, ref I1 dst) => dst = (I1)src; + public void Convert(ref BL src, ref I2 dst) => dst = (I2)src; + public void Convert(ref BL src, ref I4 dst) => dst = (I4)src; + public void Convert(ref BL src, ref I8 dst) => dst = (I8)src; public void Convert(ref BL src, ref R4 dst) => dst = (R4)src; public void Convert(ref BL src, ref R8 dst) => dst = (R8)src; public void Convert(ref BL src, ref BL dst) => dst = src; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index fef1c82a63..91470f5b6b 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ 
b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -48,7 +48,7 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index b8962dcc59..b0b3ce63b9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -992,16 +992,16 @@ public int GatherFields(DvText lineSpan, string path = null, long line = 0) } var spanT = Fields.Spans[Fields.Count - 1]; - // Note that Convert produces NA if the text is unparsable. - int? csrc = default; + // Note that Convert throws exception the text is unparsable. + int csrc = default; Conversion.Conversions.Instance.Convert(ref spanT, ref csrc); - if (!csrc.HasValue || csrc.Value <= 0) + if (csrc <= 0) { _stats.LogBadFmt(ref scan, "Bad dimensionality or ambiguous sparse item. 
Use sparse=- for non-sparse file, and/or quote the value."); break; } - csrcSparse = csrc.Value; + csrcSparse = csrc; srcLimFixed = Fields.Indices[--Fields.Count]; if (csrcSparse >= SrcLim - srcLimFixed) csrcSparse = SrcLim - srcLimFixed - 1; diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index d97e4e6593..a4a83fed28 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -998,8 +998,8 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa if (labelType.IsKey && (!perInst.Schema.HasKeyNames(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) { perInst = LambdaColumnMapper.Create(Host, "ConvertToLong", perInst, schema.Label.Name, - schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.NI8, - (ref uint src, ref Int64? dst) => dst = src == 0 ? null : src - 1 + (long?)labelType.AsKey.Min); + schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8, + (ref uint src, ref double dst) => dst = src == 0 ? 
double.NaN : src - 1 + (double)labelType.AsKey.Min); } var perInstSchema = perInst.Schema; diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 792eafe17a..2fe3f0ebe7 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -499,21 +499,21 @@ private Delegate CreateGetterDelegate(int col) case DataType.Byte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.SignedByte: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedByte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Short: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedShort: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int16: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedInt16: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int32: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int64: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int96: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.ByteArray: @@ -678,17 +678,17 @@ public ParquetConversions(IChannel channel) public void Conv(ref byte[] src, ref VBuffer dst) => dst = src != null ? new VBuffer(src.Length, src) : new VBuffer(0, new byte[0]); - public void Conv(ref sbyte? src, ref SByte? 
dst) => dst = src; + public void Conv(ref sbyte src, ref sbyte dst) => dst = src; public void Conv(ref byte src, ref byte dst) => dst = src; - public void Conv(ref short? src, ref Int16? dst) => dst = src; + public void Conv(ref short src, ref short dst) => dst = src; public void Conv(ref ushort src, ref ushort dst) => dst = src; - public void Conv(ref int? src, ref Int32? dst) => dst = src; + public void Conv(ref int src, ref int dst) => dst = src; - public void Conv(ref long? src, ref Int64? dst) => dst = src; + public void Conv(ref long src, ref long dst) => dst = src; public void Conv(ref float? src, ref Single dst) => dst = src ?? Single.NaN; diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index e0d6137a5c..dcd93d5039 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -412,13 +412,6 @@ private void GetLabels(Transposer trans, ColumnType labelType, int labelCol) BinInts(ref tmp, ref labels, _numBins, out min, out lim); _numLabels = lim - min; } - else if (labelType == NumberType.NI4) - { - var tmp = default(VBuffer); - trans.GetSingleSlotValue(labelCol, ref tmp); - BinInts(ref tmp, ref labels, _numBins, out min, out lim); - _numLabels = lim - min; - } else if (labelType == NumberType.R4) { var tmp = default(VBuffer); @@ -498,14 +491,6 @@ private Single[] ComputeMutualInformation(Transposer trans, int col) BinInts(ref src, ref dst, _numBins, out min, out lim); }); } - else if (type.ItemType == NumberType.NI4) - { - return ComputeMutualInformation(trans, col, - (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => - { - BinInts(ref src, ref dst, _numBins, out min, out lim); - }); - } if (type.ItemType == NumberType.R4) { return ComputeMutualInformation(trans, col, @@ -707,34 +692,6 @@ private void BinInts(ref VBuffer input, ref VBuffer output, 
_singles.Clear(); } - /// - /// Maps from Int32? to ints. NaNs (and only NaNs) are mapped to the first bin. - /// - private void BinInts(ref VBuffer input, ref VBuffer output, - int numBins, out int min, out int lim) - { - Contracts.Assert(_singles.Count == 0); - if (input.Values != null) - { - for (int i = 0; i < input.Count; i++) - { - var val = input.Values[i]; - if (!val.HasValue) - _singles.Add((Single)val); - } - } - - var bounds = _binFinder.FindBins(numBins, _singles, input.Length - input.Count); - min = -1 - bounds.FindIndexSorted(0); - lim = min + bounds.Length + 1; - int offset = min; - ValueMapper mapper = - (ref Int32? src, ref int dst) => - dst = !src.HasValue ? offset : offset + 1 + bounds.FindIndexSorted((Single)src); - mapper.MapVector(ref input, ref output); - _singles.Clear(); - } - /// /// Maps from Singles to ints. NaNs (and only NaNs) are mapped to the first bin. /// diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index c33e3e05f9..fd2ad40af5 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -22,14 +22,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.NI1: - return new I1.MeanAggregatorOne(ch, cursor, col); - case DataKind.NI2: - return new I2.MeanAggregatorOne(ch, cursor, col); - case DataKind.NI4: - return new I4.MeanAggregatorOne(ch, cursor, col); - case DataKind.NI8: - return new Long.MeanAggregatorOne(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorOne(ch, cursor, col); case DataKind.R8: @@ -46,14 +38,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.NI1: - return new I1.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI2: - return new I2.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - 
case DataKind.NI4: - return new I4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI8: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: @@ -78,14 +62,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.NI1: - return new I1.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.NI2: - return new I2.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.NI4: - return new I4.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.NI8: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R8: @@ -102,14 +78,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.NI1: - return new I1.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI2: - return new I2.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI4: - return new I4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI8: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: @@ -130,14 +98,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.NI1: - return new I1.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.NI2: - return new I2.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.NI4: - return new I4.MeanAggregatorAcrossSlots(ch, cursor, col); - case 
DataKind.NI8: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorAcrossSlots(ch, cursor, col); case DataKind.R8: @@ -154,14 +114,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.NI1: - return new I1.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI2: - return new I2.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI4: - return new I4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.NI8: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index c44678bf07..5e7f5b9cce 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2021,7 +2021,6 @@ public void EntryPointConvert() { "Transforms.ColumnTypeConverter", "Transforms.ColumnTypeConverter", - "Transforms.ColumnTypeConverter", }, new[] { @@ -2037,7 +2036,7 @@ public void EntryPointConvert() { 'Name': 'Feat', 'Source': 'FT', - 'Type': 'I1' + 'Type': 'R4' }, { 'Name': 'Key1', @@ -2047,18 +2046,11 @@ public void EntryPointConvert() ]", @"'Column': [ { - 'Name': 'Ints', + 'Name': 'Doubles', 'Source': 'Feat' } ], - 'Type': 'I4'", - @"'Column': [ - { - 'Name': 'Floats', - 'Source': 'Ints' - } - ], - 'Type': 'Num'", + 'Type': 'R8'", }); } diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 14a7f473f7..23921c639f 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ 
b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -299,13 +299,9 @@ public class ConversionSimpleClass public class ConversionNullalbeClass { - public int? fInt; public uint? fuInt; - public short? fShort; public ushort? fuShort; - public sbyte? fsByte; public byte? fByte; - public long? fLong; public ulong? fuLong; public float? fFloat; public double? fDouble; @@ -438,46 +434,34 @@ public void RoundTripConversionWithBasicTypes() { new ConversionNullalbeClass() { - fInt = int.MaxValue - 1, fuInt = uint.MaxValue - 1, fBool = true, - fsByte = sbyte.MaxValue - 1, fByte = byte.MaxValue - 1, fDouble = double.MaxValue - 1, fFloat = float.MaxValue - 1, - fLong = long.MaxValue - 1, fuLong = ulong.MaxValue - 1, - fShort = short.MaxValue - 1, fuShort = ushort.MaxValue - 1, fString = "ha" }, new ConversionNullalbeClass() { - fInt = int.MaxValue, fuInt = uint.MaxValue, fBool = true, - fsByte = sbyte.MaxValue, fByte = byte.MaxValue, fDouble = double.MaxValue, fFloat = float.MaxValue, - fLong = long.MaxValue, fuLong = ulong.MaxValue, - fShort = short.MaxValue, fuShort = ushort.MaxValue, fString = "ooh" }, new ConversionNullalbeClass() { - fInt = int.MinValue + 1, fuInt = uint.MinValue, fBool = false, - fsByte = sbyte.MinValue + 1, fByte = byte.MinValue, fDouble = double.MinValue + 1, fFloat = float.MinValue + 1, - fLong = long.MinValue + 1, fuLong = ulong.MinValue, - fShort = short.MinValue + 1, fuShort = ushort.MinValue, fString = "" }, @@ -542,38 +526,6 @@ public void ConversionExceptionsBehavior() } } - public class ConversionLossMinValueClass - { - public int? fInt; - public long? fLong; - public short? fShort; - public sbyte? 
fSByte; - } - - [Fact] - public void ConversionMinValueToNullBehavior() - { - using (var env = new TlcEnvironment()) - { - - var data = new List - { - new ConversionLossMinValueClass() { fSByte = null, fInt = null, fLong = null, fShort = null }, - new ConversionLossMinValueClass() { fSByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } - }; - foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); - while (enumerator.MoveNext()) - { - Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && - enumerator.Current.fSByte == null && enumerator.Current.fShort == null); - } - } - } - } - public class ConversionLossMinValueClassProperties { private int? _fInt; @@ -774,13 +726,9 @@ public class ClassWithArrays public class ClassWithNullableArrays { public string[] fString; - public int?[] fInt; public uint?[] fuInt; - public short?[] fShort; public ushort?[] fuShort; - public sbyte?[] fsByte; public byte?[] fByte; - public long?[] fLong; public ulong?[] fuLong; public float?[] fFloat; public double?[] fDouble; @@ -816,20 +764,16 @@ public void RoundTripConversionWithArrays() { new ClassWithNullableArrays() { - fInt = new int?[3] { null, -1, 1 }, fFloat = new float?[3] { -0.99f, null, 0.99f }, fString = new string[2] { null, "" }, fBool = new bool?[3] { true, null, false }, fByte = new byte?[4] { 0, 125, null, 255 }, fDouble = new double?[3] { -1, null, 1 }, - fLong = new long?[] { null, -1, 1 }, - fsByte = new sbyte?[3] { -127, 127, null }, - fShort = new short?[3] { 0, null, 32767 }, fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, fuShort = new ushort?[3] { 0, null, ushort.MaxValue } }, - new ClassWithNullableArrays() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, 
fString = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrays() { fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, new ClassWithNullableArrays() }; @@ -885,26 +829,18 @@ public class ClassWithArrayProperties public class ClassWithNullableArrayProperties { private string[] _fString; - private int?[] _fInt; private uint?[] _fuInt; - private short?[] _fShort; private ushort?[] _fuShort; - private sbyte?[] _fsByte; private byte?[] _fByte; - private long?[] _fLong; private ulong?[] _fuLong; private float?[] _fFloat; private double?[] _fDouble; private bool?[] _fBool; public string[] StringProp { get { return _fString; } set { _fString = value; } } - public int?[] IntProp { get { return _fInt; } set { _fInt = value; } } public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } - public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } } public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } - public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } } - public long?[] LongProp { get { return _fLong; } set { _fLong = value; } } public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } } public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } } @@ -940,20 +876,16 @@ public void RoundTripConversionWithArrayPropertiess() { new ClassWithNullableArrayProperties() { - IntProp = new int?[3] { null, -1, 1 }, SingleProp = new float?[3] { -0.99f, null, 0.99f }, StringProp = new string[2] { null, "" }, BoolProp = new bool?[3] { true, null, false }, ByteProp = new byte?[4] { 0, 125, null, 255 }, DoubleProp = new double?[3] { -1, null, 1 }, - LongProp = new long?[] { null, -1, 1 }, - SByteProp = new sbyte?[3] { -127, 127, null }, - ShortProp = new 
short?[3] { 0, null, 32767 }, UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue }, ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 }, UShortProp = new ushort?[3] { 0, null, ushort.MaxValue } }, - new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrayProperties() { SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, new ClassWithNullableArrayProperties() }; From a6e5c75a3deaf7b10cacb005b9479e10aa203eaf Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 16 Aug 2018 14:49:20 -0700 Subject: [PATCH 10/37] PR feedback. --- .../DataViewConstructionUtils.cs | 14 ++--- src/Microsoft.ML.Api/TypedCursor.cs | 16 ++--- .../CommandLine/CmdParser.cs | 14 ++--- src/Microsoft.ML.Core/Data/DateTime.cs | 2 +- src/Microsoft.ML.Core/Utilities/BinFinder.cs | 2 +- src/Microsoft.ML.Core/Utilities/Stream.cs | 6 +- .../DataLoadSave/Binary/CodecFactory.cs | 4 +- .../DataLoadSave/Binary/Codecs.cs | 4 +- .../DataLoadSave/Binary/UnsafeTypeOps.cs | 58 +++++++++---------- .../Evaluators/AnomalyDetectionEvaluator.cs | 6 +- .../Transforms/ConcatTransform.cs | 8 +-- .../Transforms/DropSlotsTransform.cs | 20 +++---- .../Transforms/GenerateNumberTransform.cs | 4 +- .../Transforms/KeyToVectorTransform.cs | 8 +-- src/Microsoft.ML.FastTree/Dataset/Dataset.cs | 2 +- .../Utils/PseudorandomFunction.cs | 2 +- .../RecipeInference.cs | 4 +- .../Standard/ModelStatistics.cs | 2 +- src/Microsoft.ML.Sweeper/Parameters.cs | 8 +-- .../MutualInformationFeatureSelection.cs | 10 ++-- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 38 ++++++------ .../UnitTests/CoreBaseTestClass.cs | 16 ++--- .../UnitTests/DvTypes.cs | 2 +- .../TestTransposer.cs | 28 ++++----- .../DataPipe/TestDataPipeBase.cs | 16 ++--- .../TestSparseDataView.cs | 4 +- 26 files changed, 149 insertions(+), 149 deletions(-) diff --git 
a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 98c0128102..7a801d2815 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -134,22 +134,22 @@ private Delegate CreateGetter(int index) else if (outputType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(bool)) { @@ -204,13 +204,13 @@ private Delegate CreateGetter(int index) { // int -> int Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(short)) { // short -> short Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(long)) { @@ -222,7 +222,7 @@ private Delegate CreateGetter(int index) { // sbyte -> sbyte Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } // T -> T if 
(outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 437236e63c..004c7bbaa5 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -292,22 +292,22 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit else if (fieldType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); } // VBuffer -> T[] @@ -353,25 +353,25 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(short)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(long)) { 
Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); + return CreateConvertingActionSetter(input, index, poke, x => x); } else if (fieldType == typeof(sbyte)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); + return CreateConvertingActionSetter(input, index, poke, x => x); } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index 191321213f..650b2b2ab8 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -517,23 +517,23 @@ public static int GetConsoleWindowWidth() private struct Coord { - internal Int16 X; - internal Int16 Y; + internal short X; + internal short Y; } private struct SmallRect { - internal Int16 Left; - internal Int16 Top; - internal Int16 Right; - internal Int16 Bottom; + internal short Left; + internal short Top; + internal short Right; + internal short Bottom; } private struct ConsoleScreenBufferInfo { internal Coord DwSize; internal Coord DwCursorPosition; - internal Int16 WAttributes; + internal short WAttributes; internal SmallRect SrWindow; internal Coord DwMaximumWindowSize; } diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index d11be2a494..53689eb0ae 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -242,7 +242,7 @@ private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset) Contracts.Assert(MinMinutesOffset <= offset.RawValue && offset.RawValue <= MaxMinutesOffset); var offsetTicks = offset.RawValue * TicksPerMinute; // This operation cannot overflow because offset should have already been validated to be within - // 14 hours and the DateTime instance 
is more than that distance from the boundaries of Int64. + // 14 hours and the DateTime instance is more than that distance from the boundaries of long. long utcTicks = dateTime.Ticks.RawValue - offsetTicks; var dvdt = new DvDateTime(utcTicks); if (dvdt.IsNA) diff --git a/src/Microsoft.ML.Core/Utilities/BinFinder.cs b/src/Microsoft.ML.Core/Utilities/BinFinder.cs index 7d348af1a0..b3aa759fd3 100644 --- a/src/Microsoft.ML.Core/Utilities/BinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/BinFinder.cs @@ -525,7 +525,7 @@ private void UpdatePeg(Peg peg) namespace Microsoft.ML.Runtime.Internal.Utilities { - // Reasonable choices are Double and System.Int64. + // Reasonable choices are Double and System.long. using EnergyType = System.Double; // Uses dynamic programming. diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 832d8abfdf..2cb5e0baca 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -427,7 +427,7 @@ public static void WriteBitArray(this BinaryWriter writer, BitArray arr) } } - public static long WriteSByteStream(this BinaryWriter writer, IEnumerable e) + public static long WriteSByteStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -471,7 +471,7 @@ public static long WriteUIntStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteShortStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -493,7 +493,7 @@ public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteLongStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index 91470f5b6b..5793a0b129 100644 --- 
a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -46,11 +46,11 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) // Register the current codecs. RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index f840773872..bea1900fe2 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -639,7 +639,7 @@ public override void Commit() public override long GetCommitLengthEstimate() { - return _numWritten * sizeof(Int64); + return _numWritten * sizeof(long); } } @@ -740,7 +740,7 @@ public override void Commit() public override long GetCommitLengthEstimate() { - return (long)_offsets.Count * (sizeof(Int64) + sizeof(Int16)); + return (long)_offsets.Count * (sizeof(long) + sizeof(short)); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs index 9930ce2974..fa1c4dc8e6 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs @@ -32,13 +32,13 @@ internal static class UnsafeTypeOpsFactory static UnsafeTypeOpsFactory() { _type2ops = new Dictionary(); - _type2ops[typeof(SByte)] = new SByteUnsafeTypeOps(); + 
_type2ops[typeof(sbyte)] = new SByteUnsafeTypeOps(); _type2ops[typeof(Byte)] = new ByteUnsafeTypeOps(); - _type2ops[typeof(Int16)] = new Int16UnsafeTypeOps(); + _type2ops[typeof(short)] = new Int16UnsafeTypeOps(); _type2ops[typeof(UInt16)] = new UInt16UnsafeTypeOps(); - _type2ops[typeof(Int32)] = new Int32UnsafeTypeOps(); + _type2ops[typeof(int)] = new Int32UnsafeTypeOps(); _type2ops[typeof(UInt32)] = new UInt32UnsafeTypeOps(); - _type2ops[typeof(Int64)] = new Int64UnsafeTypeOps(); + _type2ops[typeof(long)] = new Int64UnsafeTypeOps(); _type2ops[typeof(UInt64)] = new UInt64UnsafeTypeOps(); _type2ops[typeof(Single)] = new SingleUnsafeTypeOps(); _type2ops[typeof(Double)] = new DoubleUnsafeTypeOps(); @@ -51,16 +51,16 @@ public static UnsafeTypeOps Get() return (UnsafeTypeOps)_type2ops[typeof(T)]; } - private sealed class SByteUnsafeTypeOps : UnsafeTypeOps + private sealed class SByteUnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(SByte); } } - public override unsafe void Apply(SByte[] array, Action func) + public override int Size { get { return sizeof(sbyte); } } + public override unsafe void Apply(sbyte[] array, Action func) { - fixed (SByte* pArray = array) + fixed (sbyte* pArray = array) func(new IntPtr(pArray)); } - public override void Write(SByte a, BinaryWriter writer) { writer.Write(a); } - public override SByte Read(BinaryReader reader) { return reader.ReadSByte(); } + public override void Write(sbyte a, BinaryWriter writer) { writer.Write(a); } + public override sbyte Read(BinaryReader reader) { return reader.ReadSByte(); } } private sealed class ByteUnsafeTypeOps : UnsafeTypeOps @@ -75,16 +75,16 @@ public override unsafe void Apply(Byte[] array, Action func) public override Byte Read(BinaryReader reader) { return reader.ReadByte(); } } - private sealed class Int16UnsafeTypeOps : UnsafeTypeOps + private sealed class Int16UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int16); } } - public 
override unsafe void Apply(Int16[] array, Action func) + public override int Size { get { return sizeof(short); } } + public override unsafe void Apply(short[] array, Action func) { - fixed (Int16* pArray = array) + fixed (short* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int16 a, BinaryWriter writer) { writer.Write(a); } - public override Int16 Read(BinaryReader reader) { return reader.ReadInt16(); } + public override void Write(short a, BinaryWriter writer) { writer.Write(a); } + public override short Read(BinaryReader reader) { return reader.ReadInt16(); } } private sealed class UInt16UnsafeTypeOps : UnsafeTypeOps @@ -99,16 +99,16 @@ public override unsafe void Apply(UInt16[] array, Action func) public override UInt16 Read(BinaryReader reader) { return reader.ReadUInt16(); } } - private sealed class Int32UnsafeTypeOps : UnsafeTypeOps + private sealed class Int32UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int32); } } - public override unsafe void Apply(Int32[] array, Action func) + public override int Size { get { return sizeof(int); } } + public override unsafe void Apply(int[] array, Action func) { - fixed (Int32* pArray = array) + fixed (int* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int32 a, BinaryWriter writer) { writer.Write(a); } - public override Int32 Read(BinaryReader reader) { return reader.ReadInt32(); } + public override void Write(int a, BinaryWriter writer) { writer.Write(a); } + public override int Read(BinaryReader reader) { return reader.ReadInt32(); } } private sealed class UInt32UnsafeTypeOps : UnsafeTypeOps @@ -123,16 +123,16 @@ public override unsafe void Apply(UInt32[] array, Action func) public override UInt32 Read(BinaryReader reader) { return reader.ReadUInt32(); } } - private sealed class Int64UnsafeTypeOps : UnsafeTypeOps + private sealed class Int64UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int64); } 
} - public override unsafe void Apply(Int64[] array, Action func) + public override int Size { get { return sizeof(long); } } + public override unsafe void Apply(long[] array, Action func) { - fixed (Int64* pArray = array) + fixed (long* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int64 a, BinaryWriter writer) { writer.Write(a); } - public override Int64 Read(BinaryReader reader) { return reader.ReadInt64(); } + public override void Write(long a, BinaryWriter writer) { writer.Write(a); } + public override long Read(BinaryReader reader) { return reader.ReadInt64(); } } private sealed class UInt64UnsafeTypeOps : UnsafeTypeOps @@ -173,7 +173,7 @@ public override unsafe void Apply(Double[] array, Action func) private sealed class DvTimeSpanUnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int64); } } + public override int Size { get { return sizeof(long); } } public override unsafe void Apply(DvTimeSpan[] array, Action func) { fixed (DvTimeSpan* pArray = array) diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 253619249e..d5f599fab1 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -136,7 +136,7 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A var thresholdAtK = new List(); var thresholdAtP = new List(); var thresholdAtNumAnomalies = new List(); - var numAnoms = new List(); + var numAnoms = new List(); var scores = new List(); var labels = new List(); @@ -678,11 +678,11 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary col == numAnomIndex || (hasStrat && col == stratCol))) { - var numAnomGetter = cursor.GetGetter(numAnomIndex); + var numAnomGetter = cursor.GetGetter(numAnomIndex); ValueGetter stratGetter = null; if (hasStrat) { diff --git 
a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index c4ca1a0039..e47d1780f4 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -404,7 +404,7 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal if (_typesCategoricals[iinfo] == null) throw MetadataUtils.ExceptGetMetadata(); - MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); + MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); break; case MetadataUtils.Kinds.IsNormalized: if (!_isNormalized[iinfo]) @@ -417,9 +417,9 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal } } - private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) { - List allValues = new List(); + List allValues = new List(); int slotCount = 0; for (int i = 0; i < Infos[iiinfo].SrcIndices.Length; i++) { @@ -440,7 +440,7 @@ private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) Contracts.Assert(allValues.Count > 0); - dst = new VBuffer(allValues.Count, allValues.ToArray()); + dst = new VBuffer(allValues.Count, allValues.ToArray()); } private void IsNormalized(int iinfo, ref DvBool dst) diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index dfe61cddbb..ef9791f7d3 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -393,14 +393,14 @@ private void ComputeType(ISchema input, int[] slotsMin, int[] slotsMax, int iinf { if (MetadataUtils.TryGetCategoricalFeatureIndices(Source.Schema, Infos[iinfo].Source, out categoricalRanges)) { - VBuffer dst = default(VBuffer); + VBuffer dst = default(VBuffer); GetCategoricalSlotRangesCore(iinfo, slotDropper.SlotsMin, slotDropper.SlotsMax, 
categoricalRanges, ref dst); // REVIEW: cache dst as opposed to caculating it again. if (dst.Length > 0) { Contracts.Assert(dst.Length % 2 == 0); - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(dst.Length / 2), GetCategoricalSlotRanges); } } @@ -443,7 +443,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst) infoEx.SlotDropper.DropSlots(ref names, ref dst); } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { if (_exes[iinfo].CategoricalRanges != null) { @@ -452,7 +452,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) } } - private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) + private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); Host.Assert(slotsMax != null && slotsMin != null); @@ -467,9 +467,9 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots int previousDropSlotsIndex = 0; int droppedSlotsCount = 0; bool combine = false; - Int32 min = -1; - Int32 max = -1; - List newCategoricalSlotRanges = new List(); + int min = -1; + int max = -1; + List newCategoricalSlotRanges = new List(); // Six possible ways a drop slot range interacts with categorical slots range. 
// @@ -586,12 +586,12 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots Contracts.Assert(0 <= droppedSlotsCount && droppedSlotsCount <= slotsMax[slotsMax.Length - 1] + 1); if (newCategoricalSlotRanges.Count > 0) - dst = new VBuffer(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray()); + dst = new VBuffer(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray()); } private void CombineRanges( - Int32 minRange1, Int32 maxRange1, Int32 minRange2, Int32 maxRange2, - out Int32 newRangeMin, out Int32 newRangeMax) + int minRange1, int maxRange1, int minRange2, int maxRange2, + out int newRangeMin, out int newRangeMax) { Contracts.Assert(minRange2 >= 0 && maxRange2 >= 0); Contracts.Assert(minRange2 <= maxRange2); diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index 3d8823d5a1..3584531848 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -430,9 +430,9 @@ public ValueGetter GetGetter(int col) return fn; } - private ValueGetter MakeGetter() + private ValueGetter MakeGetter() { - return (ref Int64 value) => + return (ref long value) => { Ch.Check(IsGood); value = Input.Position; diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 0a1f4ce283..6671d0b033 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -319,7 +319,7 @@ private static void ComputeType(KeyToVectorTransform trans, ISchema input, int i if (!bag && info.TypeSrc.ValueCount > 0) { - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(info.TypeSrc.ValueCount), trans.GetCategoricalSlotRanges); } @@ -334,7 +334,7 @@ 
protected override ColumnType GetColumnTypeCore(int iinfo) return _types[iinfo]; } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -342,7 +342,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) Host.Assert(info.TypeSrc.ValueCount > 0); - Int32[] ranges = new Int32[info.TypeSrc.ValueCount * 2]; + int[] ranges = new int[info.TypeSrc.ValueCount * 2]; int size = info.TypeSrc.ItemType.KeyCount; ranges[0] = 0; @@ -353,7 +353,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) ranges[i + 1] = ranges[i] + size - 1; } - dst = new VBuffer(ranges.Length, ranges); + dst = new VBuffer(ranges.Length, ranges); } // Used for slot names when appropriate. diff --git a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs index f31ce73a94..b1b24bd4a1 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs @@ -609,7 +609,7 @@ public int[][] GetAssignments(double[] fraction, int randomSeed, out int[][] ass for (int i = 0; i < numParts; ++i) { cumulative += fraction[i]; - thresh[i] = (int)(cumulative * Int32.MaxValue); + thresh[i] = (int)(cumulative * int.MaxValue); if (fraction[i] == 0.0) thresh[i]--; } diff --git a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs index 43304b22f7..f94701856a 100644 --- a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs +++ b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs @@ -19,7 +19,7 @@ public sealed class PseudorandomFunction public PseudorandomFunction(Random rand) { - _data = _periodics.Select(x => Enumerable.Range(0, x).Select(y => rand.Next(-1, Int32.MaxValue) + 1).ToArray()).ToArray(); + _data = _periodics.Select(x => Enumerable.Range(0, x).Select(y => rand.Next(-1, int.MaxValue) + 
1).ToArray()).ToArray(); } public int Apply(ulong seed) diff --git a/src/Microsoft.ML.PipelineInference/RecipeInference.cs b/src/Microsoft.ML.PipelineInference/RecipeInference.cs index 2d94a099d3..2ca74a5279 100644 --- a/src/Microsoft.ML.PipelineInference/RecipeInference.cs +++ b/src/Microsoft.ML.PipelineInference/RecipeInference.cs @@ -202,7 +202,7 @@ protected override IEnumerable ApplyCore(Type predictorType, TransformInference.SuggestedTransform[] transforms) { yield return - new SuggestedRecipe(ToString(), transforms, new SuggestedRecipe.SuggestedLearner[0], Int32.MinValue + 1); + new SuggestedRecipe(ToString(), transforms, new SuggestedRecipe.SuggestedLearner[0], int.MinValue + 1); } public override string ToString() => "Default transforms"; @@ -251,7 +251,7 @@ protected override IEnumerable ApplyCore(Type predictorType, } yield return - new SuggestedRecipe(ToString(), transforms, new[] { learner }, Int32.MaxValue); + new SuggestedRecipe(ToString(), transforms, new[] { learner }, int.MaxValue); } public override string ToString() => "Text classification optimized for speed and accuracy"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index dd2a117051..7b987d4d71 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -414,7 +414,7 @@ public void AddStatsColumns(List list, LinearBinaryPredictor parent, Ro _env.AssertValueOrNull(parent); _env.AssertValue(schema); - Int64 count = _trainingExampleCount; + long count = _trainingExampleCount; list.Add(RowColumnUtils.GetColumn("Count of training examples", NumberType.I8, ref count)); var dev = _deviance; list.Add(RowColumnUtils.GetColumn("Residual Deviance", NumberType.R4, ref dev)); diff --git a/src/Microsoft.ML.Sweeper/Parameters.cs b/src/Microsoft.ML.Sweeper/Parameters.cs index dd46374732..6f78bcf521 100644 --- 
a/src/Microsoft.ML.Sweeper/Parameters.cs +++ b/src/Microsoft.ML.Sweeper/Parameters.cs @@ -588,7 +588,7 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam } if (option.StartsWith("steps")) { - numSteps = Int32.Parse(option.Substring(option.IndexOf(':') + 1)); + numSteps = int.Parse(option.Substring(option.IndexOf(':') + 1)); optionsSpecified[1] = true; } if (option.StartsWith("inc")) @@ -613,9 +613,9 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam if (paramType == typeof(UInt16) || paramType == typeof(UInt32) || paramType == typeof(UInt64) - || paramType == typeof(Int16) - || paramType == typeof(Int32) - || paramType == typeof(Int64)) + || paramType == typeof(short) + || paramType == typeof(int) + || paramType == typeof(long)) { long min; long max; diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index dcd93d5039..58fc29b67a 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -407,7 +407,7 @@ private void GetLabels(Transposer trans, ColumnType labelType, int labelCol) // Note: NAs have their own separate bin. if (labelType == NumberType.I4) { - var tmp = default(VBuffer); + var tmp = default(VBuffer); trans.GetSingleSlotValue(labelCol, ref tmp); BinInts(ref tmp, ref labels, _numBins, out min, out lim); _numLabels = lim - min; @@ -486,7 +486,7 @@ private Single[] ComputeMutualInformation(Transposer trans, int col) if (type.ItemType == NumberType.I4) { return ComputeMutualInformation(trans, col, - (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => + (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => { BinInts(ref src, ref dst, _numBins, out min, out lim); }); @@ -676,7 +676,7 @@ private static ValueMapper, VBuffer> BinKeys(ColumnType colTy /// /// Maps Ints. 
/// - private void BinInts(ref VBuffer input, ref VBuffer output, + private void BinInts(ref VBuffer input, ref VBuffer output, int numBins, out int min, out int lim) { Contracts.Assert(_singles.Count == 0); @@ -685,8 +685,8 @@ private void BinInts(ref VBuffer input, ref VBuffer output, min = -1 - bounds.FindIndexSorted(0); lim = min + bounds.Length + 1; int offset = min; - ValueMapper mapper = - (ref Int32 src, ref int dst) => + ValueMapper mapper = + (ref int src, ref int dst) => dst = offset + 1 + bounds.FindIndexSorted((Single)src); mapper.MapVector(ref input, ref output); _singles.Clear(); diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index fd2ad40af5..aac007ca1d 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -1229,14 +1229,14 @@ private static class I4 { private const long MaxVal = int.MaxValue; - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(ref Int32? val) + protected override void ProcessRow(ref int? val) { Stat.Update(val, MaxVal); } @@ -1249,14 +1249,14 @@ public override object GetStat() } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(ref Int32? val) + protected override void ProcessValue(ref int? 
val) { Stat.Update(val, MaxVal); } @@ -1269,14 +1269,14 @@ public override object GetStat() } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(ref Int32? val, int slot) + protected override void ProcessValue(ref int? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); Stat[slot].Update(val, MaxVal); @@ -1284,7 +1284,7 @@ protected override void ProcessValue(ref Int32? val, int slot) public override object GetStat() { - Int32[] stat = new Int32[Stat.Length]; + int[] stat = new int[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) { long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); @@ -1295,7 +1295,7 @@ public override object GetStat() } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1303,13 +1303,13 @@ public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnM Stat = (int)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref Int32? val) + protected override void ProcessValueMin(ref int? val) { if (val.HasValue && val < Stat) Stat = val.Value; } - protected override void ProcessValueMax(ref Int32? val) + protected override void ProcessValueMax(ref int? 
val) { if (val.HasValue && val > Stat) Stat = val.Value; @@ -1321,7 +1321,7 @@ public override object GetStat() } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) @@ -1329,13 +1329,13 @@ public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool Stat = (int)(ReturnMax ? -MaxVal : MaxVal); } - protected override void ProcessValueMin(ref Int32? val) + protected override void ProcessValueMin(ref int? val) { if (val.HasValue && val < Stat) Stat = val.Value; } - protected override void ProcessValueMax(ref Int32? val) + protected override void ProcessValueMax(ref int? val) { if (val.HasValue && val > Stat) Stat = val.Value; @@ -1346,14 +1346,14 @@ public override object GetStat() // If sparsity occurred, fold in a zero. if (ValueCount > (ulong)ValuesProcessed) { - var def = default(Int32?); + var def = default(int?); ProcValueDelegate(ref def); } return Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) @@ -1363,14 +1363,14 @@ public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, i Stat[i] = bound; } - protected override void ProcessValueMin(ref Int32? val, int slot) + protected override void ProcessValueMin(ref int? val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (val.HasValue && val < Stat[slot]) Stat[slot] = val.Value; } - protected override void ProcessValueMax(ref Int32? val, int slot) + protected override void ProcessValueMax(ref int? 
val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (val.HasValue && val > Stat[slot]) @@ -1379,13 +1379,13 @@ protected override void ProcessValueMax(ref Int32? val, int slot) public override object GetStat() { - Int32[] stat = new Int32[Stat.Length]; + int[] stat = new int[Stat.Length]; // Account for defaults resulting from sparsity. for (int slot = 0; slot < Stat.Length; slot++) { if (GetValuesProcessed(slot) < RowCount) { - var def = default(Int32?); + var def = default(int?); ProcValueDelegate(ref def, slot); } stat[slot] = Stat[slot]; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs index 813e56e545..3fd6789acb 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs @@ -153,19 +153,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -196,19 +196,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, 
size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs index d332ade37c..309f8402cb 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs @@ -16,7 +16,7 @@ public void TestComparableInt32() const int count = 100; var rand = RandomUtils.Create(42); - var values = new Int32?[2 * count]; + var values = new int?[2 * count]; for (int i = 0; i < count; i++) { var v = values[i] = rand.Next(); diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index 35661bbb15..701adea39b 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -148,19 +148,19 @@ public void TransposerTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (Int32)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. 
+ var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (Int16)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. - builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (Int32)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (int)rgen.Next())); IDataView view = builder.GetDataView(); @@ -181,11 +181,11 @@ public void TransposerTest() } // Check the contents Assert.Null(trans.TransposeSchema.GetSlotType(2)); // C check to see that it's not transposable. - TransposeCheckHelper(view, 0, trans); // A check. + TransposeCheckHelper(view, 0, trans); // A check. TransposeCheckHelper(view, 1, trans); // B check. TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. + TransposeCheckHelper(view, 5, trans); // F check. } // Force save. Recheck columns that would have previously been passthrough columns. 
@@ -200,7 +200,7 @@ public void TransposerTest() Assert.Null(trans.TransposeSchema.GetSlotType(2)); TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. + TransposeCheckHelper(view, 5, trans); // F check. } } @@ -213,19 +213,19 @@ public void TransposerSaverLoaderTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (Int32)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. + var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (Int16)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. 
- builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (Int32)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (int)rgen.Next())); IDataView view = builder.GetDataView(); @@ -240,12 +240,12 @@ public void TransposerSaverLoaderTest() // First check whether this as an IDataView yields the same values. CheckSameValues(view, loader); - TransposeCheckHelper(view, 0, loader); // A + TransposeCheckHelper(view, 0, loader); // A TransposeCheckHelper(view, 1, loader); // B - TransposeCheckHelper(view, 2, loader); // C + TransposeCheckHelper(view, 2, loader); // C TransposeCheckHelper(view, 3, loader); // D TransposeCheckHelper(view, 4, loader); // E - TransposeCheckHelper(view, 5, loader); // F + TransposeCheckHelper(view, 5, loader); // F Done(); } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index e53bc1e449..951f95051b 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -879,19 +879,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case 
DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -922,19 +922,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs index f3bf775921..37db8a26f4 100644 --- a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs +++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs @@ -35,7 +35,7 @@ private class SparseExample public void SparseDataView() { GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericSparseDataView(new Int32[] { 1, 2, 3 }, new Int32[] { 1, 10, 100 }); + GenericSparseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 }); GenericSparseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericSparseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, @@ -77,7 +77,7 @@ private void 
GenericSparseDataView(T[] v1, T[] v2) public void DenseDataView() { GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericDenseDataView(new Int32[] { 1, 2, 3 }, new Int32[] { 1, 10, 100 }); + GenericDenseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 }); GenericDenseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericDenseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, From 9aea10bf69eb39d85a84628edd1e5152954fa71a Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 16 Aug 2018 15:34:49 -0700 Subject: [PATCH 11/37] clean up. --- src/Microsoft.ML.Core/Utilities/BinFinder.cs | 2 +- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 514 ------------------ 2 files changed, 1 insertion(+), 515 deletions(-) diff --git a/src/Microsoft.ML.Core/Utilities/BinFinder.cs b/src/Microsoft.ML.Core/Utilities/BinFinder.cs index b3aa759fd3..7d348af1a0 100644 --- a/src/Microsoft.ML.Core/Utilities/BinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/BinFinder.cs @@ -525,7 +525,7 @@ private void UpdatePeg(Peg peg) namespace Microsoft.ML.Runtime.Internal.Utilities { - // Reasonable choices are Double and System.long. + // Reasonable choices are Double and System.Int64. using EnergyType = System.Double; // Uses dynamic programming. diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index aac007ca1d..acb7602c2d 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -881,520 +881,6 @@ public override object GetStat() } } - private static class I1 - { - // Utilizes MeanStatInt for the mean aggregators of all IX types, TS, and DT. 
- - private const long MaxVal = sbyte.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref sbyte? val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (sbyte)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref sbyte? val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (sbyte)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref sbyte? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val, MaxVal); - } - - public override object GetStat() - { - sbyte[] stat = new sbyte[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - stat[slot] = (sbyte)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref sbyte? 
val) - { - var raw = val; - if (raw.HasValue && raw < Stat) - Stat = raw.Value; - } - - protected override void ProcessValueMax(ref sbyte? val) - { - var raw = val; - if (raw > Stat) - Stat = raw.Value; - } - - public override object GetStat() - { - return (sbyte)Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref sbyte? val) - { - if (val.HasValue && val < Stat) - Stat = val.Value; - } - - protected override void ProcessValueMax(ref sbyte? val) - { - if (val.HasValue && val > Stat) - Stat = val.Value; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(sbyte?); - ProcValueDelegate(ref def); - } - return Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - sbyte bound = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref sbyte? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val.Value < Stat[slot]) - Stat[slot] = val.Value; - } - - protected override void ProcessValueMax(ref sbyte? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val.Value > Stat[slot]) - Stat[slot] = val.Value; - } - - public override object GetStat() - { - sbyte[] stat = new sbyte[Stat.Length]; - // Account for defaults resulting from sparsity. 
- for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(sbyte?); - ProcValueDelegate(ref def, slot); - } - stat[slot] = Stat[slot]; - } - return stat; - } - } - } - - private static class I2 - { - private const long MaxVal = short.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref short? val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (short)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref short? val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (short)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref short? 
val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val, MaxVal); - } - - public override object GetStat() - { - short[] stat = new short[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - stat[slot] = (short)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (short)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref short? val) - { - if (val.HasValue && val < Stat) - Stat = val.Value; - } - - protected override void ProcessValueMax(ref short? val) - { - if (val.HasValue && val > Stat) - Stat = val.Value; - } - - public override object GetStat() - { - return Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (short)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref short? val) - { - if (val.HasValue && val < Stat) - Stat = val.Value; - } - - protected override void ProcessValueMax(ref short? val) - { - if (val.HasValue && val > Stat) - Stat = val.Value; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(short?); - ProcValueDelegate(ref def); - } - return Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - short bound = (short)(ReturnMax ? 
-MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref short? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val < Stat[slot]) - Stat[slot] = val.Value; - } - - protected override void ProcessValueMax(ref short? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val > Stat[slot]) - Stat[slot] = val.Value; - } - - public override object GetStat() - { - short[] stat = new short[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(short?); - ProcValueDelegate(ref def, slot); - } - stat[slot] = Stat[slot]; - } - return stat; - } - } - } - - private static class I4 - { - private const long MaxVal = int.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref int? val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (int)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref int? 
val) - { - Stat.Update(val, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - return (int)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref int? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val, MaxVal); - } - - public override object GetStat() - { - int[] stat = new int[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal <= val && val <= MaxVal); - stat[slot] = (int)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (int)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref int? val) - { - if (val.HasValue && val < Stat) - Stat = val.Value; - } - - protected override void ProcessValueMax(ref int? val) - { - if (val.HasValue && val > Stat) - Stat = val.Value; - } - - public override object GetStat() - { - return Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (int)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref int? val) - { - if (val.HasValue && val < Stat) - Stat = val.Value; - } - - protected override void ProcessValueMax(ref int? 
val) - { - if (val.HasValue && val > Stat) - Stat = val.Value; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(int?); - ProcValueDelegate(ref def); - } - return Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - int bound = (int)(ReturnMax ? -MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref int? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val < Stat[slot]) - Stat[slot] = val.Value; - } - - protected override void ProcessValueMax(ref int? val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - if (val.HasValue && val > Stat[slot]) - Stat[slot] = val.Value; - } - - public override object GetStat() - { - int[] stat = new int[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(int?); - ProcValueDelegate(ref def, slot); - } - stat[slot] = Stat[slot]; - } - return stat; - } - } - } - private static class Long { private const long MaxVal = long.MaxValue; From 5d14c0a34af9a0316cf18aa74e258e3855b699eb Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 21 Aug 2018 13:03:30 -0700 Subject: [PATCH 12/37] PR feedback. 
--- .../DataViewConstructionUtils.cs | 44 -- src/Microsoft.ML.Core/Data/ColumnType.cs | 12 +- src/Microsoft.ML.Core/Data/DataKind.cs | 14 +- src/Microsoft.ML.Core/Data/DateTime.cs | 110 ++-- src/Microsoft.ML.Core/Data/DvInt1.cs | 264 --------- src/Microsoft.ML.Core/Data/DvInt2.cs | 263 --------- src/Microsoft.ML.Core/Data/DvInt4.cs | 456 ---------------- src/Microsoft.ML.Core/Data/DvInt8.cs | 511 ------------------ src/Microsoft.ML.Core/Utilities/Stream.cs | 72 ++- .../DataLoadSave/Binary/Codecs.cs | 28 +- .../DataLoadSave/Binary/UnsafeTypeOps.cs | 2 +- src/Microsoft.ML.Transforms/NAReplaceUtils.cs | 289 ---------- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 12 +- .../CollectionDataSourceTests.cs | 155 +----- .../LearningPipelineTests.cs | 32 +- .../Scenarios/Api/ApiScenariosTests.cs | 5 +- .../Scenarios/Api/SimpleTrainAndPredict.cs | 2 +- .../Scenarios/Api/TrainSaveModelAndPredict.cs | 2 +- 18 files changed, 126 insertions(+), 2147 deletions(-) delete mode 100644 src/Microsoft.ML.Core/Data/DvInt1.cs delete mode 100644 src/Microsoft.ML.Core/Data/DvInt2.cs delete mode 100644 src/Microsoft.ML.Core/Data/DvInt4.cs delete mode 100644 src/Microsoft.ML.Core/Data/DvInt8.cs diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 7a801d2815..9db8fc4f4a 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -131,26 +131,6 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.ItemType.IsText); return CreateConvertingArrayGetterDelegate(index, x => x == null ? 
DvText.NA : new DvText(x)); } - else if (outputType.GetElementType() == typeof(int)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(long)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(short)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(sbyte)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x); - } else if (outputType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); @@ -200,30 +180,6 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.IsBool); return CreateConvertingGetterDelegate(index, x => x ?? DvBool.NA); } - else if (outputType == typeof(int)) - { - // int -> int - Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(short)) - { - // short -> short - Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(long)) - { - // long -> long - Ch.Assert(colType == NumberType.I8); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(sbyte)) - { - // sbyte -> sbyte - Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x); - } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType)); diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 12d0509352..96764d68f1 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ 
b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -567,18 +567,14 @@ public static BoolType Instance get { if (_instance == null) - Interlocked.CompareExchange(ref _instance, new BoolType(DataKind.BL, "Bool"), null); + Interlocked.CompareExchange(ref _instance, new BoolType(), null); return _instance; } } - private readonly string _name; - - private BoolType(DataKind kind, string name) - : base(kind.ToType(), kind) + private BoolType() + : base(typeof(DvBool), DataKind.BL) { - Contracts.AssertNonEmpty(name); - _name = name; } public override bool Equals(ColumnType other) @@ -591,7 +587,7 @@ public override bool Equals(ColumnType other) public override string ToString() { - return _name; + return "Bool"; } } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index da0a4eaf7a..0a0707df48 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -187,27 +187,27 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) // REVIEW: Make this more efficient. Should we have a global dictionary? 
if (type == typeof(sbyte)) kind = DataKind.I1; - else if (type == typeof(byte) || type == typeof(byte?)) + else if (type == typeof(byte)) kind = DataKind.U1; else if (type == typeof(short)) kind = DataKind.I2; - else if (type == typeof(ushort) || type == typeof(ushort?)) + else if (type == typeof(ushort)) kind = DataKind.U2; else if (type == typeof(int)) kind = DataKind.I4; - else if (type == typeof(uint) || type == typeof(uint?)) + else if (type == typeof(uint)) kind = DataKind.U4; else if (type == typeof(long)) kind = DataKind.I8; - else if (type == typeof(ulong) || type == typeof(ulong?)) + else if (type == typeof(ulong)) kind = DataKind.U8; - else if (type == typeof(Single) || type == typeof(Single?)) + else if (type == typeof(Single)) kind = DataKind.R4; - else if (type == typeof(Double) || type == typeof(Double?)) + else if (type == typeof(Double)) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) + else if (type == typeof(DvBool)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index 53689eb0ae..5bb2626d61 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Data public struct DvDateTime : IEquatable, IComparable { public const long MaxTicks = 3155378975999999999; - private readonly DvInt8 _ticks; + private readonly long _ticks; /// /// This ctor initializes _ticks to the value of sdt.Ticks, and ignores its DateTimeKind value. @@ -32,22 +32,19 @@ public DvDateTime(SysDateTime sdt) /// /// This ctor accepts any value for ticks, but produces an NA if ticks is out of the legal range. 
/// - public DvDateTime(DvInt8 ticks) + public DvDateTime(long ticks) { - if ((ulong)ticks.RawValue > MaxTicks) - _ticks = DvInt8.NA; - else - _ticks = ticks; + _ticks = ticks; AssertValid(); } [Conditional("DEBUG")] internal void AssertValid() { - Contracts.Assert((ulong)_ticks.RawValue <= MaxTicks || _ticks.IsNA); + Contracts.Assert((ulong)_ticks <= MaxTicks); } - public DvInt8 Ticks + public long Ticks { get { @@ -81,13 +78,13 @@ public bool IsNA get { AssertValid(); - return (ulong)_ticks.RawValue > MaxTicks; + return (ulong)_ticks > MaxTicks; } } public static DvDateTime NA { - get { return new DvDateTime(DvInt8.NA); } + get { return new DvDateTime(long.MinValue); } } public static explicit operator SysDateTime?(DvDateTime dvDt) @@ -124,12 +121,12 @@ internal SysDateTime GetSysDateTime() { AssertValid(); Contracts.Assert(!IsNA); - return new SysDateTime(_ticks.RawValue); + return new SysDateTime(_ticks); } public bool Equals(DvDateTime other) { - return _ticks.RawValue == other._ticks.RawValue; + return _ticks == other._ticks; } public override bool Equals(object obj) @@ -139,9 +136,9 @@ public override bool Equals(object obj) public int CompareTo(DvDateTime other) { - if (_ticks.RawValue == other._ticks.RawValue) + if (_ticks == other._ticks) return 0; - return _ticks.RawValue < other._ticks.RawValue ? -1 : 1; + return _ticks < other._ticks ? 
-1 : 1; } public override int GetHashCode() @@ -162,11 +159,11 @@ public struct DvDateTimeZone : IEquatable, IComparable /// The number of clock ticks in the date time portion /// The time zone offset in minutes - public DvDateTimeZone(DvInt8 ticks, DvInt2 offset) + public DvDateTimeZone(long ticks, short offset) { var dt = new DvDateTime(ticks); - if (dt.IsNA || offset.IsNA || MinMinutesOffset > offset.RawValue || offset.RawValue > MaxMinutesOffset) + if (MinMinutesOffset > offset || offset > MaxMinutesOffset) { _dateTime = DvDateTime.NA; - _offset = DvInt2.NA; + _offset = short.MinValue; } else { @@ -203,7 +200,6 @@ public DvDateTimeZone(SysDateTimeOffset dto) Contracts.Assert(success); _dateTime = ValidateDate(new DvDateTime(dto.DateTime), ref _offset); Contracts.Assert(!_dateTime.IsNA); - Contracts.Assert(!_offset.IsNA); AssertValid(); } @@ -217,7 +213,7 @@ public DvDateTimeZone(DvDateTime dt, DvTimeSpan offset) if (dt.IsNA || offset.IsNA || !TryValidateOffset(offset.Ticks, out _offset)) { _dateTime = DvDateTime.NA; - _offset = DvInt2.NA; + _offset = short.MinValue; } else _dateTime = ValidateDate(dt, ref _offset); @@ -233,20 +229,19 @@ public DvDateTimeZone(DvDateTime dt, DvTimeSpan offset) /// The offset. This value is assumed to be validated as a legal offset: /// a value in whole minutes, between -14 and 14 hours. /// The UTC DvDateTime representing the input clock time minus the offset - private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset) + private static DvDateTime ValidateDate(DvDateTime dateTime, ref short offset) { Contracts.Assert(!dateTime.IsNA); - Contracts.Assert(!offset.IsNA); // Validate that both the UTC and clock times are legal. 
- Contracts.Assert(MinMinutesOffset <= offset.RawValue && offset.RawValue <= MaxMinutesOffset); - var offsetTicks = offset.RawValue * TicksPerMinute; + Contracts.Assert(MinMinutesOffset <= offset && offset <= MaxMinutesOffset); + var offsetTicks = offset * TicksPerMinute; // This operation cannot overflow because offset should have already been validated to be within - // 14 hours and the DateTime instance is more than that distance from the boundaries of long. - long utcTicks = dateTime.Ticks.RawValue - offsetTicks; + // 14 hours and the DateTime instance is more than that distance from the boundaries of Int64. + long utcTicks = dateTime.Ticks - offsetTicks; var dvdt = new DvDateTime(utcTicks); if (dvdt.IsNA) - offset = DvInt2.NA; + offset = short.MinValue; return dvdt; } @@ -257,23 +252,22 @@ private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset) /// /// /// - private static bool TryValidateOffset(DvInt8 offsetTicks, out DvInt2 offset) + private static bool TryValidateOffset(long offsetTicks, out short offset) { - if (offsetTicks.IsNA || offsetTicks.RawValue % TicksPerMinute != 0) + if (offsetTicks % TicksPerMinute != 0) { - offset = DvInt2.NA; + offset = short.MinValue; return false; } - long mins = offsetTicks.RawValue / TicksPerMinute; + long mins = offsetTicks / TicksPerMinute; short res = (short)mins; if (res != mins || res > MaxMinutesOffset || res < MinMinutesOffset) { - offset = DvInt2.NA; + offset = short.MinValue; return false; } offset = res; - Contracts.Assert(!offset.IsNA); return true; } @@ -281,12 +275,10 @@ private static bool TryValidateOffset(DvInt8 offsetTicks, out DvInt2 offset) private void AssertValid() { _dateTime.AssertValid(); - if (_dateTime.IsNA) - Contracts.Assert(_offset.IsNA); - else + if (!_dateTime.IsNA) { - Contracts.Assert(MinMinutesOffset <= _offset.RawValue && _offset.RawValue <= MaxMinutesOffset); - Contracts.Assert((ulong)(_dateTime.Ticks.RawValue + _offset.RawValue * TicksPerMinute) + 
Contracts.Assert(MinMinutesOffset <= _offset && _offset <= MaxMinutesOffset); + Contracts.Assert((ulong)(_dateTime.Ticks + _offset * TicksPerMinute) <= (ulong)DvDateTime.MaxTicks); } } @@ -298,7 +290,7 @@ public DvDateTime ClockDateTime AssertValid(); if (_dateTime.IsNA) return DvDateTime.NA; - var res = new DvDateTime(_dateTime.Ticks.RawValue + _offset.RawValue * TicksPerMinute); + var res = new DvDateTime(_dateTime.Ticks + _offset * TicksPerMinute); Contracts.Assert(!res.IsNA); return res; } @@ -326,16 +318,14 @@ public DvTimeSpan Offset get { AssertValid(); - if (_offset.IsNA) - return DvTimeSpan.NA; - return new DvTimeSpan(_offset.RawValue * TicksPerMinute); + return new DvTimeSpan(_offset * TicksPerMinute); } } /// /// Gets the offset in minutes. /// - public DvInt2 OffsetMinutes + public short OffsetMinutes { get { @@ -392,7 +382,7 @@ public bool IsNA // and _offset = 0. public static DvDateTimeZone NA { - get { return new DvDateTimeZone(DvDateTime.NA, DvInt2.NA); } + get { return new DvDateTimeZone(DvDateTime.NA, short.MinValue); } } public static explicit operator SysDateTimeOffset?(DvDateTimeZone dvDto) @@ -427,7 +417,7 @@ private DateTimeOffset GetSysDateTimeOffset() { AssertValid(); Contracts.Assert(!IsNA); - return new SysDateTimeOffset(ClockDateTime.GetSysDateTime(), new TimeSpan(0, _offset.RawValue, 0)); + return new SysDateTimeOffset(ClockDateTime.GetSysDateTime(), new TimeSpan(0, _offset, 0)); } /// @@ -436,7 +426,7 @@ private DateTimeOffset GetSysDateTimeOffset() /// public bool Equals(DvDateTimeZone other) { - return _offset.RawValue == other._offset.RawValue && _dateTime.Equals(other._dateTime); + return _offset == other._offset && _dateTime.Equals(other._dateTime); } public override bool Equals(object obj) @@ -456,9 +446,9 @@ public int CompareTo(DvDateTimeZone other) int res = _dateTime.CompareTo(other._dateTime); if (res != 0) return res; - if (_offset.RawValue == other._offset.RawValue) + if (_offset == other._offset) return 0; - return 
_offset.RawValue < other._offset.RawValue ? -1 : 1; + return _offset < other._offset ? -1 : 1; } public override int GetHashCode() @@ -472,11 +462,11 @@ public override int GetHashCode() /// public struct DvTimeSpan : IEquatable, IComparable { - private readonly DvInt8 _ticks; + private readonly long _ticks; - public DvInt8 Ticks { get { return _ticks; } } + public long Ticks { get { return _ticks; } } - public DvTimeSpan(DvInt8 ticks) + public DvTimeSpan(long ticks) { _ticks = ticks; } @@ -488,24 +478,24 @@ public DvTimeSpan(SysTimeSpan sts) public DvTimeSpan(SysTimeSpan? sts) { - _ticks = sts != null ? sts.GetValueOrDefault().Ticks : DvInt8.NA; + _ticks = sts != null ? sts.GetValueOrDefault().Ticks : long.MinValue; } public bool IsNA { - get { return _ticks.IsNA; } + get { return false; } } public static DvTimeSpan NA { - get { return new DvTimeSpan(DvInt8.NA); } + get { return new DvTimeSpan(long.MinValue); } } public static explicit operator SysTimeSpan?(DvTimeSpan ts) { if (ts.IsNA) return null; - return new SysTimeSpan(ts._ticks.RawValue); + return new SysTimeSpan(ts._ticks); } public static implicit operator DvTimeSpan(SysTimeSpan sts) @@ -522,12 +512,12 @@ public override string ToString() { if (IsNA) return ""; - return new SysTimeSpan(_ticks.RawValue).ToString("c"); + return new SysTimeSpan(_ticks).ToString("c"); } public bool Equals(DvTimeSpan other) { - return _ticks.RawValue == other._ticks.RawValue; + return _ticks == other._ticks; } public override bool Equals(object obj) @@ -537,9 +527,9 @@ public override bool Equals(object obj) public int CompareTo(DvTimeSpan other) { - if (_ticks.RawValue == other._ticks.RawValue) + if (_ticks == other._ticks) return 0; - return _ticks.RawValue < other._ticks.RawValue ? -1 : 1; + return _ticks < other._ticks ? 
-1 : 1; } public override int GetHashCode() diff --git a/src/Microsoft.ML.Core/Data/DvInt1.cs b/src/Microsoft.ML.Core/Data/DvInt1.cs deleted file mode 100644 index ced2a4688d..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt1.cs +++ /dev/null @@ -1,264 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I2 = DvInt2; - using I4 = DvInt4; - using I8 = DvInt8; - using IX = DvInt1; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = SByte; - - public struct DvInt1 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. 
- /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to sbyte"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I2 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I4 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - 
[MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? 
BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt2.cs b/src/Microsoft.ML.Core/Data/DvInt2.cs deleted file mode 100644 index 33599f6468..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt2.cs +++ /dev/null @@ -1,263 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I4 = DvInt4; - using I8 = DvInt8; - using IX = DvInt2; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = Int16; - - public struct DvInt2 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. 
- /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to short"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I4 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if 
(res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? 
BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt4.cs b/src/Microsoft.ML.Core/Data/DvInt4.cs deleted file mode 100644 index 23c7e89242..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt4.cs +++ /dev/null @@ -1,456 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I2 = DvInt2; - using I8 = DvInt8; - using IX = DvInt4; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = Int32; - - public struct DvInt4 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. 
The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? 
RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to int"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I2 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if 
(a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? 
BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a) - { - return -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator +(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a, IX b) - { - var av = a._value; - var bv = -b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator *(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - RawI8 res = (RawI8)av * bv; - if (-RawIX.MaxValue <= res && res <= RawIX.MaxValue) - return (RawIX)res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator /(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av / bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator %(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av % bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Abs(IX a) - { - // Can't use Math.Abs since it throws on the RawNA value. - return a._value >= 0 ? 
a._value : -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Sign(IX a) - { - var val = a._value; - var neg = -val; - // This works for NA since -RawNA == RawNA. - return val > neg ? +1 : val < neg ? -1 : val; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Min(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Min(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Max(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Max(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - /// - /// Raise a to the b power. Special cases: - /// * 1^NA => 1 - /// * NA^0 => 1 - /// - public static IX Pow(IX a, IX b) - { - var av = a.RawValue; - var bv = b.RawValue; - - if (av == 1) - return 1; - switch (bv) - { - case 0: - return 1; - case 1: - return av; - case 2: - return a * a; - case RawNA: - return RawNA; - } - if (av == -1) - return (bv & 1) == 0 ? 1 : -1; - if (bv < 0) - return RawNA; - if (av == RawNA) - return RawNA; - - // Since the abs of the base is at least two, the exponent must be less than 31. - if (bv >= 31) - return RawNA; - - bool neg = false; - if (av < 0) - { - av = -av; - neg = (bv & 1) != 0; - } - Contracts.Assert(av >= 2); - - // Since the exponent is at least three, the base must be <= 1290. 
- Contracts.Assert(bv >= 3); - if (av > 1290) - return RawNA; - - // REVIEW: Should we use a checked context and exception catching like I8 does? - ulong u = (ulong)(uint)av; - ulong result = 1; - for (; ; ) - { - if ((bv & 1) != 0 && (result *= u) > RawIX.MaxValue) - return RawNA; - bv >>= 1; - if (bv == 0) - break; - if ((u *= u) > RawIX.MaxValue) - return RawNA; - } - Contracts.Assert(result <= RawIX.MaxValue); - - var res = (RawIX)result; - if (neg) - res = -res; - return res; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt8.cs b/src/Microsoft.ML.Core/Data/DvInt8.cs deleted file mode 100644 index 3212e21fa6..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt8.cs +++ /dev/null @@ -1,511 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I2 = DvInt2; - using I4 = DvInt4; - using IX = DvInt8; - using R4 = Single; - using R8 = Double; - using RawIX = Int64; - - public struct DvInt8 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. 
- /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to long"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I2 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I4 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - 
return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - // Note that an R8 cannot represent long.MaxValue exactly so y + 1.0 below might be the same as y. - R8 x = a; - R8 y = n; - if (a < 0) - { - x = -x; - y = -y; - } - Contracts.Assert(y <= x); - Contracts.Assert(x < y + 1.0 | y + 1.0 == y & x == y); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? 
BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a) - { - return -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator +(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a, IX b) - { - var av = a._value; - var bv = -b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator *(IX a, IX b) - { - var av = a._value; - var bv = b._value; - bool neg = (av ^ bv) < 0; - if (av < 0) - { - if (av == RawNA) - return RawNA; - av = -av; - } - if (bv < 0) - { - if (bv == RawNA) - return RawNA; - bv = -bv; - } - - // Deal with the low 32 bits. 
- ulong lo1 = (ulong)av & 0x00000000FFFFFFFF; - ulong lo2 = (ulong)bv & 0x00000000FFFFFFFF; - RawIX res = (RawIX)(lo1 * lo2); - if (res < 0) - return RawNA; - - // Get the high 32 bits, including cross terms. - ulong hi1 = (ulong)av >> 32; - ulong hi2 = (ulong)bv >> 32; - if (hi1 != 0) - { - // If both high words are non-zero, overflow is guaranteed. - if (hi2 != 0) - return RawNA; - // Compute the cross term. - ulong tmp = hi1 * lo2; - if ((tmp & 0xFFFFFFFF80000000) != 0) - return RawNA; - res += (long)(tmp << 32); - if (res < 0) - return RawNA; - } - else if (hi2 != 0) - { - // Compute the cross term. - ulong tmp = hi2 * lo1; - if ((tmp & 0xFFFFFFFF80000000) != 0) - return RawNA; - res += (long)(tmp << 32); - if (res < 0) - return RawNA; - } - - // Adjust the sign. - if (neg) - res = -res; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator /(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av / bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator %(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av % bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Abs(IX a) - { - // Can't use Math.Abs since it throws on the RawNA value. - return a._value >= 0 ? a._value : -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Sign(IX a) - { - var val = a._value; - var neg = -val; - // This works for NA since -RawNA == RawNA. - return val > neg ? +1 : val < neg ? -1 : val; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Min(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? 
v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Min(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Max(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Max(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - /// - /// Raise a to the b power. Special cases: - /// * 1^NA => 1 - /// * NA^0 => 1 - /// - public static IX Pow(IX a, IX b) - { - var av = a.RawValue; - var bv = b.RawValue; - - if (av == 1) - return 1; - switch (bv) - { - case 0: - return 1; - case 1: - return av; - case 2: - return a * a; - case RawNA: - return RawNA; - } - if (av == -1) - return (bv & 1) == 0 ? 1 : -1; - if (bv < 0) - return RawNA; - if (av == RawNA) - return RawNA; - - // Since the abs of the base is at least two, the exponent must be less than 63. - if (bv >= 63) - return RawNA; - - bool neg = false; - if (av < 0) - { - av = -av; - neg = (bv & 1) != 0; - } - Contracts.Assert(av >= 2); - - // Since the exponent is at least three, the base must be < 2^21. - Contracts.Assert(bv >= 3); - if (av >= (1L << 21)) - return RawNA; - - long res = 1; - long x = av; - // REVIEW: Is the catch too slow in the overflow case? 
- try - { - checked - { - for (; ; ) - { - if ((bv & 1) != 0) - res *= x; - bv >>= 1; - if (bv == 0) - break; - x *= x; - } - } - } - catch (OverflowException) - { - return RawNA; - } - Contracts.Assert(res > 0); - - if (neg) - res = -res; - return res; - } - } -} diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 2cb5e0baca..171d73ff65 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections; using System.Collections.Generic; @@ -178,7 +176,7 @@ public static void WriteBytesNoCount(this BinaryWriter writer, byte[] values, in /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values) + public static void WriteFloatArray(this BinaryWriter writer, float[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -197,7 +195,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values) /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int count) + public static void WriteFloatArray(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -211,7 +209,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int /// /// Writes a specified number of floats starting at the specified index from an array. 
/// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int start, int count) + public static void WriteFloatArray(this BinaryWriter writer, float[] values, int start, int count) { Contracts.AssertValue(writer); Contracts.AssertValue(values); @@ -225,7 +223,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, IEnumerable values, int count) + public static void WriteFloatArray(this BinaryWriter writer, IEnumerable values, int count) { Contracts.AssertValue(writer); Contracts.AssertValue(values); @@ -244,7 +242,7 @@ public static void WriteFloatArray(this BinaryWriter writer, IEnumerable /// /// Writes an array of Floats without the length prefix. /// - public static void WriteFloatsNoCount(this BinaryWriter writer, Float[] values, int count) + public static void WriteFloatsNoCount(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -257,7 +255,7 @@ public static void WriteFloatsNoCount(this BinaryWriter writer, Float[] values, /// /// Writes a length prefixed array of singles. /// - public static void WriteSingleArray(this BinaryWriter writer, Single[] values) + public static void WriteSingleArray(this BinaryWriter writer, float[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -276,7 +274,7 @@ public static void WriteSingleArray(this BinaryWriter writer, Single[] values) /// /// Writes a length prefixed array of singles. 
/// - public static void WriteSingleArray(this BinaryWriter writer, Single[] values, int count) + public static void WriteSingleArray(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -290,7 +288,7 @@ public static void WriteSingleArray(this BinaryWriter writer, Single[] values, i /// /// Writes an array of singles without the length prefix. /// - public static void WriteSinglesNoCount(this BinaryWriter writer, Single[] values, int count) + public static void WriteSinglesNoCount(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -303,7 +301,7 @@ public static void WriteSinglesNoCount(this BinaryWriter writer, Single[] values /// /// Writes a length prefixed array of doubles. /// - public static void WriteDoubleArray(this BinaryWriter writer, Double[] values) + public static void WriteDoubleArray(this BinaryWriter writer, double[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -315,14 +313,14 @@ public static void WriteDoubleArray(this BinaryWriter writer, Double[] values) } writer.Write(values.Length); - foreach (Double val in values) + foreach (double val in values) writer.Write(val); } /// /// Writes a length prefixed array of doubles. /// - public static void WriteDoubleArray(this BinaryWriter writer, Double[] values, int count) + public static void WriteDoubleArray(this BinaryWriter writer, double[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -336,7 +334,7 @@ public static void WriteDoubleArray(this BinaryWriter writer, Double[] values, i /// /// Writes an array of doubles without the length prefix. 
/// - public static void WriteDoublesNoCount(this BinaryWriter writer, Double[] values, int count) + public static void WriteDoublesNoCount(this BinaryWriter writer, double[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -438,7 +436,7 @@ public static long WriteSByteStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteByteStream(this BinaryWriter writer, IEnumerable e) + public static long WriteByteStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -460,7 +458,7 @@ public static long WriteIntStream(this BinaryWriter writer, IEnumerable e) return c; } - public static long WriteUIntStream(this BinaryWriter writer, IEnumerable e) + public static long WriteUIntStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -482,7 +480,7 @@ public static long WriteShortStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -504,7 +502,7 @@ public static long WriteLongStream(this BinaryWriter writer, IEnumerable e return c; } - public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) + public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -515,7 +513,7 @@ public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) + public static long WriteSingleStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -526,7 +524,7 @@ public static long WriteSingleStream(this BinaryWriter writer, IEnumerable e) + public static long WriteDoubleStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -606,12 +604,12 @@ public static bool ReadBoolByte(this BinaryReader reader) return b != 0; } - public static 
Float ReadFloat(this BinaryReader reader) + public static float ReadFloat(this BinaryReader reader) { return reader.ReadSingle(); } - public static Float[] ReadFloatArray(this BinaryReader reader) + public static float[] ReadFloatArray(this BinaryReader reader) { Contracts.AssertValue(reader); @@ -620,16 +618,16 @@ public static Float[] ReadFloatArray(this BinaryReader reader) return ReadFloatArray(reader, size); } - public static Float[] ReadFloatArray(this BinaryReader reader, int size) + public static float[] ReadFloatArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Float[size]; + var values = new float[size]; - long bufferSizeInBytes = (long)size * sizeof(Float); + long bufferSizeInBytes = (long)size * sizeof(float); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) @@ -649,14 +647,14 @@ public static Float[] ReadFloatArray(this BinaryReader reader, int size) return values; } - public static void ReadFloatArray(this BinaryReader reader, Float[] array, int start, int count) + public static void ReadFloatArray(this BinaryReader reader, float[] array, int start, int count) { Contracts.AssertValue(reader); Contracts.AssertValue(array); Contracts.Assert(0 <= start && start < array.Length); Contracts.Assert(0 < count && count <= array.Length - start); - long bufferReadLengthInBytes = (long)count * sizeof(Float); + long bufferReadLengthInBytes = (long)count * sizeof(float); if (bufferReadLengthInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < count; i++) @@ -668,15 +666,15 @@ public static void ReadFloatArray(this BinaryReader reader, Float[] array, int s { fixed (void* dst = array) { - long bufferBeginOffsetInBytes = (long)start * sizeof(Float); - long bufferSizeInBytes = ((long)array.Length - start) * sizeof(Float); + long bufferBeginOffsetInBytes = (long)start * sizeof(float); + long bufferSizeInBytes = ((long)array.Length 
- start) * sizeof(float); ReadBytes(reader, (byte*)dst + bufferBeginOffsetInBytes, bufferSizeInBytes, bufferReadLengthInBytes); } } } } - public static Single[] ReadSingleArray(this BinaryReader reader) + public static float[] ReadSingleArray(this BinaryReader reader) { Contracts.AssertValue(reader); int size = reader.ReadInt32(); @@ -684,15 +682,15 @@ public static Single[] ReadSingleArray(this BinaryReader reader) return ReadSingleArray(reader, size); } - public static Single[] ReadSingleArray(this BinaryReader reader, int size) + public static float[] ReadSingleArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Single[size]; + var values = new float[size]; - long bufferSizeInBytes = (long)size * sizeof(Single); + long bufferSizeInBytes = (long)size * sizeof(float); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) @@ -712,7 +710,7 @@ public static Single[] ReadSingleArray(this BinaryReader reader, int size) return values; } - public static Double[] ReadDoubleArray(this BinaryReader reader) + public static double[] ReadDoubleArray(this BinaryReader reader) { Contracts.AssertValue(reader); @@ -721,15 +719,15 @@ public static Double[] ReadDoubleArray(this BinaryReader reader) return ReadDoubleArray(reader, size); } - public static Double[] ReadDoubleArray(this BinaryReader reader, int size) + public static double[] ReadDoubleArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Double[size]; + var values = new double[size]; - long bufferSizeInBytes = (long)size * sizeof(Double); + long bufferSizeInBytes = (long)size * sizeof(double); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index 
bea1900fe2..f0d17da56e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -626,8 +626,8 @@ public Writer(DateTimeCodec codec, Stream stream) public override void Write(ref DvDateTime value) { - var ticks = value.Ticks.RawValue; - Contracts.Assert(ticks == DvInt8.RawNA || (ulong)ticks <= DvDateTime.MaxTicks); + var ticks = value.Ticks; + Contracts.Assert((ulong)ticks <= DvDateTime.MaxTicks); Writer.Write(ticks); _numWritten++; } @@ -658,7 +658,7 @@ public override void MoveNext() { Contracts.Assert(_remaining > 0, "already consumed all values"); var value = Reader.ReadInt64(); - Contracts.CheckDecode(value == DvInt8.RawNA || (ulong)value <= DvDateTime.MaxTicks); + Contracts.CheckDecode((ulong)value <= DvDateTime.MaxTicks); _value = new DvDateTime(value); _remaining--; } @@ -711,20 +711,14 @@ public override void Write(ref DvDateTimeZone value) var ticks = value.ClockDateTime.Ticks; var offset = value.OffsetMinutes; - _ticks.Add(ticks.RawValue); - if (ticks.IsNA) - { - Contracts.Assert(offset.IsNA); - _offsets.Add(0); - } - else - { + _ticks.Add(ticks); + Contracts.Assert( - offset.RawValue >= DvDateTimeZone.MinMinutesOffset && - offset.RawValue <= DvDateTimeZone.MaxMinutesOffset); - Contracts.Assert(0 <= ticks.RawValue && ticks.RawValue <= DvDateTime.MaxTicks); - _offsets.Add(offset.RawValue); - } + offset >= DvDateTimeZone.MinMinutesOffset && + offset <= DvDateTimeZone.MaxMinutesOffset); + Contracts.Assert(0 <= ticks && ticks <= DvDateTime.MaxTicks); + _offsets.Add(offset); + } public override void Commit() @@ -773,7 +767,7 @@ public Reader(DateTimeZoneCodec codec, Stream stream, int items) for (int i = 0; i < _entries; i++) { _ticks[i] = Reader.ReadInt64(); - Contracts.CheckDecode(_ticks[i] == DvInt8.RawNA || (ulong)_ticks[i] <= DvDateTime.MaxTicks); + Contracts.CheckDecode((ulong)_ticks[i] <= DvDateTime.MaxTicks); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs 
b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs index fa1c4dc8e6..7a17b84ac1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs @@ -180,7 +180,7 @@ public override unsafe void Apply(DvTimeSpan[] array, Action func) func(new IntPtr(pArray)); } - public override void Write(DvTimeSpan a, BinaryWriter writer) { writer.Write(a.Ticks.RawValue); } + public override void Write(DvTimeSpan a, BinaryWriter writer) { writer.Write(a.Ticks); } public override DvTimeSpan Read(BinaryReader reader) { return new DvTimeSpan(reader.ReadInt64()); } } diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index acb7602c2d..40e9fc6d61 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -26,10 +26,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MeanAggregatorOne(ch, cursor, col); case DataKind.R8: return new R8.MeanAggregatorOne(ch, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorOne(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorOne(ch, type, cursor, col); default: break; } @@ -42,10 +38,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -66,10 +58,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R8: return new 
R8.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); default: break; } @@ -82,10 +70,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -102,10 +86,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MeanAggregatorAcrossSlots(ch, cursor, col); case DataKind.R8: return new R8.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); default: break; } @@ -118,10 +98,6 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, return new R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -880,270 +856,5 @@ public override object GetStat() } } } - - private static class Long - { - private const long MaxVal = long.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - // Converts between TItem and long. 
- private Converter _converter; - - public MeanAggregatorOne(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessRow(ref TItem val) - { - Stat.Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - return _converter.FromLong(val); - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - private Converter _converter; - - public MeanAggregatorAcrossSlots(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessValue(ref TItem val) - { - Stat.Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - return _converter.FromLong(val); - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - private Converter _converter; - - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessValue(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - TItem[] stat = new TItem[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - stat[slot] = _converter.FromLong(val); - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - private Converter _converter; - - public MinMaxAggregatorOne(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = ReturnMax ? 
-MaxVal : MaxVal; - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw.HasValue && raw < Stat) - Stat = raw.Value; - } - - protected override void ProcessValueMax(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw.HasValue && raw > Stat) - Stat = raw.Value; - } - - public override object GetStat() - { - return _converter.FromLong(Stat); - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - private Converter _converter; - - public MinMaxAggregatorAcrossSlots(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = ReturnMax ? -MaxVal : MaxVal; - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw.HasValue && raw < Stat) - Stat = raw.Value; - } - - protected override void ProcessValueMax(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw.HasValue && raw > Stat) - Stat = raw.Value; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - TItem def = default(TItem); - ProcValueDelegate(ref def); - } - return _converter.FromLong(Stat); - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - private Converter _converter; - - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - long bound = ReturnMax ? 
-MaxVal : MaxVal; - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = _converter.ToLong(val); - if (raw.HasValue && raw < Stat[slot]) - Stat[slot] = raw.Value; - } - - protected override void ProcessValueMax(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = _converter.ToLong(val); - if (raw.HasValue && raw > Stat[slot]) - Stat[slot] = raw.Value; - } - - public override object GetStat() - { - TItem[] stat = new TItem[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(TItem); - ProcValueDelegate(ref def, slot); - } - stat[slot] = _converter.FromLong(Stat[slot]); - } - return stat; - } - } - - private static Converter CreateConverter(ColumnType type) - { - Contracts.AssertValue(type); - Contracts.Assert(typeof(TItem) == type.ItemType.RawType); - Converter converter; - if (type.ItemType.IsTimeSpan) - converter = new TSConverter(); - else if (type.ItemType.IsDateTime) - converter = new DTConverter(); - else - { - Contracts.Assert(type.ItemType.RawKind == DataKind.I8); - converter = new I8Converter(); - } - return (Converter)converter; - } - - /// - /// The base class for conversions from types to long. - /// - private abstract class Converter - { - } - - private abstract class Converter : Converter - { - public abstract long? ToLong(T val); - public abstract T FromLong(long? val); - } - - private sealed class I8Converter : Converter - { - public override long? ToLong(long? val) - { - return val; - } - - public override long? FromLong(long? val) - { - Contracts.Assert(val.HasValue); - return val.Value; - } - } - - private sealed class TSConverter : Converter - { - public override long? 
ToLong(DvTimeSpan val) - { - return val.Ticks.RawValue; - } - - public override DvTimeSpan FromLong(long? val) - { - Contracts.Assert(val.HasValue); - return new DvTimeSpan(val); - } - } - - private sealed class DTConverter : Converter - { - public override long? ToLong(DvDateTime val) - { - return val.Ticks.RawValue; - } - - public override DvDateTime FromLong(long? val) - { - Contracts.Assert(0 <= val && val <= DvDateTime.MaxTicks); - return new DvDateTime(val); - } - } - } } } \ No newline at end of file diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 570c1e0722..a097e5f2fe 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -72,7 +72,7 @@ module SmokeTest1 = type SentimentPrediction() = [] - val mutable Sentiment : bool + val mutable Sentiment : Microsoft.ML.Runtime.Data.DvBool [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -125,7 +125,7 @@ module SmokeTest1 = |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ false; true; true ]) + Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) module SmokeTest2 = @@ -140,7 +140,7 @@ module SmokeTest2 = [] type SentimentPrediction = { [] - Sentiment : bool } + Sentiment : Microsoft.ML.Runtime.Data.DvBool } [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -193,7 +193,7 @@ module SmokeTest2 = |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ false; true; true ]) + Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) module SmokeTest3 = @@ -206,7 +206,7 @@ module SmokeTest3 = type SentimentPrediction() = [] - member val Sentiment = false with get, set + member val Sentiment = 
Microsoft.ML.Runtime.Data.DvBool.False with get, set [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -259,5 +259,5 @@ module SmokeTest3 = |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ false; true; true ]) + Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 23921c639f..546f937637 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -293,19 +293,7 @@ public class ConversionSimpleClass public ulong fuLong; public float fFloat; public double fDouble; - public bool fBool; - public string fString; - } - - public class ConversionNullalbeClass - { - public uint? fuInt; - public ushort? fuShort; - public byte? fByte; - public ulong? fuLong; - public float? fFloat; - public double? fDouble; - public bool? 
fBool; + public DvBool fBool; public string fString; } @@ -430,44 +418,6 @@ public void RoundTripConversionWithBasicTypes() new ConversionSimpleClass() }; - var dataNullable = new List - { - new ConversionNullalbeClass() - { - fuInt = uint.MaxValue - 1, - fBool = true, - fByte = byte.MaxValue - 1, - fDouble = double.MaxValue - 1, - fFloat = float.MaxValue - 1, - fuLong = ulong.MaxValue - 1, - fuShort = ushort.MaxValue - 1, - fString = "ha" - }, - new ConversionNullalbeClass() - { - fuInt = uint.MaxValue, - fBool = true, - fByte = byte.MaxValue, - fDouble = double.MaxValue, - fFloat = float.MaxValue, - fuLong = ulong.MaxValue, - fuShort = ushort.MaxValue, - fString = "ooh" - }, - new ConversionNullalbeClass() - { - fuInt = uint.MinValue, - fBool = false, - fByte = byte.MinValue, - fDouble = double.MinValue + 1, - fFloat = float.MinValue + 1, - fuLong = ulong.MinValue, - fuShort = ushort.MinValue, - fString = "" - }, - new ConversionNullalbeClass() - }; - using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -478,15 +428,6 @@ public void RoundTripConversionWithBasicTypes() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - dataView = ComponentCreation.CreateDataView(env, dataNullable); - var enumeratorNullable = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullableEnumerator = dataNullable.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullableEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); - } - Assert.True(!enumeratorNullable.MoveNext() && !originalNullableEnumerator.MoveNext()); } } @@ -720,19 +661,7 @@ public class ClassWithArrays public ulong[] fuLong; public float[] fFloat; public double[] fDouble; - public bool[] fBool; - } - - public class ClassWithNullableArrays - 
{ - public string[] fString; - public uint?[] fuInt; - public ushort?[] fuShort; - public byte?[] fByte; - public ulong?[] fuLong; - public float?[] fFloat; - public double?[] fDouble; - public bool?[] fBool; + public DvBool[] fBool; } [Fact] @@ -746,7 +675,7 @@ public void RoundTripConversionWithArrays() fInt = new int[3] { 0, 1, 2 }, fFloat = new float[3] { -0.99f, 0f, 0.99f }, fString = new string[2] { "hola", "lola" }, - fBool = new bool[2] { true, false }, + fBool = new DvBool[2] { true, false }, fByte = new byte[3] { 0, 124, 255 }, fDouble = new double[3] { -1, 0, 1 }, fLong = new long[] { 0, 1, 2 }, @@ -760,23 +689,6 @@ public void RoundTripConversionWithArrays() new ClassWithArrays() }; - var nullableData = new List - { - new ClassWithNullableArrays() - { - fFloat = new float?[3] { -0.99f, null, 0.99f }, - fString = new string[2] { null, "" }, - fBool = new bool?[3] { true, null, false }, - fByte = new byte?[4] { 0, 125, null, 255 }, - fDouble = new double?[3] { -1, null, 1 }, - fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, - fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, - fuShort = new ushort?[3] { 0, null, ushort.MaxValue } - }, - new ClassWithNullableArrays() { fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, - new ClassWithNullableArrays() - }; - using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -787,15 +699,6 @@ public void RoundTripConversionWithArrays() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); - var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullalbleEnumerator = nullableData.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) - { - 
Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); - } - Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } public class ClassWithArrayProperties @@ -811,7 +714,7 @@ public class ClassWithArrayProperties private ulong[] _fuLong; private float[] _fFloat; private double[] _fDouble; - private bool[] _fBool; + private DvBool[] _fBool; public string[] StringProp { get { return _fString; } set { _fString = value; } } public int[] IntProp { get { return _fInt; } set { _fInt = value; } } public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } @@ -823,28 +726,7 @@ public class ClassWithArrayProperties public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } } public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } } - public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } } - } - - public class ClassWithNullableArrayProperties - { - private string[] _fString; - private uint?[] _fuInt; - private ushort?[] _fuShort; - private byte?[] _fByte; - private ulong?[] _fuLong; - private float?[] _fFloat; - private double?[] _fDouble; - private bool?[] _fBool; - - public string[] StringProp { get { return _fString; } set { _fString = value; } } - public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } - public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } - public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } } - public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } - public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } } - public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } } - public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } } + public DvBool[] BoolProp { get { return 
_fBool; } set { _fBool = value; } } } [Fact] @@ -858,7 +740,7 @@ public void RoundTripConversionWithArrayPropertiess() IntProp = new int[3] { 0, 1, 2 }, FloatProp = new float[3] { -0.99f, 0f, 0.99f }, StringProp = new string[2] { "hola", "lola" }, - BoolProp = new bool[2] { true, false }, + BoolProp = new DvBool[2] { true, false }, ByteProp = new byte[3] { 0, 124, 255 }, DobuleProp = new double[3] { -1, 0, 1 }, LongProp = new long[] { 0, 1, 2 }, @@ -872,22 +754,6 @@ public void RoundTripConversionWithArrayPropertiess() new ClassWithArrayProperties() }; - var nullableData = new List - { - new ClassWithNullableArrayProperties() - { - SingleProp = new float?[3] { -0.99f, null, 0.99f }, - StringProp = new string[2] { null, "" }, - BoolProp = new bool?[3] { true, null, false }, - ByteProp = new byte?[4] { 0, 125, null, 255 }, - DoubleProp = new double?[3] { -1, null, 1 }, - UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue }, - ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 }, - UShortProp = new ushort?[3] { 0, null, ushort.MaxValue } - }, - new ClassWithNullableArrayProperties() { SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, - new ClassWithNullableArrayProperties() - }; using (var env = new TlcEnvironment()) { @@ -899,15 +765,6 @@ public void RoundTripConversionWithArrayPropertiess() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); - var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullalbleEnumerator = nullableData.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); - } - 
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } } diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index f19e3285d7..458cc01e71 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -119,7 +119,7 @@ public class BooleanLabelData public float[] Features; [ColumnName("Label")] - public bool Label; + public DvBool Label; } [Fact] @@ -137,36 +137,6 @@ public void BooleanLabelPipeline() var model = pipeline.Train(); } - public class NullableBooleanLabelData - { - [ColumnName("Features")] - [VectorType(2)] - public float[] Features; - - [ColumnName("Label")] - public bool? Label; - } - - [Fact] - public void NullableBooleanLabelPipeline() - { - var data = new NullableBooleanLabelData[2]; - data[0] = new NullableBooleanLabelData - { - Features = new float[] { 0.0f, 1.0f }, - Label = null - }; - data[1] = new NullableBooleanLabelData - { - Features = new float[] { 1.0f, 0.0f }, - Label = false - }; - var pipeline = new LearningPipeline(); - pipeline.Add(CollectionDataSource.Create(data)); - pipeline.Add(new FastForestBinaryClassifier()); - var model = pipeline.Train(); - } - [Fact] public void AppendPipeline() { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs index a8b7aaa7ad..47273bcb3f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; using Microsoft.ML.TestFramework; using Xunit.Abstractions; @@ -43,14 +44,14 @@ public class IrisPrediction public class SentimentData { [ColumnName("Label")] - public bool Sentiment; + public DvBool Sentiment; public string SentimentText; } public class SentimentPrediction { [ColumnName("PredictedLabel")] - public bool Sentiment; + public DvBool Sentiment; public float Score; } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs index 246281777c..db5ef3e779 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs @@ -55,7 +55,7 @@ public void SimpleTrainAndPredict() var prediction = model.Predict(input); // Verify that predictions match and scores are separated from zero. Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs index 51ae621c1b..c026efecfa 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs @@ -62,7 +62,7 @@ public void TrainSaveModelAndPredict() var prediction = model.Predict(input); // Verify that predictions match and scores are separated from zero. 
Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); } } } From 789182a4f5221f151110745332abff30eeb5d9af Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 21 Aug 2018 15:09:40 -0700 Subject: [PATCH 13/37] merge master and fix tests. --- .../Scenarios/PipelineApi/CrossValidation.cs | 2 +- .../Scenarios/PipelineApi/PipelineApiScenarioTests.cs | 4 ++-- .../Scenarios/PipelineApi/SimpleTrainAndPredict.cs | 2 +- .../Scenarios/PipelineApi/TrainSaveModelAndPredict.cs | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs index 6cc6630ed9..ceb4d8e973 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs @@ -34,7 +34,7 @@ void CrossValidation() var cv = new CrossValidator().CrossValidate(pipeline); var metrics = cv.BinaryClassificationMetrics[0]; var singlePrediction = cv.PredictorModels[0].Predict(new SentimentData() { SentimentText = "Not big fan of this." 
}); - Assert.True(singlePrediction.Sentiment); + Assert.True(singlePrediction.Sentiment.IsTrue); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs index 6b96929db7..82bb22848d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs @@ -47,7 +47,7 @@ public class IrisPrediction public class SentimentData { [Column("0", name: "Label")] - public bool Sentiment; + public Runtime.Data.DvBool Sentiment; [Column("1")] public string SentimentText; } @@ -55,7 +55,7 @@ public class SentimentData public class SentimentPrediction { [ColumnName("PredictedLabel")] - public bool Sentiment; + public Runtime.Data.DvBool Sentiment; public float Score; } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs index 0bf201e328..56a039b7d2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs @@ -35,7 +35,7 @@ void SimpleTrainAndPredict() pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); var model = pipeline.Train(); var singlePrediction = model.Predict(new SentimentData() { SentimentText = "Not big fan of this." 
}); - Assert.True(singlePrediction.Sentiment); + Assert.True(singlePrediction.Sentiment.IsTrue); } private static TextFeaturizer MakeSentimentTextTransform() diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs index 7e935dcb90..bf090b510e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs @@ -35,7 +35,7 @@ public async void TrainSaveModelAndPredict() await model.WriteAsync(modelName); var loadedModel = await PredictionModel.ReadAsync(modelName); var singlePrediction = loadedModel.Predict(new SentimentData() { SentimentText = "Not big fan of this." }); - Assert.True(singlePrediction.Sentiment); + Assert.True(singlePrediction.Sentiment.IsTrue); } } From ef54c6084a7bee50dfc39972044a0249e5f3b150 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 21 Aug 2018 15:22:35 -0700 Subject: [PATCH 14/37] merge master and fix tests. --- .../Scenarios/Api/Estimators/SimpleTrainAndPredict.cs | 2 +- .../Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 76c12e6068..197d0f0ad7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -45,7 +45,7 @@ public void New_SimpleTrainAndPredict() var prediction = engine.Predict(input); // Verify that predictions match and scores are separated from zero. 
Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index f5b5b98c90..a0d79f2fec 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -57,7 +57,7 @@ public void New_TrainSaveModelAndPredict() var prediction = engine.Predict(input); // Verify that predictions match and scores are separated from zero. Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); } } } From b3ba06d806957b64110f12531c17dc8244ee2cf0 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 23 Aug 2018 13:22:41 -0700 Subject: [PATCH 15/37] PR feedback. 
--- src/Microsoft.ML.Api/TypedCursor.cs | 20 ----- src/Microsoft.ML.Data/Data/Conversion.cs | 78 +++++++++---------- .../DataLoadSave/Text/TextLoaderParser.cs | 10 ++- .../MulticlassClassifierEvaluator.cs | 2 +- src/Microsoft.ML.Parquet/ParquetLoader.cs | 18 ++--- .../UnitTests/DvTypes.cs | 26 ------- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 12 +-- .../TestTransposer.cs | 2 +- .../Scenarios/Api/ApiScenariosTests.cs | 5 +- .../Api/Estimators/SimpleTrainAndPredict.cs | 2 +- .../Estimators/TrainSaveModelAndPredict.cs | 2 +- .../Scenarios/Api/SimpleTrainAndPredict.cs | 2 +- .../Scenarios/Api/TrainSaveModelAndPredict.cs | 2 +- .../Scenarios/PipelineApi/CrossValidation.cs | 2 +- .../PipelineApi/PipelineApiScenarioTests.cs | 4 +- .../PipelineApi/SimpleTrainAndPredict.cs | 2 +- .../PipelineApi/TrainSaveModelAndPredict.cs | 2 +- 17 files changed, 72 insertions(+), 119 deletions(-) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 004c7bbaa5..6751d312c8 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -289,26 +289,6 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(colType.ItemType.IsBool); return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool?)x); } - else if (fieldType.GetElementType() == typeof(int)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); - } - else if (fieldType.GetElementType() == typeof(short)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } - else if (fieldType.GetElementType() == typeof(long)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } - else if (fieldType.GetElementType() == typeof(sbyte)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return 
CreateConvertingVBufferSetter(input, index, poke, peek, x => x); - } // VBuffer -> T[] if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 16641275cd..9bfa547a64 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1277,12 +1277,11 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// public bool TryParse(ref TX src, out I1 dst) { - long? res; - bool f = TryParseSigned(I1.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert((I1)res == res); + TryParseSigned(I1.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); + Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); dst = (I1)res; - return f; + return true; } /// @@ -1291,12 +1290,11 @@ public bool TryParse(ref TX src, out I1 dst) /// public bool TryParse(ref TX src, out I2 dst) { - long? res; - bool f = TryParseSigned(I2.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert(!res.HasValue || (I2)res == res); + TryParseSigned(I2.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); + Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); dst = (I2)res; - return f; + return true; } /// @@ -1305,12 +1303,11 @@ public bool TryParse(ref TX src, out I2 dst) /// public bool TryParse(ref TX src, out I4 dst) { - long? res; - bool f = TryParseSigned(I4.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert(!res.HasValue || (I4)res == res); + TryParseSigned(I4.MaxValue, ref src, out long? 
res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to int32."); + Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); dst = (I4)res; - return f; + return true; } /// @@ -1319,11 +1316,10 @@ public bool TryParse(ref TX src, out I4 dst) /// public bool TryParse(ref TX src, out I8 dst) { - long? res; - bool f = TryParseSigned(I8.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); + TryParseSigned(I8.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); dst = (I8)res; - return f; + return true; } /// @@ -1365,7 +1361,7 @@ private bool TryParseNonNegative(string text, int ich, int lim, out long result) /// When it returns false, result is set to the NA value. The result can be NA on true return, /// since some representations of NA are not considered parse failure. /// - private bool TryParseSigned(long max, ref TX span, out long? result) + private void TryParseSigned(long max, ref TX span, out long? result) { Contracts.Assert(max > 0); Contracts.Assert((max & (max + 1)) == 0); @@ -1376,7 +1372,7 @@ private bool TryParseSigned(long max, ref TX span, out long? result) result = null; else result = 0; - return true; + return; } int ichMin; @@ -1388,34 +1384,34 @@ private bool TryParseSigned(long max, ref TX span, out long? result) { if (span.Length == 1 || !TryParseNonNegative(text, ichMin + 1, ichLim, out val) || - val > max) + val > (max + 1)) { result = null; - return false; + return; } Contracts.Assert(val >= 0); result = -(long)val; - Contracts.Assert(long.MinValue < result && result <= 0); - return true; + Contracts.Assert(long.MinValue <= result && result <= 0); + return; } if (!TryParseNonNegative(text, ichMin, ichLim, out val)) { // Check for acceptable NA forms: ? NaN NA and N/A. 
result = null; - return IsStdMissing(ref span); + return; } Contracts.Assert(val >= 0); if (val > max) { result = null; - return false; + return; } result = (long)val; Contracts.Assert(0 <= result && result <= long.MaxValue); - return true; + return; } /// @@ -1505,36 +1501,32 @@ public bool TryParse(ref TX src, out DZ dst) // These map unparsable and overflow values to "NA", which is null. private I1 ParseI1(ref TX src) { - long? res; - bool f = TryParseSigned(I1.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert((I1)res == res); + TryParseSigned(I1.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); + Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); return (I1)res; } private I2 ParseI2(ref TX src) { - long? res; - bool f = TryParseSigned(I2.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert((I2)res == res); + TryParseSigned(I2.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); + Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); return (I2)res; } private I4 ParseI4(ref TX src) { - long? res; - bool f = TryParseSigned(I4.MaxValue, ref src, out res); - Contracts.Check(f && res.HasValue); - Contracts.Assert((I4)res == res); + TryParseSigned(I4.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to int."); + Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); return (I4)res; } private I8 ParseI8(ref TX src) { - long? res; - bool f = TryParseSigned(I8.MaxValue, ref src, out res); - Contracts.Assert(f || !res.HasValue); + TryParseSigned(I8.MaxValue, ref src, out long? 
res); + Contracts.Assert(res.HasValue, "Value could not be parsed from text to long."); return res.Value; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index b0b3ce63b9..cc2c2564dc 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -994,7 +994,15 @@ public int GatherFields(DvText lineSpan, string path = null, long line = 0) // Note that Convert throws exception the text is unparsable. int csrc = default; - Conversion.Conversions.Instance.Convert(ref spanT, ref csrc); + try + { + Conversions.Instance.Convert(ref spanT, ref csrc); + } + catch + { + Contracts.Assert(csrc == default); + } + if (csrc <= 0) { _stats.LogBadFmt(ref scan, "Bad dimensionality or ambiguous sparse item. Use sparse=- for non-sparse file, and/or quote the value."); diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index a4a83fed28..ff38c5cba1 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -997,7 +997,7 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa var labelType = perInst.Schema.GetColumnType(labelCol); if (labelType.IsKey && (!perInst.Schema.HasKeyNames(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) { - perInst = LambdaColumnMapper.Create(Host, "ConvertToLong", perInst, schema.Label.Name, + perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name, schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8, (ref uint src, ref double dst) => dst = src == 0 ? 
double.NaN : src - 1 + (double)labelType.AsKey.Min); } diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 2fe3f0ebe7..8600238b76 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -499,21 +499,21 @@ private Delegate CreateGetterDelegate(int col) case DataType.Byte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.SignedByte: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedByte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Short: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedShort: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int16: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedInt16: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int32: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int64: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int96: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.ByteArray: @@ -678,17 +678,17 @@ public ParquetConversions(IChannel channel) public void Conv(ref byte[] src, ref VBuffer dst) => dst = src != null ? new VBuffer(src.Length, src) : new VBuffer(0, new byte[0]); - public void Conv(ref sbyte src, ref sbyte dst) => dst = src; + public void Conv(ref sbyte? 
src, ref sbyte dst) => dst = (sbyte)src; public void Conv(ref byte src, ref byte dst) => dst = src; - public void Conv(ref short src, ref short dst) => dst = src; + public void Conv(ref short? src, ref short dst) => dst = (short)src; public void Conv(ref ushort src, ref ushort dst) => dst = src; - public void Conv(ref int src, ref int dst) => dst = src; + public void Conv(ref int? src, ref int dst) => dst = (int)src; - public void Conv(ref long src, ref long dst) => dst = src; + public void Conv(ref long? src, ref long dst) => dst = (long)src; public void Conv(ref float? src, ref Single dst) => dst = src ?? Single.NaN; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs index 309f8402cb..5df434e7ae 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs @@ -10,32 +10,6 @@ namespace Microsoft.ML.Runtime.RunTests { public sealed class DvTypeTests { - [Fact] - public void TestComparableInt32() - { - const int count = 100; - - var rand = RandomUtils.Create(42); - var values = new int?[2 * count]; - for (int i = 0; i < count; i++) - { - var v = values[i] = rand.Next(); - values[values.Length - i - 1] = v; - } - - // Assign two NA's at random. 
- int iv1 = rand.Next(values.Length); - int iv2 = rand.Next(values.Length - 1); - if (iv2 >= iv1) - iv2++; - values[iv1] = null; - values[iv2] = null; - Array.Sort(values); - - Assert.True(!values[0].HasValue); - Assert.True(!values[1].HasValue); - Assert.True(values[2].HasValue); - } [Fact] public void TestComparableDvText() diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 7aad3318be..1a6d94c5fb 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -72,7 +72,7 @@ module SmokeTest1 = type SentimentPrediction() = [] - val mutable Sentiment : Microsoft.ML.Runtime.Data.DvBool + val mutable Sentiment : bool [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -125,7 +125,7 @@ module SmokeTest1 = |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) + Assert.Equal(predictionResults, [ false; true; true ]) module SmokeTest2 = @@ -140,7 +140,7 @@ module SmokeTest2 = [] type SentimentPrediction = { [] - Sentiment : Microsoft.ML.Runtime.Data.DvBool } + Sentiment : bool } [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -193,7 +193,7 @@ module SmokeTest2 = |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) + Assert.Equal(predictionResults, [ false; true; true ]) module SmokeTest3 = @@ -206,7 +206,7 @@ module SmokeTest3 = type SentimentPrediction() = [] - member val Sentiment = Microsoft.ML.Runtime.Data.DvBool.False with get, set + member val Sentiment = false with get, set [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -259,5 +259,5 @@ module SmokeTest3 = |> model.Predict let predictionResults = [ for p 
in predictions -> p.Sentiment ] - Assert.Equal(predictionResults, [ Microsoft.ML.Runtime.Data.DvBool.False; Microsoft.ML.Runtime.Data.DvBool.True; Microsoft.ML.Runtime.Data.DvBool.True ]) + Assert.Equal(predictionResults, [ false; true; true ]) diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index 701adea39b..c4e4ac1049 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -160,7 +160,7 @@ public void TransposerTest() // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. - builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (int)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => rgen.Next())); IDataView view = builder.GetDataView(); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs index 47273bcb3f..a8b7aaa7ad 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.TestFramework; using Xunit.Abstractions; @@ -44,14 +43,14 @@ public class IrisPrediction public class SentimentData { [ColumnName("Label")] - public DvBool Sentiment; + public bool Sentiment; public string SentimentText; } public class SentimentPrediction { [ColumnName("PredictedLabel")] - public DvBool Sentiment; + public bool Sentiment; public float Score; } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 197d0f0ad7..76c12e6068 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -45,7 +45,7 @@ public void New_SimpleTrainAndPredict() var prediction = engine.Predict(input); // Verify that predictions match and scores are separated from zero. Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index a0d79f2fec..f5b5b98c90 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -57,7 +57,7 @@ public void New_TrainSaveModelAndPredict() var prediction = engine.Predict(input); // Verify that predictions match and scores are separated from zero. 
Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs index ad645fbfeb..b05135692a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs @@ -55,7 +55,7 @@ public void SimpleTrainAndPredict() var prediction = model.Predict(input); // Verify that predictions match and scores are separated from zero. Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs index c026efecfa..51ae621c1b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs @@ -62,7 +62,7 @@ public void TrainSaveModelAndPredict() var prediction = model.Predict(input); // Verify that predictions match and scores are separated from zero. 
Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment.IsTrue && prediction.Score > 1 || input.Sentiment.IsFalse && prediction.Score < -1); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs index ceb4d8e973..6cc6630ed9 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/CrossValidation.cs @@ -34,7 +34,7 @@ void CrossValidation() var cv = new CrossValidator().CrossValidate(pipeline); var metrics = cv.BinaryClassificationMetrics[0]; var singlePrediction = cv.PredictorModels[0].Predict(new SentimentData() { SentimentText = "Not big fan of this." }); - Assert.True(singlePrediction.Sentiment.IsTrue); + Assert.True(singlePrediction.Sentiment); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs index 82bb22848d..6b96929db7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/PipelineApiScenarioTests.cs @@ -47,7 +47,7 @@ public class IrisPrediction public class SentimentData { [Column("0", name: "Label")] - public Runtime.Data.DvBool Sentiment; + public bool Sentiment; [Column("1")] public string SentimentText; } @@ -55,7 +55,7 @@ public class SentimentData public class SentimentPrediction { [ColumnName("PredictedLabel")] - public Runtime.Data.DvBool Sentiment; + public bool Sentiment; public float Score; } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs index 56a039b7d2..0bf201e328 100644 --- 
a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs @@ -35,7 +35,7 @@ void SimpleTrainAndPredict() pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); var model = pipeline.Train(); var singlePrediction = model.Predict(new SentimentData() { SentimentText = "Not big fan of this." }); - Assert.True(singlePrediction.Sentiment.IsTrue); + Assert.True(singlePrediction.Sentiment); } private static TextFeaturizer MakeSentimentTextTransform() diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs index bf090b510e..7e935dcb90 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs @@ -35,7 +35,7 @@ public async void TrainSaveModelAndPredict() await model.WriteAsync(modelName); var loadedModel = await PredictionModel.ReadAsync(modelName); var singlePrediction = loadedModel.Predict(new SentimentData() { SentimentText = "Not big fan of this." }); - Assert.True(singlePrediction.Sentiment.IsTrue); + Assert.True(singlePrediction.Sentiment); } } From 25d1a52f54c8857dba2edd96f3afe9c599107dea Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 23 Aug 2018 14:26:03 -0700 Subject: [PATCH 16/37] PR feedback. 
--- src/Microsoft.ML.Core/Data/DataKind.cs | 2 +- src/Microsoft.ML.Core/Data/DateTime.cs | 25 ++++++++++++------- .../DataLoadSave/Binary/Codecs.cs | 16 ++++++++---- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 0a0707df48..254c894734 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -207,7 +207,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool)) + else if (type == typeof(DvBool) || type == typeof(bool)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index 5bb2626d61..3ae8b9627a 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -34,7 +34,10 @@ public DvDateTime(SysDateTime sdt) /// public DvDateTime(long ticks) { - _ticks = ticks; + if ((ulong)ticks > MaxTicks) + _ticks = long.MinValue; + else + _ticks = ticks; AssertValid(); } @@ -82,10 +85,7 @@ public bool IsNA } } - public static DvDateTime NA - { - get { return new DvDateTime(long.MinValue); } - } + public static DvDateTime NA => new DvDateTime(long.MinValue); public static explicit operator SysDateTime?(DvDateTime dvDt) { @@ -180,7 +180,7 @@ private DvDateTimeZone(DvDateTime dt, short offset) public DvDateTimeZone(long ticks, short offset) { var dt = new DvDateTime(ticks); - if (MinMinutesOffset > offset || offset > MaxMinutesOffset) + if (dt.IsNA || offset == short.MinValue || MinMinutesOffset > offset || offset > MaxMinutesOffset) { _dateTime = DvDateTime.NA; _offset = short.MinValue; @@ -200,6 +200,7 @@ public DvDateTimeZone(SysDateTimeOffset dto) Contracts.Assert(success); _dateTime = ValidateDate(new DvDateTime(dto.DateTime), ref _offset); 
Contracts.Assert(!_dateTime.IsNA); + Contracts.Assert(_offset != short.MinValue); AssertValid(); } @@ -254,7 +255,7 @@ private static DvDateTime ValidateDate(DvDateTime dateTime, ref short offset) /// private static bool TryValidateOffset(long offsetTicks, out short offset) { - if (offsetTicks % TicksPerMinute != 0) + if (offsetTicks == short.MinValue || offsetTicks % TicksPerMinute != 0) { offset = short.MinValue; return false; @@ -268,6 +269,7 @@ private static bool TryValidateOffset(long offsetTicks, out short offset) return false; } offset = res; + Contracts.Assert(offset != short.MinValue); return true; } @@ -275,7 +277,10 @@ private static bool TryValidateOffset(long offsetTicks, out short offset) private void AssertValid() { _dateTime.AssertValid(); - if (!_dateTime.IsNA) + _dateTime.AssertValid(); + if (_dateTime.IsNA) + Contracts.Assert(_offset == short.MinValue); + else { Contracts.Assert(MinMinutesOffset <= _offset && _offset <= MaxMinutesOffset); Contracts.Assert((ulong)(_dateTime.Ticks + _offset * TicksPerMinute) @@ -318,6 +323,8 @@ public DvTimeSpan Offset get { AssertValid(); + if (_offset == short.MinValue) + return DvTimeSpan.NA; return new DvTimeSpan(_offset * TicksPerMinute); } } @@ -483,7 +490,7 @@ public DvTimeSpan(SysTimeSpan? 
sts) public bool IsNA { - get { return false; } + get { return _ticks == long.MinValue; } } public static DvTimeSpan NA diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index f0d17da56e..6029027f36 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -627,7 +627,7 @@ public Writer(DateTimeCodec codec, Stream stream) public override void Write(ref DvDateTime value) { var ticks = value.Ticks; - Contracts.Assert((ulong)ticks <= DvDateTime.MaxTicks); + Contracts.Assert(ticks == long.MinValue || (ulong)ticks <= DvDateTime.MaxTicks); Writer.Write(ticks); _numWritten++; } @@ -658,7 +658,7 @@ public override void MoveNext() { Contracts.Assert(_remaining > 0, "already consumed all values"); var value = Reader.ReadInt64(); - Contracts.CheckDecode((ulong)value <= DvDateTime.MaxTicks); + Contracts.CheckDecode(value == long.MinValue || (ulong)value <= DvDateTime.MaxTicks); _value = new DvDateTime(value); _remaining--; } @@ -712,13 +712,19 @@ public override void Write(ref DvDateTimeZone value) var offset = value.OffsetMinutes; _ticks.Add(ticks); - + if (ticks == long.MinValue) + { + Contracts.Assert(offset == short.MinValue); + _offsets.Add(0); + } + else + { Contracts.Assert( offset >= DvDateTimeZone.MinMinutesOffset && offset <= DvDateTimeZone.MaxMinutesOffset); Contracts.Assert(0 <= ticks && ticks <= DvDateTime.MaxTicks); _offsets.Add(offset); - + } } public override void Commit() @@ -767,7 +773,7 @@ public Reader(DateTimeZoneCodec codec, Stream stream, int items) for (int i = 0; i < _entries; i++) { _ticks[i] = Reader.ReadInt64(); - Contracts.CheckDecode((ulong)_ticks[i] <= DvDateTime.MaxTicks); + Contracts.CheckDecode(_ticks[i] == long.MinValue || (ulong)_ticks[i] <= DvDateTime.MaxTicks); } } From 1853471de72bc7ab83f0d5ecf286c33b50d2ba4b Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 23 Aug 2018 14:31:48 -0700 Subject: 
[PATCH 17/37] PR feedback. --- src/Microsoft.ML.Core/Data/DateTime.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index 3ae8b9627a..d9f3930579 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -255,7 +255,7 @@ private static DvDateTime ValidateDate(DvDateTime dateTime, ref short offset) /// private static bool TryValidateOffset(long offsetTicks, out short offset) { - if (offsetTicks == short.MinValue || offsetTicks % TicksPerMinute != 0) + if (offsetTicks == long.MinValue || offsetTicks % TicksPerMinute != 0) { offset = short.MinValue; return false; From 66fa9be4605c3a122f9d01b953526033371fe616 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 23 Aug 2018 14:37:46 -0700 Subject: [PATCH 18/37] PR feedback. --- src/Microsoft.ML.Api/TypedCursor.cs | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 6751d312c8..f9aa05ab76 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -329,30 +329,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => (bool?)x); } - else if (fieldType == typeof(int)) - { - Ch.Assert(colType == NumberType.I4); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } - else if (fieldType == typeof(short)) - { - Ch.Assert(colType == NumberType.I2); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } - else if (fieldType == typeof(long)) - { - Ch.Assert(colType == NumberType.I8); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } - else if (fieldType == typeof(sbyte)) - { - Ch.Assert(colType == 
NumberType.I1); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => x); - } + // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType)); From 0c8fb8c5440655706f7168fd0030e6b669f78e1e Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 25 Aug 2018 19:48:21 -0700 Subject: [PATCH 19/37] Tests. --- src/Microsoft.ML.Data/Data/Conversion.cs | 38 +-- .../UnitTests/DataTypes.cs | 255 ++++++++++++++++++ test/Microsoft.ML.Tests/TextLoaderTests.cs | 128 ++++++++- 3 files changed, 401 insertions(+), 20 deletions(-) create mode 100644 test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 9bfa547a64..84bee650ea 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1273,10 +1273,11 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I1 dst) { + dst = default; TryParseSigned(I1.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); @@ -1286,10 +1287,11 @@ public bool TryParse(ref TX src, out I1 dst) /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I2 dst) { + dst = default; TryParseSigned(I2.MaxValue, ref src, out long? 
res); Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); @@ -1299,10 +1301,11 @@ public bool TryParse(ref TX src, out I2 dst) /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the defualt value. /// public bool TryParse(ref TX src, out I4 dst) { + dst = default; TryParseSigned(I4.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to int32."); Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); @@ -1312,10 +1315,11 @@ public bool TryParse(ref TX src, out I4 dst) /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I8 dst) { + dst = default; TryParseSigned(I8.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); dst = (I8)res; @@ -1368,23 +1372,19 @@ private void TryParseSigned(long max, ref TX span, out long? result) if (!span.HasChars) { - if (span.IsNA) - result = null; - else - result = 0; + result = default(long); return; } int ichMin; int ichLim; string text = span.GetRawUnderlyingBufferInfo(out ichMin, out ichLim); - - long val; + ulong val; if (span[0] == '-') { if (span.Length == 1 || - !TryParseNonNegative(text, ichMin + 1, ichLim, out val) || - val > (max + 1)) + !TryParseCore(text, ichMin + 1, ichLim, out val) || + (val > ((ulong)max + 1))) { result = null; return; @@ -1395,21 +1395,21 @@ private void TryParseSigned(long max, ref TX span, out long? 
result) return; } - if (!TryParseNonNegative(text, ichMin, ichLim, out val)) + long sVal; + if (!TryParseNonNegative(text, ichMin, ichLim, out sVal)) { - // Check for acceptable NA forms: ? NaN NA and N/A. result = null; return; } - Contracts.Assert(val >= 0); - if (val > max) + Contracts.Assert(sVal >= 0); + if (sVal > max) { result = null; return; } - result = (long)val; + result = (long)sVal; Contracts.Assert(0 <= result && result <= long.MaxValue); return; } @@ -1498,7 +1498,7 @@ public bool TryParse(ref TX src, out DZ dst) return IsStdMissing(ref src); } - // These map unparsable and overflow values to "NA", which is null. + // These throw an exception for unparsable and overflow values. private I1 ParseI1(ref TX src) { TryParseSigned(I1.MaxValue, ref src, out long? res); @@ -1526,7 +1526,7 @@ private I4 ParseI4(ref TX src) private I8 ParseI8(ref TX src) { TryParseSigned(I8.MaxValue, ref src, out long? res); - Contracts.Assert(res.HasValue, "Value could not be parsed from text to long."); + Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); return res.Value; } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs new file mode 100644 index 0000000000..2258ea6f57 --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -0,0 +1,255 @@ +using System; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.Conversion; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Runtime.RunTests +{ + public class DataTypesTest : TestDataViewBase + { + public DataTypesTest(ITestOutputHelper helper) + : base(helper) + { + } + + private readonly static Conversions _conv = Conversions.Instance; + + [Fact] + public void TXToI1() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. sbyte.MinValue in text to sbyte. 
+ sbyte minValue = sbyte.MinValue; + sbyte maxValue = sbyte.MaxValue; + DvText src = new DvText(minValue.ToString()); + sbyte dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. sbyte.MaxValue in text to sbyte. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: sbyte.MinValue - 1 in text to sbyte. + src = new DvText((sbyte.MinValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //4. ERROR condition: sbyte.MaxValue + 1 in text to sbyte. + src = new DvText((sbyte.MaxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //5. Missing value as empty string in text to sbyte. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + } + + [Fact] + public void TXToI2() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. short.MinValue in text to short. + short minValue = short.MinValue; + short maxValue = short.MaxValue; + DvText src = new DvText(minValue.ToString()); + short dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. short.MaxValue in text to short. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: short.MinValue - 1 in text to short. + src = new DvText((minValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //4. ERROR condition: short.MaxValue + 1 in text to short. + src = new DvText((maxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //5. 
Missing value as empty string in text to short. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + } + + [Fact] + public void TXToI4() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. int.MinValue in text to int. + int minValue = int.MinValue; + int maxValue = int.MaxValue; + DvText src = new DvText(minValue.ToString()); + int dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. int.MaxValue in text to int. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: int.MinValue - 1 in text to int. + src = new DvText(((long)minValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //4. ERROR condition: int.MaxValue + 1 in text to int. + src = new DvText(((long)maxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //5. Missing value as empty string in text to int. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + } + + [Fact] + public void TXToI8() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. long.MinValue in text to long. + var minValue = long.MinValue; + var maxValue = long.MaxValue; + DvText src = new DvText(minValue.ToString()); + var dst = default(long); + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. long.MaxValue in text to long. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. long.MinValue - 1 in text to long. + src = new DvText(((long)minValue - 1).ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, (long)minValue - 1); + + //4. ERROR condition: long.MaxValue + 1 in text to long. 
+ src = new DvText(((ulong)maxValue + 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch + { + error = true; + } + + Assert.True(error); + + //5. Missing value as empty string in text to long. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + } + + public ValueMapper GetMapper() + { + Assert.True(typeof(TSrc).TryGetDataKind(out DataKind srcDataKind)); + Assert.True(typeof(TDst).TryGetDataKind(out DataKind dstDataKind)); + + return Conversions.Instance.GetStandardConversion( + TextType.Instance, NumberType.FromKind(dstDataKind), out bool identity); + } + } +} + + diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 50a7e55975..7e484bd386 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -7,13 +7,139 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using System; +using System.IO; using Xunit; using Xunit.Abstractions; namespace Microsoft.ML.EntryPoints.Tests { + public class TextLoaderTestPipe : TestDataPipeBase + { + public TextLoaderTestPipe(ITestOutputHelper output) + : base(output) + { + + } + + [Fact] + public void TestTextLoaderDataTypes() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + "127,-32768,-2147483648,-9223372036854775808", + "-128,32767,2147483647,9223372036854775807", + ",,," + }); + + try + { + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", + }, logCurs: true); + + using (var cursor = data.GetRowCursor((a => true))) + { + var col1 = cursor.GetGetter(0); + var col2 = cursor.GetGetter(1); + var col3 = cursor.GetGetter(2); + var col4 = cursor.GetGetter(3); + + 
Assert.True(cursor.MoveNext()); + + sbyte[] sByteTargets = new sbyte[] { 127, -128, default}; + short[] shortTargets = new short[] { -32768, 32767, default }; + int[] intTargets = new int[] { -2147483648, 2147483647, default }; + long[] longTargets = new long[] { -9223372036854775808, 9223372036854775807, default }; + + int i = 0; + for (; i < sByteTargets.Length; i++) + { + sbyte sbyteValue = -1; + col1(ref sbyteValue); + Assert.Equal(sByteTargets[i], sbyteValue); + + short shortValue = -1; + col2(ref shortValue); + Assert.Equal(shortTargets[i], shortValue); + + int intValue = -1; + col3(ref intValue); + Assert.Equal(intTargets[i], intValue); + + long longValue = -1; + col4(ref longValue); + Assert.Equal(longTargets[i], longValue); + + if (i < sByteTargets.Length - 1) + Assert.True(cursor.MoveNext()); + else + Assert.False(cursor.MoveNext()); + } + + Assert.Equal(i, sByteTargets.Length); + } + } + catch + { + Assert.True(false, "Test failed."); + } + } + + [Fact] + public void TestTextLoaderInvalidLongMin() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + "-9223372036854775809" + + }); + + try + { + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt8:I8:0 sep=comma}", + }, logCurs: true); + } + catch(Exception ex) + { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); + return; + } + + Assert.True(false, "Test failed."); + } + + [Fact] + public void TestTextLoaderInvalidLongMax() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + "9223372036854775808" + + }); + + try + { + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt8:I8:0 sep=comma}", + }, logCurs: true); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); + return; + } + + Assert.True(false, "Test failed."); + } + } + public class TextLoaderTests : 
BaseTestClass { public TextLoaderTests(ITestOutputHelper output) @@ -21,7 +147,7 @@ public TextLoaderTests(ITestOutputHelper output) { } - + [Fact] public void ConstructorDoesntThrow() { From 9c128a3106be7d1e1262622f16419231ef69892e Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 25 Aug 2018 19:58:13 -0700 Subject: [PATCH 20/37] cleanup. --- test/Microsoft.ML.Tests/TextLoaderTests.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 7e484bd386..5b576d5c30 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -29,8 +29,8 @@ public void TestTextLoaderDataTypes() { string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); File.WriteAllLines(pathData, new string[] { - "127,-32768,-2147483648,-9223372036854775808", - "-128,32767,2147483647,9223372036854775807", + string.Format("{0},{1},{2},{3}", sbyte.MinValue, short.MinValue, int.MinValue, long.MinValue), + string.Format("{0},{1},{2},{3}", sbyte.MaxValue, short.MaxValue, int.MaxValue, long.MaxValue), ",,," }); @@ -50,10 +50,10 @@ public void TestTextLoaderDataTypes() Assert.True(cursor.MoveNext()); - sbyte[] sByteTargets = new sbyte[] { 127, -128, default}; - short[] shortTargets = new short[] { -32768, 32767, default }; - int[] intTargets = new int[] { -2147483648, 2147483647, default }; - long[] longTargets = new long[] { -9223372036854775808, 9223372036854775807, default }; + sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default}; + short[] shortTargets = new short[] { short.MinValue, short.MaxValue, default }; + int[] intTargets = new int[] { int.MinValue, int.MaxValue, default }; + long[] longTargets = new long[] { long.MinValue, long.MaxValue, default }; int i = 0; for (; i < sByteTargets.Length; i++) From 9c4eebf0dcb7c2e6c09dfbf3d10e5e2efd05b86a Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: 
Sat, 25 Aug 2018 20:03:48 -0700 Subject: [PATCH 21/37] cleanup. --- .../UnitTests/DataTypes.cs | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index 2258ea6f57..cdab734c32 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -47,8 +47,9 @@ public void TXToI1() { mapper(ref src, ref dst); } - catch + catch(Exception ex) { + Assert.Equal("Value could not be parsed from text to sbyte.", ex.Message); error = true; } @@ -62,8 +63,9 @@ public void TXToI1() { mapper(ref src, ref dst); } - catch + catch(Exception ex) { + Assert.Equal("Value could not be parsed from text to sbyte.", ex.Message); error = true; } @@ -105,8 +107,9 @@ public void TXToI2() { mapper(ref src, ref dst); } - catch + catch(Exception ex) { + Assert.Equal("Value could not be parsed from text to short.", ex.Message); error = true; } @@ -120,8 +123,9 @@ public void TXToI2() { mapper(ref src, ref dst); } - catch + catch (Exception ex) { + Assert.Equal("Value could not be parsed from text to short.", ex.Message); error = true; } @@ -163,8 +167,9 @@ public void TXToI4() { mapper(ref src, ref dst); } - catch + catch (Exception ex) { + Assert.Equal("Value could not be parsed from text to int.", ex.Message); error = true; } @@ -178,8 +183,9 @@ public void TXToI4() { mapper(ref src, ref dst); } - catch + catch (Exception ex) { + Assert.Equal("Value could not be parsed from text to int.", ex.Message); error = true; } @@ -227,8 +233,9 @@ public void TXToI8() { mapper(ref src, ref dst); } - catch + catch (Exception ex) { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); error = true; } From d7ba332762e9d7396292df7b1c674c4c29428a55 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 12:51:02 -0700 Subject: [PATCH 22/37] Add IDV test for backward 
compatiblity with DvTypes. --- test/Microsoft.ML.TestFramework/TestCommandBase.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index 18b4abf0cf..018d6087ab 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -2028,5 +2028,14 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidationAnd Assert.True(outputPath.CheckEqualityNormalized()); Done(); } + + [Fact] + public void DataTypes() + { + string idvPath = GetDataPath("datatypes.idv"); + OutputPath textOutputPath = CreateOutputPath("datatypes.txt"); + TestCore("savedata", idvPath, "loader=binary", "saver=text", textOutputPath.Arg("dout")); + Assert.True(textOutputPath.CheckEquality()); + } } } From 63e45e75351eff48f7428d67c93c939d93014f8f Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 12:51:32 -0700 Subject: [PATCH 23/37] baseline and idv file. 
--- .../Command/Datatypes-datatypes.txt | 18 ++++++++++++++++++ test/data/datatypes.idv | Bin 0 -> 804 bytes 2 files changed, 18 insertions(+) create mode 100644 test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt create mode 100644 test/data/datatypes.idv diff --git a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt new file mode 100644 index 0000000000..020a3e4ec7 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt @@ -0,0 +1,18 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=bl:BL:0 +#@ col=i1:I1:1 +#@ col=i2:I2:2 +#@ col=i4:I4:3 +#@ col=i8:I8:4 +#@ col=ts:TS:5 +#@ col=dto:DZ:6 +#@ col=dt:DT:7 +#@ col=tx:TX:8 +#@ } +bl i1 i2 i4 i8 ts dto dt tx +0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo +1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz + "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz + diff --git a/test/data/datatypes.idv b/test/data/datatypes.idv new file mode 100644 index 0000000000000000000000000000000000000000..d801f99dff9644b6c46c1cf7c5e9122501624ab0 GIT binary patch literal 804 zcmZ?v^bdFWCK#*&_&~`L-`y)c@XA-@_B)50U(xy^7(*l1t8Xk@~}|0aK8j s%FhEi>=v3Trj!ycgd(7VSLg~#DmX$?D@vfY{6dpva0+u_@bPs907N-kl>h($ literal 0 HcmV?d00001 From 681b3ef1e56b7d1385a83181c0f20ab3a7140863 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 16:38:19 -0700 Subject: [PATCH 24/37] PR feedback. 
--- .../Command/Datatypes-datatypes.txt | 3 ++- .../Command/Datatypes-datatypes.txt | 19 ++++++++++++++++++ test/data/datatypes.idv | Bin 804 -> 809 bytes 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt diff --git a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt index 020a3e4ec7..e9936a9901 100644 --- a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt @@ -14,5 +14,6 @@ bl i1 i2 i4 i8 ts dto dt tx 0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz - "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz + "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" +9 8: diff --git a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt new file mode 100644 index 0000000000..e9936a9901 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt @@ -0,0 +1,19 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=bl:BL:0 +#@ col=i1:I1:1 +#@ col=i2:I2:2 +#@ col=i4:I4:3 +#@ col=i8:I8:4 +#@ col=ts:TS:5 +#@ col=dto:DZ:6 +#@ col=dt:DT:7 +#@ col=tx:TX:8 +#@ } +bl i1 i2 i4 i8 ts dto dt tx +0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo +1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz + "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" +9 8: + diff --git a/test/data/datatypes.idv 
b/test/data/datatypes.idv index d801f99dff9644b6c46c1cf7c5e9122501624ab0..155b537c1fed505be261a4d52ad6d87a592adae0 100644 GIT binary patch literal 809 zcmZ?v^>fx`8+^27Z9^S`TRh(2oO6$`2s+;1`x|Y`5-wE z-vG)N1hS2QI1oucNIn6|7Y4HBftV>Nhs~wTDL+4lfswTyDhU(^1Ex$v)?lZ~l2o8v zKZcwUt7l$`p&3waK8BnLNY2;@D7PL%&H^N7W&)Jkk1khI%n_2An;Kk@mMRB(i(R+K<(`Hv>e;1uS<;N$BI E0MkTZ9RL6T literal 804 zcmZ?v^bdFWCK#*&_&~`L-`y)c@XA-@_B)50U(xy^7(*l1t8Xk@~}|0aK8j s%FhEi>=v3Trj!ycgd(7VSLg~#DmX$?D@vfY{6dpva0+u_@bPs907N-kl>h($ From 758051369601a10db84fb35c428e380fb99dcb11 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 18:15:48 -0700 Subject: [PATCH 25/37] disable test for Linux. --- test/Microsoft.ML.TestFramework/TestCommandBase.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index 018d6087ab..07732d1f1f 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft.ML.Runtime.Command; @@ -2032,6 +2033,10 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidationAnd [Fact] public void DataTypes() { + //Skip for linux because DATE/TIME format is different. + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return; + string idvPath = GetDataPath("datatypes.idv"); OutputPath textOutputPath = CreateOutputPath("datatypes.txt"); TestCore("savedata", idvPath, "loader=binary", "saver=text", textOutputPath.Arg("dout")); From 8f8e5440ddad90cc9fbb3dd9465f1bcbc669aeb5 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 21:18:45 -0700 Subject: [PATCH 26/37] PR feedback. 
--- src/Microsoft.ML.Data/Data/Conversion.cs | 1 + .../SingleDebug/Command/DataTypes-out.txt | 2 +- .../Command/Datatypes-datatypes.txt | 3 +- .../SingleRelease/Command/DataTypes-out.txt | 2 +- .../Command/Datatypes-datatypes.txt | 3 +- .../UnitTests/DataTypes.cs | 62 ++++++++++++- .../CopyColumnEstimatorTests.cs | 20 ++-- test/Microsoft.ML.Tests/TextLoaderTests.cs | 87 ++++++++---------- test/data/datatypes.idv | Bin 809 -> 801 bytes 9 files changed, 116 insertions(+), 64 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 84bee650ea..c6ce03fed6 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1369,6 +1369,7 @@ private void TryParseSigned(long max, ref TX span, out long? result) { Contracts.Assert(max > 0); Contracts.Assert((max & (max + 1)) == 0); + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to signed numbers."); if (!span.HasChars) { diff --git a/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt b/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt index a2aaab4439..0ec588d71d 100644 --- a/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt +++ b/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt @@ -1 +1 @@ -Wrote 5 rows of length 9 +Wrote 4 rows of length 9 diff --git a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt index e7d128e400..278c1c9ade 100644 --- a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt @@ -14,6 +14,5 @@ bl i1 i2 i4 i8 ts dto dt tx 0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz - "7.00:00:00" 
"2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" + -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 - diff --git a/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt b/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt index a2aaab4439..0ec588d71d 100644 --- a/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt +++ b/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt @@ -1 +1 @@ -Wrote 5 rows of length 9 +Wrote 4 rows of length 9 diff --git a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt index e7d128e400..278c1c9ade 100644 --- a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt @@ -14,6 +14,5 @@ bl i1 i2 i4 i8 ts dto dt tx 0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz - "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" + -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 - diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index cdab734c32..866834e46d 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -71,11 +71,26 @@ public void TXToI1() Assert.True(error); - //5. Missing value as empty string in text to sbyte. + //5. Empty string in text to sbyte. src = default; dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); + + //6. Missing value as empty string in text to sbyte. 
+ src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + error = true; + } + + Assert.True(error); } [Fact] @@ -136,6 +151,21 @@ public void TXToI2() dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); + + //6. Empty string in text to sbyte. + src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + error = true; + } + + Assert.True(error); } [Fact] @@ -196,6 +226,21 @@ public void TXToI4() dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); + + //6. Empty string in text to sbyte. + src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + error = true; + } + + Assert.True(error); } [Fact] @@ -246,6 +291,21 @@ public void TXToI8() dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); + + //6. Empty string in text to sbyte. 
+ src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + error = true; + } + + Assert.True(error); } public ValueMapper GetMapper() diff --git a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs index 6b0a2adc38..f4bcab1660 100644 --- a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs @@ -171,16 +171,16 @@ private void ValidateCopyColumnTransformer(IDataView result) { using (var cursor = result.GetRowCursor(x => true)) { - DvInt4 avalue = 0; - DvInt4 bvalue = 0; - DvInt4 dvalue = 0; - DvInt4 evalue = 0; - DvInt4 fvalue = 0; - var aGetter = cursor.GetGetter(0); - var bGetter = cursor.GetGetter(1); - var dGetter = cursor.GetGetter(3); - var eGetter = cursor.GetGetter(4); - var fGetter = cursor.GetGetter(5); + int avalue = 0; + int bvalue = 0; + int dvalue = 0; + int evalue = 0; + int fvalue = 0; + var aGetter = cursor.GetGetter(0); + var bGetter = cursor.GetGetter(1); + var dGetter = cursor.GetGetter(3); + var eGetter = cursor.GetGetter(4); + var fGetter = cursor.GetGetter(5); while (cursor.MoveNext()) { aGetter(ref avalue); diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 5b576d5c30..b6cabef1a7 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -34,58 +34,51 @@ public void TestTextLoaderDataTypes() ",,," }); - try - { - var data = TestCore(pathData, true, - new[] { - "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", - }, logCurs: true); - - using (var cursor = data.GetRowCursor((a => true))) - { - var col1 = cursor.GetGetter(0); - var col2 = cursor.GetGetter(1); - var col3 = cursor.GetGetter(2); - var col4 = cursor.GetGetter(3); + var data = TestCore(pathData, true, + 
new[] { + "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", + }, logCurs: true); - Assert.True(cursor.MoveNext()); + using (var cursor = data.GetRowCursor((a => true))) + { + var col1 = cursor.GetGetter(0); + var col2 = cursor.GetGetter(1); + var col3 = cursor.GetGetter(2); + var col4 = cursor.GetGetter(3); - sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default}; - short[] shortTargets = new short[] { short.MinValue, short.MaxValue, default }; - int[] intTargets = new int[] { int.MinValue, int.MaxValue, default }; - long[] longTargets = new long[] { long.MinValue, long.MaxValue, default }; + Assert.True(cursor.MoveNext()); - int i = 0; - for (; i < sByteTargets.Length; i++) - { - sbyte sbyteValue = -1; - col1(ref sbyteValue); - Assert.Equal(sByteTargets[i], sbyteValue); - - short shortValue = -1; - col2(ref shortValue); - Assert.Equal(shortTargets[i], shortValue); - - int intValue = -1; - col3(ref intValue); - Assert.Equal(intTargets[i], intValue); - - long longValue = -1; - col4(ref longValue); - Assert.Equal(longTargets[i], longValue); - - if (i < sByteTargets.Length - 1) - Assert.True(cursor.MoveNext()); - else - Assert.False(cursor.MoveNext()); - } + sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default}; + short[] shortTargets = new short[] { short.MinValue, short.MaxValue, default }; + int[] intTargets = new int[] { int.MinValue, int.MaxValue, default }; + long[] longTargets = new long[] { long.MinValue, long.MaxValue, default }; - Assert.Equal(i, sByteTargets.Length); + int i = 0; + for (; i < sByteTargets.Length; i++) + { + sbyte sbyteValue = -1; + col1(ref sbyteValue); + Assert.Equal(sByteTargets[i], sbyteValue); + + short shortValue = -1; + col2(ref shortValue); + Assert.Equal(shortTargets[i], shortValue); + + int intValue = -1; + col3(ref intValue); + Assert.Equal(intTargets[i], intValue); + + long longValue = -1; + col4(ref longValue); + Assert.Equal(longTargets[i], 
longValue); + + if (i < sByteTargets.Length - 1) + Assert.True(cursor.MoveNext()); + else + Assert.False(cursor.MoveNext()); } - } - catch - { - Assert.True(false, "Test failed."); + + Assert.Equal(i, sByteTargets.Length); } } diff --git a/test/data/datatypes.idv b/test/data/datatypes.idv index 15f12b97484b09de15d051c78a494ed29357198b..e2c7a8543b03bd9327f623c400e3030cfd5634ff 100644 GIT binary patch literal 801 zcmZ?v^NKR;AYEQRb zrZwl;fdvbJX2hSrap1y%13>)=JF1tP{$pZiX5(c9nk~=sN>cK`f=0%vX~4i^1Y#gy z17c9nvZ3(2(?wg3=kK>7SYwgM1KLHWW^z5|pGQmX;PQBXcezXA~J zBB_@K@|luy*j&n-^7C^T7+K1pl0b1VV9GRP4R)$5Nd?NaW5^k?dghfFngQjeW5}6+ zu7=L&m6cVq@dMhn;TB+CsftQVhun4b1K7*2}c! zJUg%;(2*IaRi5XSq~w7Gjf_*<8rgsfc2qAn{l~=4%+|{URG*rVkieE?#*x=>IrHaF z(W#*fz%T~`HXz9Y#EMWpJCF?uayBTR2gv3EVj(D>AIKH~VqYj<0Laz=VJIIY2jbg6 z`GP>U5fH~B=?BT@K>5Nzwmc9sCFQWWlsV<+=P)p`) Date: Mon, 27 Aug 2018 21:28:15 -0700 Subject: [PATCH 27/37] add more tests. --- .../UnitTests/DataTypes.cs | 8 +++--- test/Microsoft.ML.Tests/TextLoaderTests.cs | 26 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index 866834e46d..d8c33b4775 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -77,7 +77,7 @@ public void TXToI1() mapper(ref src, ref dst); Assert.Equal(default, dst); - //6. Missing value as empty string in text to sbyte. + //6. Missing value in text to sbyte. src = DvText.NA; dst = -1; try @@ -146,7 +146,7 @@ public void TXToI2() Assert.True(error); - //5. Missing value as empty string in text to short. + //5. Missing value in text to short. src = default; dst = -1; mapper(ref src, ref dst); @@ -221,7 +221,7 @@ public void TXToI4() Assert.True(error); - //5. Missing value as empty string in text to int. + //5. Missing value in text to int. 
src = default; dst = -1; mapper(ref src, ref dst); @@ -286,7 +286,7 @@ public void TXToI8() Assert.True(error); - //5. Missing value as empty string in text to long. + //5. Missing value in text to long. src = default; dst = -1; mapper(ref src, ref dst); diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index b6cabef1a7..11f82003fb 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -31,7 +31,7 @@ public void TestTextLoaderDataTypes() File.WriteAllLines(pathData, new string[] { string.Format("{0},{1},{2},{3}", sbyte.MinValue, short.MinValue, int.MinValue, long.MinValue), string.Format("{0},{1},{2},{3}", sbyte.MaxValue, short.MaxValue, int.MaxValue, long.MaxValue), - ",,," + "\"\",\"\",\"\",\"\"" }); var data = TestCore(pathData, true, @@ -82,6 +82,30 @@ public void TestTextLoaderDataTypes() } } + [Fact] + public void TestTextLoaderDataTypesMissing() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + ",,," + }); + + try + { + TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", + }, logCurs: true); + } + catch(Exception ex) + { + Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + return; + } + + Assert.True(false, "Test failed."); + } + [Fact] public void TestTextLoaderInvalidLongMin() { From c844b4ed9b5e96d8378acd68509cb0475b505950 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 27 Aug 2018 23:16:17 -0700 Subject: [PATCH 28/37] PR feedback. 
--- src/Microsoft.ML.Core/Data/DateTime.cs | 2 +- src/Microsoft.ML.Data/Data/Conversion.cs | 37 +++++++++++++++++- .../SingleDebug/Command/DataTypes-out.txt | 2 +- .../Command/Datatypes-datatypes.txt | 1 + .../SingleRelease/Command/DataTypes-out.txt | 2 +- .../Command/Datatypes-datatypes.txt | 1 + test/Microsoft.ML.Tests/TextLoaderTests.cs | 24 ------------ test/data/datatypes.idv | Bin 801 -> 809 bytes 8 files changed, 41 insertions(+), 28 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index d9f3930579..1e90a80f81 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -44,7 +44,7 @@ public DvDateTime(long ticks) [Conditional("DEBUG")] internal void AssertValid() { - Contracts.Assert((ulong)_ticks <= MaxTicks); + Contracts.Assert((ulong)_ticks <= MaxTicks || _ticks == long.MinValue); } public long Ticks diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index c6ce03fed6..59c7addc7f 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1278,6 +1278,12 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) public bool TryParse(ref TX src, out I1 dst) { dst = default; + if (src.IsNA) + { + dst = I1.MinValue; + return true; + } + TryParseSigned(I1.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); @@ -1292,6 +1298,12 @@ public bool TryParse(ref TX src, out I1 dst) public bool TryParse(ref TX src, out I2 dst) { dst = default; + if (src.IsNA) + { + dst = I2.MinValue; + return true; + } + TryParseSigned(I2.MaxValue, ref src, out long? 
res); Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); @@ -1306,6 +1318,12 @@ public bool TryParse(ref TX src, out I2 dst) public bool TryParse(ref TX src, out I4 dst) { dst = default; + if (src.IsNA) + { + dst = I4.MinValue; + return true; + } + TryParseSigned(I4.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to int32."); Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); @@ -1320,6 +1338,12 @@ public bool TryParse(ref TX src, out I4 dst) public bool TryParse(ref TX src, out I8 dst) { dst = default; + if (src.IsNA) + { + dst = I8.MinValue; + return true; + } + TryParseSigned(I8.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); dst = (I8)res; @@ -1369,7 +1393,6 @@ private void TryParseSigned(long max, ref TX span, out long? result) { Contracts.Assert(max > 0); Contracts.Assert((max & (max + 1)) == 0); - Contracts.Check(!span.IsNA, "Missing text value cannot be converted to signed numbers."); if (!span.HasChars) { @@ -1502,6 +1525,9 @@ public bool TryParse(ref TX src, out DZ dst) // These throw an exception for unparsable and overflow values. private I1 ParseI1(ref TX src) { + if (src.IsNA) + return I1.MinValue; + TryParseSigned(I1.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); @@ -1510,6 +1536,9 @@ private I1 ParseI1(ref TX src) private I2 ParseI2(ref TX src) { + if (src.IsNA) + return I2.MinValue; + TryParseSigned(I2.MaxValue, ref src, out long? 
res); Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); @@ -1518,6 +1547,9 @@ private I2 ParseI2(ref TX src) private I4 ParseI4(ref TX src) { + if (src.IsNA) + return I4.MinValue; + TryParseSigned(I4.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to int."); Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); @@ -1526,6 +1558,9 @@ private I4 ParseI4(ref TX src) private I8 ParseI8(ref TX src) { + if (src.IsNA) + return I8.MinValue; + TryParseSigned(I8.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); return res.Value; diff --git a/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt b/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt index 0ec588d71d..a2aaab4439 100644 --- a/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt +++ b/test/BaselineOutput/SingleDebug/Command/DataTypes-out.txt @@ -1 +1 @@ -Wrote 4 rows of length 9 +Wrote 5 rows of length 9 diff --git a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt index 278c1c9ade..8815a6b6ee 100644 --- a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt @@ -16,3 +16,4 @@ bl i1 i2 i4 i8 ts dto dt tx 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 + -128 -32768 -2147483648 -9223372036854775808 diff --git a/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt b/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt 
index 0ec588d71d..a2aaab4439 100644 --- a/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt +++ b/test/BaselineOutput/SingleRelease/Command/DataTypes-out.txt @@ -1 +1 @@ -Wrote 4 rows of length 9 +Wrote 5 rows of length 9 diff --git a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt index 278c1c9ade..8815a6b6ee 100644 --- a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt @@ -16,3 +16,4 @@ bl i1 i2 i4 i8 ts dto dt tx 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 + -128 -32768 -2147483648 -9223372036854775808 diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 11f82003fb..8e4c298f82 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -82,30 +82,6 @@ public void TestTextLoaderDataTypes() } } - [Fact] - public void TestTextLoaderDataTypesMissing() - { - string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); - File.WriteAllLines(pathData, new string[] { - ",,," - }); - - try - { - TestCore(pathData, true, - new[] { - "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", - }, logCurs: true); - } - catch(Exception ex) - { - Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); - return; - } - - Assert.True(false, "Test failed."); - } - [Fact] public void TestTextLoaderInvalidLongMin() { diff --git a/test/data/datatypes.idv b/test/data/datatypes.idv index e2c7a8543b03bd9327f623c400e3030cfd5634ff..288f75e0ea49618068ef55082e4df88f199b298a 100644 GIT binary patch literal 809 
zcmZ?v^~4~P(DA9EdsVrT}Gn~x!90+KT}0?Mt&kh1{EnVA6P_M^*{6mx`R=B5T0B<2C7&!b5* zrNKR;AYEQRb zrZwl;fdvbJX2hSrap1y%13>)=JF1tP{$pZiX5(c9nk~=sN>cK`f=0%vX~4i^1Y#gy z17c9nvZ3(2(?wg3=kK>7SYwgM1KLHWW^z5|pGQmX;PQBXcezXA~J zBB_@K@|luy*j&n-^7C^T7+K1pl0b1VV9GRP4R)$5Nd?NaW5^k?dghfFngQjeW5}6+ z Date: Tue, 28 Aug 2018 09:29:52 -0700 Subject: [PATCH 29/37] add test for parquet loader. --- .../SavePipe/TestParquetNull-Data.txt | 9 ++++++++ .../SavePipe/TestParquetNull-Schema.txt | 4 ++++ .../TestParquetPrimitiveDataTypes-Data.txt | 21 ++++++++++++++++++ .../TestParquetPrimitiveDataTypes-Schema.txt | 16 +++++++++++++ .../SavePipe/TestParquetNull-Data.txt | 9 ++++++++ .../SavePipe/TestParquetNull-Schema.txt | 4 ++++ .../TestParquetPrimitiveDataTypes-Data.txt | 21 ++++++++++++++++++ .../TestParquetPrimitiveDataTypes-Schema.txt | 16 +++++++++++++ .../DataPipe/TestDataPipe.cs | 20 +++++++++++++++++ .../Microsoft.ML.TestFramework.csproj | 1 + .../TestInitialization.cs | 1 + test/data/Parquet/alltypes.parquet | Bin 0 -> 1419 bytes test/data/Parquet/test-null.parquet | Bin 0 -> 349 bytes 13 files changed, 122 insertions(+) create mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt create mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt create mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt create mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt create mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt create mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt create mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt create mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt create mode 100644 test/data/Parquet/alltypes.parquet create mode 100644 test/data/Parquet/test-null.parquet diff --git 
a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt new file mode 100644 index 0000000000..c7049cd12a --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt @@ -0,0 +1,9 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=foo:I4:0 +#@ col=bar:I4:1 +#@ } +foo bar +1 2 +1 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt new file mode 100644 index 0000000000..8fa619c171 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt @@ -0,0 +1,4 @@ +---- ParquetLoader ---- +2 columns: + foo: I4 + bar: I4 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt new file mode 100644 index 0000000000..bbdd0c18ab --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -0,0 +1,21 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=Id:I4:0 +#@ col=Timestamp:DZ:1 +#@ col=Message:TX:2 +#@ col=Data:TX:3 +#@ col=IsDeleted:BL:4 +#@ col=Amount:R4:5 +#@ col=TotalAmount:R8:6 +#@ col=Counter:I8:7 +#@ col=Amount2:R8:8 +#@ col=Flag:U1:9 +#@ col=Flag2:I1:10 +#@ col=Flag3:I2:11 +#@ col=Flag4:U2:12 +#@ col=Flag5:TS:13 +#@ } +Id Timestamp Message Data IsDeleted Amount TotalAmount Counter Amount2 Flag Flag2 Flag3 Flag4 Flag5 +1 "2000-01-01T01:01:01.0000000+00:00" Record1 SomeData3 0 125.4 400 300000 3331313 3 -3 -600 600 "3100.00:00:00.1000000" +1 "2000-12-31T23:59:59.9990000+00:00" Record2 SomeData4 0 126.4 500 400000 4331313 4 -4 -700 700 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt new file mode 100644 index 0000000000..213ef605e6 --- /dev/null +++ 
b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt @@ -0,0 +1,16 @@ +---- ParquetLoader ---- +14 columns: + Id: I4 + Timestamp: DateTimeZone + Message: Text + Data: Text + IsDeleted: Bool + Amount: R4 + TotalAmount: R8 + Counter: I8 + Amount2: R8 + Flag: U1 + Flag2: I1 + Flag3: I2 + Flag4: U2 + Flag5: TimeSpan diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt new file mode 100644 index 0000000000..c7049cd12a --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt @@ -0,0 +1,9 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=foo:I4:0 +#@ col=bar:I4:1 +#@ } +foo bar +1 2 +1 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt new file mode 100644 index 0000000000..8fa619c171 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt @@ -0,0 +1,4 @@ +---- ParquetLoader ---- +2 columns: + foo: I4 + bar: I4 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt new file mode 100644 index 0000000000..bbdd0c18ab --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -0,0 +1,21 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=Id:I4:0 +#@ col=Timestamp:DZ:1 +#@ col=Message:TX:2 +#@ col=Data:TX:3 +#@ col=IsDeleted:BL:4 +#@ col=Amount:R4:5 +#@ col=TotalAmount:R8:6 +#@ col=Counter:I8:7 +#@ col=Amount2:R8:8 +#@ col=Flag:U1:9 +#@ col=Flag2:I1:10 +#@ col=Flag3:I2:11 +#@ col=Flag4:U2:12 +#@ col=Flag5:TS:13 +#@ } +Id Timestamp Message Data IsDeleted Amount TotalAmount Counter Amount2 Flag Flag2 Flag3 Flag4 Flag5 +1 "2000-01-01T01:01:01.0000000+00:00" Record1 SomeData3 0 125.4 400 300000 3331313 3 -3 -600 600 
"3100.00:00:00.1000000" +1 "2000-12-31T23:59:59.9990000+00:00" Record2 SomeData4 0 126.4 500 400000 4331313 4 -4 -700 700 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt new file mode 100644 index 0000000000..213ef605e6 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Schema.txt @@ -0,0 +1,16 @@ +---- ParquetLoader ---- +14 columns: + Id: I4 + Timestamp: DateTimeZone + Message: Text + Data: Text + IsDeleted: Bool + Amount: R4 + TotalAmount: R8 + Counter: I8 + Amount2: R8 + Flag: U1 + Flag2: I1 + Flag3: I2 + Flag4: U2 + Flag5: TimeSpan diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index c598879795..ba0749318b 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -145,4 +145,24 @@ public void TestLdaTransformEmptyDocumentException() Assert.True(false, "The LDA transform does not throw expected error on empty documents."); } } + + public sealed partial class TestDataPipe : TestDataPipeBase + { + + [Fact] + public void TestParquetPrimitiveDataTypes() + { + string pathData = GetDataPath(@"..\data\Parquet", "alltypes.parquet"); + TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }); + Done(); + } + + [Fact] + public void TestParquetNull() + { + string pathData = GetDataPath(@"..\data\Parquet", "test-null.parquet"); + TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + Done(); + } + } } diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index a80b2df83e..454d1d7a31 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ 
b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -9,6 +9,7 @@ + diff --git a/test/Microsoft.ML.TestFramework/TestInitialization.cs b/test/Microsoft.ML.TestFramework/TestInitialization.cs index ebe0eb0a79..50d54c9f21 100644 --- a/test/Microsoft.ML.TestFramework/TestInitialization.cs +++ b/test/Microsoft.ML.TestFramework/TestInitialization.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Xunit; using Xunit.Abstractions; namespace Microsoft.ML.Runtime.RunTests diff --git a/test/data/Parquet/alltypes.parquet b/test/data/Parquet/alltypes.parquet new file mode 100644 index 0000000000000000000000000000000000000000..e47e6b6e5587d5fb458fee85eab343b29e55af16 GIT binary patch literal 1419 zcmaKsPe>GD6u_V3=kB_RDfNcg8ft@PDYI?cU16aPnpR4twn+yKnNg=QT>tG@2ZIRG zFfBre5W-WGMG^D|%#&G!9y*kjvlH4Omt8*)AF(6KWSfC-}Cur zcm3T`_p32+9Rwavp7olE9Ol?30$iEn-lmvV7DjBXJssa~OneHRzX|UK$|b7-`B-Ty z*U6Cj=FcOn?hpkzzv{>M(}Q!Vf&io5)+5$#ytvmkwIk5qnZh>Q`ZyAXzijDsmboo2 z7TfA|vV3ZA?6*dj(-N83Fztf%*N&H4NepxClTM_7)Y%{7%-XAI3kyz&ghnbY&4mHm zYSq6czwLte?twyP)>}>Nk*8iTxO3lm*%AYoZ(H$@BFn`Eqf1Lm51k4W{fs;?Muzh& zt}VtErLDDH)ppi5YfbbQuREp26chnD2?VaBo(7j5(kmNcYFtxdEr1ykCpZtQ@wjqQ zm1uDFm=!6SLNlzZCGjG)QPosRby4gnh&8u2X@HsNMG)3hHfu_w`A%}yar0&26v)jmk-hH-j`i2uhohqLAcGio%khFDB4phMOnl0M9(4XV&Kuk@Ms^ zT;S1hNUFVZwOU<6G1Y@xD~_bh=FVQ zHX?+FQ$>Wt8p~ZWoUME*VRE{dh%U>qI%fz$zKR$L( OkUtDF)b$Gb4F3ax_YjN# literal 0 HcmV?d00001 diff --git a/test/data/Parquet/test-null.parquet b/test/data/Parquet/test-null.parquet new file mode 100644 index 0000000000000000000000000000000000000000..a4c8a943b312b4b0d238e5b4dac44accbe799e04 GIT binary patch literal 349 zcmZXQu};G<6h)oHR1p$P*s>)L9v}j#iDJbu8M?4BRvj1^knF~lQUwgA<;(apegHdk zh{WKl>-Wy{yYl*BBM?9g@g;)6EKtq}OaK6Xih~1}G7Xw#u^dm}t`M0I#6E#hag_7x zX&{M;Qo4GC-nj#U;c}CbZ0Fp`8SMdvl+k68vp}{vSkMVTDTeq3n<8KR%Tw%UYvO+7f>txfkjW$~&yGClcwS9L7enNzG_}E|F7&hPl literal 0 HcmV?d00001 From 
f05e2f1a196799c4371765bf1f58608c77afac01 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 28 Aug 2018 09:51:36 -0700 Subject: [PATCH 30/37] Update parquet tests. --- .../DataPipe/TestDataPipe.cs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index ba0749318b..7809c9ea52 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -161,7 +161,19 @@ public void TestParquetPrimitiveDataTypes() public void TestParquetNull() { string pathData = GetDataPath(@"..\data\Parquet", "test-null.parquet"); - TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + bool exception = false; + try + { + TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + } + catch(Exception ex) + { + Assert.Equal("Nullable object must have a value.", ex.Message); + exception = true; + } + + Assert.True(exception, "Test failed because control reached here without an expected exception for nullable values."); + Done(); } } From 33b7150c7b953ca6346a22d08e1fec36ab7642c3 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 28 Aug 2018 10:13:19 -0700 Subject: [PATCH 31/37] PR feedback. 
--- .../DataLoadSave/Binary/BinaryLoader.cs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 7bc0a8d2ad..329a12d20d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -731,6 +731,12 @@ public void GetMetadata(string kind, int col, ref TValue value) /// private const ulong ReaderVersion = MissingTextVersion; + /// + /// The first version that removes DvTypes and uses .NET standard + /// data types. + /// + private const ulong StandardDataTypesVersion = 0x0001000100010006; + /// /// The first version of the format that accomodated DvText.NA. /// @@ -1090,10 +1096,10 @@ private unsafe Header InitHeader() throw _host.Except("Cannot read version {0} data, earliest that can be handled is {1}", Header.VersionToString(header.CompatibleVersion), Header.VersionToString(MetadataVersion)); } - if (header.CompatibleVersion > ReaderVersion) + if (header.CompatibleVersion > StandardDataTypesVersion) { throw _host.Except("Cannot read version {0} data, latest that can be handled is {1}", - Header.VersionToString(header.CompatibleVersion), Header.VersionToString(ReaderVersion)); + Header.VersionToString(header.CompatibleVersion), Header.VersionToString(StandardDataTypesVersion)); } _host.CheckDecode(header.RowCount >= 0, "Row count cannot be negative"); From d0287023b9cb502993e9bd43a8dfe488b6239d1d Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 28 Aug 2018 12:58:27 -0700 Subject: [PATCH 32/37] fix build. 
--- test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 7809c9ea52..06cd4a71e7 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -152,7 +152,7 @@ public sealed partial class TestDataPipe : TestDataPipeBase [Fact] public void TestParquetPrimitiveDataTypes() { - string pathData = GetDataPath(@"..\data\Parquet", "alltypes.parquet"); + string pathData = GetDataPath(@"Parquet", "alltypes.parquet"); TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }); Done(); } @@ -160,7 +160,7 @@ public void TestParquetPrimitiveDataTypes() [Fact] public void TestParquetNull() { - string pathData = GetDataPath(@"..\data\Parquet", "test-null.parquet"); + string pathData = GetDataPath(@"Parquet", "test-null.parquet"); bool exception = false; try { From a7cd3d83c2b21dafabd50352c70c6c15e88e44d1 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 30 Aug 2018 12:01:03 -0700 Subject: [PATCH 33/37] PR feedback. 
--- src/Microsoft.ML.Data/Data/Conversion.cs | 66 +++++++------------ .../DataLoadSave/Binary/BinaryLoader.cs | 6 +- .../DataLoadSave/Binary/Header.cs | 5 +- .../UnitTests/DataTypes.cs | 16 ++--- .../UnitTests/TestEntryPoints.cs | 2 +- test/Microsoft.ML.Tests/TextLoaderTests.cs | 1 - 6 files changed, 38 insertions(+), 58 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 59c7addc7f..2fe524827f 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1035,6 +1035,7 @@ public void Convert(ref BL src, ref SB dst) /// public bool TryParse(ref TX src, out U1 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U1.MaxValue) { @@ -1050,6 +1051,7 @@ public bool TryParse(ref TX src, out U1 dst) /// public bool TryParse(ref TX src, out U2 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U2.MaxValue) { @@ -1065,6 +1067,7 @@ public bool TryParse(ref TX src, out U2 dst) /// public bool TryParse(ref TX src, out U4 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U4.MaxValue) { @@ -1080,12 +1083,7 @@ public bool TryParse(ref TX src, out U4 dst) /// public bool TryParse(ref TX src, out U8 dst) { - if (src.IsNA) - { - dst = 0; - return false; - } - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); int ichMin; int ichLim; string text = src.GetRawUnderlyingBufferInfo(out ichMin, out ichLim); @@ -1203,6 +1201,8 @@ private bool IsStdMissing(ref TX src) /// public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer 
type."); + Contracts.Check(!IsStdMissing(ref src), "Missing text value cannot be converted to unsigned integer type."); Contracts.Assert(min <= max); // This simply ensures we don't have min == 0 and max == U8.MaxValue. This is illegal since @@ -1227,7 +1227,7 @@ public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst) { dst = 0; // Return true only for standard forms for NA. - return IsStdMissing(ref src); + return false; } if (min > uu || uu > max) @@ -1277,13 +1277,9 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// public bool TryParse(ref TX src, out I1 dst) { - dst = default; - if (src.IsNA) - { - dst = I1.MinValue; - return true; - } + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + dst = default; TryParseSigned(I1.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); @@ -1297,13 +1293,9 @@ public bool TryParse(ref TX src, out I1 dst) /// public bool TryParse(ref TX src, out I2 dst) { - dst = default; - if (src.IsNA) - { - dst = I2.MinValue; - return true; - } + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + dst = default; TryParseSigned(I2.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); @@ -1317,13 +1309,9 @@ public bool TryParse(ref TX src, out I2 dst) /// public bool TryParse(ref TX src, out I4 dst) { - dst = default; - if (src.IsNA) - { - dst = I4.MinValue; - return true; - } + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + dst = default; TryParseSigned(I4.MaxValue, ref src, out long? 
res); Contracts.Check(res.HasValue, "Value could not be parsed from text to int32."); Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); @@ -1337,13 +1325,9 @@ public bool TryParse(ref TX src, out I4 dst) /// public bool TryParse(ref TX src, out I8 dst) { - dst = default; - if (src.IsNA) - { - dst = I8.MinValue; - return true; - } + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + dst = default; TryParseSigned(I8.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); dst = (I8)res; @@ -1525,9 +1509,7 @@ public bool TryParse(ref TX src, out DZ dst) // These throw an exception for unparsable and overflow values. private I1 ParseI1(ref TX src) { - if (src.IsNA) - return I1.MinValue; - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); TryParseSigned(I1.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); @@ -1536,9 +1518,7 @@ private I1 ParseI1(ref TX src) private I2 ParseI2(ref TX src) { - if (src.IsNA) - return I2.MinValue; - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); TryParseSigned(I2.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); @@ -1547,9 +1527,7 @@ private I2 ParseI2(ref TX src) private I4 ParseI4(ref TX src) { - if (src.IsNA) - return I4.MinValue; - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); TryParseSigned(I4.MaxValue, ref src, out long? 
res); Contracts.Check(res.HasValue, "Value could not be parsed from text to int."); Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); @@ -1558,9 +1536,7 @@ private I4 ParseI4(ref TX src) private I8 ParseI8(ref TX src) { - if (src.IsNA) - return I8.MinValue; - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); TryParseSigned(I8.MaxValue, ref src, out long? res); Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); return res.Value; @@ -1571,6 +1547,7 @@ private I8 ParseI8(ref TX src) // unsigned integer types. private U1 ParseU1(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1581,6 +1558,7 @@ private U1 ParseU1(ref TX span) private U2 ParseU2(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1591,6 +1569,7 @@ private U2 ParseU2(ref TX span) private U4 ParseU4(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1601,6 +1580,7 @@ private U4 ParseU4(ref TX span) private U8 ParseU8(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 329a12d20d..582212738a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -729,7 +729,7 @@ public void GetMetadata(string kind, int col, ref TValue value) /// /// Upper inclusive bound of versions this reader can read. 
/// - private const ulong ReaderVersion = MissingTextVersion; + private const ulong ReaderVersion = StandardDataTypesVersion; /// /// The first version that removes DvTypes and uses .NET standard @@ -1096,10 +1096,10 @@ private unsafe Header InitHeader() throw _host.Except("Cannot read version {0} data, earliest that can be handled is {1}", Header.VersionToString(header.CompatibleVersion), Header.VersionToString(MetadataVersion)); } - if (header.CompatibleVersion > StandardDataTypesVersion) + if (header.CompatibleVersion > ReaderVersion) { throw _host.Except("Cannot read version {0} data, latest that can be handled is {1}", - Header.VersionToString(header.CompatibleVersion), Header.VersionToString(StandardDataTypesVersion)); + Header.VersionToString(header.CompatibleVersion), Header.VersionToString(ReaderVersion)); } _host.CheckDecode(header.RowCount >= 0, "Row count cannot be negative"); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs index 36186cf7af..b552ab6523 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs @@ -34,8 +34,9 @@ public struct Header //public const ulong WriterVersion = 0x0001000100010002; // Codec changes. //public const ulong WriterVersion = 0x0001000100010003; // Slot names. //public const ulong WriterVersion = 0x0001000100010004; // Column metadata. - public const ulong WriterVersion = 0x0001000100010005; // "NA" DvText support. - public const ulong CanBeReadByVersion = 0x0001000100010005; + //public const ulong WriterVersion = 0x0001000100010005; // "NA" DvText support. + public const ulong WriterVersion = 0x0001000100010006; // Replace DvTypes with .NET Standard data types. 
+ public const ulong CanBeReadByVersion = 0x0001000100010006; internal static string VersionToString(ulong v) { diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index d8c33b4775..51543d5d98 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -19,7 +19,7 @@ public DataTypesTest(ITestOutputHelper helper) private readonly static Conversions _conv = Conversions.Instance; [Fact] - public void TXToI1() + public void TXToSByte() { var mapper = GetMapper(); @@ -86,7 +86,7 @@ public void TXToI1() } catch (Exception ex) { - Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); error = true; } @@ -94,7 +94,7 @@ public void TXToI1() } [Fact] - public void TXToI2() + public void TXToShort() { var mapper = GetMapper(); @@ -161,7 +161,7 @@ public void TXToI2() } catch (Exception ex) { - Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); error = true; } @@ -169,7 +169,7 @@ public void TXToI2() } [Fact] - public void TXToI4() + public void TXToInt() { var mapper = GetMapper(); @@ -236,7 +236,7 @@ public void TXToI4() } catch (Exception ex) { - Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); error = true; } @@ -244,7 +244,7 @@ public void TXToI4() } [Fact] - public void TXToI8() + public void TXToLong() { var mapper = GetMapper(); @@ -301,7 +301,7 @@ public void TXToI8() } catch (Exception ex) { - Assert.Equal("Missing text value cannot be converted to signed numbers.", ex.Message); + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); error = 
true; } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 468df84446..37a0e7772c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -66,7 +66,7 @@ private IDataView GetBreastCancerDataviewWithTextColumns() { new TextLoader.Column("Label", type: null, 0), new TextLoader.Column("F1", DataKind.Text, 1), - new TextLoader.Column("F2", DataKind.I4, 2), + new TextLoader.Column("F2", DataKind.R4, 2), new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) }) } }, diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 8e4c298f82..82a2b6192d 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -113,7 +113,6 @@ public void TestTextLoaderInvalidLongMax() string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); File.WriteAllLines(pathData, new string[] { "9223372036854775808" - }); try From ad37fb56cf59bf2df701b57144bd6b10498f71a5 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 30 Aug 2018 12:25:32 -0700 Subject: [PATCH 34/37] cleanup. --- src/Microsoft.ML.Data/Data/Conversion.cs | 2 +- .../UnitTests/DataTypes.cs | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 2fe524827f..e39a0242e6 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -1196,7 +1196,7 @@ private bool IsStdMissing(ref TX src) /// Utility to assist in parsing key-type values. The min and max values define /// the legal input value bounds. The output dst value is "normalized" so min is /// mapped to 1, max is mapped to 1 + (max - min). - /// Missing values are mapped to zero with a true return. 
+ /// Exception is thrown for missing values. /// Unparsable or out of range values are mapped to zero with a false return. /// public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index 51543d5d98..94a0459a35 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -146,15 +146,16 @@ public void TXToShort() Assert.True(error); - //5. Missing value in text to short. + //5. Empty value in text to short. src = default; dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); - //6. Empty string in text to sbyte. + //6. Missing string in text to sbyte. src = DvText.NA; dst = -1; + error = false; try { mapper(ref src, ref dst); @@ -221,15 +222,16 @@ public void TXToInt() Assert.True(error); - //5. Missing value in text to int. + //5. Empty value in text to int. src = default; dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); - //6. Empty string in text to sbyte. + //6. Missing string in text to sbyte. src = DvText.NA; dst = -1; + error = false; try { mapper(ref src, ref dst); @@ -286,13 +288,14 @@ public void TXToLong() Assert.True(error); - //5. Missing value in text to long. + //5. Empty value in text to long. src = default; dst = -1; mapper(ref src, ref dst); Assert.Equal(default, dst); - //6. Empty string in text to sbyte. + //6. Missing string in text to sbyte. + error = false; src = DvText.NA; dst = -1; try From 078cd8cbe7c96ace3651a070454c403a098b8675 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 30 Aug 2018 20:40:10 -0700 Subject: [PATCH 35/37] resolve merge conflict. 
--- .../DataPipe/TestDataPipeBase.cs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index d189e627f5..7070e556b7 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -983,19 +983,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -1029,19 +1029,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return 
GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: From cd7de9f3533f5e46b04d4e04133240b1da290f6b Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 3 Sep 2018 17:06:14 -0700 Subject: [PATCH 36/37] PR feedback. --- .../SavePipe/TestParquetNull-Data.txt | 9 ----- .../SavePipe/TestParquetNull-Schema.txt | 4 --- .../TestParquetPrimitiveDataTypes-Data.txt | 2 +- .../SavePipe/TestParquetNull-Data.txt | 9 ----- .../SavePipe/TestParquetNull-Schema.txt | 4 --- .../TestParquetPrimitiveDataTypes-Data.txt | 2 +- .../DataPipe/Parquet.cs | 14 +++++++- .../DataPipe/TestDataPipe.cs | 32 ------------------ test/data/datatypes.idv | Bin 809 -> 809 bytes 9 files changed, 15 insertions(+), 61 deletions(-) delete mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt delete mode 100644 test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt delete mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt delete mode 100644 test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt deleted file mode 100644 index c7049cd12a..0000000000 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt +++ /dev/null @@ -1,9 +0,0 @@ -#@ TextLoader{ -#@ header+ -#@ sep=tab -#@ col=foo:I4:0 -#@ col=bar:I4:1 -#@ } -foo bar -1 2 -1 diff --git 
a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt deleted file mode 100644 index 8fa619c171..0000000000 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt +++ /dev/null @@ -1,4 +0,0 @@ ----- ParquetLoader ---- -2 columns: - foo: I4 - bar: I4 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt index af1e19e1cc..85a3d35b4b 100644 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt +++ b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -11,5 +11,5 @@ #@ col=string:TX:7 #@ } sbyte short int long bool DateTimeOffset Interval string - 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" +-128 -32768 -2147483648 -9223372036854775808 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" 127 32767 2147483647 9223372036854775807 0 "2018-09-01T19:53:18.3110000+00:00" "31.00:00:00.0010000" """""" diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt deleted file mode 100644 index c7049cd12a..0000000000 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt +++ /dev/null @@ -1,9 +0,0 @@ -#@ TextLoader{ -#@ header+ -#@ sep=tab -#@ col=foo:I4:0 -#@ col=bar:I4:1 -#@ } -foo bar -1 2 -1 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt deleted file mode 100644 index 8fa619c171..0000000000 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt +++ /dev/null @@ -1,4 +0,0 @@ ----- ParquetLoader ---- -2 columns: - foo: I4 - bar: I4 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt 
b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt index af1e19e1cc..85a3d35b4b 100644 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -11,5 +11,5 @@ #@ col=string:TX:7 #@ } sbyte short int long bool DateTimeOffset Interval string - 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" +-128 -32768 -2147483648 -9223372036854775808 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" 127 32767 2147483647 9223372036854775807 0 "2018-09-01T19:53:18.3110000+00:00" "31.00:00:00.0010000" """""" diff --git a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs index f5be433b3e..ace98c93f9 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs @@ -33,7 +33,19 @@ public void TestParquetPrimitiveDataTypes() public void TestParquetNull() { string pathData = GetDataPath(@"Parquet", "test-null.parquet"); - TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + bool exception = false; + try + { + TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + } + catch (Exception ex) + { + Assert.Equal("Nullable object must have a value.", ex.Message); + exception = true; + } + + Assert.True(exception, "Test failed because control reached here without an expected exception for nullable values."); + Done(); } } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 1adfd71f1b..3abff3e560 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -243,36 +243,4 @@ public void TestLdaTransformEmptyDocumentException() Assert.True(false, "The LDA transform does not throw 
expected error on empty documents."); } } - - public sealed partial class TestDataPipe : TestDataPipeBase - { - - [Fact] - public void TestParquetPrimitiveDataTypes() - { - string pathData = GetDataPath(@"Parquet", "alltypes.parquet"); - TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }); - Done(); - } - - [Fact] - public void TestParquetNull() - { - string pathData = GetDataPath(@"Parquet", "test-null.parquet"); - bool exception = false; - try - { - TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); - } - catch(Exception ex) - { - Assert.Equal("Nullable object must have a value.", ex.Message); - exception = true; - } - - Assert.True(exception, "Test failed because control reached here without an expected exception for nullable values."); - - Done(); - } - } } diff --git a/test/data/datatypes.idv b/test/data/datatypes.idv index 288f75e0ea49618068ef55082e4df88f199b298a..15f12b97484b09de15d051c78a494ed29357198b 100644 GIT binary patch delta 216 zcmZ3&d*86U|>kj*db%o zIkBcK`f=0%vZH;U|1v{#j zoBm^BXJ+eVlA4^&sFJA2$N&cHK#~=R*`RzLAe#$_g`j+XAX@~8eW82-AX@{3p?r`W Uh;IYs3j*0jKpeaI1EVD)0FvM^SO5S3 delta 216 zcmZ3kswA>8GJpX)kYoko1Sp>e$mRlKUMQa*$QA)&87N->$kqU2TPPnS V2jc5M`GP>U5fBG%{=jI-2mqLIFj@cr From 9d4f4d6627b703efa7a18cf4fbd6824bad0aa21c Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 7 Sep 2018 11:38:46 -0700 Subject: [PATCH 37/37] merge master. 
--- test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 603d9dcd2f..e29407cb40 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -165,7 +165,7 @@ public void AssertStaticKeys() var col1 = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); // Next the case where those values are ints. - var metaValues2 = new VBuffer(3, new DvInt4[] { 1, 2, 3, 4 }); + var metaValues2 = new VBuffer(3, new int[] { 1, 2, 3, 4 }); var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2); var value2 = new VBuffer(2, 0, null, null); var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2));