Added onnx export functionality for LpNormNormalizingTransformer #4161

Merged (18 commits), Sep 6, 2019

Changes from all commits
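Before the file-by-file diff, here is a minimal usage sketch of the scenario this change enables, mirroring the test added in OnnxConversionTest.cs further down: fit a NormalizeLpNorm pipeline and convert the fitted transformer to an ONNX model. The DataPoint class, the NormalizeLpNorm call, and ConvertToOnnxProtobuf are taken from that test; the wrapping class, Main method, and sample values are illustrative only.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;

class DataPoint
{
    [VectorType(3)]
    public float[] Features { get; set; }
}

class LpNormOnnxExportSketch
{
    static void Main()
    {
        var mlContext = new MLContext(seed: 1);
        var data = mlContext.Data.LoadFromEnumerable(new List<DataPoint>
        {
            new DataPoint { Features = new float[] { 0.01f, 0.02f, 0.03f } },
            new DataPoint { Features = new float[] { 0.04f, 0.05f, 0.06f } },
        });

        // Fit an LpNorm normalizer (here L2 without mean subtraction)...
        var pipeline = mlContext.Transforms.NormalizeLpNorm(
            nameof(DataPoint.Features),
            norm: LpNormNormalizingEstimatorBase.NormFunction.L2,
            ensureZeroMean: false);
        var model = pipeline.Fit(data);

        // ...and, with this change, export the fitted transformer to an ONNX graph,
        // as the new test below does with ConvertToOnnxProtobuf.
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
    }
}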
131 changes: 130 additions & 1 deletion src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

@@ -313,11 +314,13 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _srcTypes;
private readonly int[] _srcCols;
private readonly DataViewType[] _types;
private readonly LpNormNormalizingEstimatorBase.NormFunction[] _norms;
private readonly bool[] _ensureZeroMeans;
private readonly LpNormNormalizingTransformer _parent;

public Mapper(LpNormNormalizingTransformer parent, DataViewSchema inputSchema)
@@ -327,12 +330,16 @@ public Mapper(LpNormNormalizingTransformer parent, DataViewSchema inputSchema)
_types = new DataViewType[_parent.ColumnPairs.Length];
_srcTypes = new DataViewType[_parent.ColumnPairs.Length];
_srcCols = new int[_parent.ColumnPairs.Length];
_norms = new LpNormNormalizingEstimatorBase.NormFunction[_parent.ColumnPairs.Length];
_ensureZeroMeans = new bool[_parent.ColumnPairs.Length];
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]);
var srcCol = inputSchema[_srcCols[i]];
_srcTypes[i] = srcCol.Type;
_types[i] = srcCol.Type;
_norms[i] = _parent._columns[i].Norm;
_ensureZeroMeans[i] = _parent._columns[i].EnsureZeroMean;
}
}

@@ -594,6 +601,128 @@ private static float Mean(ReadOnlySpan<float> src, int length)
return 0;
return CpuMathUtils.Sum(src) / length;
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

for (int iinfo = 0; iinfo < _srcCols.Length; ++iinfo)
{
string inputColumnName = InputSchema[_srcCols[iinfo]].Name;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(inputColumnName, false);
continue;
}

if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_srcTypes[iinfo], inputColumnName)))
{
ctx.RemoveColumn(inputColumnName, true);
}
}
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
string opType;

if ((_norms[iinfo] != LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation) && (_ensureZeroMeans[iinfo] == false))
{
string strNorm;
if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L1)
strNorm = "L1";
else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L2)
strNorm = "L2";
else
strNorm = "MAX";
opType = "Normalizer";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("norm", strNorm);
return true;
}

opType = "ReduceMean";
string meanOfInput = ctx.AddIntermediateVariable(_types[iinfo], "MeanOfInput", true);
var meanNode = ctx.CreateNode(opType, srcVariableName, meanOfInput, ctx.GetNodeName(opType), "");
meanNode.AddAttribute("axes", new long[] { 1 });

opType = "Sub";
string inputMinusMean = ctx.AddIntermediateVariable(_types[iinfo], "InputMinusMean");
var subtractNode = ctx.CreateNode(opType, new[] { srcVariableName, meanOfInput }, new[] { inputMinusMean }, ctx.GetNodeName(opType), "");

if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L1)
{
opType = "Abs";
string absOfInput = ctx.AddIntermediateVariable(_types[iinfo], "AbsOfInput");
var absNode = ctx.CreateNode(opType, inputMinusMean, absOfInput, ctx.GetNodeName(opType), "");

opType = "ReduceSum";
string sumOfAbsOfInput = ctx.AddIntermediateVariable(_types[iinfo], "SumOfAbsOfInput", true);
var sumOfAbsNode = ctx.CreateNode(opType, absOfInput, sumOfAbsOfInput, ctx.GetNodeName(opType), "");
sumOfAbsNode.AddAttribute("axes", new long[] { 1 });

opType = "Div";
var l1Node = ctx.CreateNode(opType, new[] { inputMinusMean, sumOfAbsOfInput }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
}
else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L2)
{
opType = "Pow";
string two = ctx.AddInitializer(2.0f);
string squareOfInput = ctx.AddIntermediateVariable(_types[iinfo], "SquareOfInput", true);
var squareNode = ctx.CreateNode(opType, new[] { inputMinusMean, two }, new[] { squareOfInput }, ctx.GetNodeName(opType), "");

opType = "ReduceSum";
string sumOfSquares = ctx.AddIntermediateVariable(_types[iinfo], "SumOfSquares", true);
var sumOfSquaresNode = ctx.CreateNode(opType, squareOfInput, sumOfSquares, ctx.GetNodeName(opType), "");
sumOfSquaresNode.AddAttribute("axes", new long[] { 1 });

opType = "Sqrt";
string squareRoot = ctx.AddIntermediateVariable(_types[iinfo], "SquareRoot", true);
var squareRootNode = ctx.CreateNode(opType, sumOfSquares, squareRoot, ctx.GetNodeName(opType), "");

opType = "Div";
var l2Node = ctx.CreateNode(opType, new[] { inputMinusMean, squareRoot }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
}
else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.Infinity)
{
opType = "ReduceMax";
string maxOfInput = ctx.AddIntermediateVariable(_types[iinfo], "MaxOfInput", true);
var maxNode = ctx.CreateNode(opType, inputMinusMean, maxOfInput, ctx.GetNodeName(opType), "");
maxNode.AddAttribute("axes", new long[] { 1 });

opType = "Div";
var lMaxNode = ctx.CreateNode(opType, new[] { inputMinusMean, maxOfInput }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
}
else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation)
{
// first calculate the standard deviation
opType = "Pow";
string two = ctx.AddInitializer(2.0f);
string squareOfInputMinusMean = ctx.AddIntermediateVariable(_types[iinfo], "SquareOfInputMinusMean", true);
var squareOfInputMinusMeanNode = ctx.CreateNode(opType, new[] { inputMinusMean, two }, new[] { squareOfInputMinusMean }, ctx.GetNodeName(opType), "");

opType = "ReduceMean";
string average = ctx.AddIntermediateVariable(_types[iinfo], "SumOfSquares", true);
var sumOfSquaresNode = ctx.CreateNode(opType, squareOfInputMinusMean, average, ctx.GetNodeName(opType), "");
sumOfSquaresNode.AddAttribute("axes", new long[] { 1 });

opType = "Sqrt";
string stdDev = ctx.AddIntermediateVariable(_types[iinfo], "SquareRoot", true);
var stdDevNode = ctx.CreateNode(opType, average, stdDev, ctx.GetNodeName(opType), "");

opType = "Div";
string input = _ensureZeroMeans[iinfo] ? inputMinusMean : srcVariableName;
var lStdDevNode = ctx.CreateNode(opType, new[] {input, stdDev }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
}
else
{
Contracts.Assert(false);
return false;
}
return true;
}
}
}
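For reference, the row-level arithmetic that the composed graph above encodes for NormFunction.L2 with ensureZeroMean = true (the path that cannot use the single ONNX Normalizer node) is: ReduceMean, Sub, Pow(2), ReduceSum, Sqrt, Div. A minimal sketch of that computation on a plain float array follows; the helper class is illustrative only and not part of the PR.

using System;
using System.Linq;

static class L2GraphSketch
{
    // Equivalent per-row computation; each step maps to one node created in SaveAsOnnxCore above.
    public static float[] Apply(float[] x)
    {
        float mean = x.Average();                                 // ReduceMean
        float[] centered = x.Select(v => v - mean).ToArray();     // Sub
        double sumOfSquares = centered.Sum(v => (double)v * v);   // Pow(2) + ReduceSum
        float norm = (float)Math.Sqrt(sumOfSquares);              // Sqrt
        return centered.Select(v => v / norm).ToArray();          // Div
    }
}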

62 changes: 62 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -180,6 +180,68 @@ public void KmeansOnnxConversionTest()
Done();
}

private class DataPoint
{
[VectorType(3)]
public float[] Features { get; set; }
}

[Fact]
void LpNormOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);

var samples = new List<DataPoint>()
{
new DataPoint() { Features = new float[3] {0.01f, 0.02f, 0.03f} },
new DataPoint() { Features = new float[3] {0.04f, 0.05f, 0.06f} },
new DataPoint() { Features = new float[3] {0.07f, 0.08f, 0.09f} },
new DataPoint() { Features = new float[3] {0.10f, 0.11f, 0.12f} },
new DataPoint() { Features = new float[3] {0.13f, 0.14f, 0.15f} }
};
var dataView = mlContext.Data.LoadFromEnumerable(samples);

LpNormNormalizingEstimatorBase.NormFunction[] norms =
{
LpNormNormalizingEstimatorBase.NormFunction.L1,
LpNormNormalizingEstimatorBase.NormFunction.L2,
LpNormNormalizingEstimatorBase.NormFunction.Infinity,
LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation
};

bool[] ensureZeroMeans = { true, false};
foreach (var ensureZeroMean in ensureZeroMeans)
{
foreach (var norm in norms)
{
var pipe = mlContext.Transforms.NormalizeLpNorm(nameof(DataPoint.Features), norm:norm, ensureZeroMean: ensureZeroMean);

var model = pipe.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var onnxFileName = $"LpNorm-{norm.ToString()}-{ensureZeroMean}.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);

SaveOnnxModel(onnxModel, onnxModelPath, null);

// Compare results produced by ML.NET and ONNX's runtime.
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
Member
RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
Do you need this condition? If it's Linux, will the results match?

Contributor Author
I think it is just that the test will only run on Windows; the results should still match. It appears that OnnxRuntime doesn't support Linux and Mac yet.

{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4VectorColumns(nameof(DataPoint.Features), outputNames[0], transformedData, onnxResult, 3);
}
}
}

Done();
}

[Fact]
void CommandLineOnnxConversionTest()
{