-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Field-aware factorization machine to estimator #912
Changes from 7 commits
b53d09e
d7c942d
6358777
0e31686
67f41a3
d4f5413
26691e3
5890b11
174e75d
65a6296
492a890
e86cfb9
3d4858d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,22 +2,23 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Calibration; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
/// <summary> | ||
/// An <see cref="ITransformer"/> that additionally exposes the <see cref="IPredictor"/> model it was built around. | ||
/// </summary> | ||
/// <typeparam name="TModel">The type of the exposed model; it must implement <see cref="IPredictor"/>.</typeparam> | ||
public interface IPredictionTransformer<out TModel> : ITransformer | ||
where TModel : IPredictor | ||
{ | ||
/// <summary> | ||
/// The model exposed by this transformer. | ||
/// </summary> | ||
TModel Model { get; } | ||
} | ||
|
||
public interface IClassicPredictionTransformer<out TModel> : IPredictionTransformer<TModel> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
'Classic' sounds a bit wacky, even though it was my suggestion. Maybe 'SingleInput' ?.. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I sort of feel like an interface called I am choosing to interpret this interface as it gives me classic coke whenever I use it. #Resolved |
||
where TModel : IPredictor | ||
{ | ||
string FeatureColumn { get; } | ||
|
||
ColumnType FeatureColumnType { get; } | ||
|
||
TModel Model { get; } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
using System; | ||
using System.IO; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
|
@@ -20,39 +21,33 @@ | |
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel | ||
|
||
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No documentation? #Pending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
where TModel : class, IPredictor | ||
{ | ||
private const string DirModel = "Model"; | ||
private const string DirTransSchema = "TrainSchema"; | ||
/// <summary> | ||
/// The model. | ||
/// </summary> | ||
public TModel Model { get; } | ||
|
||
protected const string DirModel = "Model"; | ||
protected const string DirTransSchema = "TrainSchema"; | ||
protected readonly IHost Host; | ||
protected readonly ISchemaBindableMapper BindableMapper; | ||
protected readonly ISchema TrainSchema; | ||
protected ISchemaBindableMapper BindableMapper; | ||
protected ISchema TrainSchema; | ||
|
||
public string FeatureColumn { get; } | ||
|
||
public ColumnType FeatureColumnType { get; } | ||
|
||
public TModel Model { get; } | ||
|
||
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) | ||
protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema) | ||
{ | ||
Contracts.CheckValue(host, nameof(host)); | ||
Host = host; | ||
Host.CheckValue(trainSchema, nameof(trainSchema)); | ||
|
||
Model = model; | ||
FeatureColumn = featureColumn; | ||
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) | ||
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); | ||
FeatureColumnType = trainSchema.GetColumnType(col); | ||
|
||
TrainSchema = trainSchema; | ||
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); | ||
} | ||
|
||
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) | ||
protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) | ||
|
||
{ | ||
Host = host; | ||
|
||
|
@@ -74,16 +69,72 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) | |
ms.Position = 0; | ||
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms); | ||
TrainSchema = loader.Schema; | ||
} | ||
|
||
public abstract ISchema GetOutputSchema(ISchema inputSchema); | ||
|
||
/// <summary> | ||
/// Transforms the input data. | ||
/// </summary> | ||
/// <param name="input">The data view to transform.</param> | ||
/// <returns>The transformed data view.</returns> | ||
public abstract IDataView Transform(IDataView input); | ||
|
||
protected void SaveModel(ModelSaveContext ctx) | ||
{ | ||
// *** Binary format *** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
whenever you save or load, need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
// model: prediction model. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Technically the model isn't part of this format, since you're not writing it to the stream, you're writing it somewhere else entirely, but that's OK. Consider fixing if you have to change the code anyway. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tfinley@gmail.com fixing it == remove the comment? In reply to: 218818528 [](ancestors = 218818528) |
||
// stream: empty data view that contains train schema. | ||
// id of string: feature column. | ||
|
||
ctx.SaveModel(Model, DirModel); | ||
ctx.SaveBinaryStream(DirTransSchema, writer => | ||
{ | ||
using (var ch = Host.Start("Saving train schema")) | ||
{ | ||
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); | ||
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); | ||
} | ||
}); | ||
} | ||
} | ||
|
||
public abstract class ClassicPredictionTransformerBase<TModel> : PredictionTransformerBase<TModel>, IClassicPredictionTransformer<TModel>, ICanSaveModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess by "classic" this just means a prediction transformer base that takes a single features column as its input. Classic is a bit of a funny word, but then again |
||
where TModel : class, IPredictor | ||
{ | ||
/// <summary> | ||
/// The name of the feature column used by the prediction transformer. | ||
/// </summary> | ||
public string FeatureColumn { get; } | ||
|
||
/// <summary> | ||
/// The type of the feature column used by the prediction transformer. | ||
/// </summary> | ||
public ColumnType FeatureColumnType { get; } | ||
|
||
public ClassicPredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) | ||
:base(host, model, trainSchema) | ||
{ | ||
FeatureColumn = featureColumn; | ||
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) | ||
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); | ||
FeatureColumnType = trainSchema.GetColumnType(col); | ||
|
||
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); | ||
} | ||
|
||
internal ClassicPredictionTransformerBase(IHost host, ModelLoadContext ctx) | ||
:base(host, ctx) | ||
{ | ||
FeatureColumn = ctx.LoadString(); | ||
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) | ||
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn); | ||
FeatureColumnType = TrainSchema.GetColumnType(col); | ||
|
||
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); | ||
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); | ||
} | ||
|
||
public ISchema GetOutputSchema(ISchema inputSchema) | ||
public override ISchema GetOutputSchema(ISchema inputSchema) | ||
{ | ||
Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
||
|
@@ -95,8 +146,6 @@ public ISchema GetOutputSchema(ISchema inputSchema) | |
return Transform(new EmptyDataView(Host, inputSchema)).Schema; | ||
} | ||
|
||
public abstract IDataView Transform(IDataView input); | ||
|
||
public void Save(ModelSaveContext ctx) | ||
{ | ||
Host.CheckValue(ctx, nameof(ctx)); | ||
|
@@ -106,26 +155,12 @@ public void Save(ModelSaveContext ctx) | |
|
||
protected virtual void SaveCore(ModelSaveContext ctx) | ||
{ | ||
// *** Binary format *** | ||
// model: prediction model. | ||
// stream: empty data view that contains train schema. | ||
// id of string: feature column. | ||
|
||
ctx.SaveModel(Model, DirModel); | ||
ctx.SaveBinaryStream(DirTransSchema, writer => | ||
{ | ||
using (var ch = Host.Start("Saving train schema")) | ||
{ | ||
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); | ||
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); | ||
} | ||
}); | ||
|
||
SaveModel(ctx); | ||
ctx.SaveString(FeatureColumn); | ||
} | ||
} | ||
|
||
public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerBase<TModel> | ||
public sealed class BinaryPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel> | ||
where TModel : class, IPredictorProducing<float> | ||
{ | ||
private readonly BinaryClassifierScorer _scorer; | ||
|
@@ -194,7 +229,7 @@ private static VersionInfo GetVersionInfo() | |
} | ||
} | ||
|
||
public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransformerBase<TModel> | ||
public sealed class MulticlassPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel> | ||
where TModel : class, IPredictorProducing<VBuffer<float>> | ||
{ | ||
private readonly MultiClassClassifierScorer _scorer; | ||
|
@@ -255,7 +290,7 @@ private static VersionInfo GetVersionInfo() | |
} | ||
} | ||
|
||
public sealed class RegressionPredictionTransformer<TModel> : PredictionTransformerBase<TModel> | ||
public sealed class RegressionPredictionTransformer<TModel> : ClassicPredictionTransformerBase<TModel> | ||
where TModel : class, IPredictorProducing<float> | ||
{ | ||
private readonly GenericScorer _scorer; | ||
|
@@ -324,4 +359,4 @@ internal static class RegressionPredictionTransformer | |
public static RegressionPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx) | ||
=> new RegressionPredictionTransformer<IPredictorProducing<float>>(env, ctx); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
namespace Microsoft.ML.Runtime.Training | ||
{ | ||
public interface ITrainerEstimator<out TTransformer, out TPredictor>: IEstimator<TTransformer> | ||
where TTransformer: IPredictionTransformer<TPredictor> | ||
where TTransformer: IClassicPredictionTransformer<TPredictor> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I believe this change is incorrect, isn't it? #Pending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, maybe i named in reverse, but this is the old interface. In reply to: 218633926 [](ancestors = 218633926) |
||
where TPredictor: IPredictor | ||
{ | ||
TrainerInfo Info { get; } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Training; | ||
|
||
namespace Microsoft.ML.Core.Prediction | ||
{ | ||
/// <summary> | ||
/// Holds information relevant to trainers. It is passed to the constructor of the <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> | ||
/// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model. | ||
/// This holds an optional validation set, as well as optionally an initial predictor. | ||
/// </summary> | ||
public class TrainerEstimatorContext | ||
{ | ||
/// <summary> | ||
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into | ||
/// a trainer that does not support validation sets should not be considered an error condition. It | ||
/// should simply be ignored in that case. | ||
/// </summary> | ||
public IDataView ValidationSet { get; } | ||
|
||
/// <summary> | ||
/// The initial predictor, for incremental training. Can be <c>null</c>. Note that if a <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor | ||
/// does not support incremental training, then it can ignore it similarly to how one would ignore | ||
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there | ||
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception. | ||
/// </summary> | ||
public IPredictor InitialPredictor { get; } | ||
|
||
/// <summary> | ||
/// Initializes a new instance of <see cref="TrainerEstimatorContext"/>, given an optional validation set and an optional initial predictor. | ||
/// </summary> | ||
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified; defaults to <c>null</c>.</param> | ||
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified; defaults to <c>null</c>.</param> | ||
public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null) | ||
{ | ||
// Both arguments default to null, so null is an accepted value for either of them. | ||
Contracts.CheckValueOrNull(validationSet); | ||
Contracts.CheckValueOrNull(initialPredictor); | ||
|
||
ValidationSet = validationSet; | ||
InitialPredictor = initialPredictor; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we even need an interface for FFM, since for the time being it's the only trainer that accepts multiple feature columns. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to inquire about all trainers, it is useful to have them extend one interface. #Closed