Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field-aware factorization machine to estimator #912

Merged
merged 13 commits into from
Sep 20, 2018
13 changes: 7 additions & 6 deletions src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Runtime
{
public interface IPredictionTransformer<out TModel> : ITransformer
Copy link
Contributor

@Zruty0 Zruty0 Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IPredictionTransformer [](start = 21, length = 22)

I don't think we even need an interface for FFM, since for the time being it's the only trainer that accepts multiple feature columns. #Closed

Copy link
Member Author

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to inquire about all trainers, it is useful to have them extend one interface. #Closed

where TModel : IPredictor
{
TModel Model { get; }
}

public interface IClassicPredictionTransformer<out TModel> : IPredictionTransformer<TModel>
Copy link
Contributor

@Zruty0 Zruty0 Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Classic [](start = 22, length = 7)

'Classic' sounds a bit wacky, even though it was my suggestion. Maybe 'SingleInput' ?.. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Classic is weird.


In reply to: 218633462 [](ancestors = 218633462)

Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IClassicPredictionTransformer [](start = 21, length = 29)

I sort of feel like an interface called IClassicPredictionTransformer needs some XML comment on it. All public interfaces should, but especially one with the word "classic" in the name.

I am choosing to interpret this interface as it gives me classic coke whenever I use it. #Resolved

where TModel : IPredictor
{
string FeatureColumn { get; }

ColumnType FeatureColumnType { get; }

TModel Model { get; }
}
}
119 changes: 77 additions & 42 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.IO;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
Expand All @@ -20,39 +21,33 @@

namespace Microsoft.ML.Runtime.Data
{
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel

public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>
Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No documentation? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is there? Are you looking at iteration 10?


In reply to: 218905403 [](ancestors = 218905403)

where TModel : class, IPredictor
{
private const string DirModel = "Model";
private const string DirTransSchema = "TrainSchema";
/// <summary>
/// The model.
/// </summary>
public TModel Model { get; }

protected const string DirModel = "Model";
protected const string DirTransSchema = "TrainSchema";
protected readonly IHost Host;
protected readonly ISchemaBindableMapper BindableMapper;
protected readonly ISchema TrainSchema;
protected ISchemaBindableMapper BindableMapper;
protected ISchema TrainSchema;

public string FeatureColumn { get; }

public ColumnType FeatureColumnType { get; }

public TModel Model { get; }

public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));

Model = model;
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}

internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)

{
Host = host;

Expand All @@ -74,16 +69,72 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
ms.Position = 0;
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
TrainSchema = loader.Schema;
}

public abstract ISchema GetOutputSchema(ISchema inputSchema);

/// <summary>
/// Transforms the input data.
/// </summary>
/// <param name="input"></param>
/// <returns></returns>
public abstract IDataView Transform(IDataView input);

protected void SaveModel(ModelSaveContext ctx)
{
// *** Binary format ***
Copy link
Contributor

@Zruty0 Zruty0 Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*** Binary format *** [](start = 15, length = 21)

whenever you save or load, need *** Binary format *** #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?


In reply to: 218633828 [](ancestors = 218633828)

// model: prediction model.
Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model: prediction model. [](start = 15, length = 24)

Technically the model isn't part of this format, since you're not writing it to the stream, you're writing it somewhere else entirely, but that's OK. Consider fixing if you have to change the code anyway. #Resolved

Copy link
Member Author

@sfilipi sfilipi Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tfinley@gmail.com fixing it == remove the comment?


In reply to: 218818528 [](ancestors = 218818528)

// stream: empty data view that contains train schema.
// id of string: feature column.

ctx.SaveModel(Model, DirModel);
ctx.SaveBinaryStream(DirTransSchema, writer =>
{
using (var ch = Host.Start("Saving train schema"))
{
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
}
});
}
}

public abstract class ClassicPredictionTransformerBase<TModel> : PredictionTransformerBase<TModel>, IClassicPredictionTransformer<TModel>, ICanSaveModel
Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess by "classic" this just means a prediction transformer base that takes a single features column as its input. Classic is a bit of a funny word, but then again SinlgeFeaturesPredictionTransformerBase might be a bit of a mouthful and itself potentially confusing. #Resolved

where TModel : class, IPredictor
{
/// <summary>
/// The name of the feature column used by the prediction transformer.
/// </summary>
public string FeatureColumn { get; }

/// <summary>
/// The type of the prediction transformer
/// </summary>
public ColumnType FeatureColumnType { get; }

public ClassicPredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
:base(host, model, trainSchema)
{
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}

internal ClassicPredictionTransformerBase(IHost host, ModelLoadContext ctx)
:base(host, ctx)
{
FeatureColumn = ctx.LoadString();
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
FeatureColumnType = TrainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
}

public ISchema GetOutputSchema(ISchema inputSchema)
public override ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

Expand All @@ -95,8 +146,6 @@ public ISchema GetOutputSchema(ISchema inputSchema)
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}

public abstract IDataView Transform(IDataView input);

public void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Expand All @@ -106,26 +155,12 @@ public void Save(ModelSaveContext ctx)

protected virtual void SaveCore(ModelSaveContext ctx)
{
// *** Binary format ***
// model: prediction model.
// stream: empty data view that contains train schema.
// id of string: feature column.

ctx.SaveModel(Model, DirModel);
ctx.SaveBinaryStream(DirTransSchema, writer =>
{
using (var ch = Host.Start("Saving train schema"))
{
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
}
});

SaveModel(ctx);
ctx.SaveString(FeatureColumn);
}
}

public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
public sealed class BinaryPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<float>
{
private readonly BinaryClassifierScorer _scorer;
Expand Down Expand Up @@ -194,7 +229,7 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
public sealed class MulticlassPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<VBuffer<float>>
{
private readonly MultiClassClassifierScorer _scorer;
Expand Down Expand Up @@ -255,7 +290,7 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class RegressionPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
public sealed class RegressionPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<float>
{
private readonly GenericScorer _scorer;
Expand Down Expand Up @@ -324,4 +359,4 @@ internal static class RegressionPredictionTransformer
public static RegressionPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
=> new RegressionPredictionTransformer<IPredictorProducing<float>>(env, ctx);
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/ITrainerEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace Microsoft.ML.Runtime.Training
{
public interface ITrainerEstimator<out TTransformer, out TPredictor>: IEstimator<TTransformer>
where TTransformer: IPredictionTransformer<TPredictor>
where TTransformer: IClassicPredictionTransformer<TPredictor>
Copy link
Contributor

@Zruty0 Zruty0 Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IClassicPredictionTransformer [](start = 28, length = 29)

I believe this change is incorrect, isn't it? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, maybe i named in reverse, but this is the old interface.


In reply to: 218633926 [](ancestors = 218633926)

where TPredictor: IPredictor
{
TrainerInfo Info { get; }
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training
/// It produces a 'prediction transformer'.
/// </summary>
public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
where TTransformer : IPredictionTransformer<TModel>
where TTransformer : IClassicPredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
Expand Down
50 changes: 50 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Training;

namespace Microsoft.ML.Core.Prediction
{
/// <summary>
/// Holds information relevant to trainers. It is passed to the constructor of the<see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/>
/// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model.
/// This holds at least a training set, as well as optioonally a predictor.
/// </summary>
public class TrainerEstimatorContext
{
/// <summary>
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
/// a trainer that does not support validation sets should not be considered an error condition. It
/// should simply be ignored in that case.
/// </summary>
public IDataView ValidationSet { get; }

/// <summary>
/// The initial predictor, for incremental training. Note that if a <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor
/// does not support incremental training, then it can ignore it similarly to how one would ignore
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
/// </summary>
public IPredictor InitialPredictor { get; }

/// <summary>
/// Initializes a new instance of <see cref="TrainerEstimatorContext"/>, given a training set and optional other arguments.
/// </summary>
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null)
{
Contracts.CheckValueOrNull(validationSet);
Contracts.CheckValueOrNull(initialPredictor);

ValidationSet = validationSet;
InitialPredictor = initialPredictor;
}
}
}
Loading