Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Train FieldAwareFactorizationMachines without providing arguments #2931

Merged
merged 6 commits into from
Mar 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Samples.Dynamic
{
public static class FFMBinaryClassificationWithoutArguments
{
    /// <summary>
    /// Trains a field-aware factorization machine with its default settings on the featurized
    /// sentiment dataset, prints a few of the learned parameters, and reports evaluation
    /// metrics computed on the test split.
    /// </summary>
    public static void Example()
    {
        // The MLContext is the starting point for every ML.NET operation: it exposes the catalogs of
        // available components, handles logging and exception tracking, and owns the source of randomness.
        var mlContext = new MLContext();

        // Download the sentiment dataset already featurized; the first view is used for training
        // and the second one for testing.
        var splits = SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext);
        var rawTrainingSet = splits[0];
        var testSet = splits[1];

        // ML.NET does not cache data by default, so a data set that is read from a file and scanned
        // repeatedly pays for featurization and disk access on every pass. When the data fits in memory,
        // caching it avoids that cost, which matters most for iterative trainers that make many passes
        // over the data. The factorization machine trainer used below is one of them, so we cache here.
        // A cache step can also be inserted directly into a pipeline, as shown a few lines further down.
        var trainingSet = mlContext.Data.Cache(rawTrainingSet);

        // Build the pipeline: copy "Sentiment" into the "Label" column, add a cache checkpoint, and
        // finish with the 'FieldAwareFactorizationMachine' binary classifier. The trainer is created
        // without arguments, so it relies on its default label and feature column names.
        var labelCopy = mlContext.Transforms.CopyColumns("Label", "Sentiment");
        var cachedLabelCopy = labelCopy.AppendCacheCheckpoint(mlContext);
        var pipeline = cachedLabelCopy.Append(
            mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine());

        // Train.
        var trainedModel = pipeline.Fit(trainingSet);

        // The learned parameters live on the last transformer of the fitted chain.
        var parameters = trainedModel.LastTransformer.Model;

        // Pull out everything we want to look at before printing anything.
        var featureCount = parameters.FeatureCount;
        var fieldCount = parameters.FieldCount;
        var latentDimension = parameters.LatentDimension;
        var linearWeights = parameters.GetLinearWeights();
        var latentWeights = parameters.GetLatentWeights();

        Console.WriteLine("The feature count is: " + featureCount);
        Console.WriteLine("The number of fields is: " + fieldCount);
        Console.WriteLine("The latent dimension is: " + latentDimension);

        // Only a small window of weights (indices 1 through 10) is printed to keep the output short.
        var shownIndices = Enumerable.Range(1, 10);

        var linearText = string.Concat(shownIndices.Select(i => $"{linearWeights[i]:F4} "));
        Console.WriteLine("The linear weights of some of the features are: " + linearText);

        var latentText = string.Concat(shownIndices.Select(i => $"{latentWeights[i]:F4} "));
        Console.WriteLine("The weights of some of the latent features are: " + latentText);

        // Expected Output:
        //  The feature count is: 9374
        //  The number of fields is: 1
        //  The latent dimension is: 20
        //  The linear weights of some of the features are: 0.0188 0.0000 -0.0048 -0.0184 0.0000 0.0031 0.0914 0.0112 -0.0152 0.0110
        //  The weights of some of the latent features are: 0.0631 0.0041 -0.0333 0.0694 0.1330 0.0790 0.1168 -0.0848 0.0431 0.0411

        // Score the held-out test data and evaluate the predictions against the "Sentiment" column.
        var scoredTestSet = trainedModel.Transform(testSet);
        var metrics = mlContext.BinaryClassification.Evaluate(scoredTestSet, "Sentiment");
        SamplesUtils.ConsoleUtils.PrintMetrics(metrics);

        // Expected output:
        //  Accuracy: 0.61
        //  AUC: 0.72
        //  F1 Score: 0.59
        //  Negative Precision: 0.60
        //  Negative Recall: 0.67
        //  Positive Precision: 0.63
        //  Positive Recall: 0.56
        //  Log Loss: 1.21
        //  Log Loss Reduction: -21.20
        //  Entropy: 1.00
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@ namespace Microsoft.ML
/// </summary>
public static class FactorizationMachineExtensions
{
/// <summary>
/// Predict a target using a field-aware factorization machine algorithm.
Copy link
Member

@wschin wschin Mar 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add: "Note that because there is only one feature column, the underlying model is equivalent to a standard factorization machine." #Resolved

/// </summary>
/// <remarks>
/// Note that because there is only one feature column, the underlying model is equivalent to a standard factorization machine.
/// </remarks>
/// <param name="catalog">The binary classification catalog trainer object.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FieldAwareFactorizationMachineWithoutArguments](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs)]
/// ]]></format>
Copy link
Member

@sfilipi sfilipi Mar 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this sample illustrate this API, or one of the other ones below? The first and second are fairly similar, but the third API is a bit different; I'm not sure whether that will be confusing. #Resolved

/// </example>
public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string featureColumnName = DefaultColumnNames.Features,
Copy link
Member

@wschin wschin Mar 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test please. #Resolved

string labelColumnName = DefaultColumnNames.Label,
string exampleWeightColumnName = null)
{
// Validate the catalog first, because the environment is resolved from it on the next line.
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
// The trainer constructor takes an array of feature column names; this overload wraps the
// single feature column in a one-element array before handing it over.
return new FieldAwareFactorizationMachineTrainer(env, new string[] { featureColumnName }, labelColumnName, exampleWeightColumnName);
}

/// <summary>
/// Predict a target using a field-aware factorization machine algorithm.
/// </summary>
Expand Down
21 changes: 21 additions & 0 deletions test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,27 @@ namespace Microsoft.ML.Tests.TrainerEstimators
{
public partial class TrainerEstimators : TestDataPipeBase
{
[Fact]
public void FfmBinaryClassificationWithoutArguments()
{
    // Arrange: a context with a fixed seed and a data view built from 500 generated FFM samples.
    var context = new MLContext(seed: 0);
    var samples = DatasetUtils.GenerateFfmSamples(500);
    var trainingView = context.Data.LoadFromEnumerable(samples);

    // The argument-free trainer reads the default feature column, so Field0 is copied into it first.
    var featureCopy = context.Transforms.CopyColumns(
        DefaultColumnNames.Features,
        nameof(DatasetUtils.FfmExample.Field0));
    var estimator = featureCopy.Append(
        context.BinaryClassification.Trainers.FieldAwareFactorizationMachine());

    // Act: fit on the generated data and score that same data.
    var fitted = estimator.Fit(trainingView);
    var scored = fitted.Transform(trainingView);
    var result = context.BinaryClassification.Evaluate(scored);

    // Assert: loose lower bounds only, as a sanity check that training produced a usable model.
    Assert.InRange(result.Accuracy, 0.6, 1);
    Assert.InRange(result.AreaUnderRocCurve, 0.7, 1);
    Assert.InRange(result.AreaUnderPrecisionRecallCurve, 0.65, 1);
}

[Fact]
public void FfmBinaryClassificationWithAdvancedArguments()
{
Expand Down