diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs index 30670cd96e..d956532921 100644 --- a/src/Microsoft.ML.PCA/PcaTransform.cs +++ b/src/Microsoft.ML.PCA/PcaTransform.cs @@ -1,12 +1,12 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; +using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -15,48 +15,48 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms; -[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), typeof(PcaTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(PcaTransform.Summary, typeof(IDataTransform), typeof(PcaTransform), typeof(PcaTransform.Arguments), typeof(SignatureDataTransform), PcaTransform.UserName, PcaTransform.LoaderSignature, PcaTransform.ShortName)] -[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(PcaTransform.Summary, typeof(IDataTransform), typeof(PcaTransform), null, typeof(SignatureLoadDataTransform), + PcaTransform.UserName, PcaTransform.LoaderSignature)] + +[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), null, typeof(SignatureLoadModel), + PcaTransform.UserName, PcaTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(PcaTransform), null, typeof(SignatureLoadRowMapper), PcaTransform.UserName, PcaTransform.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(PcaTransform), null, typeof(SignatureEntryPointModule), PcaTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Transforms { /// - public sealed class PcaTransform : OneToOneTransformBase + public sealed class PcaTransform : OneToOneTransformerBase { - internal static class Defaults - { - public const string WeightColumn = null; - public const int Rank = 20; - public const int Oversampling = 20; - public const bool Center = true; - public const int Seed = 0; - } - public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight", Purpose = SpecialPurpose.ColumnName)] - public string WeightColumn = Defaults.WeightColumn; + public string WeightColumn = PcaEstimator.Defaults.WeightColumn; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")] - public int Rank = Defaults.Rank; + public int Rank = PcaEstimator.Defaults.Rank; [Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")] - public int Oversampling = Defaults.Oversampling; + public int Oversampling = PcaEstimator.Defaults.Oversampling; [Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean")] - public bool Center = Defaults.Center; + public bool Center = PcaEstimator.Defaults.Center; [Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation")] - public int Seed = Defaults.Seed; + public int Seed = PcaEstimator.Defaults.Seed; } public class Column : OneToOneColumn @@ -98,22 +98,64 @@ public bool TryUnparse(StringBuilder sb) } } + public sealed class ColumnInfo + { + public readonly string Input; + public readonly string Output; + public readonly string WeightColumn; + public readonly int Rank; + public readonly int Oversampling; + public readonly bool Center; + public readonly int? Seed; + + /// + /// Describes how the transformer handles one column pair. + /// + /// The column to apply PCA to. + /// The output column that contains PCA values. + /// The name of the weight column. + /// The number of components in the PCA. + /// Oversampling parameter for randomized PCA training. + /// If enabled, data is centered to be zero mean. + /// The seed for random number generation. + public ColumnInfo(string input, + string output, + string weightColumn = PcaEstimator.Defaults.WeightColumn, + int rank = PcaEstimator.Defaults.Rank, + int overSampling = PcaEstimator.Defaults.Oversampling, + bool center = PcaEstimator.Defaults.Center, + int? seed = null) + { + Input = input; + Output = output; + WeightColumn = weightColumn; + Rank = rank; + Oversampling = overSampling; + Center = center; + Seed = seed; + Contracts.CheckParam(Oversampling >= 0, nameof(Oversampling), "Oversampling must be non-negative."); + Contracts.CheckParam(Rank > 0, nameof(Rank), "Rank must be positive."); + } + } + private sealed class TransformInfo { public readonly int Dimension; public readonly int Rank; - public Float[][] Eigenvectors; - public Float[] MeanProjected; + public float[][] Eigenvectors; + public float[] MeanProjected; - public TransformInfo(Column item, Arguments args, int d) + public ColumnType OutputType => new VectorType(NumberType.Float, Rank); + + public TransformInfo(int rank, int dim) { - Dimension = d; - Rank = item.Rank ?? args.Rank; - Contracts.CheckUserArg(0 < Rank && Rank <= Dimension, nameof(item.Rank), "Rank must be positive, and at most the dimension of untransformed data"); + Dimension = dim; + Rank = rank; + Contracts.CheckParam(0 < Rank && Rank <= Dimension, nameof(Rank), "Rank must be positive, and at most the dimension of untransformed data"); } - public TransformInfo(ModelLoadContext ctx, int colValueCount) + public TransformInfo(ModelLoadContext ctx) { Contracts.AssertValue(ctx); @@ -121,17 +163,15 @@ public TransformInfo(ModelLoadContext ctx, int colValueCount) // int: Dimension // int: Rank // for i=0,..,Rank-1: - // Float[]: the i'th eigenvector + // float[]: the i'th eigenvector // int: the size of MeanProjected (0 if it is null) - // Float[]: MeanProjected + // float[]: MeanProjected Dimension = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(Dimension == colValueCount); - Rank = ctx.Reader.ReadInt32(); Contracts.CheckDecode(0 < Rank && Rank <= Dimension); - Eigenvectors = new Float[Rank][]; + Eigenvectors = new float[Rank][]; for (int i = 0; i < Rank; i++) { Eigenvectors[i] = ctx.Reader.ReadFloatArray(Dimension); @@ -150,9 +190,9 @@ public void Save(ModelSaveContext ctx) // int: Dimension // int: Rank // for i=0,..,Rank-1: - // Float[]: the i'th eigenvector + // float[]: the i'th eigenvector // int: the size of MeanProjected (0 if it is null) - // Float[]: MeanProjected + // float[]: MeanProjected Contracts.Assert(0 < Rank && Rank <= Dimension); ctx.Writer.Write(Dimension); @@ -166,7 +206,7 @@ public void Save(ModelSaveContext ctx) ctx.Writer.WriteFloatArray(MeanProjected); } - internal void ProjectMean(Float[] mean) + public void ProjectMean(float[] mean) { Contracts.AssertValue(Eigenvectors); if (mean == null) @@ -175,7 +215,7 @@ internal void ProjectMean(Float[] mean) return; } - MeanProjected = new Float[Rank]; + MeanProjected = new float[Rank]; for (var i = 0; i < Rank; ++i) MeanProjected[i] = VectorUtils.DotProduct(Eigenvectors[i], mean); } @@ -190,62 +230,41 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "PCA FUNC", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Got rid of writing float size in model context + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(PcaTransform).Assembly.FullName); } - // These are parallel to Infos. - private readonly ColumnType[] _types; + private readonly int _numColumns; + private readonly Mapper.ColumnSchemaInfo[] _schemaInfos; private readonly TransformInfo[] _transformInfos; - private readonly int[] _oversampling; - private readonly bool[] _center; - private readonly int[] _weightColumnIndex; - private const string RegistrationName = "Pca"; - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public PcaTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestIsFloatItem) + internal PcaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PcaTransform)), GetColumnPairs(columns)) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); + Host.AssertNonEmpty(ColumnPairs); + _numColumns = columns.Length; + _transformInfos = new TransformInfo[_numColumns]; + _schemaInfos = new Mapper.ColumnSchemaInfo[_numColumns]; - _transformInfos = new TransformInfo[args.Column.Length]; - _oversampling = new int[args.Column.Length]; - _center = new bool[args.Column.Length]; - _weightColumnIndex = new int[args.Column.Length]; - for (int i = 0; i < _transformInfos.Length; i++) + for (int i = 0; i < _numColumns; i++) { - Host.Check(Infos[i].TypeSrc.VectorSize > 1, "Pca transform can only be applied to columns with known dimensionality greater than 1"); - _transformInfos[i] = new TransformInfo(args.Column[i], args, Infos[i].TypeSrc.ValueCount); - _center[i] = args.Column[i].Center ?? args.Center; - _oversampling[i] = args.Column[i].Oversampling ?? args.Oversampling; - Host.CheckUserArg(_oversampling[i] >= 0, nameof(args.Oversampling), "Oversampling must be non-negative"); - _weightColumnIndex[i] = -1; - var weightColumn = args.Column[i].WeightColumn ?? args.WeightColumn; - if (weightColumn != null) - { - if (!Source.Schema.TryGetColumnIndex(weightColumn, out _weightColumnIndex[i])) - throw Host.Except("weight column '{0}' does not exist", weightColumn); - var type = Source.Schema.GetColumnType(_weightColumnIndex[i]); - Host.CheckUserArg(type == NumberType.Float, nameof(args.WeightColumn)); - } + var colInfo = columns[i]; + var sInfo = _schemaInfos[i] = new Mapper.ColumnSchemaInfo(ColumnPairs[i], input.Schema, colInfo.WeightColumn); + ValidatePcaInput(Host, colInfo.Input, sInfo.InputType); + _transformInfos[i] = new TransformInfo(colInfo.Rank, sInfo.InputType.ValueCount); } - Train(args, _transformInfos, input); - - _types = InitColumnTypes(); + Train(columns, _transformInfos, input); } - private PcaTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsFloatItem) + private PcaTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { Host.AssertValue(ctx); @@ -253,27 +272,53 @@ private PcaTransform(IHost host, ModelLoadContext ctx, IDataView input) // // // transformInfos - Host.AssertNonEmpty(Infos); - _transformInfos = new TransformInfo[Infos.Length]; - for (int i = 0; i < Infos.Length; i++) - _transformInfos[i] = new TransformInfo(ctx, Infos[i].TypeSrc.ValueCount); - _types = InitColumnTypes(); + Host.AssertNonEmpty(ColumnPairs); + _numColumns = ColumnPairs.Length; + _transformInfos = new TransformInfo[_numColumns]; + for (int i = 0; i < _numColumns; i++) + _transformInfos[i] = new TransformInfo(ctx); } - public static PcaTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + env.CheckValue(args.Column, nameof(args.Column)); + var cols = args.Column.Select(item => new ColumnInfo( + item.Source, + item.Name, + item.WeightColumn, + item.Rank ?? args.Rank, + item.Oversampling ?? args.Oversampling, + item.Center ?? args.Center, + item.Seed ?? args.Seed)).ToArray(); + return new PcaTransform(env, input, cols).MakeDataTransform(input); + } - // *** Binary format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - h.CheckDecode(cbFloat == sizeof(Float)); - return h.Apply("Loading Model", ch => new PcaTransform(h, ctx, input)); + // Factory method for SignatureLoadModel. + private static PcaTransform Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(nameof(PcaTransform)); + + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + if (ctx.Header.ModelVerWritten == 0x00010001) + { + int cbFloat = ctx.Reader.ReadInt32(); + env.CheckDecode(cbFloat == sizeof(float)); + } + return new PcaTransform(host, ctx); } public override void Save(ModelSaveContext ctx) @@ -283,54 +328,56 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); // *** Binary format *** - // int: sizeof(Float) // // transformInfos - ctx.Writer.Write(sizeof(Float)); - SaveBase(ctx); + SaveColumns(ctx); for (int i = 0; i < _transformInfos.Length; i++) _transformInfos[i].Save(ctx); } - - private void Train(Arguments args, TransformInfo[] transformInfos, IDataView trainingData) + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { - var y = new Float[transformInfos.Length][][]; - var omega = new Float[transformInfos.Length][][]; - var mean = new Float[transformInfos.Length][]; + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); + } - var oversampledRank = new int[transformInfos.Length]; + private void Train(ColumnInfo[] columns, TransformInfo[] transformInfos, IDataView trainingData) + { + var y = new float[_numColumns][][]; + var omega = new float[_numColumns][][]; + var mean = new float[_numColumns][]; + var oversampledRank = new int[_numColumns]; var rnd = Host.Rand; Double totalMemoryUsageEstimate = 0; - for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++) + for (int iinfo = 0; iinfo < _numColumns; iinfo++) { - oversampledRank[iinfo] = Math.Min(transformInfos[iinfo].Rank + _oversampling[iinfo], transformInfos[iinfo].Dimension); + oversampledRank[iinfo] = Math.Min(transformInfos[iinfo].Rank + columns[iinfo].Oversampling, transformInfos[iinfo].Dimension); //exact: (size of the 2 big matrices + other minor allocations) / (2^30) - Double colMemoryUsageEstimate = 2.0 * transformInfos[iinfo].Dimension * oversampledRank[iinfo] * sizeof(Float) / 1e9; + Double colMemoryUsageEstimate = 2.0 * transformInfos[iinfo].Dimension * oversampledRank[iinfo] * sizeof(float) / 1e9; totalMemoryUsageEstimate += colMemoryUsageEstimate; if (colMemoryUsageEstimate > 2) { using (var ch = Host.Start("Memory usage")) { ch.Info("Estimate memory usage for transforming column {1}: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.", - colMemoryUsageEstimate, Infos[iinfo].Name); + colMemoryUsageEstimate, ColumnPairs[iinfo].input); } } - y[iinfo] = new Float[oversampledRank[iinfo]][]; - omega[iinfo] = new Float[oversampledRank[iinfo]][]; + y[iinfo] = new float[oversampledRank[iinfo]][]; + omega[iinfo] = new float[oversampledRank[iinfo]][]; for (int i = 0; i < oversampledRank[iinfo]; i++) { - y[iinfo][i] = new Float[transformInfos[iinfo].Dimension]; - omega[iinfo][i] = new Float[transformInfos[iinfo].Dimension]; + y[iinfo][i] = new float[transformInfos[iinfo].Dimension]; + omega[iinfo][i] = new float[transformInfos[iinfo].Dimension]; for (int j = 0; j < transformInfos[iinfo].Dimension; j++) { - omega[iinfo][i][j] = (Float)Stats.SampleFromGaussian(rnd); + omega[iinfo][i][j] = (float)Stats.SampleFromGaussian(rnd); } } - if (_center[iinfo]) - mean[iinfo] = new Float[transformInfos[iinfo].Dimension]; + if (columns[iinfo].Center) + mean[iinfo] = new float[transformInfos[iinfo].Dimension]; } if (totalMemoryUsageEstimate > 2) { @@ -365,15 +412,15 @@ private void Train(Arguments args, TransformInfo[] transformInfos, IDataView tra for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++) { //Compute B2 = B' * B - var b2 = new Float[oversampledRank[iinfo] * oversampledRank[iinfo]]; + var b2 = new float[oversampledRank[iinfo] * oversampledRank[iinfo]]; for (var i = 0; i < oversampledRank[iinfo]; ++i) { for (var j = i; j < oversampledRank[iinfo]; ++j) b2[i * oversampledRank[iinfo] + j] = b2[j * oversampledRank[iinfo] + i] = VectorUtils.DotProduct(b[iinfo][i], b[iinfo][j]); } - Float[] smallEigenvalues; // eigenvectors and eigenvalues of the small matrix B2. - Float[] smallEigenvectors; + float[] smallEigenvalues; // eigenvectors and eigenvalues of the small matrix B2. + float[] smallEigenvectors; EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors); transformInfos[iinfo].Eigenvectors = PostProcess(b[iinfo], smallEigenvalues, smallEigenvectors, transformInfos[iinfo].Dimension, oversampledRank[iinfo]); @@ -384,9 +431,9 @@ private void Train(Arguments args, TransformInfo[] transformInfos, IDataView tra //Project the covariance matrix A on to Omega: Y <- A * Omega //A = X' * X / n, where X = data - mean //Note that the covariance matrix is not computed explicitly - private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, Float[][][] y, TransformInfo[] transformInfos) + private void Project(IDataView trainingData, float[][] mean, float[][][] omega, float[][][] y, TransformInfo[] transformInfos) { - Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == Infos.Length); + Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == _numColumns); for (int i = 0; i < omega.Length; i++) Contracts.Assert(omega[i].Length == y[i].Length); @@ -399,37 +446,35 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, bool[] center = Enumerable.Range(0, mean.Length).Select(i => mean[i] != null).ToArray(); - Double[] totalColWeight = new Double[Infos.Length]; + Double[] totalColWeight = new Double[_numColumns]; - bool[] activeColumns = new bool[Source.Schema.ColumnCount]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; + foreach (var sInfo in _schemaInfos) { - activeColumns[Infos[iinfo].Source] = true; - if (_weightColumnIndex[iinfo] >= 0) - activeColumns[_weightColumnIndex[iinfo]] = true; + activeColumns[sInfo.InputIndex] = true; + if (sInfo.WeightColumnIndex >= 0) + activeColumns[sInfo.WeightColumnIndex] = true; } + using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) { - var weightGetters = new ValueGetter[Infos.Length]; - var columnGetters = new ValueGetter>[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + var weightGetters = new ValueGetter[_numColumns]; + var columnGetters = new ValueGetter>[_numColumns]; + for (int iinfo = 0; iinfo < _numColumns; iinfo++) { - if (_weightColumnIndex[iinfo] >= 0) - weightGetters[iinfo] = cursor.GetGetter(_weightColumnIndex[iinfo]); - columnGetters[iinfo] = cursor.GetGetter>(Infos[iinfo].Source); + var sInfo = _schemaInfos[iinfo]; + if (sInfo.WeightColumnIndex >= 0) + weightGetters[iinfo] = cursor.GetGetter(sInfo.WeightColumnIndex); + columnGetters[iinfo] = cursor.GetGetter>(sInfo.InputIndex); } - var features = default(VBuffer); + var features = default(VBuffer); while (cursor.MoveNext()) { - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int iinfo = 0; iinfo < _numColumns; iinfo++) { - Contracts.Check(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsNumber, - "PCA transform can only be performed on numeric columns of dimension > 1"); - - Float weight = 1; - if (weightGetters[iinfo] != null) - weightGetters[iinfo](ref weight); + float weight = 1; + weightGetters[iinfo]?.Invoke(ref weight); columnGetters[iinfo](ref features); if (FloatUtils.IsFinite(weight) && weight >= 0 && (features.Count == 0 || FloatUtils.IsFinite(features.Values, features.Count))) @@ -445,15 +490,15 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, } } - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int iinfo = 0; iinfo < _numColumns; iinfo++) { if (totalColWeight[iinfo] <= 0) - throw Host.Except("Empty data in column '{0}'", Source.Schema.GetColumnName(Infos[iinfo].Source)); + throw Host.Except("Empty data in column '{0}'", ColumnPairs[iinfo].input); } - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + for (int iinfo = 0; iinfo < _numColumns; iinfo++) { - var invn = (Float)(1 / totalColWeight[iinfo]); + var invn = (float)(1 / totalColWeight[iinfo]); for (var i = 0; i < omega[iinfo].Length; ++i) VectorUtils.ScaleBy(y[iinfo][i], invn); @@ -470,13 +515,13 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, //return Y * eigenvectors / eigenvalues // REVIEW: improve - private Float[][] PostProcess(Float[][] y, Float[] sigma, Float[] z, int d, int k) + private float[][] PostProcess(float[][] y, float[] sigma, float[] z, int d, int k) { - var pinv = new Float[k]; - var tmp = new Float[k]; + var pinv = new float[k]; + var tmp = new float[k]; for (int i = 0; i < k; i++) - pinv[i] = (Float)(1.0) / ((Float)(1e-6) + sigma[i]); + pinv[i] = (float)(1.0) / ((float)(1e-6) + sigma[i]); for (int i = 0; i < d; i++) { @@ -493,56 +538,109 @@ private Float[][] PostProcess(Float[][] y, Float[] sigma, Float[] z, int d, int return y; } - private ColumnType[] InitColumnTypes() + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, Schema.Create(schema)); + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Host.Assert(Infos.Length == _transformInfos.Length); - var types = new ColumnType[Infos.Length]; - for (int i = 0; i < _transformInfos.Length; i++) - types[i] = new VectorType(NumberType.Float, _transformInfos[i].Rank); - Metadata.Seal(); - return types; + ValidatePcaInput(Host, inputSchema.GetColumnName(srcCol), inputSchema.GetColumnType(srcCol)); } - protected override ColumnType GetColumnTypeCore(int iinfo) + internal static void ValidatePcaInput(IExceptionContext ectx, string name, ColumnType type) { - Host.Check(0 <= iinfo & iinfo < Utils.Size(_types)); - return _types[iinfo]; + string inputSchema; // just used for the excpections + + if (!(type.IsKnownSizeVector && type.VectorSize > 1 && type.ItemType.Equals(NumberType.R4))) + throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "vector of floats with fixed size greater than 1", type.ToString()); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private sealed class Mapper : MapperBase { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var getSrc = GetSrcGetter>(input, iinfo); - var src = default(VBuffer); - var trInfo = _transformInfos[iinfo]; - ValueGetter> del = - (ref VBuffer dst) => + public sealed class ColumnSchemaInfo + { + public ColumnType InputType { get; } + public int InputIndex { get; } + public int WeightColumnIndex { get; } + + public ColumnSchemaInfo((string input, string output) columnPair, Schema schema, string weightColumn = null) { - getSrc(ref src); - TransformFeatures(Host, ref src, ref dst, trInfo); - }; - return del; - } + schema.TryGetColumnIndex(columnPair.input, out int inputIndex); + InputIndex = inputIndex; + InputType = schema[columnPair.input].Type; - private static void TransformFeatures(IExceptionContext ectx, ref VBuffer src, ref VBuffer dst, TransformInfo transformInfo) - { - ectx.Check(src.Length == transformInfo.Dimension); + var weightIndex = -1; + if (weightColumn != null) + { + if (!schema.TryGetColumnIndex(weightColumn, out weightIndex)) + throw Contracts.Except("Weight column '{0}' does not exist.", weightColumn); + Contracts.CheckParam(schema[weightIndex].Type == NumberType.Float, nameof(weightColumn)); + } + WeightColumnIndex = weightIndex; + } + } - var values = dst.Values; - if (Utils.Size(values) < transformInfo.Rank) - values = new Float[transformInfo.Rank]; + private readonly PcaTransform _parent; + private readonly int _numColumns; - for (int i = 0; i < transformInfo.Rank; i++) + public Mapper(PcaTransform parent, Schema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { - values[i] = VectorUtils.DotProductWithOffset(transformInfo.Eigenvectors[i], 0, ref src) - - (transformInfo.MeanProjected == null ? 0 : transformInfo.MeanProjected[i]); + _parent = parent; + _numColumns = parent._numColumns; + for (int i = 0; i < _numColumns; i++) + { + var colPair = _parent.ColumnPairs[i]; + var colSchemaInfo = new ColumnSchemaInfo(colPair, inputSchema); + ValidatePcaInput(Host, colPair.input, colSchemaInfo.InputType); + if (colSchemaInfo.InputType.VectorSize != _parent._transformInfos[i].Dimension) + { + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input, + new VectorType(NumberType.R4, _parent._transformInfos[i].Dimension).ToString(), colSchemaInfo.InputType.ToString()); + } + } } - dst = new VBuffer(transformInfo.Rank, values, dst.Indices); + public override Schema.Column[] GetOutputColumns() + { + var result = new Schema.Column[_numColumns]; + for (int i = 0; i < _numColumns; i++) + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._transformInfos[i].OutputType, null); + return result; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _numColumns); + disposer = null; + + var srcGetter = input.GetGetter>(ColMapNewToOld[iinfo]); + var src = default(VBuffer); + + ValueGetter> dstGetter = (ref VBuffer dst) => + { + srcGetter(ref src); + TransformFeatures(Host, ref src, ref dst, _parent._transformInfos[iinfo]); + }; + + return dstGetter; + } + + private static void TransformFeatures(IExceptionContext ectx, ref VBuffer src, ref VBuffer dst, TransformInfo transformInfo) + { + ectx.Check(src.Length == transformInfo.Dimension); + + var values = dst.Values; + if (Utils.Size(values) < transformInfo.Rank) + values = new float[transformInfo.Rank]; + + for (int i = 0; i < transformInfo.Rank; i++) + { + values[i] = VectorUtils.DotProductWithOffset(transformInfo.Eigenvectors[i], 0, ref src) - + (transformInfo.MeanProjected == null ? 0 : transformInfo.MeanProjected[i]); + } + + dst = new VBuffer(transformInfo.Rank, values, dst.Indices); + } } [TlcModule.EntryPoint(Name = "Transforms.PcaCalculator", @@ -554,7 +652,7 @@ private static void TransformFeatures(IExceptionContext ectx, ref VBuffer public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "Pca", input); - var view = new PcaTransform(h, input, input.Data); + var view = PcaTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), @@ -562,4 +660,125 @@ public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Argu }; } } + + public sealed class PcaEstimator : IEstimator + { + internal static class Defaults + { + public const string WeightColumn = null; + public const int Rank = 20; + public const int Oversampling = 20; + public const bool Center = true; + public const int Seed = 0; + } + + private readonly IHost _host; + private readonly PcaTransform.ColumnInfo[] _columns; + + /// Convinence constructor for simple one column case. + /// + /// The environment. + /// Input column to apply PCA on. + /// Output column. Null means is replaced. + /// The name of the weight column. + /// The number of components in the PCA. + /// Oversampling parameter for randomized PCA training. + /// If enabled, data is centered to be zero mean. + /// The seed for random number generation. + public PcaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + string weightColumn = Defaults.WeightColumn, int rank = Defaults.Rank, + int overSampling = Defaults.Oversampling, bool center = Defaults.Center, + int? seed = null) + : this(env, new PcaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, weightColumn, rank, overSampling, center, seed)) + { + } + + public PcaEstimator(IHostEnvironment env, params PcaTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(PcaEstimator)); + _columns = columns; + } + + public PcaTransform Fit(IDataView input) => new PcaTransform(_host, input, _columns); + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + if (col.Kind != SchemaShape.Column.VectorKind.Vector || !col.ItemType.Equals(NumberType.R4)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, + SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + return new SchemaShape(result.Values); + } + } + + public static class PcaEstimatorExtensions + { + private sealed class OutPipelineColumn : Vector + { + public readonly Vector Input; + + public OutPipelineColumn(Vector input, string weightColumn, int rank, + int overSampling, bool center, int? seed = null) + : base(new Reconciler(weightColumn, rank, overSampling, center, seed), input) + { + Input = input; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + private readonly PcaTransform.ColumnInfo _colInfo; + + public Reconciler(string weightColumn, int rank, int overSampling, bool center, int? seed = null) + { + _colInfo = new PcaTransform.ColumnInfo( + null, null, weightColumn, rank, overSampling, center, seed); + } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + Contracts.Assert(toOutput.Length == 1); + var outCol = (OutPipelineColumn)toOutput[0]; + var inputColName = inputNames[outCol.Input]; + var outputColName = outputNames[outCol]; + return new PcaEstimator(env, inputColName, outputColName, + _colInfo.WeightColumn, _colInfo.Rank, _colInfo.Oversampling, + _colInfo.Center, _colInfo.Seed); + } + } + + /// + /// Replaces the input vector with its projection to the principal component subspace, + /// which can significantly reduce size of vector. + /// + /// + /// The column to apply PCA to. + /// The name of the weight column. + /// The number of components in the PCA. + /// Oversampling parameter for randomized PCA training. + /// If enabled, data is centered to be zero mean. + /// The seed for random number generation + /// Vector containing the principal components. + public static Vector ToPrincipalComponents(this Vector input, + string weightColumn = PcaEstimator.Defaults.WeightColumn, + int rank = PcaEstimator.Defaults.Rank, + int overSampling = PcaEstimator.Defaults.Oversampling, + bool center = PcaEstimator.Defaults.Center, + int? seed = null) => new OutPipelineColumn(input, weightColumn, rank, overSampling, center, seed); + } } diff --git a/src/Microsoft.ML.PCA/WrappedPcaTransform.cs b/src/Microsoft.ML.PCA/WrappedPcaTransform.cs deleted file mode 100644 index 1f082e193e..0000000000 --- a/src/Microsoft.ML.PCA/WrappedPcaTransform.cs +++ /dev/null @@ -1,116 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Core.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.StaticPipe; -using Microsoft.ML.StaticPipe.Runtime; -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Microsoft.ML.Runtime.Data -{ - /// - public sealed class PcaEstimator : TrainedWrapperEstimatorBase - { - private readonly PcaTransform.Arguments _args; - - /// - /// The environment. - /// Input column to apply PCA on. - /// Output column. Null means is replaced. - /// The number of components in the PCA. - /// A delegate to apply all the advanced arguments to the algorithm. - public PcaEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, - int rank = PcaTransform.Defaults.Rank, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, rank, advancedSettings) - { - } - - /// - /// The environment. - /// Pairs of columns to run the PCA on. - /// The number of components in the PCA. - /// A delegate to apply all the advanced arguments to the algorithm. - public PcaEstimator(IHostEnvironment env, (string input, string output)[] columns, - int rank = PcaTransform.Defaults.Rank, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PcaEstimator))) - { - foreach (var (input, output) in columns) - { - Host.CheckUserArg(Utils.Size(input) > 0, nameof(input)); - Host.CheckValue(output, nameof(input)); - } - - _args = new PcaTransform.Arguments(); - _args.Column = columns.Select(x => new PcaTransform.Column { Source = x.input, Name = x.output }).ToArray(); - _args.Rank = rank; - - advancedSettings?.Invoke(_args); - } - - public override TransformWrapper Fit(IDataView input) - { - return new TransformWrapper(Host, new PcaTransform(Host, _args, input)); - } - } - - /// - /// Extensions for statically typed . - /// - public static class PcaEstimatorExtensions - { - private sealed class OutPipelineColumn : Vector - { - public readonly Vector Input; - - public OutPipelineColumn(Vector input, int rank, Action advancedSettings) - : base(new Reconciler(null, rank, advancedSettings), input) - { - Input = input; - } - } - - private sealed class Reconciler : EstimatorReconciler - { - private readonly int _rank; - private readonly Action _advancedSettings; - - public Reconciler(PipelineColumn weightColumn, int rank, Action advancedSettings) - { - _rank = rank; - _advancedSettings = advancedSettings; - } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - Contracts.Assert(toOutput.Length == 1); - - var pairs = new List<(string input, string output)>(); - foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); - - return new PcaEstimator(env, pairs.ToArray(), _rank, _advancedSettings); - } - } - - /// Replace current vector with its principal components. Can significantly reduce size of vector. - /// - /// The column to apply PCA to. - /// The number of components in the PCA. - /// A delegate to apply all the advanced arguments to the algorithm. - public static Vector ToPrincipalComponents(this Vector input, - int rank = PcaTransform.Defaults.Rank, - Action advancedSettings = null) => new OutPipelineColumn(input, rank, advancedSettings); - } -} diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 7095eeaaed..bb0ca58a38 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -119,7 +119,7 @@ Transforms.ModelCombiner Combines a sequence of TransformModels into a single mo Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Runtime.Transforms.TextAnalytics NGramTransform Microsoft.ML.Runtime.Data.NgramTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.NoOperation Does nothing. Microsoft.ML.Runtime.Data.NopTransform Nop Microsoft.ML.Runtime.Data.NopTransform+NopInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Runtime.DataPipe.OptionalColumnTransform MakeOptional Microsoft.ML.Runtime.DataPipe.OptionalColumnTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput -Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Runtime.Data.PcaTransform Calculate Microsoft.ML.Runtime.Data.PcaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.PcaTransform Calculate Microsoft.ML.Transforms.PcaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Runtime.Data.RandomNumberGenerator Generate Microsoft.ML.Runtime.Data.GenerateNumberTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.Runtime.EntryPoints.SelectRows FilterByRange Microsoft.ML.Runtime.Data.RangeFilter+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 6d4cae0995..7234173e9f 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -856,5 +856,25 @@ public void TextNormalizeStatic() type = schema.GetColumnType(numbers); Assert.True(!type.IsVector && type.ItemType.IsText); } + + [Fact] + public void TestPcaStatic() + { + var env = new ConsoleEnvironment(seed: 1); + var dataSource = GetDataPath("generated_regression_dataset.csv"); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), + separator: ';', hasHeader: true); + var data = reader.Read(dataSource); + var est = reader.MakeNewEstimator() + .Append(r => (r.label, pca: r.features.ToPrincipalComponents(rank: 5))); + var tdata = est.Fit(data).Transform(data); + var schema = tdata.AsDynamic.Schema; + + Assert.True(schema.TryGetColumnIndex("pca", out int pca)); + var type = schema[pca].Type; + Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4); + Assert.True(type.VectorSize == 5); + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs index 5561b4df7d..8f1089e0dc 100644 --- a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.IO; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Transforms; -using System.IO; using Xunit; using Xunit.Abstractions; @@ -14,43 +14,57 @@ namespace Microsoft.ML.Tests.Transformers { public sealed class PcaTests : TestDataPipeBase { + private readonly ConsoleEnvironment _env; + private readonly string _dataSource; + private readonly TextSaver _saver; + public PcaTests(ITestOutputHelper helper) : base(helper) { + _env = new ConsoleEnvironment(seed: 1); + _dataSource = GetDataPath("generated_regression_dataset.csv"); + _saver = new TextSaver(_env, new TextSaver.Arguments { Silent = true, OutputHeader = false }); } [Fact] public void PcaWorkout() { - var env = new ConsoleEnvironment(seed: 1, conc: 1); - string dataSource = GetDataPath("generated_regression_dataset.csv"); - var data = TextLoader.CreateReader(env, - c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), + var data = TextLoader.CreateReader(_env, + c => (label: c.LoadFloat(11), weight: c.LoadFloat(0), features: c.LoadFloat(1, 10)), separator: ';', hasHeader: true) - .Read(dataSource); + .Read(_dataSource); - var invalidData = TextLoader.CreateReader(env, - c => (label: c.LoadFloat(11), features: c.LoadText(0, 10)), + var invalidData = TextLoader.CreateReader(_env, + c => (label: c.LoadFloat(11), weight: c.LoadFloat(0), features: c.LoadText(1, 10)), separator: ';', hasHeader: true) - .Read(dataSource); + .Read(_dataSource); + + var est = new PcaEstimator(_env, "features", "pca", rank: 4, seed: 10); + TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); + + var estNonDefaultArgs = new PcaEstimator(_env, "features", "pca", rank: 3, weightColumn: "weight", overSampling: 2, center: false); + TestEstimatorCore(estNonDefaultArgs, data.AsDynamic, invalidInput: invalidData.AsDynamic); - var est = new PcaEstimator(env, "features", "pca", rank: 5, advancedSettings: s => { - s.Seed = 1; - }); + Done(); + } - // The following call fails because of the following issue - // https://github.com/dotnet/machinelearning/issues/969 - // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); + [Fact] + public void TestPcaEstimator() + { + var data = TextLoader.CreateReader(_env, + c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), + separator: ';', hasHeader: true) + .Read(_dataSource); + var est = new PcaEstimator(_env, "features", "pca", rank: 5, seed: 1); var outputPath = GetOutputPath("PCA", "pca.tsv"); - using (var ch = env.Start("save")) + using (var ch = _env.Start("save")) { - var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false }); - IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(env, savedData, "pca"); + IDataView savedData = TakeFilter.Create(_env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); + savedData = new ChooseColumnsTransform(_env, savedData, "pca"); using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + DataSaverUtils.SaveDataView(ch, _saver, savedData, fs, keepHidden: true); } CheckEquality("PCA", "pca.tsv");