Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed memory leaks from OnnxTransformer #5518

Merged
merged 6 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public static void Example()
//Create the pipeline using onnx file.
var onnxModelPath = "your_path_to_sample_onnx_conversion_1.onnx";
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);
//Make sure to either use the 'using' clause or explicitly dispose the returned onnxTransformer to prevent memory leaks
using var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);

//Inference the testset
var output = transformer.Transform(trainTestOriginalData.TestSet);
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private protected abstract class MapperBase : IRowMapper
{
protected readonly IHost Host;
protected readonly DataViewSchema InputSchema;
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;
protected readonly Lazy<DataViewSchema.DetachedColumn[]> OutputColumns;
private readonly RowToRowTransformerBase _parent;

protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
Expand All @@ -68,21 +68,21 @@ protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformer
Host = host;
InputSchema = inputSchema;
_parent = parent;
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
OutputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
}

protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore();

DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => OutputColumns.Value;

Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
public virtual Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
// REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema
// of the input row.
// It still has to be the same schema, but because we may make a transition from lazy to eager schema, the reference-equality
// is no longer always possible. So, we relax the assert as below.
Contracts.Assert(input.Schema == InputSchema);
int n = _outputColumns.Value.Length;
int n = OutputColumns.Value.Length;
var result = new Delegate[n];
var disposers = new Action[n];
for (int i = 0; i < n; i++)
Expand Down
74 changes: 56 additions & 18 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,10 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
=> throw new NotImplementedException("This should never be called!");

private Delegate CreateGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
{
disposer = null;
Host.AssertValue(input);

var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
Expand All @@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
var elemRawType = vectorType.ItemType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
if (vectorType.ItemType is TextDataViewType)
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
else
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
else
{
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
}

/// <summary>
/// Creates one value getter per requested output column. All getters share a single
/// <c>OnnxRuntimeOutputCacher</c>, which is handed to each getter so the ONNX model's
/// outputs for a given row position can be cached and reused across columns.
/// </summary>
/// <param name="input">The input row; its schema must equal the mapper's input schema.</param>
/// <param name="activeOutput">Predicate indicating which output column indices need getters.</param>
/// <param name="disposer">Set to an action that disposes the shared output cacher
/// (releasing the ONNX runtime output values it holds — presumably native memory;
/// this is the leak fix this change targets).</param>
/// <returns>An array with one delegate per output column; slots for inactive columns remain null.</returns>
public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
// The row handed to us must carry the same schema the mapper was built against.
Contracts.Assert(input.Schema == InputSchema);

// Single cacher shared by every getter created below, so model outputs can be
// cached once per row position rather than per column.
OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();

int n = OutputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
// Columns the caller did not request get no getter; their slots stay null.
if (!activeOutput(i))
continue;
result[i] = CreateGetter(input, i, activeOutput, outputCacher);
}
// The disposer is invoked by the cursor when it is done with these getters;
// disposing the cacher frees the cached ONNX runtime output collection.
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}

private class OnnxRuntimeOutputCacher
private sealed class OnnxRuntimeOutputCacher : IDisposable
{
public long Position;
public Dictionary<string, NamedOnnxValue> Outputs;
public Dictionary<string, DisposableNamedOnnxValue> Outputs;
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;

public OnnxRuntimeOutputCacher()
{
Position = -1;
Outputs = new Dictionary<string, NamedOnnxValue>();
Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
}

private bool _isDisposed;

public void Dispose()
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}
}

Expand All @@ -529,21 +564,22 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
}

var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputNamedOnnxValues.Count > 0);
outputCache.OutputOnnxValues?.Dispose();
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);

foreach (var outputNameOnnxValue in outputNamedOnnxValues)
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
{
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
}
outputCache.Position = position;
}
}

private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
Expand All @@ -558,10 +594,11 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
return valueGetter;
}

private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();

ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
Expand All @@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
return valueGetter;
}

private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCache = new OnnxRuntimeOutputCacher();

ValueGetter<T> valueGetter = (ref T dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
dst = (T)caster(namedOnnxValue);
Expand Down
71 changes: 40 additions & 31 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -198,40 +198,49 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
_session = new InferenceSession(modelFile);
}

// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

// Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
try
{
outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}
// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

// Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
{
outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}

var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);
var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

Graph = model.Graph;
Graph = model.Graph;
}
catch
{
_session.Dispose();
_session = null;
throw;
}
}

private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
Expand Down Expand Up @@ -350,7 +359,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu
/// </summary>
/// <param name="inputNamedOnnxValues">The NamedOnnxValues to score.</param>
/// <returns>Resulting output NamedOnnxValues list.</returns>
public IReadOnlyCollection<NamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
{
return _session.Run(inputNamedOnnxValues);
}
Expand Down
4 changes: 2 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ public void RemoveVariablesInPipelineTest()
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label", featureColumnName: "Features", numberOfLeaves: 2, numberOfTrees: 1, minimumExampleCountPerLeaf: 2));

var model = pipeline.Fit(data);
using var model = pipeline.Fit(data);
var transformedData = model.Transform(data);

var onnxConversionContext = new OnnxContextImpl(mlContext, "A Simple Pipeline", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
Expand Down Expand Up @@ -2029,7 +2029,7 @@ private void TestPipeline<TLastTransformer, TRow>(EstimatorChain<TLastTransforme
private void TestPipeline<TLastTransformer>(EstimatorChain<TLastTransformer> pipeline, IDataView dataView, string onnxFileName, ColumnComparison[] columnsToCompare, string onnxTxtName = null, string onnxTxtSubDir = null)
where TLastTransformer : class, ITransformer
{
var model = pipeline.Fit(dataView);
using var model = pipeline.Fit(dataView);
Copy link
Member

@antoniovs1029 antoniovs1029 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to update the samples inside the docs folder as well. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


In reply to: 533779757 [](ancestors = 533779757)

var transformedData = model.Transform(dataView);
var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, dataView);

Expand Down